
Commit a139758

DesmonDay and lugimzzz authored
[DistDataloader] Update implementation, add nested.py (#8380)
* Fix sharding overlap bug
* [DistDataloader] Update implementation, add nested.py
* Fix pipeline
* add first try
* update dataloader

Co-authored-by: lugimzzz <zhenglujing@baidu.com>
1 parent 46f49df commit a139758

5 files changed: +218 additions, −225 deletions

paddlenlp/data/dist_dataloader.py

Lines changed: 70 additions & 125 deletions
```diff
@@ -12,13 +12,16 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
-import numpy as np
 import paddle
 from paddle.distributed import fleet
 
 from paddlenlp.utils.log import logger
-
-_MAX_DATA_DIM = 64
+from paddlenlp.utils.nested import (
+    nested_broadcast_tensor,
+    nested_copy_place,
+    nested_empty_tensor,
+    nested_reduce_tensor,
+)
 
 
 class DummyDataset(paddle.io.Dataset):
@@ -53,6 +56,7 @@ def __init__(
         timeout=0,
         worker_init_fn=None,
         persistent_workers=False,
+        eval=False,
     ):
 
         if dataset is None:
@@ -62,12 +66,15 @@ def __init__(
         super().__init__(dataset=dataset, batch_sampler=batch_sampler, collate_fn=collate_fn, num_workers=num_workers)
 
         self._hcg = fleet.get_hybrid_communicate_group()
+        self.eval = eval
 
         # Init pp data comm group.
         if self._hcg.get_pipe_parallel_world_size() > 1:
             self._pp_data_group = self._init_dataloader_comm_group()
+            self._pp_group = self._hcg.get_pipe_parallel_group()
         else:
             self._pp_data_group = None
+            self._pp_group = None
 
         self.mp_group = self._hcg.get_model_parallel_group()
         self.mp_rank = self._hcg.get_model_parallel_rank()
@@ -78,10 +85,6 @@ def __init__(
         sharding_rank = self._hcg.get_sharding_parallel_rank()
         self._need_data = (self.mp_rank == 0) and (self.pp_rank == 0)
 
-        # When needed other data types, we can modify dtype_list.
-        self.dtype_list = [paddle.int64, paddle.float32, paddle.int32]
-        self._data_keys_list, self._data_keys_size = None, None
-
         if self._need_data:
             self._dataloader = paddle.io.DataLoader(
                 dataset,
@@ -127,7 +130,6 @@ def _init_dataloader_comm_group(self):
         parallel_groups = topo.get_comm_list("pipe")
 
         for group in parallel_groups:
-            # only first rank and last rank
            ranks = [group[0], group[-1]]
            comm_group = paddle.distributed.new_group(ranks=ranks)
            if paddle.distributed.get_rank() in ranks:
@@ -137,127 +139,70 @@ def _init_dataloader_comm_group(self):
     def __iter__(self):
         return self
 
-    def __next__(self):
-        data_keys_size = [0 for i in range(len(self.dtype_list))]
-        if self._need_data:
-            data = next(self._dataloader_iter)
-            data_keys = list(data.keys())
-
-            for key in data_keys:
-                if data[key].dtype not in self.dtype_list:
-                    raise ValueError(
-                        f"Dist dataloader requires dtype as `int64`, `float32` or `int32` currently, but got: {data[key].dtype}"
+    def _broadcast_data(self, data):
+        process_rank = paddle.distributed.get_rank()
+        if self.mp_group.nranks > 1:
+            if process_rank == self.mp_src_rank:
+                fake_data = [nested_reduce_tensor(data)]
+            else:
+                if data is not None:
+                    logger.warning(
+                        f"Your local rank {paddle.distributed.get_rank()} are forbidden to have a state_dict."
                     )
-
-            data_list, data_keys_list = [], []
-            for i, dtype in enumerate(self.dtype_list):
-                data_list.append([data[key] for key in data_keys if data[key].dtype == dtype])
-                data_keys_list.append([key for key in data_keys if data[key].dtype == dtype])
-            data_keys_size = [len(keys) for keys in data_keys_list]
-
-        # Broadcast data keys size.
-        if self._data_keys_size is None:
-            if self.mp_group.nranks > 1 and self.pp_rank == 0:
-                paddle.distributed.broadcast_object_list(data_keys_size, src=self.mp_src_rank, group=self.mp_group)
-            if self._pp_data_group is not None:
-                paddle.distributed.broadcast_object_list(
-                    data_keys_size, src=self._pp_data_group.ranks[0], group=self._pp_data_group
-                )
-            self._data_keys_size = data_keys_size
-
-        if not self._need_data:
-            data_keys_list = [[None for i in range(keys_size)] for keys_size in self._data_keys_size]
-
-        # Broadcast data keys name.
-        if self._data_keys_list is None:
-            if self.mp_group.nranks > 1 and self.pp_rank == 0:
-                paddle.distributed.broadcast_object_list(data_keys_list, src=self.mp_src_rank, group=self.mp_group)
-            if self._pp_data_group is not None:
-                paddle.distributed.broadcast_object_list(
-                    data_keys_list, src=self._pp_data_group.ranks[0], group=self._pp_data_group
-                )
-            self._data_keys_list = data_keys_list
-
-        # Broadcast data.
-        if not self._need_data:
-            data_list = [[None for i in range(keys_size)] for keys_size in self._data_keys_size]
-
-        if self.mp_group.nranks > 1 and self.pp_rank == 0:
-            for i, dtype in enumerate(self.dtype_list):
-                if self._data_keys_size[i] > 0:
-                    data_list[i] = broadcast_data_list(
-                        data_list[i], dtype, self.mp_rank, self.mp_group, self.mp_src_rank
+                fake_data = [None]
+        if self._pp_group is not None:
+            if process_rank == self._pp_group.ranks[0]:
+                fake_data = [nested_reduce_tensor(data)]
+            else:
+                if data is not None:
+                    logger.warning(
+                        f"Your local rank {paddle.distributed.get_rank()} are forbidden to have a state_dict."
                    )
-
-        if self._pp_data_group is not None:
-            # Note(daisimng): In last stage of pp, we don't need input_ids.
-            # It will be removed in future.
-            for i, dtype in enumerate(self.dtype_list):
-                if self._data_keys_size[i] > 0:
-                    data_list[i] = broadcast_data_list(
-                        data_list[i],
-                        dtype,
-                        self.pp_rank,
-                        self._pp_data_group,
-                        self._pp_data_group.ranks[0],
-                    )
-
-        out_data = {}
-        for keys, datas in zip(self._data_keys_list, data_list):
-            out_data.update([(k, d) for k, d in zip(keys, datas)])
-
-        return out_data
-
-
-def broadcast_data_list(data_list, datatype, comm_rank=0, comm_group=None, src_rank=0):
-    """
-    Broadcast data from src_rank to all ranks in comm_group.
-    """
-    # Move to GPU and broadcast.
-    size_cpu = []
-    if comm_rank == 0:
-        for data in data_list:
-            size_cpu.append(len(data.shape))
-            size_cpu += data.shape
-    size_cpu = size_cpu + [0] * (_MAX_DATA_DIM - len(size_cpu))
-    size_cuda = paddle.to_tensor(size_cpu)
-    paddle.distributed.broadcast(size_cuda, src_rank, group=comm_group).wait()
-
-    size_cpu = size_cuda.tolist()
-    i = 0
-    numel = 0
-    sizes = []
-    while size_cpu[i] > 0:
-        rank = size_cpu[i]
-        this_size = size_cpu[i + 1 : i + 1 + rank]
-        numel += int(np.prod(this_size))
-        sizes.append(this_size)
-        i += rank + 1
-
-    if comm_rank == 0:
-        assert data.dtype == datatype, "input has data type {} which " "is different than {}".format(
-            data.dtype, datatype
-        )
-        if paddle.is_compiled_with_cuda():
-            data_b = paddle.concat([d.cuda().reshape([-1]) for d in data_list], 0)
+                fake_data = [None]
+        if self.mp_group.nranks > 1 and self.pp_rank == 0:
+            paddle.distributed.broadcast_object_list(
+                fake_data,
+                src=self.mp_src_rank,
+                group=self.mp_group,
+            )
+        if self._pp_group is not None:
+            paddle.distributed.broadcast_object_list(
+                fake_data,
+                src=self._pp_group.ranks[0],
+                group=self._pp_group,
+            )
         else:
-            data_b = paddle.concat([d.reshape([-1]) for d in data_list], 0)
+            fake_data = [None]
 
-        assert numel == sum([d.numel().item() for d in data_list]), (numel, [d.numel().item() for d in data_list])
-    else:
-        if paddle.is_compiled_with_cuda():
-            data_b = paddle.empty([numel], dtype=datatype).cuda()
-        else:
-            data_b = paddle.empty([numel], dtype=datatype)
+        fake_data = fake_data[0]
+        if fake_data is None:
+            raise StopIteration
 
-        # Broadcast
-        paddle.distributed.broadcast(data_b, src_rank, group=comm_group).wait()
+        dst_pp_group = self._pp_group if self.eval else self._pp_data_group
+        if self.mp_group.nranks > 1:
+            if process_rank != self.mp_src_rank:
+                data = nested_empty_tensor(fake_data)
+        if dst_pp_group is not None:
+            if process_rank != dst_pp_group.ranks[0]:
+                data = nested_empty_tensor(fake_data)
 
-    ret = []
-    offset = 0
-    for size in sizes:
-        numel = int(np.prod(size))
-        ret.append(data_b[offset : offset + numel].reshape(size))
-        offset += numel
+        if self.mp_group.nranks > 1 and self.pp_rank == 0:
+            data = nested_broadcast_tensor(data, src=self.mp_src_rank, group=self.mp_group)
+        if dst_pp_group is not None:
+            data = nested_broadcast_tensor(data, src=dst_pp_group.ranks[0], group=dst_pp_group)
+        # for pp1 - pp_{n-1}, Paddle need to recevie empty dict for pipeline parallel.
+        if data is None:
+            data = {}
 
-    return ret
+        return data
+
+    def __next__(self):
+        data = None
+        if self._need_data:
+            try:
+                data = next(self._dataloader_iter)
+                data = nested_copy_place(data, place=paddle.framework._current_expected_place())
+            except:
+                pass
+        data = self._broadcast_data(data)
+        return data
```
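The four `nested_*` helpers imported at the top of this file come from `paddlenlp/utils/nested.py`, which this commit adds but which is not one of the two files shown in this view. Below is a minimal sketch of what such helpers could look like, inferred only from how `_broadcast_data` and `__next__` call them; the helper bodies are assumptions, not the actual module.

```python
# Sketch only: the real paddlenlp/utils/nested.py is not shown in this diff,
# so these bodies are guesses based on the call sites in dist_dataloader.py.
import paddle
import paddle.distributed as dist


def nested_reduce_tensor(tensor):
    # Replace each tensor in a nested dict/list with lightweight (shape, dtype)
    # metadata that is cheap to send via broadcast_object_list.
    if isinstance(tensor, dict):
        return {k: nested_reduce_tensor(v) for k, v in tensor.items()}
    if isinstance(tensor, (list, tuple)):
        return type(tensor)(nested_reduce_tensor(t) for t in tensor)
    if isinstance(tensor, paddle.Tensor):
        return {"shape": tensor.shape, "dtype": tensor.dtype}
    return tensor


def nested_empty_tensor(meta):
    # Rebuild empty tensors from the metadata produced by nested_reduce_tensor,
    # so non-source ranks have buffers of the right shape/dtype to receive into.
    if isinstance(meta, dict) and set(meta) == {"shape", "dtype"}:
        return paddle.empty(meta["shape"], dtype=meta["dtype"])
    if isinstance(meta, dict):
        return {k: nested_empty_tensor(v) for k, v in meta.items()}
    if isinstance(meta, (list, tuple)):
        return type(meta)(nested_empty_tensor(m) for m in meta)
    return meta


def nested_broadcast_tensor(tensor, src=0, group=None):
    # Broadcast every tensor in the nested structure from src to all ranks in group.
    if isinstance(tensor, dict):
        return {k: nested_broadcast_tensor(v, src=src, group=group) for k, v in tensor.items()}
    if isinstance(tensor, (list, tuple)):
        return type(tensor)(nested_broadcast_tensor(t, src=src, group=group) for t in tensor)
    if isinstance(tensor, paddle.Tensor):
        dist.broadcast(tensor, src=src, group=group)
    return tensor


def nested_copy_place(inputs, place=None, blocking=False):
    # Mirrors the helper removed from unified_checkpoint.py below:
    # move every tensor in the nested structure to `place`.
    if isinstance(inputs, dict):
        return {k: nested_copy_place(v, place, blocking) for k, v in inputs.items()}
    if isinstance(inputs, paddle.Tensor):
        return inputs if inputs.place._equals(place) else inputs._copy_to(place, blocking)
    return inputs
```

The design change itself is visible in the diff: instead of bucketing tensors by dtype and packing their shapes into a fixed `_MAX_DATA_DIM` buffer for `broadcast_data_list`, the source rank now sends nested metadata with `broadcast_object_list`, non-source ranks allocate matching empty tensors, and each tensor is then broadcast in place.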

paddlenlp/trainer/plugins/unified_checkpoint.py

Lines changed: 1 addition & 23 deletions
```diff
@@ -63,6 +63,7 @@
     SAFE_WEIGHTS_NAME,
 )
 from paddlenlp.utils.log import logger
+from paddlenlp.utils.nested import nested_copy, nested_copy_place
 
 if is_safetensors_available():
     # from safetensors import safe_open
@@ -1880,29 +1881,6 @@ def mapping_optimizer_tp_actions(tp_actions, optimizer_loaded_keys):
     return new_actions
 
 
-def nested_copy(inputs):
-    if isinstance(inputs, dict):
-        outputs = {}
-        for key in list(inputs.keys()):
-            outputs[key] = nested_copy(inputs[key])
-        return outputs
-    return inputs
-
-
-def nested_copy_place(inputs, place=None, blocking=False):
-    if isinstance(inputs, dict):
-        outputs = {}
-        for key in list(inputs.keys()):
-            outputs[key] = nested_copy_place(inputs[key], place, blocking)
-        return outputs
-    if isinstance(inputs, paddle.Tensor):
-        if inputs.place._equals(place):
-            return inputs
-        else:
-            return inputs._copy_to(place, blocking)
-    return inputs
-
-
 def flatten_list(nested_list):
     flattened_list = []
     for item in nested_list:
```
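The two helpers removed here are not dropped from the codebase: the import added in the first hunk pulls `nested_copy` and `nested_copy_place` from `paddlenlp.utils.nested` instead of defining them locally. A small, hypothetical usage check, assuming the relocated helpers keep the behavior of the removed bodies:

```python
# Hypothetical quick check of the relocated helpers; names and behavior are
# assumed to match the implementations removed above.
import paddle
from paddlenlp.utils.nested import nested_copy, nested_copy_place

state = {"weight": paddle.randn([2, 3]), "step": 10}

copied = nested_copy(state)          # new dict object, same tensor values per key
moved = nested_copy_place(state, place=paddle.CPUPlace(), blocking=True)

print(copied is state)               # False: the dict structure is copied
print(moved["weight"].place)         # CPU place; tensors are copied only if needed
print(moved["step"])                 # 10: non-tensor leaves pass through unchanged
```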

0 commit comments
