@@ -1447,24 +1447,41 @@ def get_eval_dataloader(self, eval_dataset: Optional[Dataset] = None) -> DataLoa
1447
1447
process_index = self .args .dataset_rank ,
1448
1448
)
1449
1449
1450
- return _DataLoader (
1451
- eval_dataset ,
1452
- batch_size = self .args .per_device_eval_batch_size ,
1453
- collate_fn = self .data_collator ,
1454
- num_workers = self .args .dataloader_num_workers ,
1455
- )
1450
+ if self .args .distributed_dataloader :
1451
+ return _DataLoader (
1452
+ eval_dataset ,
1453
+ batch_size = self .args .per_device_eval_batch_size ,
1454
+ collate_fn = self .data_collator ,
1455
+ num_workers = self .args .dataloader_num_workers ,
1456
+ eval = True ,
1457
+ )
1458
+ else :
1459
+ return _DataLoader (
1460
+ eval_dataset ,
1461
+ batch_size = self .args .per_device_eval_batch_size ,
1462
+ collate_fn = self .data_collator ,
1463
+ num_workers = self .args .dataloader_num_workers ,
1464
+ )
1456
1465
1457
1466
eval_sampler = self ._get_eval_sampler (eval_dataset )
1458
1467
1459
1468
if self .args .distributed_dataloader :
1460
1469
logger .info ("Eval using DistDataLoader." )
1461
1470
1462
- return _DataLoader (
1463
- eval_dataset ,
1464
- batch_sampler = eval_sampler ,
1465
- collate_fn = self .data_collator ,
1466
- num_workers = self .args .dataloader_num_workers ,
1467
- )
1471
+ return _DataLoader (
1472
+ eval_dataset ,
1473
+ batch_sampler = eval_sampler ,
1474
+ collate_fn = self .data_collator ,
1475
+ num_workers = self .args .dataloader_num_workers ,
1476
+ eval = True ,
1477
+ )
1478
+ else :
1479
+ return _DataLoader (
1480
+ eval_dataset ,
1481
+ batch_sampler = eval_sampler ,
1482
+ collate_fn = self .data_collator ,
1483
+ num_workers = self .args .dataloader_num_workers ,
1484
+ )
1468
1485
1469
1486
def get_test_dataloader (self , test_dataset : Dataset ) -> DataLoader :
1470
1487
"""
@@ -1497,25 +1514,42 @@ def get_test_dataloader(self, test_dataset: Dataset) -> DataLoader:
1497
1514
process_index = self .args .dataset_rank ,
1498
1515
)
1499
1516
1500
- return _DataLoader (
1501
- test_dataset ,
1502
- batch_size = self .args .per_device_eval_batch_size * self .world_size ,
1503
- collate_fn = self .data_collator , # _get_collator_with_removed_columns
1504
- num_workers = self .args .dataloader_num_workers ,
1505
- )
1517
+ if self .args .distributed_dataloader :
1518
+ return _DataLoader (
1519
+ test_dataset ,
1520
+ batch_size = self .args .per_device_eval_batch_size * self .world_size ,
1521
+ collate_fn = self .data_collator , # _get_collator_with_removed_columns
1522
+ num_workers = self .args .dataloader_num_workers ,
1523
+ eval = True ,
1524
+ )
1525
+ else :
1526
+ return _DataLoader (
1527
+ test_dataset ,
1528
+ batch_size = self .args .per_device_eval_batch_size * self .world_size ,
1529
+ collate_fn = self .data_collator , # _get_collator_with_removed_columns
1530
+ num_workers = self .args .dataloader_num_workers ,
1531
+ )
1506
1532
1507
1533
test_sampler = self ._get_eval_sampler (test_dataset )
1508
1534
1509
1535
if self .args .distributed_dataloader :
1510
1536
logger .info ("Test using DistDataLoader." )
1511
1537
1512
- # We use the same batch_size as for eval.
1513
- return _DataLoader (
1514
- test_dataset ,
1515
- batch_sampler = test_sampler ,
1516
- collate_fn = self .data_collator ,
1517
- drop_last = self .args .dataloader_drop_last ,
1518
- )
1538
+ # We use the same batch_size as for eval.
1539
+ return _DataLoader (
1540
+ test_dataset ,
1541
+ batch_sampler = test_sampler ,
1542
+ collate_fn = self .data_collator ,
1543
+ drop_last = self .args .dataloader_drop_last ,
1544
+ eval = True ,
1545
+ )
1546
+ else :
1547
+ return _DataLoader (
1548
+ test_dataset ,
1549
+ batch_sampler = test_sampler ,
1550
+ collate_fn = self .data_collator ,
1551
+ drop_last = self .args .dataloader_drop_last ,
1552
+ )
1519
1553
1520
1554
def create_optimizer_and_scheduler (self , num_training_steps : int ):
1521
1555
"""
0 commit comments