
Commit 465ce1d

fix sp (#9795)
1 parent bd2d9d0 commit 465ce1d


2 files changed: +17 −5 lines


paddlenlp/peft/lora/loraga_utils.py

Lines changed: 12 additions & 0 deletions
@@ -16,6 +16,13 @@
 import paddle.distributed as dist
 from paddle.distributed import fleet
 
+try:
+    from paddle.distributed.fleet.utils.sequence_parallel_utils import (
+        register_sequence_parallel_allreduce_hooks,
+    )
+except:
+    pass
+
 from paddlenlp.peft import LoRAModel
 from paddlenlp.peft.lora.lora_layers import (
     ColumnParallelLoRALinear,
@@ -83,6 +90,11 @@ def estimate_gradient(self, model: PretrainedModel):
     def _wrap_model(self, model):
         """Wrap Model without optimizer, support dp, tp and sharding"""
 
+        if self.args.tensor_parallel_degree > 1 and self.args.sequence_parallel:
+            register_sequence_parallel_allreduce_hooks(
+                model, self.args.gradient_accumulation_steps, self.args.fuse_sequence_parallel_allreduce
+            )
+
         in_pipeline_parallel_mode = self.args.pipeline_parallel_degree > 1
         in_sharding_parallel_mode = self.sharding is not None
         in_tensor_parallel_mode = self.args.tensor_parallel_degree > 1
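
For context, the two hunks above combine a defensive import with a conditional call: register_sequence_parallel_allreduce_hooks is imported inside a try/except so that Paddle builds without the helper still load, and the hooks are registered while wrapping the model only when sequence parallelism runs on top of tensor parallelism. A minimal standalone sketch of that pattern, assuming a Paddle install; the maybe_register_sp_hooks helper and the bare args object are illustrative stand-ins, not names used by the commit:

    # Sketch only: mirrors the guarded-import + conditional-registration pattern
    # added in loraga_utils.py. `maybe_register_sp_hooks` and `args` are
    # illustrative stand-ins, not part of the commit.
    try:
        from paddle.distributed.fleet.utils.sequence_parallel_utils import (
            register_sequence_parallel_allreduce_hooks,
        )
    except ImportError:
        # Older Paddle builds may not expose the helper; skip registration then.
        register_sequence_parallel_allreduce_hooks = None


    def maybe_register_sp_hooks(model, args):
        """Register sequence-parallel allreduce hooks only when TP and SP are both enabled."""
        if register_sequence_parallel_allreduce_hooks is None:
            return
        if args.tensor_parallel_degree > 1 and args.sequence_parallel:
            register_sequence_parallel_allreduce_hooks(
                model,
                args.gradient_accumulation_steps,
                args.fuse_sequence_parallel_allreduce,
            )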

paddlenlp/trainer/trainer.py

Lines changed: 5 additions & 5 deletions
@@ -414,11 +414,6 @@ def _save_ckpt_func(state_dict, path, signal_path=None):
                     "We do not support skip_save_model_weight in peft model when using unified checkpoint, remove this config."
                 )
 
-        if args.sequence_parallel:
-            register_sequence_parallel_allreduce_hooks(
-                self.model, args.gradient_accumulation_steps, args.fuse_sequence_parallel_allreduce
-            )
-
         self.do_grad_scaling = False
         self.enable_autocast_context_manager = False
         if args.fp16 or args.bf16:
@@ -1987,6 +1982,11 @@ def _wrap_model(self, model, training=True):
         else:
             model, self.optimizer = decorated
 
+        if self.args.tensor_parallel_degree > 1 and self.args.sequence_parallel:
+            register_sequence_parallel_allreduce_hooks(
+                model, self.args.gradient_accumulation_steps, self.args.fuse_sequence_parallel_allreduce
+            )
+
         if self.args.world_size == 1:
             if self.args.amp_master_grad:
                 mix_precision_utils.MixPrecisionLayer(model, dtype=self.amp_dtype)
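
The net effect in trainer.py is a relocation plus a tighter gate: the hooks are no longer registered unconditionally during trainer setup whenever args.sequence_parallel is set; registration now happens in _wrap_model and only when tensor parallelism is also active. A tiny illustrative check of that gating, using types.SimpleNamespace as a stand-in for the parsed training arguments (not part of the commit):

    from types import SimpleNamespace


    def should_register_sp_hooks(args) -> bool:
        # Mirrors the condition the commit adds in _wrap_model.
        return args.tensor_parallel_degree > 1 and args.sequence_parallel


    print(should_register_sp_hooks(SimpleNamespace(tensor_parallel_degree=4, sequence_parallel=True)))   # True
    print(should_register_sp_hooks(SimpleNamespace(tensor_parallel_degree=1, sequence_parallel=True)))   # False: SP without TP no longer registers hooks
    print(should_register_sp_hooks(SimpleNamespace(tensor_parallel_degree=4, sequence_parallel=False)))  # False: SP disabled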
