@@ -285,8 +285,12 @@ def naive_fuse_merge_tp(weight_list, is_column=True, fuse_tensor_parts=2):
285
285
286
286
if isinstance (weight_list [0 ], np .ndarray ):
287
287
return np .concatenate ([reorder [i ] for i in index ], axis = axis )
288
+ else :
289
+ tensor = paddle .concat ([reorder [i ] for i in index ], axis = axis )
288
290
289
- return paddle .concat ([reorder [i ] for i in index ], axis = axis )._copy_to (paddle .CUDAPinnedPlace (), False )
291
+ if tensor .is_gpu_place ():
292
+ tensor = tensor ._copy_to (paddle .CUDAPinnedPlace (), False )
293
+ return tensor
290
294
291
295
292
296
def naive_fuse_split_tp (
@@ -361,12 +365,18 @@ def normal_fuse_merge_tp(weight_list, is_column=True):
361
365
if isinstance (weight_list [0 ], np .ndarray ):
362
366
return np .concatenate (weight_list , axis = - 1 )
363
367
else :
364
- return paddle .concat (weight_list , axis = - 1 )._copy_to (paddle .CUDAPinnedPlace (), False )
368
+ tensor = paddle .concat (weight_list , axis = - 1 )
369
+ if tensor .is_gpu_place ():
370
+ tensor = tensor ._copy_to (paddle .CUDAPinnedPlace (), False )
371
+ return tensor
365
372
else :
366
373
if isinstance (weight_list [0 ], np .ndarray ):
367
374
return np .concatenate (weight_list , axis = 0 )
368
375
else :
369
- return paddle .concat (weight_list , axis = 0 )._copy_to (paddle .CUDAPinnedPlace (), False )
376
+ tensor = paddle .concat (weight_list , axis = 0 )
377
+ if tensor .is_gpu_place ():
378
+ tensor = tensor ._copy_to (paddle .CUDAPinnedPlace (), False )
379
+ return tensor
370
380
371
381
372
382
def normal_fuse_split_tp (weight , tensor_parallel_degree , tensor_parallel_rank = None , is_column = True ):
0 commit comments