|
30 | 30 | require_torchvision,
|
31 | 31 | require_vision,
|
32 | 32 | )
|
33 |
| -from transformers.video_utils import make_batched_videos |
| 33 | +from transformers.video_utils import group_videos_by_shape, make_batched_videos, reorder_videos |
34 | 34 |
|
35 | 35 |
|
36 | 36 | if is_torch_available():
|
|
43 | 43 | from transformers.video_utils import VideoMetadata, load_video
|
44 | 44 |
|
45 | 45 |
|
46 |
| -def get_random_video(height, width, return_torch=False): |
| 46 | +def get_random_video(height, width, num_frames=8, return_torch=False): |
47 | 47 | random_frame = np.random.randint(0, 256, (height, width, 3), dtype=np.uint8)
|
48 |
| - video = np.array(([random_frame] * 8)) |
| 48 | + video = np.array(([random_frame] * num_frames)) |
49 | 49 | if return_torch:
|
50 | 50 | # move channel first
|
51 | 51 | return torch.from_numpy(video).permute(0, 3, 1, 2)
|
@@ -189,6 +189,53 @@ def test_convert_to_rgb(self):
|
189 | 189 | rgb_video = video_processor.convert_to_rgb(torch.cat([video, video[:, :1]], dim=1))
|
190 | 190 | self.assertEqual(rgb_video.shape, (8, 3, 20, 20))
|
191 | 191 |
|
def test_group_and_reorder_videos(self):
    """Tests that videos can be grouped by frame size and number of frames, then restored to input order.

    Each scenario groups a list of videos with `group_videos_by_shape`, checks the
    number of resulting groups, then calls `reorder_videos` and checks that the
    result lines up with the original input list position by position.
    """
    video_1 = get_random_video(20, 20, num_frames=3, return_torch=True)
    video_2 = get_random_video(20, 20, num_frames=5, return_torch=True)
    video_3 = get_random_video(15, 20, num_frames=3, return_torch=True)

    # (description, input videos, expected number of groups)
    scenarios = [
        # Two videos of same size but different number of frames
        ("same size, different number of frames", [video_1, video_2], 2),
        # Two videos of different size but same number of frames
        ("different size, same number of frames", [video_1, video_3], 2),
        # Some have same size or same frame count, but since none have
        # frames and sizes identical, we'll have 3 groups
        ("no two videos share both size and frame count", [video_1, video_2, video_3], 3),
        # Some videos with identical shapes
        ("some videos with identical shapes", [video_1, video_1, video_3], 2),
        # All videos with identical shapes
        ("all videos with identical shapes", [video_1, video_1, video_1], 1),
    ]

    for description, videos, expected_num_groups in scenarios:
        with self.subTest(msg=description):
            grouped_videos, grouped_videos_index = group_videos_by_shape(videos)
            self.assertEqual(len(grouped_videos), expected_num_groups)

            regrouped_videos = reorder_videos(grouped_videos, grouped_videos_index)
            # The previous `assertTrue(len(...), N)` treated N as the failure message,
            # so the length was never actually compared. Compare it for real here.
            # NOTE(review): this assumes `reorder_videos` returns one video per input
            # video (not one per group) -- confirm against `transformers.video_utils`.
            self.assertEqual(len(regrouped_videos), len(videos))

            # Check every position, not just index 0, so a wrong ordering is caught.
            for original_video, regrouped_video in zip(videos, regrouped_videos):
                self.assertEqual(original_video.shape, regrouped_video.shape)
| 238 | + |
192 | 239 |
|
193 | 240 | @require_vision
|
194 | 241 | @require_av
|
|
0 commit comments