From a2094dc30d53c200ea87efdd847f435bce02c822 Mon Sep 17 00:00:00 2001
From: yujun <573009727@qq.com>
Date: Mon, 1 Jul 2024 15:21:38 +0800
Subject: [PATCH] Round num_samples down so that prefetch does not read past
 the maximum length of the dataset...
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

---
 paddlenlp/utils/batch_sampler.py | 6 ++----
 1 file changed, 2 insertions(+), 4 deletions(-)

diff --git a/paddlenlp/utils/batch_sampler.py b/paddlenlp/utils/batch_sampler.py
index 1cee8d1cb4c6..619904a6d33f 100644
--- a/paddlenlp/utils/batch_sampler.py
+++ b/paddlenlp/utils/batch_sampler.py
@@ -14,8 +14,6 @@
 
 from __future__ import division, print_function
 
-import math
-
 import paddle
 
 __all__ = ["DistributedBatchSampler"]
@@ -110,7 +108,7 @@ def __init__(
             # In pre-training mode when using distributed dataloader, the input dataset can be None. We should handle this situation.
             self.num_samples = 0
         else:
-            self.num_samples = int(math.ceil(len(self.dataset) * 1.0 / self.nranks))
+            self.num_samples = int(len(self.dataset) * 1.0 / self.nranks)
         self.total_size = self.num_samples * self.nranks
 
     def get_start_end_idx(self):
@@ -125,7 +123,7 @@ def __iter__(self):
             self.consumed_samples,
             self.nranks,
         )
-        self.remain_num_samples = int(math.ceil((len(self.dataset) - self.consumed_samples) * 1.0 / self.nranks))
+        self.remain_num_samples = int((len(self.dataset) - self.consumed_samples) * 1.0 / self.nranks)
         self.remain_total_size = self.remain_num_samples * self.nranks
         self.batch_size_times_rank_size = self.batch_size * self.nranks
 
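
A minimal sketch of the effect of this change (not part of the commit),
assuming a hypothetical dataset of 10 samples sharded across 4 ranks:
with math.ceil, each rank claims 3 samples, so total_size becomes 12 and a
prefetching dataloader would index past the 10 available samples; with
floor division, each rank claims 2 samples and total_size stays at 8,
always within the dataset.

    import math

    dataset_len, nranks = 10, 4  # hypothetical values for illustration

    # Old behavior: round up, so total_size can exceed the dataset length.
    ceil_num_samples = int(math.ceil(dataset_len * 1.0 / nranks))  # 3
    print(ceil_num_samples * nranks)  # 12 -> 2 indices past the end

    # New behavior: round down, so total_size never exceeds the dataset length.
    floor_num_samples = int(dataset_len * 1.0 / nranks)  # 2
    print(floor_num_samples * nranks)  # 8 -> always in range

Flooring drops at most nranks - 1 trailing samples per epoch, the usual
trade-off a distributed sampler makes to keep every rank's shard the same
size without indexing out of bounds.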