12
12
# See the License for the specific language governing permissions and
13
13
# limitations under the License.
14
14
15
-
16
15
from collections import OrderedDict
17
16
18
17
from paddle .distributed .fleet .model import PipelineParallel
@@ -46,6 +45,25 @@ def get_index_layer_func():
46
45
return _GLOBAL_INDEX_LAYER_FUNC
47
46
48
47
48
# Registry for an optional, user-supplied function that maps structured
# parameter names (sname) to tensor names (tname).
_GLOBAL_SNAME_TO_TNAME_FUNC = None


def register_sname_to_tname_func(func):
    """Install *func* as the global sname-to-tname mapping function."""
    global _GLOBAL_SNAME_TO_TNAME_FUNC
    _GLOBAL_SNAME_TO_TNAME_FUNC = func


def has_register_sname_to_tname_func():
    """Return True when a sname-to-tname mapping function is registered."""
    return _GLOBAL_SNAME_TO_TNAME_FUNC is not None


def get_sname_to_tname_func():
    """Return the registered mapping function; assert that one exists."""
    assert _GLOBAL_SNAME_TO_TNAME_FUNC is not None, "sname to tname func is not registered yet"
    return _GLOBAL_SNAME_TO_TNAME_FUNC
66
+
49
67
class LayerNameScope :
50
68
"""
51
69
layer name scope for a layer, layer name of the same kind of layer will be named consecutively
@@ -206,6 +224,7 @@ def __init__(self):
206
224
self ._segments = OrderedDict ()
207
225
self ._layer_to_segment = OrderedDict ()
208
226
self ._param_to_tname = OrderedDict ()
227
+ self ._wname_to_rname = OrderedDict ()
209
228
210
229
def add_segment (self , start_index , end_index ):
211
230
segment = PipeLineSegment (start_index , end_index )
@@ -218,19 +237,24 @@ def add_layer(self, layer_index, layer_name, param_names):
218
237
segment = self ._layer_to_segment [layer_index ]
219
238
segment .add_layer (layer_name , param_names )
220
239
221
def build_name_mapping(self, sname_to_tname=None):
    """Record a renamed tensor name for every parameter of this stage.

    Walks every segment, every layer in the segment, and every parameter of
    the layer, asks the rename manager for the new parameter name, and stores
    the ``(tensor_name, new_name)`` pair in ``self._param_to_tname``.

    Args:
        sname_to_tname: optional dict mapping a parameter name to an explicit
            replacement name; matching entries are recorded in
            ``self._wname_to_rname`` and take precedence in ``map_name``.
    """
    # .values() instead of .items(): segment/layer keys were unused.
    for segment in self._segments.values():
        for layer in segment.layers.values():
            for param_name, tensor_name in layer.params.items():
                # map to a new name
                n_name = self._rename_mgr.get_new_param_name(layer.name, tensor_name)
                if sname_to_tname is not None and param_name in sname_to_tname:
                    self._wname_to_rname[param_name] = sname_to_tname[param_name]
                # logger.info(f"{param_name} {tensor_name}=>{n_name}")
                self._param_to_tname[param_name] = (tensor_name, n_name)
230
252
231
253
def map_name (self , param_name , t_name ):
232
254
assert param_name in self ._param_to_tname
233
255
tensor_name , n_name = self ._param_to_tname [param_name ]
256
+ if param_name in self ._wname_to_rname :
257
+ n_name = self ._wname_to_rname [param_name ]
234
258
assert tensor_name == t_name
235
259
return n_name
236
260
@@ -261,6 +285,11 @@ def __init__(
261
285
self ._index_layers ()
262
286
263
287
stage_segments = self ._segment ()
288
+ if has_register_sname_to_tname_func ():
289
+ self ._sname_to_tname = get_sname_to_tname_func ()(pp_model )
290
+ else :
291
+ self ._sname_to_tname = None
292
+
264
293
for (i , stage_seg ) in enumerate (stage_segments ):
265
294
pipe_stage = PipeLineStage ()
266
295
self ._stages .append (pipe_stage )
@@ -275,7 +304,7 @@ def __init__(
275
304
self ._layer_name_to_stage [layer_name ] = i
276
305
277
306
for stage in self ._stages :
278
- stage .build_name_mapping ()
307
+ stage .build_name_mapping (self . _sname_to_tname )
279
308
280
309
def _index_layers (self ):
281
310
for layer_name in self ._param_names_by_layer .keys ():
0 commit comments