diff --git a/src/transformers/video_utils.py b/src/transformers/video_utils.py
index 505f018f47c3..af9d80bab76b 100644
--- a/src/transformers/video_utils.py
+++ b/src/transformers/video_utils.py
@@ -696,11 +696,13 @@ def group_videos_by_shape(
     grouped_videos_index = {}
     for i, video in enumerate(videos):
         shape = video.shape[-2::]
+        num_frames = video.shape[-4]  # videos are in (B)TCHW format; dim -4 is the frame axis
+        shape = (num_frames, *shape)
         if shape not in grouped_videos:
            grouped_videos[shape] = []
         grouped_videos[shape].append(video)
         grouped_videos_index[i] = (shape, len(grouped_videos[shape]) - 1)
-    # stack videos with the same shape
+    # stack videos with the same size and number of frames
     grouped_videos = {shape: torch.stack(videos, dim=0) for shape, videos in grouped_videos.items()}
 
     return grouped_videos, grouped_videos_index
diff --git a/tests/utils/test_video_utils.py b/tests/utils/test_video_utils.py
index 441838ffcab4..21a5b44ff8e1 100644
--- a/tests/utils/test_video_utils.py
+++ b/tests/utils/test_video_utils.py
@@ -30,7 +30,7 @@
     require_torchvision,
     require_vision,
 )
-from transformers.video_utils import make_batched_videos
+from transformers.video_utils import group_videos_by_shape, make_batched_videos, reorder_videos
 
 
 if is_torch_available():
@@ -43,9 +43,9 @@
     from transformers.video_utils import VideoMetadata, load_video
 
 
-def get_random_video(height, width, return_torch=False):
+def get_random_video(height, width, num_frames=8, return_torch=False):
     random_frame = np.random.randint(0, 256, (height, width, 3), dtype=np.uint8)
-    video = np.array(([random_frame] * 8))
+    video = np.array(([random_frame] * num_frames))
     if return_torch:
         # move channel first
         return torch.from_numpy(video).permute(0, 3, 1, 2)
@@ -189,6 +189,53 @@ def test_convert_to_rgb(self):
         rgb_video = video_processor.convert_to_rgb(torch.cat([video, video[:, :1]], dim=1))
         self.assertEqual(rgb_video.shape, (8, 3, 20, 20))
 
+    def test_group_and_reorder_videos(self):
+        """Tests that videos can be grouped by frame size and number of frames"""
+        video_1 = get_random_video(20, 20, num_frames=3, return_torch=True)
+        video_2 = get_random_video(20, 20, num_frames=5, return_torch=True)
+
+        # Group two videos of the same size but with different numbers of frames
+        grouped_videos, grouped_videos_index = group_videos_by_shape([video_1, video_2])
+        self.assertEqual(len(grouped_videos), 2)
+
+        regrouped_videos = reorder_videos(grouped_videos, grouped_videos_index)
+        self.assertEqual(len(regrouped_videos), 2)
+        self.assertEqual(video_1.shape, regrouped_videos[0].shape)
+
+        # Group two videos of different sizes but with the same number of frames
+        video_3 = get_random_video(15, 20, num_frames=3, return_torch=True)
+        grouped_videos, grouped_videos_index = group_videos_by_shape([video_1, video_3])
+        self.assertEqual(len(grouped_videos), 2)
+
+        regrouped_videos = reorder_videos(grouped_videos, grouped_videos_index)
+        self.assertEqual(len(regrouped_videos), 2)
+        self.assertEqual(video_1.shape, regrouped_videos[0].shape)
+
+        # Group all three videos: some share a size and some share a frame count,
+        # but no two share both, so we end up with 3 groups
+        grouped_videos, grouped_videos_index = group_videos_by_shape([video_1, video_2, video_3])
+        self.assertEqual(len(grouped_videos), 3)
+
+        regrouped_videos = reorder_videos(grouped_videos, grouped_videos_index)
+        self.assertEqual(len(regrouped_videos), 3)
+        self.assertEqual(video_1.shape, regrouped_videos[0].shape)
+
+        # Group when some videos have identical shapes
+        grouped_videos, grouped_videos_index = group_videos_by_shape([video_1, video_1, video_3])
+        self.assertEqual(len(grouped_videos), 2)
+
+        regrouped_videos = reorder_videos(grouped_videos, grouped_videos_index)
+        self.assertEqual(len(regrouped_videos), 2)
+        self.assertEqual(video_1.shape, regrouped_videos[0].shape)
+
+        # Group when all videos have identical shapes
+        grouped_videos, grouped_videos_index = group_videos_by_shape([video_1, video_1, video_1])
+        self.assertEqual(len(grouped_videos), 1)
+
+        regrouped_videos = reorder_videos(grouped_videos, grouped_videos_index)
+        self.assertEqual(len(regrouped_videos), 1)
+        self.assertEqual(video_1.shape, regrouped_videos[0].shape)
+
 
 @require_vision
 @require_av