Skip to content

Commit 19fdb75

Browse files
authored
[video utils] group and reorder by number of frames (#38374)
fix
1 parent b0735dc commit 19fdb75

File tree

2 files changed

+53
-4
lines changed

2 files changed

+53
-4
lines changed

src/transformers/video_utils.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -696,11 +696,13 @@ def group_videos_by_shape(
696696
grouped_videos_index = {}
697697
for i, video in enumerate(videos):
698698
shape = video.shape[-2::]
699+
num_frames = video.shape[-4] # video format BTCHW
700+
shape = (num_frames, *shape)
699701
if shape not in grouped_videos:
700702
grouped_videos[shape] = []
701703
grouped_videos[shape].append(video)
702704
grouped_videos_index[i] = (shape, len(grouped_videos[shape]) - 1)
703-
# stack videos with the same shape
705+
# stack videos with the same size and number of frames
704706
grouped_videos = {shape: torch.stack(videos, dim=0) for shape, videos in grouped_videos.items()}
705707
return grouped_videos, grouped_videos_index
706708

tests/utils/test_video_utils.py

Lines changed: 50 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -30,7 +30,7 @@
3030
require_torchvision,
3131
require_vision,
3232
)
33-
from transformers.video_utils import make_batched_videos
33+
from transformers.video_utils import group_videos_by_shape, make_batched_videos, reorder_videos
3434

3535

3636
if is_torch_available():
@@ -43,9 +43,9 @@
4343
from transformers.video_utils import VideoMetadata, load_video
4444

4545

46-
def get_random_video(height, width, return_torch=False):
46+
def get_random_video(height, width, num_frames=8, return_torch=False):
4747
random_frame = np.random.randint(0, 256, (height, width, 3), dtype=np.uint8)
48-
video = np.array(([random_frame] * 8))
48+
video = np.array(([random_frame] * num_frames))
4949
if return_torch:
5050
# move channel first
5151
return torch.from_numpy(video).permute(0, 3, 1, 2)
@@ -189,6 +189,53 @@ def test_convert_to_rgb(self):
189189
rgb_video = video_processor.convert_to_rgb(torch.cat([video, video[:, :1]], dim=1))
190190
self.assertEqual(rgb_video.shape, (8, 3, 20, 20))
191191

192+
def test_group_and_reorder_videos(self):
    """Tests that videos can be grouped by frame size and number of frames.

    Each scenario groups a list of videos, checks the number of distinct
    (num_frames, height, width) groups, then reorders the groups back and
    checks that every video comes back at its original position.
    """
    video_1 = get_random_video(20, 20, num_frames=3, return_torch=True)
    video_2 = get_random_video(20, 20, num_frames=5, return_torch=True)
    video_3 = get_random_video(15, 20, num_frames=3, return_torch=True)

    # (description, input videos, expected number of groups)
    scenarios = [
        ("same size, different number of frames", [video_1, video_2], 2),
        ("different size, same number of frames", [video_1, video_3], 2),
        # Some share a size or a frame count, but no two share both
        ("no two videos identical in size and frames", [video_1, video_2, video_3], 3),
        ("some videos with identical shapes", [video_1, video_1, video_3], 2),
        ("all videos with identical shapes", [video_1, video_1, video_1], 1),
    ]

    for description, videos, expected_num_groups in scenarios:
        with self.subTest(description):
            grouped_videos, grouped_videos_index = group_videos_by_shape(videos)
            self.assertEqual(len(grouped_videos), expected_num_groups)

            regrouped_videos = reorder_videos(grouped_videos, grouped_videos_index)
            # The previous `assertTrue(len(...), N)` treated N as the failure
            # message and could never fail. Reordering is expected to return
            # one video per input, not one per group.
            self.assertEqual(len(regrouped_videos), len(videos))
            # Check every position, not only the first, so a wrong ordering
            # between differently-shaped videos is detected.
            for original, regrouped in zip(videos, regrouped_videos):
                self.assertEqual(original.shape, regrouped.shape)
238+
192239

193240
@require_vision
194241
@require_av

0 commit comments

Comments
 (0)