
Commit 91936bd

ForFishes and westfish authored
[Cherry-pick] Add release grad & sharding format & decorate_exclude_layers (#8545)
* fix bug of sharding format (#8483)
* Add release grad for SD (#8478)
* add decorate_exclude_layers

---------

Co-authored-by: westfish <westfish@126.com>
1 parent 1cf780e commit 91936bd

File tree

2 files changed: +31 -8 lines changed

  paddlenlp/trainer/trainer.py
  paddlenlp/trainer/training_args.py

paddlenlp/trainer/trainer.py

Lines changed: 25 additions & 7 deletions
@@ -404,7 +404,7 @@ def _wrap_amp_model(self, args, model):
                models=model,
                level=self.args.fp16_opt_level,
                dtype=self.amp_dtype,
-               excluded_layers=QuantizationLinear,
+               excluded_layers=[QuantizationLinear] + self._decorate_exclude_layers(model),
            )
        # for pipeline mode and pure tensor parallel
        if self.args.pipeline_parallel_degree > 1 or (self.args.tensor_parallel_degree > 1 and self.sharding is None):
@@ -998,8 +998,14 @@ def fused_allreduce_gradients_no_sync(paramlist, hcg):
                pipeline_parallel_config = (
                    set(args.pipeline_parallel_config.split(" ")) if args.pipeline_parallel_degree > 1 else set()
                )
+               sharding_parallel_config = (
+                   set(args.sharding_parallel_config.split(" ")) if args.sharding_parallel_degree > 1 else set()
+               )
                enable_dp_comm_overlap = "enable_dp_comm_overlap" in pipeline_parallel_config
-               enable_release_grads = "enable_release_grads" in pipeline_parallel_config
+               enable_release_grads = (
+                   "enable_release_grads" in pipeline_parallel_config
+                   or "enable_release_grads" in sharding_parallel_config
+               )

                # Case 3: Pipeline parallel mode, overlap with dp
                if isinstance(self.optimizer, HybridParallelOptimizer) and not self.do_grad_scaling:
@@ -1058,11 +1064,12 @@ def fused_allreduce_gradients_no_sync(paramlist, hcg):
                    if optimizer_was_run:
                        self.lr_scheduler.step()

-                   if enable_release_grads and args.pipeline_parallel_degree > 1:
+                   if enable_release_grads:
                        self.optimizer.clear_grad(set_to_zero=False)
-                       for _, buffers in model._chunk_2_comm_buffers.items():
-                           for buffer in buffers:
-                               buffer._clear_grad_storage()
+                       if args.pipeline_parallel_degree > 1:
+                           for _, buffers in model._chunk_2_comm_buffers.items():
+                               for buffer in buffers:
+                                   buffer._clear_grad_storage()
                    else:
                        self.optimizer.clear_grad()

@@ -1728,6 +1735,17 @@ def num_examples(self, dataloader: DataLoader) -> int:
        except (NameError, AttributeError, TypeError):  # no dataset or length, estimate by length of dataloader
            return len(dataloader) * self.args.per_device_train_batch_size

+   def _decorate_exclude_layers(self, model: nn.Layer):
+       """
+       Exclude layers from the model for paddle.amp.decorate.
+       Args:
+           model (`nn.Layer`): The model to exclude layers from.
+       Returns:
+           A list of excluded layers.
+       """
+       exclude_layers = []
+       return exclude_layers
+
    def _wrap_model(self, model, training=True):

        # train/eval could be run multiple-times - if already wrapped, don't re-wrap it again
@@ -1747,7 +1765,7 @@ def _wrap_model(self, model, training=True):
                optimizers=self.optimizer,
                level=self.args.fp16_opt_level,
                dtype=self.amp_dtype,
-               excluded_layers=QuantizationLinear,
+               excluded_layers=[QuantizationLinear] + self._decorate_exclude_layers(model),
            )

        if self.optimizer is None:
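With this change, both amp-decorate call sites pass `[QuantizationLinear] + self._decorate_exclude_layers(model)` as `excluded_layers`, and the new hook returns an empty list by default, so behavior is unchanged unless a subclass overrides it. Below is a minimal sketch of such an override; the `FP32NormTrainer` subclass and the choice to exclude `nn.LayerNorm` are illustrative assumptions, not part of this commit.

import paddle.nn as nn

from paddlenlp.trainer import Trainer


class FP32NormTrainer(Trainer):  # hypothetical subclass, for illustration only
    def _decorate_exclude_layers(self, model: nn.Layer):
        # Keep LayerNorm layers out of the amp decoration so they stay in float32.
        # The returned list is appended to [QuantizationLinear] before being passed
        # to paddle.amp.decorate as excluded_layers (the diff above passes a layer
        # class there, so a class is used here as well).
        return [nn.LayerNorm]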

paddlenlp/trainer/training_args.py

Lines changed: 6 additions & 1 deletion
@@ -266,6 +266,7 @@ class TrainingArguments:
            enable_stage1_broadcast_overlap, overlap stage1 V1 broadcast with next step forward computation. There are some constraints for the overlap, such as the logging_step should be bigger than 1 for broadcast overlap forward compute and no other sync could be called during the training for broadcast overlap.
            enable_stage1_allgather_overlap, overlap stage1 V2 allgather with next step forward computation. There are some constraints for the overlap, such as the logging_step should be bigger than 1 for allgather overlap forward compute and no other sync could be called during the training for allgather overlap.
            disable_stage1_reduce_avg, replace reduce_avg with original reduce_sum+scale in stage1, which can be used for accuracy verification.
+           enable_release_grads, reduce peak memory usage by releasing gradients after each iteration. The creation of gradients will be postponed until backward propagation of the next iteration.
        recompute (`bool`, *optional*, defaults to `False`):
            Recompute the forward pass to calculate gradients. Used for saving memory.
            Only support for networks with transformer blocks.
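As a usage sketch, the new `enable_release_grads` option is passed through `sharding_parallel_config` in the same way as the existing stage1 flags; the trainer splits the string on spaces, so several options can be combined. The concrete argument values below are illustrative assumptions, not taken from this commit.

from paddlenlp.trainer import TrainingArguments

# Illustrative configuration; only sharding_parallel_config is the point here.
args = TrainingArguments(
    output_dir="./checkpoints",
    sharding="stage1",
    sharding_parallel_degree=8,
    sharding_parallel_config="enable_release_grads enable_stage1_tensor_fusion",
)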
@@ -1198,6 +1199,7 @@ def is_segment_parallel_supported():
                        "disable_stage1_reduce_avg",
                        "enable_stage1_broadcast_overlap",
                        "enable_stage1_allgather_overlap",
+                       "enable_release_grads",
                    ]:
                        raise ValueError(
                            f"Found unknown pipeline mode config {x}, "
@@ -1218,6 +1220,9 @@ def is_segment_parallel_supported():
            if "split_param" in sharding_parallel_config:
                strategy.hybrid_configs["sharding_configs"].split_param = True

+           if "enable_release_grads" in sharding_parallel_config:
+               strategy.hybrid_configs["sharding_configs"].release_gradients = True
+
            if self.pipeline_parallel_degree == 1:
                strategy.hybrid_configs["sharding_configs"].tensor_fusion = (
                    True if "enable_stage1_tensor_fusion" in sharding_parallel_config else False
@@ -1671,7 +1676,7 @@ def pipeline_parallel_rank(self):
            return 0

    def _format_name(self, prefix, rank, degree):
-       size = max(2, len(str(degree)))
+       size = 2
        return f"{prefix}{rank:0>{size}d}"

    @property
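The `_format_name` change appears to implement the "fix bug of sharding format (#8483)" part of the commit: the rank suffix is now always zero-padded to exactly two digits instead of widening with the parallel degree. A quick illustration of the formatting expression; the "tp" prefix is just an example value, not taken from this diff.

# New behavior: size is fixed at 2, so every rank is rendered with two digits.
prefix, rank = "tp", 3
print(f"{prefix}{rank:0>{2}d}")                       # tp03

# Old behavior: size = max(2, len(str(degree))), so a degree of 128 gave width 3.
print(f"{prefix}{rank:0>{max(2, len(str(128)))}d}")   # tp003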
