From 2b553b4e703a6b343fb73228abec890a0d56f306 Mon Sep 17 00:00:00 2001 From: DesmonDay <908660116@qq.com> Date: Fri, 29 Nov 2024 11:46:33 +0800 Subject: [PATCH 1/2] fix load missing keys --- paddlenlp/trainer/unified_checkpoint/check_completion.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/paddlenlp/trainer/unified_checkpoint/check_completion.py b/paddlenlp/trainer/unified_checkpoint/check_completion.py index 626d25875740..833303759648 100644 --- a/paddlenlp/trainer/unified_checkpoint/check_completion.py +++ b/paddlenlp/trainer/unified_checkpoint/check_completion.py @@ -186,6 +186,8 @@ def check_dynamic_load(args, weight_map, existed_files, is_master_weights=False, state_dict = get_expected_state_dict(model) for key in state_dict.keys(): + if model._keys_to_ignore_on_load_massing is not None and key in model._keys_to_ignore_on_load_missing: + continue if sharding_group.nranks > 1: static_name = struct2static_name_mappings.get(key, None) param_rank = param2rank.get(static_name, None) From 9cea3069425d0a7f1767acb9bbeeae9cbc10dff8 Mon Sep 17 00:00:00 2001 From: Siming Dai <908660116@qq.com> Date: Tue, 3 Dec 2024 19:21:08 +0800 Subject: [PATCH 2/2] Update check_completion.py --- paddlenlp/trainer/unified_checkpoint/check_completion.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/paddlenlp/trainer/unified_checkpoint/check_completion.py b/paddlenlp/trainer/unified_checkpoint/check_completion.py index 833303759648..8e83425b9c39 100644 --- a/paddlenlp/trainer/unified_checkpoint/check_completion.py +++ b/paddlenlp/trainer/unified_checkpoint/check_completion.py @@ -186,7 +186,7 @@ def check_dynamic_load(args, weight_map, existed_files, is_master_weights=False, state_dict = get_expected_state_dict(model) for key in state_dict.keys(): - if model._keys_to_ignore_on_load_massing is not None and key in model._keys_to_ignore_on_load_missing: + if model._keys_to_ignore_on_load_missing is not None and key in model._keys_to_ignore_on_load_missing: continue if sharding_group.nranks > 1: static_name = struct2static_name_mappings.get(key, None)