Commit c6e5459

[Distributed] metric calculation supports tp logits (#8370)
1 parent ac117a1 · commit c6e5459

File tree

3 files changed (+14 -3 lines):

  llm/finetune_generation.py
  llm/utils.py
  paddlenlp/trainer/training_args.py

llm/finetune_generation.py

Lines changed: 3 additions & 3 deletions

@@ -140,7 +140,7 @@ def main():
     if not training_args.autotuner_benchmark:
         model = AutoModelForCausalLMPipe.from_pretrained(
             model_args.model_name_or_path,
-            tensor_parallel_output=False,
+            tensor_parallel_output=training_args.tensor_parallel_output,
             tensor_parallel_degree=training_args.tensor_parallel_degree,
             tensor_parallel_rank=training_args.tensor_parallel_rank,
             use_flash_attention=model_args.use_flash_attention,
@@ -152,7 +152,7 @@ def main():
         # NOTE(gongenlei): new add autotuner_benchmark
         model_config = AutoConfig.from_pretrained(
             model_args.model_name_or_path,
-            tensor_parallel_output=False,
+            tensor_parallel_output=training_args.tensor_parallel_output,
             tensor_parallel_degree=training_args.tensor_parallel_degree,
             tensor_parallel_rank=training_args.tensor_parallel_rank,
             dtype=dtype,
@@ -163,7 +163,7 @@ def main():
     else:
         model_config = AutoConfig.from_pretrained(
             model_args.model_name_or_path,
-            tensor_parallel_output=False,
+            tensor_parallel_output=training_args.tensor_parallel_output,
             tensor_parallel_degree=training_args.tensor_parallel_degree,
             tensor_parallel_rank=training_args.tensor_parallel_rank,
             dtype=dtype,
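
What the flag changes in practice is the shape of the logits the trainer receives: with tensor_parallel_output=False the parallel LM head gathers the full vocabulary on every rank, while with True each tensor-parallel rank keeps only its own slice of the vocab dimension, saving a collective in the forward pass but leaving the metric code with shards. A minimal shape sketch, with made-up sizes purely for illustration:

# Illustration only: how the logits shape differs between the two settings,
# assuming the vocab dimension is what gets sharded across tensor-parallel ranks.
batch_size, seq_len, vocab_size, tp_degree = 2, 128, 32000, 4

full_shape = (batch_size, seq_len, vocab_size)                # tensor_parallel_output=False
shard_shape = (batch_size, seq_len, vocab_size // tp_degree)  # tensor_parallel_output=True, per rank

print(full_shape, shard_shape)  # (2, 128, 32000) (2, 128, 8000)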

llm/utils.py

Lines changed: 7 additions & 0 deletions

@@ -211,6 +211,13 @@ def prediction_step(
             # keepdim in order to maintain the same shape as logits
             if isinstance(logits, (list, tuple)):
                 logits = logits[0]
+            # all gather logits when enabling tensor_parallel_output
+            if self.args.tensor_parallel_degree > 1 and self.args.tensor_parallel_output:
+                hcg = fleet.get_hybrid_communicate_group()
+                model_parallel_group = hcg.get_model_parallel_group()
+                gathered_logits = []
+                dist.all_gather(gathered_logits, logits, group=model_parallel_group)
+                logits = paddle.concat(gathered_logits, axis=-1)
             return (loss, logits.argmax(axis=-1, keepdim=True), labels)

         loss = None
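
Read outside the diff context, the new branch amounts to a small, self-contained gather step. The sketch below is a minimal rewrite of that logic as a standalone helper, assuming the trainer has already initialized a hybrid-parallel fleet environment; the helper name gather_tp_logits is hypothetical and not part of the repo.

import paddle
import paddle.distributed as dist
from paddle.distributed import fleet

def gather_tp_logits(logits, tensor_parallel_degree, tensor_parallel_output):
    """Reassemble full-vocab logits from per-rank shards before computing metrics."""
    if tensor_parallel_degree <= 1 or not tensor_parallel_output:
        return logits  # logits already cover the full vocabulary
    hcg = fleet.get_hybrid_communicate_group()
    mp_group = hcg.get_model_parallel_group()
    shards = []
    dist.all_gather(shards, logits, group=mp_group)  # one shard per tensor-parallel rank
    return paddle.concat(shards, axis=-1)            # stitch the vocab dimension back together

# Inside prediction_step this feeds straight into the argmax, e.g.:
#   preds = gather_tp_logits(logits, args.tensor_parallel_degree, args.tensor_parallel_output)
#   return (loss, preds.argmax(axis=-1, keepdim=True), labels)

The final concat assumes dist.all_gather fills the list in rank order, matching a contiguous vocab partition across the model-parallel group.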

paddlenlp/trainer/training_args.py

Lines changed: 4 additions & 0 deletions

@@ -803,6 +803,10 @@ class TrainingArguments:
        default=False,
        metadata={"help": "whether to run distributed training in auto parallel mode"},
    )
+    tensor_parallel_output: Optional[bool] = field(
+        default=False,
+        metadata={"help": "whether to output logits in distributed status"},
+    )

    def __post_init__(self):
        env_local_rank = int(os.environ.get("PADDLE_RANK_IN_NODE", -1))
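
To exercise the new argument end to end, one option is the same PdArgumentParser path the llm/ entry scripts already use. A hedged sketch, under the assumption that --tensor_parallel_output parses like other boolean training arguments; a real tensor-parallel run additionally needs a multi-rank paddle.distributed.launch job with --tensor_parallel_degree greater than 1:

from paddlenlp.trainer import PdArgumentParser, TrainingArguments

# Parse an explicit argument list instead of sys.argv, for illustration.
parser = PdArgumentParser(TrainingArguments)
(training_args,) = parser.parse_args_into_dataclasses(
    ["--output_dir", "./checkpoints", "--tensor_parallel_output", "true"]
)
print(training_args.tensor_parallel_output)  # True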
