Commit 38ef1f6

Cherry pick lora fix (#7826)
* fix lora (#7824)
* [llm] support qlora pp (#7801)
* support qlora pp
* fix scale dtype
1 parent d4acbfc commit 38ef1f6

File tree

4 files changed: +25 -10 lines changed


paddlenlp/peft/lora/lora_model.py

Lines changed: 2 additions & 2 deletions
@@ -231,12 +231,12 @@ def save_pretrained(self, save_directory: str, merge_tensor_parallel: bool = Fal
         if self.is_pipelinemodel:
             self.model._single_to_pp_mapping = None
-        if self.quantized and merge_tensor_parallel and self.lora_config.tensor_parallel_degre > 1:
+        if self.quantized and merge_tensor_parallel and self.lora_config.tensor_parallel_degree > 1:
             merge_tensor_parallel = False
             logger.warning(
                 "Quantized strategy does not support merge_tensor_parallel. Set merge_tensor_parallel to False."
             )
-        if self.is_pipelinemodel and merge_tensor_parallel and self.lora_config.tensor_parallel_degre > 1:
+        if self.is_pipelinemodel and merge_tensor_parallel and self.lora_config.tensor_parallel_degree > 1:
             merge_tensor_parallel = False
             logger.warning(
                 "Pipeline parallism does not support merge_tensor_parallel. Set merge_tensor_parallel to False."

paddlenlp/quantization/quantization_linear.py

Lines changed: 3 additions & 3 deletions
@@ -83,7 +83,7 @@ def __init__(
         self.quant_scale = self.create_parameter(
             shape=[out_features],
             attr=scale_attr,
-            dtype="float32",
+            dtype=self._dtype,
             is_bias=False,
         )
         if self.quant_algo in ["fp4", "nf4"]:
@@ -231,7 +231,7 @@ def __init__(
         self.quant_scale = self.create_parameter(
             shape=[self.output_size_per_partition],
             attr=scale_attr,
-            dtype="float32",
+            dtype=self._dtype,
             is_bias=False,
         )
         self.quant_scale.is_distributed = True if self.is_mp else False
@@ -345,7 +345,7 @@ def __init__(
         self.quant_scale = self.create_parameter(
             shape=[out_features],
             attr=scale_attr,
-            dtype="float32",
+            dtype=self._dtype,
             is_bias=False,
         )
         self.quant_scale.is_distributed = True if self.is_mp else False
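
These three hunks create `quant_scale` in the layer's own dtype (`self._dtype`, typically float16 or bfloat16) instead of a hard-coded float32, so the scale matches the compute dtype during dequantization. A rough numpy sketch of the channel-wise dequantization this affects, with made-up values and the assumption that columns correspond to output channels:

```python
import numpy as np

compute_dtype = np.float16  # stand-in for the model's half-precision dtype

# int8 quantized weight and a per-output-channel scale in the compute dtype
quant_weight = np.array([[12, -7], [3, 127]], dtype=np.int8)
quant_scale = np.array([0.02, 0.01], dtype=compute_dtype)

# Keeping quant_scale in the compute dtype keeps the dequantized weight there
# too, instead of silently promoting the matmul path to float32.
dequantized = quant_weight.astype(compute_dtype) * quant_scale
print(dequantized.dtype)  # float16
```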

paddlenlp/quantization/quantization_utils.py

Lines changed: 5 additions & 2 deletions
@@ -109,9 +109,12 @@ def convert_to_quantize_state_dict_with_check(state_dict, quantization_linear_li
             raise ValueError(
                 f"{quant_weight_name} should be {paddle.int8} in state_dict but received dtype {state_dict[quant_weight_name].dtype}."
             )
-        if state_dict[quant_scale_name].dtype != paddle.float32:
+        if (
+            state_dict[quant_scale_name].dtype != paddle.float16
+            and state_dict[quant_scale_name].dtype != paddle.bfloat16
+        ):
             raise ValueError(
-                f"{quant_scale_name} should be {paddle.float32} in state_dict but received dtype {state_dict[quant_scale_name].dtype}."
+                f"{quant_scale_name} should be {paddle.float16} or {paddle.bfloat16} in state_dict but received dtype {state_dict[quant_scale_name].dtype}."
             )
         elif weight_name in state_dict:
             target_weight = state_dict.pop(weight_name).cast(dtype)
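
The state-dict check is relaxed to match: a checkpointed `quant_scale` is now expected in float16 or bfloat16 rather than float32. A condensed sketch of the same validation, using a hypothetical helper name and plain dtype strings instead of paddle dtypes:

```python
ALLOWED_SCALE_DTYPES = {"float16", "bfloat16"}


def check_quant_scale_dtype(name: str, dtype: str) -> None:
    # Hypothetical helper (not PaddleNLP API); mirrors the ValueError raised by
    # convert_to_quantize_state_dict_with_check for unexpected scale dtypes.
    if dtype not in ALLOWED_SCALE_DTYPES:
        raise ValueError(f"{name} should be float16 or bfloat16 in state_dict but received dtype {dtype}.")


check_quant_scale_dtype("linear_0.quant_scale", "bfloat16")   # passes silently
# check_quant_scale_dtype("linear_0.quant_scale", "float32")  # would raise ValueError
```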

paddlenlp/transformers/model_utils.py

Lines changed: 15 additions & 3 deletions
@@ -1759,9 +1759,15 @@ def _load_pretrained_model(
                 loaded_keys, quantization_linear_list, config.quantization_config
             )
             if keep_in_fp32_modules is None:
-                keep_in_fp32_modules = ["quant_scale"]
+                keep_in_fp32_modules = (
+                    ["quant_scale"] if config.quantization_config.weight_quantize_algo in ["nf4", "fp4"] else None
+                )
             else:
-                keep_in_fp32_modules += ["quant_scale"]
+                keep_in_fp32_modules = (
+                    keep_in_fp32_modules + ["quant_scale"]
+                    if config.quantization_config.weight_quantize_algo in ["nf4", "fp4"]
+                    else keep_in_fp32_modules
+                )

         missing_keys = list(set(expected_keys) - set(loaded_keys))
         unexpected_keys = list(set(loaded_keys) - set(expected_keys))
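
With scales now stored in half precision, only the nf4/fp4 algorithms still need `quant_scale` pinned to float32. A condensed, equivalent form of the branch above, wrapped in a hypothetical helper for illustration:

```python
from typing import List, Optional


def resolve_keep_in_fp32_modules(
    keep_in_fp32_modules: Optional[List[str]], weight_quantize_algo: str
) -> Optional[List[str]]:
    # Only nf4/fp4 keep quant_scale in float32; other algorithms leave the
    # caller's list (possibly None) untouched.
    if weight_quantize_algo in ["nf4", "fp4"]:
        return (keep_in_fp32_modules or []) + ["quant_scale"]
    return keep_in_fp32_modules


print(resolve_keep_in_fp32_modules(None, "nf4"))                    # ['quant_scale']
print(resolve_keep_in_fp32_modules(["norm"], "weight_only_int8"))   # ['norm']
```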
@@ -2173,7 +2179,9 @@ def from_pretrained(cls, pretrained_model_name_or_path, *args, **kwargs):
         logger.info("Loaded weights file from disk, setting weights to model.")

         # Check if `_keep_in_fp32_modules` is not None
-        use_keep_in_fp32_modules = (cls._keep_in_fp32_modules is not None) and dtype == "float16"
+        use_keep_in_fp32_modules = (cls._keep_in_fp32_modules is not None) and (
+            dtype == "float16" or dtype == "bfloat16"
+        )

         if is_sharded:
             loaded_state_dict_keys = sharded_metadata["all_checkpoint_keys"]
@@ -2208,6 +2216,10 @@ def from_pretrained(cls, pretrained_model_name_or_path, *args, **kwargs):
                 quantization_config=config.quantization_config,
                 llm_int8_threshold=config.quantization_config.llm_int8_threshold,
             )
+            quantization_linear_list = []
+            for key in model.state_dict().keys():
+                if "quant_weight" in key:
+                    quantization_linear_list.append(key[:-13])

         model, missing_keys, unexpected_keys, mismatched_keys = cls._load_pretrained_model(
             model=model,
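
The added loop rebuilds `quantization_linear_list` from the instantiated model's state-dict keys: every key containing `quant_weight` marks a quantized linear layer, and the trailing `.quant_weight` (13 characters) is stripped to recover the layer prefix. A small standalone illustration with made-up key names:

```python
state_dict_keys = [
    "llama.layers.0.self_attn.q_proj.quant_weight",
    "llama.layers.0.self_attn.q_proj.quant_scale",
    "llama.layers.0.mlp.gate_proj.quant_weight",
]

quantization_linear_list = []
for key in state_dict_keys:
    if "quant_weight" in key:
        quantization_linear_list.append(key[:-13])  # len(".quant_weight") == 13

print(quantization_linear_list)
# ['llama.layers.0.self_attn.q_proj', 'llama.layers.0.mlp.gate_proj']
```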
