Skip to content

Commit 09c652a

Browse files
authored
changes (#6077)
1 parent 935f102 commit 09c652a

File tree

1 file changed

+21
-5
lines changed

1 file changed

+21
-5
lines changed

paddlenlp/transformers/model_utils.py

Lines changed: 21 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -338,12 +338,12 @@ def resolve_weight_file_from_hf_hub(repo_id: str, cache_dir: str, support_conver
338338
support_conversion (bool): whether support converting pytorch weight file to paddle weight file
339339
subfolder (str, optional) An optional value corresponding to a folder inside the repo.
340340
"""
341-
if hf_file_exists(repo_id, "model_state.pdparams", subfolder=subfolder):
342-
file_name = "model_state.pdparams"
341+
if hf_file_exists(repo_id, PADDLE_WEIGHT_FILE_NAME, subfolder=subfolder):
342+
file_name = PADDLE_WEIGHT_FILE_NAME
343343
elif hf_file_exists(repo_id, PYTORCH_WEIGHT_FILE_NAME, subfolder=subfolder):
344344
if not support_conversion:
345345
raise EntryNotFoundError(
346-
f"can not download `model_state.pdparams from https://huggingface.co/{repo_id}` "
346+
f"can not download `{PADDLE_WEIGHT_FILE_NAME} from https://huggingface.co/{repo_id}` "
347347
"and current model doesn't support conversion from pytorch weight file to paddle weight file"
348348
)
349349
file_name = PYTORCH_WEIGHT_FILE_NAME
@@ -448,7 +448,7 @@ class is a pretrained model class adding layers on top of the base model,
448448
# TODO: more flexible resource handle, namedtuple with fields as:
449449
# resource_name, saved_file, handle_name_for_load(None for used as __init__
450450
# arguments), handle_name_for_save
451-
resource_files_names = {"model_state": "model_state.pdparams"}
451+
resource_files_names = {"model_state": PADDLE_WEIGHT_FILE_NAME}
452452
pretrained_resource_files_map = {}
453453
base_model_prefix = ""
454454
main_input_name = "input_ids"
@@ -583,6 +583,22 @@ def model_name_list(self):
583583
# Todo: return all model name
584584
return list(self.pretrained_init_configuration.keys())
585585

586+
def get_memory_footprint(self, return_buffers=True):
    r"""
    Return the total memory footprint of the model, in bytes.

    Useful for benchmarking the memory usage of the current model and for
    designing tests around it.

    Arguments:
        return_buffers (`bool`, *optional*, defaults to `True`):
            Whether to include buffer tensors when computing the memory
            footprint. Buffers are tensors that do not require gradients
            and are not registered as parameters.
    """
    # Bytes occupied by registered parameters.
    # NOTE(review): `numel()` returning an object with `.item()` matches the
    # Paddle Tensor API, where `numel()` yields a 0-d Tensor.
    total_bytes = 0
    for parameter in self.parameters():
        total_bytes += parameter.numel().item() * parameter.element_size()
    if return_buffers:
        # Add non-parameter buffers (no-gradient tensors) when requested.
        for buffer in self.buffers():
            total_bytes += buffer.numel().item() * buffer.element_size()
    return total_bytes
601+
586602
def get_input_embeddings(self) -> nn.Embedding:
587603
"""get input embedding of model
588604
@@ -1275,7 +1291,7 @@ def from_pretrained(
12751291

12761292
else:
12771293
# 4. loading the state dict
1278-
if config.tensor_parallel_degree > 1 and model_weight_file.endswith("model_state.pdparams"):
1294+
if config.tensor_parallel_degree > 1 and model_weight_file.endswith(PADDLE_WEIGHT_FILE_NAME):
12791295
model_state_dict = cls.convert_tensor_parallel(model_weight_file, config)
12801296
else:
12811297
model_state_dict = paddle.load(model_weight_file, return_numpy=load_state_as_np)

0 commit comments

Comments
 (0)