Skip to content

Commit 77ed530

Browse files
committed
add test_make_batched_videos
1 parent f0da40b commit 77ed530

File tree

2 files changed

+114
-7
lines changed

2 files changed

+114
-7
lines changed

src/transformers/image_utils.py

Lines changed: 4 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -292,9 +292,6 @@ def make_nested_list_of_images(
292292
return [list(image) for image in images]
293293

294294
# If it's a single image, convert it to a list of lists
295-
if is_pil_image(images):
296-
return [[images]]
297-
298295
if is_valid_image(images):
299296
if is_pil_image(images) or images.ndim == 3:
300297
return [[images]]
@@ -317,15 +314,15 @@ def make_batched_videos(videos) -> VideoInput:
317314
return videos
318315

319316
elif isinstance(videos, (list, tuple)) and is_valid_image(videos[0]):
320-
if is_pil_image(videos[0]):
317+
if is_pil_image(videos[0]) or videos[0].ndim == 3:
321318
return [videos]
322-
elif len(videos[0].shape) == 4:
319+
elif videos[0].ndim == 4:
323320
return [list(video) for video in videos]
324321

325322
elif is_valid_image(videos):
326-
if is_pil_image(videos):
323+
if is_pil_image(videos) or videos.ndim == 3:
327324
return [[videos]]
328-
elif len(videos.shape) == 4:
325+
elif videos.ndim == 4:
329326
return [list(videos)]
330327

331328
raise ValueError(f"Could not make batched video from {videos}")

tests/utils/test_image_utils.py

Lines changed: 110 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -31,6 +31,7 @@
3131
from transformers.image_utils import (
3232
ChannelDimension,
3333
get_channel_dimension_axis,
34+
make_batched_videos,
3435
make_flat_list_of_images,
3536
make_list_of_images,
3637
make_nested_list_of_images,
@@ -396,6 +397,115 @@ def test_make_nested_list_of_images_torch(self):
396397
self.assertEqual(len(images_list[0]), 4)
397398
self.assertTrue(np.array_equal(images_list[0][0], images[0][0]))
398399

400+
def test_make_batched_videos_pil(self):
    """Check that `make_batched_videos` normalizes PIL inputs to a list of videos.

    Each video is represented as a list of frames. Covers a single image,
    a flat list of images, and an already-nested list of images.
    """
    # Test a single image is converted to a list of 1 video with 1 frame
    pil_image = get_random_image(16, 32)
    videos_list = make_batched_videos(pil_image)
    self.assertIsInstance(videos_list[0], list)
    self.assertEqual(len(videos_list), 1)
    self.assertEqual(len(videos_list[0]), 1)
    self.assertIsInstance(videos_list[0][0], PIL.Image.Image)

    # Test a list of images is converted to a list of 1 video
    images = [get_random_image(16, 32) for _ in range(4)]
    videos_list = make_batched_videos(images)
    self.assertIsInstance(videos_list[0], list)
    self.assertEqual(len(videos_list), 1)
    self.assertEqual(len(videos_list[0]), 4)
    self.assertIsInstance(videos_list[0][0], PIL.Image.Image)

    # Test a nested list of images is not modified.
    # NOTE(review): this must exercise make_batched_videos (the function under
    # test), not make_nested_list_of_images; it is expected to pass an already
    # batched input through as-is -- confirm against the function's first branch.
    images = [[get_random_image(16, 32) for _ in range(2)] for _ in range(2)]
    videos_list = make_batched_videos(images)
    self.assertIsInstance(videos_list[0], list)
    self.assertEqual(len(videos_list), 2)
    self.assertEqual(len(videos_list[0]), 2)
    self.assertIsInstance(videos_list[0][0], PIL.Image.Image)
423+
424+
def test_make_batched_videos_numpy(self):
    """Check that `make_batched_videos` normalizes numpy inputs to a list of videos.

    Covers a single 3D frame, a single 4D video, a flat list of 3D frames,
    an already-nested list of frames, and a list of 4D videos. Every case
    calls `make_batched_videos` so the function named by this test is the
    one actually exercised.
    """
    # Test a single image is converted to a list of 1 video with 1 frame
    images = np.random.randint(0, 256, (16, 32, 3))
    videos_list = make_batched_videos(images)
    self.assertIsInstance(videos_list[0], list)
    self.assertEqual(len(videos_list), 1)
    self.assertTrue(np.array_equal(videos_list[0][0], images))

    # Test a 4d array of images is converted to a list of 1 video
    images = np.random.randint(0, 256, (4, 16, 32, 3))
    videos_list = make_batched_videos(images)
    self.assertIsInstance(videos_list[0], list)
    self.assertIsInstance(videos_list[0][0], np.ndarray)
    self.assertEqual(len(videos_list), 1)
    self.assertEqual(len(videos_list[0]), 4)
    self.assertTrue(np.array_equal(videos_list[0][0], images[0]))

    # Test a list of images is converted to a list of 1 video
    images = [np.random.randint(0, 256, (16, 32, 3)) for _ in range(4)]
    videos_list = make_batched_videos(images)
    self.assertIsInstance(videos_list[0], list)
    self.assertEqual(len(videos_list), 1)
    self.assertEqual(len(videos_list[0]), 4)
    self.assertTrue(np.array_equal(videos_list[0][0], images[0]))

    # Test a nested list of images is left unchanged.
    # NOTE(review): relies on make_batched_videos passing already-batched
    # input through as-is -- confirm against the function's first branch.
    images = [[np.random.randint(0, 256, (16, 32, 3)) for _ in range(2)] for _ in range(2)]
    videos_list = make_batched_videos(images)
    self.assertIsInstance(videos_list[0], list)
    self.assertEqual(len(videos_list), 2)
    self.assertEqual(len(videos_list[0]), 2)
    self.assertTrue(np.array_equal(videos_list[0][0], images[0][0]))

    # Test a list of 4d array images is converted to a list of videos
    images = [np.random.randint(0, 256, (4, 16, 32, 3)) for _ in range(2)]
    videos_list = make_batched_videos(images)
    self.assertIsInstance(videos_list[0], list)
    self.assertIsInstance(videos_list[0][0], np.ndarray)
    self.assertEqual(len(videos_list), 2)
    self.assertEqual(len(videos_list[0]), 4)
    self.assertTrue(np.array_equal(videos_list[0][0], images[0][0]))
465+
466+
@require_torch
def test_make_batched_videos_torch(self):
    """Check that `make_batched_videos` normalizes torch inputs to a list of videos.

    Covers a single 3D frame, a single 4D video, a flat list of 3D frames,
    an already-nested list of frames, and a list of 4D videos. Every case
    calls `make_batched_videos` so the function named by this test is the
    one actually exercised.
    """
    # Test a single image is converted to a list of 1 video with 1 frame
    images = torch.randint(0, 256, (16, 32, 3))
    videos_list = make_batched_videos(images)
    self.assertIsInstance(videos_list[0], list)
    self.assertEqual(len(videos_list[0]), 1)
    self.assertTrue(np.array_equal(videos_list[0][0], images))

    # Test a 4d tensor of images is converted to a list of 1 video
    images = torch.randint(0, 256, (4, 16, 32, 3))
    videos_list = make_batched_videos(images)
    self.assertIsInstance(videos_list[0], list)
    self.assertIsInstance(videos_list[0][0], torch.Tensor)
    self.assertEqual(len(videos_list), 1)
    self.assertEqual(len(videos_list[0]), 4)
    self.assertTrue(np.array_equal(videos_list[0][0], images[0]))

    # Test a list of images is converted to a list of 1 video
    images = [torch.randint(0, 256, (16, 32, 3)) for _ in range(4)]
    videos_list = make_batched_videos(images)
    self.assertIsInstance(videos_list[0], list)
    self.assertEqual(len(videos_list), 1)
    self.assertEqual(len(videos_list[0]), 4)
    self.assertTrue(np.array_equal(videos_list[0][0], images[0]))

    # Test a nested list of images is left unchanged.
    # NOTE(review): relies on make_batched_videos passing already-batched
    # input through as-is -- confirm against the function's first branch.
    images = [[torch.randint(0, 256, (16, 32, 3)) for _ in range(2)] for _ in range(2)]
    videos_list = make_batched_videos(images)
    self.assertIsInstance(videos_list[0], list)
    self.assertEqual(len(videos_list), 2)
    self.assertEqual(len(videos_list[0]), 2)
    self.assertTrue(np.array_equal(videos_list[0][0], images[0][0]))

    # Test a list of 4d tensor images is converted to a list of videos
    images = [torch.randint(0, 256, (4, 16, 32, 3)) for _ in range(2)]
    videos_list = make_batched_videos(images)
    self.assertIsInstance(videos_list[0], list)
    self.assertIsInstance(videos_list[0][0], torch.Tensor)
    self.assertEqual(len(videos_list), 2)
    self.assertEqual(len(videos_list[0]), 4)
    self.assertTrue(np.array_equal(videos_list[0][0], images[0][0]))
508+
399509
@require_torch
400510
def test_conversion_torch_to_array(self):
401511
feature_extractor = ImageFeatureExtractionMixin()

0 commit comments

Comments
 (0)