
Commit 31a383a

Merge branch 'develop' into dev_20241231_add_deepseekv3
2 parents f9abe9c + 1d74d62 commit 31a383a

File tree: 18 files changed (+2846 −30 lines)

llm/auto_parallel/gpt-3/run_pretrain_auto.py

Lines changed: 0 additions & 4 deletions
@@ -91,10 +91,6 @@ class PreTrainingArguments(AutoTrainingArguments):
         default=False,
         metadata={"help": "Weather to run benchmark by autotuner. True for from_scratch and pad_max_length."},
     )
-    use_intermediate_api: bool = field(
-        default=False,
-        metadata={"help": "Weather to use auto_parallel intermediate api"},
-    )
 
     def __post_init__(self):
         super().__post_init__()

llm/auto_parallel/llama/run_pretrain_auto.py

Lines changed: 0 additions & 4 deletions
@@ -100,10 +100,6 @@ class PreTrainingArguments(AutoTrainingArguments):
         default=False,
         metadata={"help": "Weather to run benchmark by autotuner. True for from_scratch and pad_max_length."},
     )
-    use_intermediate_api: bool = field(
-        default=False,
-        metadata={"help": "Weather to use auto_parallel intermediate api"},
-    )
 
     def __post_init__(self):
         super().__post_init__()

llm/auto_parallel/qwen/run_pretrain_3D_auto.py

Lines changed: 0 additions & 4 deletions
@@ -106,10 +106,6 @@ class PreTrainingArguments(AutoTrainingArguments):
         default=False,
         metadata={"help": "whether use lazy init for model parameters"},
     )
-    use_intermediate_api: bool = field(
-        default=False,
-        metadata={"help": "Weather to use auto_parallel intermediate api"},
-    )
 
     def __post_init__(self):
         super().__post_init__()

paddlenlp/trainer/auto_training_args.py

Lines changed: 4 additions & 1 deletion
@@ -47,7 +47,10 @@ class AutoTrainingArguments(TrainingArguments):
             "help": "Enable eliminate_transpose pass, which should replace transpose with reshape when sequence parallel is enabled."
         },
     )
-
+    use_intermediate_api: bool = field(
+        default=False,
+        metadata={"help": "Weather to use auto_parallel intermediate api"},
+    )
     refined_ops_patterns: str = field(default=None, metadata={"help": "The pattern of refined recompute."})
 
     def __post_init__(self):
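Since use_intermediate_api now lives on AutoTrainingArguments, the per-script copies deleted above become redundant: any argument dataclass that subclasses AutoTrainingArguments inherits the flag. A minimal sketch of that relationship (not part of the commit; it only assumes the module path shown in this diff):

from dataclasses import dataclass, fields

from paddlenlp.trainer.auto_training_args import AutoTrainingArguments


@dataclass
class PreTrainingArguments(AutoTrainingArguments):
    """Script-specific arguments; use_intermediate_api is now inherited rather than redeclared."""


# the flag is visible on the subclass even though the script no longer defines it
print("use_intermediate_api" in {f.name for f in fields(PreTrainingArguments)})  # True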

paddlenlp/trainer/unified_checkpoint/load_local.py

Lines changed: 5 additions & 6 deletions
@@ -282,16 +282,15 @@ def load_resolved_archive_file(
                     returned_optim_state_dict[key_name] = state_dict_optim.pop(key)
                     returned_optim_state_dict[key_name].name = key_name
 
-                    # master weight cast (only in remove_master_weight)
-                    if has_master_weights and state_dict_master_weight[model_weight_key].dtype != paddle.float32:
-                        state_dict_master_weight[model_weight_key] = paddle.cast(
-                            state_dict_master_weight[model_weight_key], dtype=paddle.float32
-                        )
-
                 if has_master_weights:
                     for key in list(state_dict_master_weight.keys()):
                         static_name = struct2static_name_mappings[key]
                         returned_optim_state_dict["master_weights"][static_name] = state_dict_master_weight.pop(key)
+                        # master weight cast (only in remove_master_weight)
+                        if returned_optim_state_dict["master_weights"][static_name].dtype != paddle.float32:
+                            returned_optim_state_dict["master_weights"][static_name] = paddle.cast(
+                                returned_optim_state_dict["master_weights"][static_name], dtype=paddle.float32
+                            )
                         returned_optim_state_dict["master_weights"][static_name].name = "_".join([static_name, FP32_MASTER])
 
         return returned_optim_state_dict
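The master-weight cast now runs per entry, right after each weight is moved into returned_optim_state_dict["master_weights"], instead of once against state_dict_master_weight. A standalone sketch of the promotion it performs, using a toy tensor rather than a real checkpoint:

import paddle

# stand-in for a master weight that was saved in reduced precision (the remove_master_weight case)
master_weight = paddle.ones([4], dtype="float16")

# the same promotion the patched loop applies to every "master_weights" entry
if master_weight.dtype != paddle.float32:
    master_weight = paddle.cast(master_weight, dtype=paddle.float32)

print(master_weight.dtype)  # paddle.float32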

paddlenlp/trainer/utils/ckpt_converter.py

Lines changed: 89 additions & 11 deletions
@@ -41,7 +41,13 @@
 
 class CheckpointConverter:
     def __init__(
-        self, hybrid_parallel_ckpt_path, state_dict, parameter_to_structured_name, trainging_args=None, patch_dict=None
+        self,
+        hybrid_parallel_ckpt_path,
+        state_dict,
+        parameter_to_structured_name,
+        trainging_args=None,
+        patch_dict=None,
+        local_view_pattern: list | bool = None,
     ):
         self.use_dist = True if paddle.distributed.get_world_size() > 1 else False
         self.path = hybrid_parallel_ckpt_path
@@ -85,6 +91,17 @@ def __init__(
                 self.auto_parallel_state_dict[self.patch_dict[k]] = self.auto_parallel_state_dict[k]
             for k in del_keys:
                 self.auto_parallel_state_dict.pop(k)
+        # solve the problem of inconsistent parameter names in moe automatic parallel mode.
+        if hasattr(trainging_args, "moe_group") and trainging_args.moe_group:
+            if local_view_pattern is False:
+                self.local_view_pattern_list = None
+            else:
+                if isinstance(local_view_pattern, list):
+                    self.local_view_pattern_list = local_view_pattern
+                else:
+                    self.local_view_pattern_list = ["experts"]
+        else:
+            self.local_view_pattern_list = None
 
         flags = [
             ["tp degree", self.tp_degree],
@@ -497,6 +514,46 @@ def gen_metadata_and_prepare_source_state_dict(self):
         else:
             return self.gen_metadata_for_tp_sharded_tensor()
 
+    def rename_local_view_state_dict(self, state_dict, file_name):
+        """
+        Rename the key for local views to the key for global views, and return the renamed `state_dict`.
+        """
+        if self.local_view_pattern_list is None:
+            return state_dict
+        # case 1: moe_group is mp_group
+        if self.tp_degree > 1 and self.sharding_degree <= 1:
+            (tp_rank, pp_rank, sharding_rank) = self.get_distribution_rank_from_file_name(file_name)
+            expert_name_old2new = {}
+            for pattern in self.local_view_pattern_list:
+                expert_pattern = rf"({pattern}\.)(\d+)"
+                # extract all experts IDs
+                expert_ids = set()
+                for state_name in state_dict.keys():
+                    res = re.search(expert_pattern, state_name)
+                    if res:
+                        expert_ids.add(int(res.group(2)))
+                expert_num = len(expert_ids)
+                # construct old name to new name mapping
+                for state_name in state_dict.keys():
+                    res = re.search(expert_pattern, state_name)
+                    if res:
+                        new_expert_id = int(res.group(2)) % expert_num + tp_rank * expert_num
+                        expert_name_old2new[state_name] = re.sub(
+                            expert_pattern, f"{res.group(1)}{new_expert_id}", state_name
+                        )
+            # rename state_dict
+            renamed_state_dict = {
+                expert_name_old2new[state_name]
+                if state_name in expert_name_old2new
+                else state_name: state_dict[state_name]
+                for state_name in state_dict.keys()
+            }
+
+            return renamed_state_dict
+        # TODO: add support for sharding
+        else:
+            return state_dict
+
     def load_state_dict_and_rename(self):
         """
         Parse the distributed information from the names of the checkpoint files and evenly parse out the distributed information for each weight/optimizer state
@@ -741,11 +798,10 @@ def load_state_dict_and_rename(self):
                     model_state_file_name = self.get_model_state_file_from(file_name)
                     assert model_state_file_name is not None
                     model_state_keys = global_file_to_state_dict_keys_mapping[model_state_file_name]
-                    renamed_state_dict = self.rename_using_optimizer_state_order(model_state_keys, state_dict)
-                    self.get_sharded_tensor_infos(file, renamed_state_dict, cur_rank_sharded_tensor_infos)
-                    self.cur_rank_loaded_state_dict[file_name] = renamed_state_dict
-                else:
-                    self.get_sharded_tensor_infos(file_name, state_dict, cur_rank_sharded_tensor_infos)
+                    state_dict = self.rename_using_optimizer_state_order(model_state_keys, state_dict)
+                renamed_state_dict = self.rename_local_view_state_dict(state_dict, file_name)
+                self.get_sharded_tensor_infos(file_name, renamed_state_dict, cur_rank_sharded_tensor_infos)
+                self.cur_rank_loaded_state_dict[file_name] = renamed_state_dict
         else:
             for file, state_dict in self.cur_rank_loaded_state_dict.items():
                 # The rule for renaming is to change the master_weights name in the optimizer state to the model weight name,
@@ -897,6 +953,9 @@ def rename(old_name, parameter_to_structured_name):
                 return None
 
         for key, value in state_dict.items():
+            # NOTE: Skip the parameters that are not initialized,which are not in the current rank.
+            if value is None or (isinstance(value, paddle.Tensor) and not value._is_initialized()):
+                continue
            if key in parameter_to_structured_name.values():
                new_name = key
            else:
@@ -909,7 +968,9 @@ def rename(old_name, parameter_to_structured_name):
     def rename_using_optimizer_state_order(self, model_state_keys, optimizer_state_dict):
         name_mapping = {}
         suffix_bucket = {}
-        assert len(optimizer_state_dict) % len(model_state_keys) == 0
+        # TODO: After adapting to sharding, remove the code below.
+        if self.is_sharding_stage3 or (self.sharding_degree > 1 and self.sharding_stage1_v == 2):
+            assert len(optimizer_state_dict) % len(model_state_keys) == 0
         for suffix in OPTIMIZER_STATE_NAME_SUFFIX:
             suffix_bucket[suffix] = []
         for opt_name, opt_value in optimizer_state_dict.items():
@@ -927,10 +988,27 @@ def rename_using_optimizer_state_order(self, model_state_keys, optimizer_state_dict):
         for suffix, old_names in suffix_bucket.items():
             if len(old_names) == 0:
                 continue
-            assert len(old_names) == len(model_state_keys)
-            for i in range(len(old_names)):
-                name_mapping[old_names[i]] = model_state_keys[i] + suffix
-
+            # TODO: After adapting to sharding, remove the code below.
+            if self.is_sharding_stage3 or (self.sharding_degree > 1 and self.sharding_stage1_v == 2):
+                assert len(old_names) == len(model_state_keys)
+
+            # NOTE: Handle the case where the number of master_weight elements is not equal to the number of model_state_keys.
+            if suffix != ".master_weight":
+                for i in range(len(old_names)):
+                    name_mapping[old_names[i]] = model_state_keys[i] + suffix
+            else:
+                for i in range(len(old_names)):
+                    param = old_names[i][:-14]
+                    index = -1
+                    for idx, opt_name in enumerate(suffix_bucket[".moment1"]):
+                        if param == opt_name[:-24]:
+                            index = idx
+                            break
+                    if index >= 0:
+                        name_mapping[old_names[i]] = model_state_keys[index] + suffix
+                    else:
+                        raise RuntimeError(f"Can't find {param} in optimizer state dict.")
+        # rename state dict
         renamed_state_dict = {}
         for k, v in optimizer_state_dict.items():
             renamed_state_dict[name_mapping[k]] = v
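The heart of the new rename_local_view_state_dict method is the renaming rule new_id = local_id % expert_num + tp_rank * expert_num, applied to every key matching the "experts.<n>" pattern. A simplified, self-contained illustration with made-up parameter names (stdlib re only, no Paddle or checkpoint files):

import re

# two locally numbered experts as seen by tensor-parallel rank 1 (names are illustrative)
state_dict = {"layers.0.mlp.experts.0.w1": "t0", "layers.0.mlp.experts.1.w1": "t1"}
tp_rank = 1
expert_pattern = r"(experts\.)(\d+)"

# number of experts held locally by this rank
expert_num = len({int(m.group(2)) for k in state_dict if (m := re.search(expert_pattern, k))})

# map each local expert index to its global index
renamed = {
    re.sub(
        expert_pattern,
        lambda m: f"{m.group(1)}{int(m.group(2)) % expert_num + tp_rank * expert_num}",
        k,
    ): v
    for k, v in state_dict.items()
}
print(sorted(renamed))  # ['layers.0.mlp.experts.2.w1', 'layers.0.mlp.experts.3.w1']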

paddlenlp/transformers/__init__.py

Lines changed: 90 additions & 0 deletions
@@ -306,6 +306,96 @@
 from .unimo.configuration import *
 from .unimo.modeling import *
 from .unimo.tokenizer import *
+from .unimo.configuration import *
+from .xlnet.modeling import *
+from .xlnet.tokenizer import *
+from .xlnet.configuration import *
+from .xlm.modeling import *
+from .xlm.tokenizer import *
+from .xlm.configuration import *
+from .xlm_roberta.modeling import *
+from .xlm_roberta.tokenizer import *
+from .xlm_roberta.configuration import *
+from .gau_alpha.modeling import *
+from .gau_alpha.tokenizer import *
+from .gau_alpha.configuration import *
+from .gemma import *
+from .roformerv2.modeling import *
+from .roformerv2.tokenizer import *
+from .roformerv2.configuration import *
+from .optimization import *
+from .opt.configuration import *
+from .opt.modeling import *
+from .auto.modeling import *
+from .auto.tokenizer import *
+from .auto.processing import *
+from .auto.image_processing import *
+from .auto.configuration import *
+from .codegen.modeling import *
+from .codegen.tokenizer import *
+from .codegen.configuration import *
+from .artist.modeling import *
+from .artist.tokenizer import *
+from .artist.configuration import *
+from .dallebart.modeling import *
+from .dallebart.tokenizer import *
+from .dallebart.configuration import *
+from .clip.modeling import *
+from .clip.configuration import *
+from .clip.feature_extraction import *
+from .clip.tokenizer import *
+from .clip.processing import *
+from .clip.image_processing import *
+from .chineseclip.modeling import *
+from .chineseclip.configuration import *
+from .chineseclip.feature_extraction import *
+from .chineseclip.processing import *
+from .chineseclip.image_processing import *
+from .chineseclip.tokenizer import *
+from .gptj.modeling import *
+from .gptj.tokenizer import *
+from .gptj.configuration import *
+from .pegasus.modeling import *
+from .pegasus.tokenizer import *
+from .pegasus.configuration import *
+from .glm.configuration import *
+from .glm.modeling import *
+from .glm.tokenizer import *
+from .nystromformer.configuration import *
+from .nystromformer.modeling import *
+from .nystromformer.tokenizer import *
+from .bloom.configuration import *
+from .bloom.modeling import *
+from .bloom.tokenizer import *
+from .bloom.tokenizer_fast import *
+from .clipseg.configuration import *
+from .clipseg.modeling import *
+from .clipseg.processing import *
+from .clipseg.image_processing import *
+from .blip_2.modeling import *
+from .blip_2.configuration import *
+from .blip_2.processing import *
+from .chatglm.configuration import *
+from .chatglm.modeling import *
+from .chatglm.tokenizer import *
+from .chatglm_v2.configuration import *
+from .chatglm_v2.modeling import *
+from .chatglm_v2.modeling_pp import *
+from .chatglm_v2.tokenizer import *
+from .speecht5.configuration import *
+from .speecht5.modeling import *
+from .speecht5.tokenizer import *
+from .speecht5.processing import *
+from .speecht5.feature_extraction import *
+from .minigpt4.modeling import *
+from .minigpt4.configuration import *
+from .minigpt4.processing import *
+from .minigpt4.image_processing import *
+from .clap.configuration import *
+from .clap.feature_extraction import *
+from .clap.modeling import *
+from .clap.processing import *
+from .visualglm.modeling import *
 from .visualglm.configuration import *
 from .visualglm.image_processing import *
 from .visualglm.modeling import *
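Because the package __init__ now re-exports the xlm_roberta submodules, the XLM-RoBERTa classes should be importable from the top-level namespace; XLMRobertaConfig and XLMRobertaTokenizer are the names confirmed by the auto registries later in this diff:

import inspect

from paddlenlp.transformers import XLMRobertaConfig, XLMRobertaTokenizer

# both names resolve to real classes once the re-exports above are in place
print(inspect.isclass(XLMRobertaConfig), inspect.isclass(XLMRobertaTokenizer))  # True True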

paddlenlp/transformers/auto/configuration.py

Lines changed: 2 additions & 0 deletions
@@ -115,6 +115,7 @@
     ("unimo", "UNIMOConfig"),
     ("visualglm", "VisualGLMConfig"),
     ("xlm", "XLMConfig"),
+    ("xlm-roberta", "XLMRobertaConfig"),
     ("xlnet", "XLNetConfig"),
     ("yuan", "YuanConfig"),
 ]
@@ -206,6 +207,7 @@
     ("unimo", "UNIMO"),
     ("visualglm", "VisualGLM"),
     ("xlm", "XLM"),
+    ("xlm-roberta", "XLMRoberta"),
     ("xlnet", "XLNet"),
     ("yuan", "Yuan"),
 ]
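With the two registry entries above, AutoConfig can map the "xlm-roberta" model type to XLMRobertaConfig. A hedged usage sketch; the checkpoint id below is an assumption for illustration, not something added by this commit:

from paddlenlp.transformers import AutoConfig

# "xlm-roberta-base" is an assumed/illustrative checkpoint id
config = AutoConfig.from_pretrained("xlm-roberta-base")
print(type(config).__name__)  # expected: XLMRobertaConfig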

paddlenlp/transformers/auto/modeling.py

Lines changed: 1 addition & 0 deletions
@@ -96,6 +96,7 @@
     ("UNIMO", "unimo"),
     ("XLNet", "xlnet"),
     ("XLM", "xlm"),
+    ("XLMRoberta", "xlm_roberta"),
     ("GPT", "gpt"),
     ("GLM", "glm"),
     ("MT5", "mt5"),

paddlenlp/transformers/auto/tokenizer.py

Lines changed: 1 addition & 0 deletions
@@ -115,6 +115,7 @@
     ("squeezebert", "SqueezeBertTokenizer"),
     ("t5", "T5Tokenizer"),
     ("xlm", "XLMTokenizer"),
+    ("xlm_roberta", "XLMRobertaTokenizer"),
     ("xlnet", "XLNetTokenizer"),
     ("bert_japanese", "BertJapaneseTokenizer"),
     ("bigbird", "BigBirdTokenizer"),
Lines changed: 17 additions & 0 deletions
@@ -0,0 +1,17 @@
+# Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from .configuration import *
+from .modeling import *
+from .tokenizer import *
