Skip to content

Commit 7304171

Browse files
authored
fix load torch (#4383)
1 parent 5a44636 commit 7304171

File tree

1 file changed

+4
-8
lines changed

1 file changed

+4
-8
lines changed

paddlenlp/utils/serialization.py

Lines changed: 4 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -99,7 +99,7 @@ def find_class(self, mod_name, name):
9999
return _rebuild_tensor_stage
100100

101101
# pytorch_lightning tensor builder
102-
if mod_name == "pytorch_lightning":
102+
if "pytorch_lightning" in mod_name:
103103
return dumpy
104104
return super().find_class(mod_name, name)
105105

@@ -219,23 +219,19 @@ def extract_maybe_dict(result):
219219
for res in result:
220220
extract_maybe_dict(res)
221221
elif isinstance(result, TensorMeta):
222-
metadata.append(result)
222+
if result not in metadata:
223+
metadata.append(result)
223224

224225
extract_maybe_dict(result_stage1)
225226
metadata = sorted(metadata, key=lambda x: x.key)
226227
# 3. parse the tensor of pytorch weight file
227228
stage1_key_to_tensor = {}
228229
content_size = os.stat(path).st_size
229230
with open(path, "rb") as file_handler:
230-
prefix_key = read_prefix_key(file_handler, content_size).decode("latin")
231231
file_handler.seek(pre_offset)
232-
233232
for tensor_meta in metadata:
234233
key = tensor_meta.key
235-
# eg: archive/data/1FB
236-
filename = f"{prefix_key}/data/{key}"
237-
seek_by_string(file_handler, filename, content_size)
238-
file_handler.seek(2, 1)
234+
seek_by_string(file_handler, "FB", content_size)
239235

240236
padding_offset = np.frombuffer(file_handler.read(2)[:1], dtype=np.uint8)[0]
241237
file_handler.seek(padding_offset, 1)

0 commit comments

Comments
 (0)