
Commit 487428b

Fix shared weights sync for PipelineLayer (#7772)
* fix shared weights sync
* fix typo
1 parent ff1e910 commit 487428b

File tree

1 file changed: +16 -1 lines changed

paddlenlp/transformers/model_utils.py

Lines changed: 16 additions & 1 deletion
@@ -41,7 +41,10 @@
 )
 from huggingface_hub.utils import EntryNotFoundError
 from paddle import Tensor
-from paddle.distributed.fleet.meta_parallel.parallel_layers import SharedLayerDesc
+from paddle.distributed.fleet.meta_parallel.parallel_layers import (
+    PipelineLayer,
+    SharedLayerDesc,
+)
 from paddle.nn import Embedding, Layer
 
 # TODO(fangzeyang) Temporary fix and replace by paddle framework downloader later
@@ -933,6 +936,18 @@ def _post_init(self, original_init, *args, **kwargs):
         ):
             self.init_weights()
 
+        # Note:
+        # 1. PipelineLayer will create parameters for each layer and
+        #    call `_synchronize_shared_weights()` to synchronize the shared parameters.
+        # 2. When setting the model `state_dict`, `_synchronize_shared_weights` will be called to
+        #    synchronize the shared parameters.
+        # However, `self._init_weights` will re-initialize the parameters without
+        # synchronizing the shared parameters. If the following step does not load a checkpoint,
+        # the shared parameters will be different.
+
+        if isinstance(self, PipelineLayer):
+            self._synchronize_shared_weights()
+
     def _init_weights(self, layer):
         """
         Initialize the weights. This method should be overridden by derived class.
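
The note in the diff describes the failure mode this commit fixes: a PipelineLayer keeps an independent copy of each logically shared parameter (for example tied embeddings) on every stage that uses it, and `_init_weights` re-initializes those copies independently, so without a follow-up `_synchronize_shared_weights()` call, and without a checkpoint load to overwrite them, the copies silently diverge. Below is a minimal, framework-free sketch of that effect; `Stage`, `reinit`, and `synchronize_shared_weights` are illustrative stand-ins for the Paddle machinery, not PaddleNLP or Paddle APIs, and NumPy arrays stand in for the real parameters.

import numpy as np

rng = np.random.default_rng(0)


class Stage:
    """One pipeline stage holding its own copy of a logically shared weight."""

    def __init__(self, shared_weight):
        self.shared_weight = shared_weight.copy()

    def reinit(self):
        # Mimics `_init_weights`: every stage re-initializes its copy
        # independently, so the copies stop matching.
        self.shared_weight = rng.normal(size=self.shared_weight.shape)


def synchronize_shared_weights(stages):
    # Mimics `_synchronize_shared_weights`: broadcast the owning stage's
    # copy to every other stage that shares the parameter.
    owner = stages[0]
    for stage in stages[1:]:
        stage.shared_weight = owner.shared_weight.copy()


initial = rng.normal(size=(4, 8))
stages = [Stage(initial), Stage(initial)]

# Right after construction the copies match (PipelineLayer syncs them once).
assert np.allclose(stages[0].shared_weight, stages[1].shared_weight)

# Re-initialization without a follow-up sync leaves the copies different ...
for stage in stages:
    stage.reinit()
assert not np.allclose(stages[0].shared_weight, stages[1].shared_weight)

# ... which is what the new `isinstance(self, PipelineLayer)` branch repairs
# by calling `_synchronize_shared_weights()` again after `init_weights()`.
synchronize_shared_weights(stages)
assert np.allclose(stages[0].shared_weight, stages[1].shared_weight)

The actual fix keeps the re-sync inside `_post_init`, so it runs immediately after `self.init_weights()` and before any optional checkpoint load.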
