File tree Expand file tree Collapse file tree 1 file changed +11
-0
lines changed Expand file tree Collapse file tree 1 file changed +11
-0
lines changed Original file line number Diff line number Diff line change @@ -86,6 +86,12 @@ class PreTrainingArguments(TrainingArguments):
86
86
"help" : "Enable fused_linear_param_grad pass, which should replace add_n_op with add_op for gradients accumulation."
87
87
},
88
88
)
89
+ eliminate_transpose : bool = field (
90
+ default = False ,
91
+ metadata = {
92
+ "help" : "Enable eliminate_transpose pass, which should replace transpose with reshape when sequence parallel is enabled."
93
+ },
94
+ )
89
95
job_schedule_profiler_start : int = field (
90
96
default = - 1 ,
91
97
metadata = {"help" : "The step to start job_schedule_profiler." },
@@ -132,6 +138,11 @@ def __post_init__(self):
132
138
fused_passes .enable = True
133
139
fused_passes .fused_passes_list .append ("fused_linear_param_grad_add_pass" )
134
140
141
+ if self .eliminate_transpose :
142
+ fused_passes = self .strategy .fused_passes
143
+ fused_passes .enable = True
144
+ fused_passes .fused_passes_list .append ("eliminate_transpose" )
145
+
135
146
logger .info (self .strategy )
136
147
137
148
You can’t perform that action at this time.
0 commit comments