@@ -99,7 +99,7 @@ def find_class(self, mod_name, name):
99
99
return _rebuild_tensor_stage
100
100
101
101
# pytorch_lightning tensor builder
102
- if mod_name == "pytorch_lightning" :
102
+ if "pytorch_lightning" in mod_name :
103
103
return dumpy
104
104
return super ().find_class (mod_name , name )
105
105
@@ -219,23 +219,19 @@ def extract_maybe_dict(result):
219
219
for res in result :
220
220
extract_maybe_dict (res )
221
221
elif isinstance (result , TensorMeta ):
222
- metadata .append (result )
222
+ if result not in metadata :
223
+ metadata .append (result )
223
224
224
225
extract_maybe_dict (result_stage1 )
225
226
metadata = sorted (metadata , key = lambda x : x .key )
226
227
# 3. parse the tensor of pytorch weight file
227
228
stage1_key_to_tensor = {}
228
229
content_size = os .stat (path ).st_size
229
230
with open (path , "rb" ) as file_handler :
230
- prefix_key = read_prefix_key (file_handler , content_size ).decode ("latin" )
231
231
file_handler .seek (pre_offset )
232
-
233
232
for tensor_meta in metadata :
234
233
key = tensor_meta .key
235
- # eg: archive/data/1FB
236
- filename = f"{ prefix_key } /data/{ key } "
237
- seek_by_string (file_handler , filename , content_size )
238
- file_handler .seek (2 , 1 )
234
+ seek_by_string (file_handler , "FB" , content_size )
239
235
240
236
padding_offset = np .frombuffer (file_handler .read (2 )[:1 ], dtype = np .uint8 )[0 ]
241
237
file_handler .seek (padding_offset , 1 )
0 commit comments