Skip to content

Commit d9ddc29

Browse files
committed
optimizer save/load speed
1 parent 235c73e commit d9ddc29

File tree

1 file changed

+13
-3
lines changed

1 file changed

+13
-3
lines changed

paddlenlp/transformers/conversion_utils.py

Lines changed: 13 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -285,8 +285,12 @@ def naive_fuse_merge_tp(weight_list, is_column=True, fuse_tensor_parts=2):
285285

286286
if isinstance(weight_list[0], np.ndarray):
287287
return np.concatenate([reorder[i] for i in index], axis=axis)
288+
else:
289+
tensor = paddle.concat([reorder[i] for i in index], axis=axis)
288290

289-
return paddle.concat([reorder[i] for i in index], axis=axis)._copy_to(paddle.CUDAPinnedPlace(), False)
291+
if tensor.is_gpu_place():
292+
tensor = tensor._copy_to(paddle.CUDAPinnedPlace(), False)
293+
return tensor
290294

291295

292296
def naive_fuse_split_tp(
@@ -361,12 +365,18 @@ def normal_fuse_merge_tp(weight_list, is_column=True):
361365
if isinstance(weight_list[0], np.ndarray):
362366
return np.concatenate(weight_list, axis=-1)
363367
else:
364-
return paddle.concat(weight_list, axis=-1)._copy_to(paddle.CUDAPinnedPlace(), False)
368+
tensor = paddle.concat(weight_list, axis=-1)
369+
if tensor.is_gpu_place():
370+
tensor = tensor._copy_to(paddle.CUDAPinnedPlace(), False)
371+
return tensor
365372
else:
366373
if isinstance(weight_list[0], np.ndarray):
367374
return np.concatenate(weight_list, axis=0)
368375
else:
369-
return paddle.concat(weight_list, axis=0)._copy_to(paddle.CUDAPinnedPlace(), False)
376+
tensor = paddle.concat(weight_list, axis=0)
377+
if tensor.is_gpu_place():
378+
tensor = tensor._copy_to(paddle.CUDAPinnedPlace(), False)
379+
return tensor
370380

371381

372382
def normal_fuse_split_tp(weight, tensor_parallel_degree, tensor_parallel_rank=None, is_column=True):

0 commit comments

Comments
 (0)