@@ -121,6 +121,7 @@ class PipelinePretrainedModel(PretrainedModel):
     _pipeline_name_mapping = None

     def __init__(self, config, *args, **kwargs):
+        raise ValueError()
         super().__init__(config, *args, **kwargs)

     def add_sequential_layer(self, layer_desc, name_prefix=""):
@@ -138,23 +139,39 @@ def _set_pipeline_name_mapping(self, mappings=None):
         else:
             mapping = {}
             state_dict_keys = list(super().state_dict().keys())
+            first_key = state_dict_keys[0].split(".")
+            # if use virtual pp_degree, the prefix is like 0.0.xxx
+            # else it will be like 0.xxx
+            use_virtual_pp_degree = first_key[0].isdigit() and first_key[1].isdigit()
+
             prefixs = self.get_sequential_name_prefixs()
             for k in state_dict_keys:
                 name_splited = k.split(".")
-                name_splited[0] = prefixs[name_splited[0]]
-                mapping[".".join(name_splited)] = k
+                if use_virtual_pp_degree:
+                    idx = str(int(name_splited[0]) + int(name_splited[1]))
+                    single_name = [prefixs[idx]]
+                    single_name.extend(name_splited[2:])
+                else:
+                    idx = name_splited[0]
+                    single_name = [prefixs[idx]]
+                    single_name.extend(name_splited[1:])
+                mapping[".".join(single_name)] = k
+
             self._pipeline_name_mapping = mapping

         return self._pipeline_name_mapping

     def state_dict(self, *args, **kwargs):
         state_dict = super().state_dict(*args, **kwargs)
-        prefixs = self.get_sequential_name_prefixs()
+
+        if self._pipeline_name_mapping is None:
+            self._set_pipeline_name_mapping()
+        assert len(self._pipeline_name_mapping) > 0, "The pipeline stage must have parameters!"
+        pp_to_single_mapping = {v: k for k, v in self._pipeline_name_mapping.items()}
+
         for k in list(state_dict.keys()):
             v = state_dict.pop(k)
-            name_splited = k.split(".")
-            name_splited[0] = prefixs[name_splited[0]]
-            state_dict[".".join(name_splited)] = v
+            state_dict[pp_to_single_mapping[k]] = v

         return state_dict

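As a reading aid (not part of the diff): for virtual-pp keys the new branch sums the two leading numeric segments to form the prefix-lookup index, while the plain-pp path keeps the old single-index behaviour. A minimal, hypothetical sketch of that mapping; the prefix table and parameter names below are invented, not taken from the PR:

    # Illustrative only; prefix table and keys are made up.
    prefixs = {"0": "llama.embed_tokens", "1": "llama.layers.0"}

    def to_single_name(pp_key, use_virtual_pp_degree):
        parts = pp_key.split(".")
        if use_virtual_pp_degree:
            # virtual-pp keys look like "<idx0>.<idx1>.rest"; the two indices are summed
            idx = str(int(parts[0]) + int(parts[1]))
            return ".".join([prefixs[idx]] + parts[2:])
        # plain pp keys look like "<idx>.rest"
        return ".".join([prefixs[parts[0]]] + parts[1:])

    assert to_single_name("0.1.self_attn.q_proj.weight", True) == "llama.layers.0.self_attn.q_proj.weight"
    assert to_single_name("1.self_attn.q_proj.weight", False) == "llama.layers.0.self_attn.q_proj.weight"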
@@ -169,7 +186,8 @@ def set_state_dict(self, state_dict, *args, **kwargs):
                 continue
             state_dict[self._pipeline_name_mapping[k]] = v

-        return super().set_state_dict(state_dict, *args, **kwargs)
+        ret = super().set_state_dict(state_dict, *args, **kwargs)
+        return ret


 class LlamaForCausalLMPipe(PipelinePretrainedModel, PipelineLayer):
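For reference (again not part of the diff): state_dict() renames pipeline keys to single-model keys through the inverted mapping, and set_state_dict() applies the forward (single-to-pipeline) mapping, so the two directions are exact inverses. A small hypothetical round trip with an invented mapping entry:

    # Illustrative only; the mapping entry and value are made up.
    pipeline_name_mapping = {"llama.layers.0.self_attn.q_proj.weight": "1.self_attn.q_proj.weight"}
    pp_to_single_mapping = {v: k for k, v in pipeline_name_mapping.items()}

    pp_state = {"1.self_attn.q_proj.weight": "tensor"}
    single_state = {pp_to_single_mapping[k]: v for k, v in pp_state.items()}   # save path
    restored = {pipeline_name_mapping[k]: v for k, v in single_state.items()}  # load path
    assert restored == pp_state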
@@ -182,28 +200,25 @@ class LlamaForCausalLMPipe(PipelinePretrainedModel, PipelineLayer):
     config_class = LlamaConfig

     _get_tensor_parallel_mappings = LlamaPretrainedModel._get_tensor_parallel_mappings
+
     # NO base_model_prefix !!!!

     def __init__(
         self,
         config,
-        # num_partitions=1,
-        # topology=None,
-        use_recompute=None,
-        # fused_linear=False,
-        # fuse_attn_qkv=False,
+        # use_recompute=None,
         # scale_qk_by_layer_num=True,
-        recompute_granularity="full",
-        virtual_pp_degree=1,
+        # recompute_granularity="full",
+        # virtual_pp_degree=4,
         # sequence_parallel=False,
         # no_recompute_layers=None,
         pp_recompute_interval=1,
-        # use_flash_attn=False,
-        # fused_softmax_with_triangular=False,
     ):
         self.config = config
-        if use_recompute is None:
-            use_recompute = self.config.use_recompute
+
+        use_recompute = self.config.use_recompute
+        recompute_granularity = self.config.recompute_granularity
+        virtual_pp_degree = self.config.virtual_pp_degree

         hcg = get_hcg()
         tensor_parallel_degree = max(hcg.get_model_parallel_world_size(), 1)
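A hedged usage sketch of the new constructor contract: the recompute and virtual-pp knobs now ride on the config object rather than being keyword arguments. The values below are made up, and it is assumed here that LlamaConfig is default-constructible and that the class remains instantiable at this point in the series:

    # Illustrative only; attribute values are made up.
    config = LlamaConfig()
    config.use_recompute = True
    config.recompute_granularity = "full"
    config.virtual_pp_degree = 1
    model = LlamaForCausalLMPipe(config, pp_recompute_interval=1)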