Commit e7de0fa

add eliminate_transpose arg (#8339)
1 parent ba9d9bd commit e7de0fa

File tree

1 file changed: +11 −0 lines

llm/llama/auto_parallel/run_pretrain_auto.py

Lines changed: 11 additions & 0 deletions
@@ -86,6 +86,12 @@ class PreTrainingArguments(TrainingArguments):
             "help": "Enable fused_linear_param_grad pass, which should replace add_n_op with add_op for gradients accumulation."
         },
     )
+    eliminate_transpose: bool = field(
+        default=False,
+        metadata={
+            "help": "Enable eliminate_transpose pass, which should replace transpose with reshape when sequence parallel is enabled."
+        },
+    )
     job_schedule_profiler_start: int = field(
         default=-1,
         metadata={"help": "The step to start job_schedule_profiler."},
@@ -132,6 +138,11 @@ def __post_init__(self):
             fused_passes.enable = True
             fused_passes.fused_passes_list.append("fused_linear_param_grad_add_pass")
 
+        if self.eliminate_transpose:
+            fused_passes = self.strategy.fused_passes
+            fused_passes.enable = True
+            fused_passes.fused_passes_list.append("eliminate_transpose")
+
         logger.info(self.strategy)
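For reference, here is a self-contained sketch of what the new __post_init__ branch does. FusedPasses, Strategy, and apply_eliminate_transpose below are simplified stand-ins invented for illustration, not the real paddle auto-parallel classes; only the control flow mirrors the diff:

from dataclasses import dataclass, field
from typing import List

@dataclass
class FusedPasses:
    # Stand-in for strategy.fused_passes: a global toggle plus a list of pass names.
    enable: bool = False
    fused_passes_list: List[str] = field(default_factory=list)

@dataclass
class Strategy:
    fused_passes: FusedPasses = field(default_factory=FusedPasses)

def apply_eliminate_transpose(strategy: Strategy, eliminate_transpose: bool) -> None:
    # Mirrors the new branch: enable the fused-pass machinery and
    # register the eliminate_transpose pass by name.
    if eliminate_transpose:
        fused_passes = strategy.fused_passes
        fused_passes.enable = True
        fused_passes.fused_passes_list.append("eliminate_transpose")

strategy = Strategy()
apply_eliminate_transpose(strategy, eliminate_transpose=True)
assert strategy.fused_passes.fused_passes_list == ["eliminate_transpose"]

Since eliminate_transpose is declared as a field on PreTrainingArguments with default=False, it should be settable like the other flags there, presumably as --eliminate_transpose true on the run_pretrain_auto.py command line (assuming the usual argument-parser handling of bool fields).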