3 files changed, 6 insertions(+), 5 deletions(-)
@@ -16,6 +16,7 @@
 import sys
 from dataclasses import dataclass, field
 from functools import partial
+from typing import Optional

 import paddle
 from argument import (
@@ -66,6 +67,10 @@ class FinetuneArguments(TrainingArguments):
         default=0,
         metadata={"help": "The steps use to control the learing rate."},
     )
+    tensor_parallel_output: Optional[bool] = field(
+        default=False,
+        metadata={"help": "whether to output logits in distributed status"},
+    )


 def read_local_dataset(path):
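For context, the new flag follows the usual dataclass-field pattern of PaddleNLP argument classes. A minimal standalone sketch, with an illustrative class name (in the PR the field lives on FinetuneArguments, which subclasses TrainingArguments):

from dataclasses import dataclass, field
from typing import Optional

@dataclass
class FinetuneArgumentsSketch:
    # mirrors the field added above; an argument parser built on this
    # dataclass would expose it as a --tensor_parallel_output CLI flag
    tensor_parallel_output: Optional[bool] = field(
        default=False,
        metadata={"help": "whether to output logits in distributed status"},
    )

args = FinetuneArgumentsSketch(tensor_parallel_output=True)
print(args.tensor_parallel_output)  # True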
@@ -212,7 +212,7 @@ def prediction_step(
         if isinstance(logits, (list, tuple)):
             logits = logits[0]
         # all gather logits when enabling tensor_parallel_output
-        if self.args.tensor_parallel_degree > 1 and self.args.tensor_parallel_output:
+        if self.args.tensor_parallel_degree > 1 and getattr(self.args, "tensor_parallel_output", False):
             hcg = fleet.get_hybrid_communicate_group()
             model_parallel_group = hcg.get_model_parallel_group()
             gathered_logits = []
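The guarded branch all-gathers the vocabulary-sharded logits across the model-parallel group. A hedged sketch of that step, assuming paddle.distributed.all_gather and a hybrid fleet setup; it is only meaningful inside a job launched with paddle.distributed.launch, and the helper name gather_logits is illustrative:

import paddle
import paddle.distributed as dist
from paddle.distributed import fleet

def gather_logits(logits):
    # each tensor-parallel rank holds one shard of the vocab dimension
    hcg = fleet.get_hybrid_communicate_group()
    mp_group = hcg.get_model_parallel_group()
    gathered = []
    dist.all_gather(gathered, logits, group=mp_group)
    # concatenate the shards back into the full logits tensor
    return paddle.concat(gathered, axis=-1)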
@@ -787,10 +787,6 @@ class TrainingArguments:
         default=False,
         metadata={"help": "whether to run distributed training in auto parallel mode"},
     )
-    tensor_parallel_output: Optional[bool] = field(
-        default=False,
-        metadata={"help": "whether to output logits in distributed status"},
-    )
     use_expert_parallel: Optional[bool] = field(
         default=False,
         metadata={"help": "Enable MoE (Mixture of Experts) expert parallel training"},