Commit d815fce

update
1 parent b5c08aa commit d815fce

File tree

4 files changed: 54 additions, 37 deletions

llm/run_embedding.py

Lines changed: 1 addition & 1 deletion
@@ -1,4 +1,4 @@
-# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
+# Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved.
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
 # you may not use this file except in compliance with the License.
paddlenlp/transformers/embedding_utils.py

Lines changed: 51 additions & 0 deletions
@@ -0,0 +1,51 @@
+# Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import paddle
+from paddle.distributed import fleet
+
+
+def dist_gather_tensor_with_gradient(tensor):
+    if tensor is None:
+        return None
+
+    if paddle.distributed.get_world_size() <= 1:
+        return tensor
+
+    hcg = fleet.get_hybrid_communicate_group()
+    sharding_group = hcg.get_sharding_parallel_group()
+    sharding_rank = sharding_group.rank
+    data_group = hcg.get_data_parallel_group()
+    data_rank = data_group.rank
+
+    if sharding_group.nranks == 1 and data_group.nranks == 1:
+        return tensor
+
+    if sharding_group.nranks > 1:
+        all_tensors = []
+        paddle.distributed.all_gather(all_tensors, tensor.contiguous(), group=sharding_group)
+        all_tensors[sharding_rank] = tensor
+        all_tensors = paddle.concat(all_tensors, axis=0)
+    else:
+        all_tensors = tensor
+
+    if data_group.nranks > 1:
+        final_tensors = []
+        paddle.distributed.all_gather(final_tensors, all_tensors.contiguous(), group=data_group)
+        final_tensors[data_rank] = all_tensors
+        final_tensors = paddle.concat(final_tensors, axis=0)
+    else:
+        final_tensors = all_tensors
+
+    return final_tensors
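
The new helper gathers a tensor first across the sharding-parallel group and then across the data-parallel group, re-inserting the local rank's tensor into each gathered list so that gradients still flow through the local slice of the concatenated result. Below is a minimal usage sketch, assuming a hybrid-parallel job launched with paddle.distributed.launch and an initialized fleet environment; gather_in_batch_negatives, query_embeds, and passage_embeds are illustrative names, not part of this commit.

import paddle

from paddlenlp.transformers.embedding_utils import dist_gather_tensor_with_gradient


def gather_in_batch_negatives(query_embeds, passage_embeds):
    # Each rank holds a local [batch, hidden] slice; the helper concatenates
    # the slices from all sharding ranks, then all data-parallel ranks,
    # along axis 0 while keeping the local slice on the autograd graph.
    all_queries = dist_gather_tensor_with_gradient(query_embeds)
    all_passages = dist_gather_tensor_with_gradient(passage_embeds)
    # Similarity matrix over the globally gathered batch, e.g. for a
    # contrastive loss that uses the other ranks' passages as negatives.
    return paddle.matmul(all_queries, all_passages, transpose_y=True)

On a single process the helper returns the input unchanged, so the same code path works without any parallelism.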

paddlenlp/transformers/qwen2/modeling.py

Lines changed: 1 addition & 1 deletion
@@ -42,11 +42,11 @@
     create_skip_config_for_refined_recompute,
     recompute,
 )
-from paddlenlp.trl.embedding_trainer import dist_gather_tensor_with_gradient
 
 from .. import linear_utils
 from ..activations import ACT2FN
 from ..conversion_utils import StateDictNameMapping, init_name_mappings
+from ..embedding_utils import dist_gather_tensor_with_gradient
 from ..linear_utils import Linear
 from ..llama import fusion_ops
 from ..model_outputs import (

paddlenlp/trl/embedding_trainer.py

Lines changed: 1 addition & 35 deletions
@@ -23,6 +23,7 @@
     MatryoshkaContrastiveLoss,
     SimpleContrastiveLoss,
 )
+from paddlenlp.transformers.embedding_utils import dist_gather_tensor_with_gradient
 
 __all__ = ["EmbeddingTrainer"]
 
@@ -178,38 +179,3 @@ def training_step(
 
         loss = self.accum_forward_backward(model)
         return loss
-
-
-def dist_gather_tensor_with_gradient(tensor):
-    if tensor is None:
-        return None
-
-    if paddle.distributed.get_world_size() <= 1:
-        return tensor
-
-    hcg = fleet.get_hybrid_communicate_group()
-    sharding_group = hcg.get_sharding_parallel_group()
-    sharding_rank = sharding_group.rank
-    data_group = hcg.get_data_parallel_group()
-    data_rank = data_group.rank
-
-    if sharding_group.nranks == 1 and data_group.nranks == 1:
-        return tensor
-
-    if sharding_group.nranks > 1:
-        all_tensors = []
-        paddle.distributed.all_gather(all_tensors, tensor.contiguous(), group=sharding_group)
-        all_tensors[sharding_rank] = tensor
-        all_tensors = paddle.concat(all_tensors, axis=0)
-    else:
-        all_tensors = tensor
-
-    if data_group.nranks > 1:
-        final_tensors = []
-        paddle.distributed.all_gather(final_tensors, all_tensors.contiguous(), group=data_group)
-        final_tensors[data_rank] = all_tensors
-        final_tensors = paddle.concat(final_tensors, axis=0)
-    else:
-        final_tensors = all_tensors
-
-    return final_tensors
