[Cherry pick] Sharding reshard function enhancement #8544

Merged (7 commits) on Jun 11, 2024
31 changes: 23 additions & 8 deletions paddlenlp/trainer/trainer.py
@@ -565,7 +565,12 @@
base_weight_name=weight_name,
model_wrapped=self.model_wrapped,
)
self.model.set_state_dict(state_dict)
old_state_dict = self.model.state_dict()
new_state_dict = {}
for k, v in state_dict.items():
    if k not in old_state_dict or id(v) != id(old_state_dict[k]):
        new_state_dict[k] = v
self.model.set_state_dict(new_state_dict)

else:
if resume_from_checkpoint is not None and (self.args.dataset_rank == 0 or self.args.use_expert_parallel):
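Note on the hunk above: instead of feeding the whole resharded state dict into `set_state_dict`, the trainer now keeps only the entries that are new or bound to a different tensor object than the one the model already holds, comparing by `id()`. A minimal standalone sketch of that filtering step, with plain Python objects standing in for Paddle tensors (the helper name is illustrative, not part of the PR):

def filter_changed_entries(loaded_state, current_state):
    """Keep entries of loaded_state that are absent from current_state or refer to a
    different object (identity check, not value equality)."""
    changed = {}
    for key, value in loaded_state.items():
        if key not in current_state or id(value) != id(current_state[key]):
            changed[key] = value
    return changed


# Tiny usage example with plain objects in place of tensors:
shared = object()
current = {"linear.weight": shared}
loaded = {"linear.weight": shared, "linear.bias": object()}
print(list(filter_changed_entries(loaded, current)))  # ['linear.bias']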

@@ -887,7 +892,8 @@

npu_accelerate_plugin(self.optimizer)

self.timers and self.timers("read-data").start()
if self.args.ignore_data_skip:
    self.timers and self.timers("read-data").start()

for epoch in range(epochs_trained, num_train_epochs):
    if isinstance(train_dataloader, paddle.io.DataLoader) and isinstance(
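Note on this hunk and the two that follow: the start, stop, and restart of the "read-data" timer are now guarded by `if self.args.ignore_data_skip:`, so the data-reading timer only runs when data skipping is disabled. The `self.timers and ...` part is the existing short-circuit idiom that makes the call a no-op when timers are not configured. A self-contained sketch of the combined pattern; the `_Timer`/`_Timers` classes are stand-ins, not PaddleNLP's implementation:

import time


class _Timer:
    def __init__(self):
        self.elapsed = 0.0
        self._started_at = None

    def start(self):
        self._started_at = time.perf_counter()

    def stop(self):
        self.elapsed += time.perf_counter() - self._started_at


class _Timers:
    def __init__(self):
        self._timers = {}

    def __call__(self, name):
        return self._timers.setdefault(name, _Timer())


timers = _Timers()        # would be None when timing is disabled
ignore_data_skip = True   # stands in for self.args.ignore_data_skip

if ignore_data_skip:
    timers and timers("read-data").start()
    time.sleep(0.01)      # pretend to fetch a batch
    timers and timers("read-data").stop()

print(f"read-data took {timers('read-data').elapsed:.3f}s")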
@@ -903,7 +909,9 @@
inputs = split_inputs_sequence_dim(inputs)
if self.args.use_hybrid_parallel and self.args.context_parallel_degree > 1:
    inputs = split_inputs_sequence_dim_load_balance(inputs)
self.timers and self.timers("read-data").stop()
if self.args.ignore_data_skip:
    self.timers and self.timers("read-data").stop()

os.environ["TRAINER_GLOBAL_STEP"] = str(self.state.global_step)
self.callback_handler.on_load_data_end(args, self.state, self.control, inputs=inputs)

@@ -1094,7 +1102,9 @@

if self.control.should_epoch_stop or self.control.should_training_stop:
    break
self.timers and self.timers("read-data").start()
if self.args.ignore_data_skip:
    self.timers and self.timers("read-data").start()

if step < 0:
    logger.warning(
@@ -2467,10 +2477,15 @@
if state_dict is None:
    state_dict = self.model.state_dict()

self._save_ckpt_func(
    state_dict,
    os.path.join(output_dir, _add_variant(PADDLE_WEIGHTS_NAME, self.args.weight_name_suffix)),
)
if self.args.should_save_sharding_stage1_model:
    state_dict, _, _ = self.sharding_io.manipulate_state_dict_and_config(
        unwrap_model(self.model), merge_tensor_parallel=False, state_dict=state_dict
    )
    variant = _add_variant(PADDLE_WEIGHTS_NAME, self.args.sharded_name_suffix())
else:
    variant = _add_variant(PADDLE_WEIGHTS_NAME, self.args.weight_name_suffix)

self._save_ckpt_func(state_dict, os.path.join(output_dir, variant))

else:
if isinstance(self.model, PretrainedModel) and self.args.should_save_sharding_stage1_model:
config_to_save = None
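Note on the hunk above: when a sharding stage 1 model should be saved, the state dict is first passed through `sharding_io.manipulate_state_dict_and_config`, and the checkpoint file name uses the sharded name suffix instead of the plain weight name suffix; only then is `_save_ckpt_func` called with the chosen variant. A rough sketch of how a suffix turns into a file-name variant; the `add_variant` helper, the suffix strings, and the constant value below are assumptions for illustration, not the PaddleNLP implementation:

import os

PADDLE_WEIGHTS_NAME = "model_state.pdparams"  # assumed value, for illustration only


def add_variant(weights_name, variant=None):
    # Insert the variant between the stem and the extension: name.<variant>.ext
    if not variant:
        return weights_name
    stem, ext = os.path.splitext(weights_name)
    return f"{stem}.{variant}{ext}"


should_save_sharding_stage1_model = True  # stands in for self.args.should_save_sharding_stage1_model
sharded_suffix = "tp00_pp00_shard00"      # hypothetical sharded_name_suffix()
weight_suffix = "tp00_pp00"               # hypothetical weight_name_suffix

if should_save_sharding_stage1_model:
    variant = add_variant(PADDLE_WEIGHTS_NAME, sharded_suffix)
else:
    variant = add_variant(PADDLE_WEIGHTS_NAME, weight_suffix)

print(variant)  # model_state.tp00_pp00_shard00.pdparams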
18 changes: 12 additions & 6 deletions paddlenlp/trainer/utils/sharding_io.py
@@ -321,12 +321,13 @@
node_model_state = reshard_pp(node_model_state)
return reshard_sharding(node_model_state)

def manipulate_state_dict_and_config(self, model_to_save, merge_tensor_parallel=False):
def manipulate_state_dict_and_config(self, model_to_save, merge_tensor_parallel=False, state_dict=None):
    weight_name_suffix = self.args.sharded_name_suffix()

    state_dict = model_to_save.state_dict()
    if self.args.should_save_sharding_stage1_model:
        state_dict = filter_sharded_params(state_dict, self.optimizer, self.sharding_group)
    if state_dict is None:
        state_dict = model_to_save.state_dict()
        if self.args.should_save_sharding_stage1_model:
            state_dict = filter_sharded_params(state_dict, self.optimizer, self.sharding_group)

config_to_save = None
merge_tensor_parallel = merge_tensor_parallel and self.args.use_hybrid_parallel
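Note on the hunk above: the new optional `state_dict` argument lets a caller, such as the trainer change earlier in this PR, hand in a state dict it already holds; the method only recomputes it from the model and filters the sharded parameters when nothing was passed. A minimal sketch of that "compute only if not provided" pattern; the function and argument names are illustrative, not the PaddleNLP API:

def get_state_for_save(compute_state, should_filter, filter_fn, state_dict=None):
    """Recompute and filter the state dict only when the caller did not supply one."""
    if state_dict is None:
        state_dict = compute_state()
        if should_filter:
            state_dict = filter_fn(state_dict)
    return state_dict


# Usage: the second call skips both the recompute and the filtering.
full = {"w": 1, "b": 2}
keep_owned = lambda sd: {k: v for k, v in sd.items() if k == "w"}

print(get_state_for_save(lambda: dict(full), True, keep_owned))                   # {'w': 1}
print(get_state_for_save(lambda: dict(full), True, keep_owned, state_dict=full))  # {'w': 1, 'b': 2}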
@@ -384,7 +385,7 @@

path = os.path.join(dir, MODEL_META_NAME)
with open(path, "w") as f:
    json.dump(model_meta, f, indent=4)
    json.dump(model_meta, f)

def _get_distributed_strategy(self):
pp_degree = 1
@@ -544,13 +545,18 @@
pp_overlap = unwrap_optimizer(self.optimizer, DygraphShardingOptimizerV2).pp_overlap

model = self.model
structure_name_mapping = {k: v.name for (k, v) in model.state_dict().items()}
structure_name_mapping = {}
param_meta = {}
for k, v in model.state_dict().items():
    structure_name_mapping[k] = v.name
    param_meta[k] = (v.shape, int(v.dtype))

sharding_metas = {}
sharding_meta = {}

sharding_meta["param2rank"] = param2rank
sharding_meta["structure_name_mapping"] = structure_name_mapping
sharding_meta["param_meta"] = param_meta

sharding_meta["sharding_strategy"] = sharding_strategy
sharding_meta["enable_overlap"] = pp_overlap
suffix = f"tp{self.args.tensor_parallel_rank:0>2d}_pp{self.args.pipeline_parallel_rank:0>2d}"