Skip to content

Commit 0289118

Browse files
yuanlehomeming1753
andauthored
[LLM_INFER] Support quantized model from bos and fix docs (#9197)
* llama A8W8 support skip_scale * support top_k * qwen2 support skip scales * merge develop * add comment * support quantized model from bos --------- Co-authored-by: minghaipeng <minghaipeng@baidu.com>
1 parent b2e4db2 commit 0289118

File tree

16 files changed

+571
-361
lines changed

16 files changed

+571
-361
lines changed

llm/docs/predict/best_practices.md

Lines changed: 5 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -4,18 +4,17 @@ PaddleNLP 提供了多种环境变量,用于优化推理性能和资源使用
44

55
**GEMM 优化**
66

7-
- `FLAGS_enable_blaslt_global_search`:int8 gemm是否开启全局调优,默认值为0,表示不开启。设置为1,PaddleNLP 会在推理过程中使用`FLAGS_cublaslt_device_best_config`中记录的最优GEMM配置
7+
- `FLAGS_enable_blaslt_global_search`:int8 gemm 是否开启全局调优,默认值为0,表示不开启。设置为1,PaddleNLP 会在推理过程中动态搜索最优的 gemm 算法。推理 A8W8模型时使用此 FLAG 会获得更优的性能
88

9-
- `FLAGS_cublaslt_device_best_config`:指向性能最优的int8 gemm配置文件,默认值为""。配置文件可以通过`PaddleNLP/csrc/generation/test_tune_cublaslt_gemm.py`产出,该脚本会自动搜索当前输入大小下cuBLASLt提供的最优gemm配置并将结果记录下来。
9+
10+
- `FLAGS_cublaslt_device_best_config`:在 FLAGS_enable_blaslt_global_search 设为1的前提下,使用`FLAGS_cublaslt_device_best_config`来指定离线调优出的 int8 gemm 配置文件,默认值为""。配置文件可以通过`PaddleNLP/csrc/utils/tune_cublaslt_int8_gemm.py`产出,该脚本会自动搜索当前输入大小下 cuBLASLt 提供的最优 gemm 配置并将结果记录下来,需要注意的是不同的 CUDA 版本需要分别 tune。推理 A8W8模型并且 FLAGS_enable_blaslt_global_search 设为1时使用此 FLAG 会获得更优的性能。
1011

1112
**GQA 优化**
1213

13-
- `FLAGS_use_xqa_optim`gpa是否开启xqa优化,默认值为0,表示不开启。gqa模型(如llama3/3.1、qwen2)设为1性能会更好。
14+
- `FLAGS_use_xqa_optim`gpa 是否开启 xqa 优化,默认值为0,表示不开启。gqa 模型(如 llama3/3.1、qwen2)设为1性能会更好。
1415

1516
**显存优化**
1617

17-
- `FLAGS_allocator_strategy`:显存管理策略,默认值为 `auto_growth`。可优先设为`naive_best_fit`,若显存oom可配置为`auto_growth`
18-
19-
- `FLAGS_fraction_of_gpu_memory_to_use`:GPU显存使用率,默认值为0.9。设置为0.9即可。
18+
- `FLAGS_fraction_of_gpu_memory_to_use`:GPU 显存使用率,默认值为0.9。设置为0.9即可。
2019

2120
- `FLAGS_gemm_use_half_precision_compute_type`:是否使用半精度浮点数计算,默认值为0。设置为0即可。

llm/docs/predict/llama.md

Lines changed: 32 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -1,34 +1,43 @@
11
# LLaMA
22

3-
本文档展示了如何在 PaddleNLP中构建和运行[LLaMA](https://llama.meta.com/) 系列大模型。
3+
本文档展示了如何在 PaddleNLP 中构建和运行[LLaMA](https://llama.meta.com/) 系列大模型。
44

55
## 模型介绍
66

77
* LLaMA 系列大模型是由 Meta AI 发布的一个开放且高效的大型基础语言模型。
88

9-
* [Llama 2](https://llama.meta.com/llama2/):2023年7月,Meta发布了Llama 2系列,有7B、13B、34B和70B四个版本。该版本实现了开源商用,降低了初创公司创建类似ChatGPT聊天机器人的成本
9+
* [Llama 2](https://llama.meta.com/llama2/):2023年7月,Meta 发布了 Llama 2系列,有7B、13B、34B 和70B 四个版本。该版本实现了开源商用,降低了初创公司创建类似 ChatGPT 聊天机器人的成本
1010

11-
* [Llama 3](https://llama.meta.com/):2024年4月19日,Meta推出了Llama 3系列,包括8B和70B两个版本,400B的Llama-3还在训练中。该版本在多个基准测试中取得了全面进步,性能优异。
11+
* [Llama 3](https://llama.meta.com/):2024年4月19日,Meta 推出了 Llama 3系列,包括8B 和70B 两个版本,400B 的 Llama-3还在训练中。该版本在多个基准测试中取得了全面进步,性能优异。
1212

13-
* [Llama 3.1](https://llama.meta.com/):2024年7月23日,Meta发布了Llama 3.1 8B、70B、405B模型,进一步提升了模型的性能和效率。
13+
* [Llama 3.1](https://llama.meta.com/):2024年7月23日,Meta 发布了 Llama 3.1 8B、70B、405B 模型,进一步提升了模型的性能和效率。
1414

15-
## 模型支持
15+
## 已验证的模型
1616

17-
| Model |
18-
| :----------------------------: |
19-
| meta-llama/Llama-2-7b(-chat) |
20-
| meta-llama/Llama-2-13b(-chat) |
21-
| meta-llama/Llama-2-70b(-chat) |
22-
| meta-llama/Meta-Llama-3-8B(-Instruct) |
23-
| meta-llama/Meta-Llama-3-70B(-Instruct) |
24-
| meta-llama/Meta-Llama-3.1-8B(-Instruct) |
25-
| meta-llama/Meta-Llama-3.1-70B(-Instruct) |
26-
| meta-llama/Meta-Llama-3.1-405B(-Instruct) |
17+
|Model|
18+
|:-|
19+
|meta-llama/Llama-2-7b-chat|
20+
|meta-llama/Llama-2-13b-chat|
21+
|meta-llama/Llama-2-70b-chat|
22+
|meta-llama/Meta-Llama-3-8B-Instruct|
23+
|meta-llama/Meta-Llama-3-70B-Instruct|
24+
|meta-llama/Meta-Llama-3.1-8B-Instruct|
25+
|meta-llama/Meta-Llama-3.1-70B-Instruct|
26+
|meta-llama/Meta-Llama-3.1-405B-Instruct|
27+
28+
## 已验证的预量化模型
29+
30+
|Model|
31+
|:-|
32+
|meta-llama/Meta-Llama-3-8B-Instruct-A8W8C8|
33+
|meta-llama/Meta-Llama-3-8B-Instruct-A8W8-FP8|
34+
|meta-llama/Meta-Llama-3.1-8B-Instruct-A8W8C8|
35+
|meta-llama/Meta-Llama-3.1-8B-Instruct-A8W8-FP8|
2736

2837

2938
## 模型推理
3039

31-
以meta-llama/Meta-Llama-3-8B-Instruct单卡和meta-llama/Meta-Llama-3.1-405B-Instruct多卡为例
40+
以 meta-llama/Meta-Llama-3-8B-Instruct 单卡和 meta-llama/Meta-Llama-3.1-405B-Instruct 多卡为例
3241

3342
BF16推理
3443

@@ -57,7 +66,7 @@ python predict/export_model.py --model_name_or_path meta-llama/Meta-Llama-3-8B-I
5766
python predict/predictor.py --model_name_or_path /path/to/exported_model --dtype bfloat16 --mode static --inference_model 1 --block_attn 1 --quant_type weight_only_int8
5867
```
5968

60-
下面量化推理所需要的模型需要根据[大模型量化教程](../quantization.md)产出。
69+
下面量化推理所需要的模型需要根据[大模型量化教程](../quantization.md)产出,如 checkpoints/llama_ptq_ckpts,或者使用所提供的预先量化好的模型,如 meta-llama/Meta-Llama-3-8B-Instruct-A8W8C8
6170

6271
INT8-A8W8推理
6372

@@ -76,10 +85,10 @@ INT8-A8W8C8推理
7685

7786
```shell
7887
# 动态图推理
79-
python predict/predictor.py --model_name_or_path checkpoints/llama_ptq_ckpts --dtype bfloat16 --mode dynamic --inference_model 1 --block_attn 1 --quant_type a8w8 --cachekv_int8_type static
88+
python predict/predictor.py --model_name_or_path meta-llama/Meta-Llama-3-8B-Instruct-A8W8C8 --dtype bfloat16 --mode dynamic --inference_model 1 --block_attn 1 --quant_type a8w8 --cachekv_int8_type static
8089

8190
# 动转静导出模型
82-
python predict/export_model.py --model_name_or_path checkpoints/llama_ptq_ckpts --output_path /path/to/exported_model --dtype bfloat16 --inference_model 1 --block_attn 1 --quant_type a8w8 --cachekv_int8_type static
91+
python predict/export_model.py --model_name_or_path meta-llama/Meta-Llama-3-8B-Instruct-A8W8C8 --output_path /path/to/exported_model --dtype bfloat16 --inference_model 1 --block_attn 1 --quant_type a8w8 --cachekv_int8_type static
8392

8493
# 静态图推理
8594
python predict/predictor.py --model_name_or_path /path/to/exported_model --dtype bfloat16 --mode static --inference_model 1 --block_attn 1 --quant_type a8w8 --cachekv_int8_type static
@@ -88,10 +97,10 @@ python predict/predictor.py --model_name_or_path /path/to/exported_model --dtype
8897
FP8-A8W8推理
8998
```shell
9099
# 动态图推理
91-
python predict/predictor.py --model_name_or_path checkpoints/llama_ptq_ckpts --dtype bfloat16 --mode dynamic --inference_model 1 --block_attn 1 --quant_type a8w8_fp8
100+
python predict/predictor.py --model_name_or_path meta-llama/Meta-Llama-3-8B-Instruct-A8W8-FP8 --dtype bfloat16 --mode dynamic --inference_model 1 --block_attn 1 --quant_type a8w8_fp8
92101

93102
# 动转静导出模型
94-
python predict/export_model.py --model_name_or_path checkpoints/llama_ptq_ckpts --output_path /path/to/exported_model --dtype bfloat16 --inference_model 1 --block_attn 1 --quant_type a8w8_fp8
103+
python predict/export_model.py --model_name_or_path meta-llama/Meta-Llama-3-8B-Instruct-A8W8-FP8 --output_path /path/to/exported_model --dtype bfloat16 --inference_model 1 --block_attn 1 --quant_type a8w8_fp8
95104

96105
# 静态图推理
97106
python predict/predictor.py --model_name_or_path /path/to/exported_model --dtype bfloat16 --mode static --inference_model 1 --block_attn 1 --quant_type a8w8_fp8
@@ -108,13 +117,12 @@ tokenizer = AutoTokenizer.from_pretrained("meta-llama/Meta-Llama-3.1-405B-Instru
108117
generation_config = GenerationConfig.from_pretrained("meta-llama/Meta-Llama-3.1-405B-Instruct")
109118
```
110119

111-
这里通过--use_fake_parameter使用fake parameters,如需要推理正确的量化模型,请自行参考[大模型量化教程](../quantization.md)进行量化。
120+
这里通过--use_fake_parameter 使用 fake parameters,如需要推理正确的量化模型,请自行参考[大模型量化教程](../quantization.md)进行量化。
112121

113122
```shell
114123
# 导出模型 (可在predict/export_model.py中设置paddle.set_device("cpu"),通过内存导出模型)
115124
python -m paddle.distributed.launch --gpus "0,1,2,3,4,5,6,7" predict/export_model.py --model_name_or_path meta-llama/Meta-Llama-3.1-405B-Instruct --output_path /path/to/a8w8c8_tp8 --inference_model 1 --block_attn 1 --dtype bfloat16 --quant_type a8w8 --cachekv_int8_type static --use_fake_parameter 1
116125

117126
# 推理
118-
python -m paddle.distributed.launch --gpus "0,1,2,3,4,5,6,7" predict/predictor.py --model_name_or_path /path/to/a8w8c8_tp8 --mode static --inference_model 1 --block_attn 1 --dtype bfloat16 --quant_type a8w8 --cachekv_int8_type static
127+
python -m paddle.distributed.launch --gpus "0,1,2,3,4,5,6,7" predict/predictor.py --model_name_or_path /path/to/a8w8c8_tp8 --mode static --inference_model 1 --block_attn 1 --dtype bfloat16 --quant_type a8w8 --cachekv_int8_type static
119128
```
120-

llm/docs/predict/mixtral.md

Lines changed: 9 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -1,23 +1,23 @@
11
# Mixtral
22

3-
本文档展示了如何在 PaddleNLP中构建和运行 [Mxtral](https://huggingface.co/mistralai/Mixtral-8x7B-Instruct-v0.1) 模型。
3+
本文档展示了如何在 PaddleNLP 中构建和运行 [Mxtral](https://huggingface.co/mistralai/Mixtral-8x7B-Instruct-v0.1) 模型。
44

55
## 模型介绍
66

77

8-
* [Mistral系列](https://arxiv.org/abs/2310.06825) 是Mistral AI研发的基座大模型,使用了分组查询注意力和滑动窗口注意力机制来提高模型性能表现和推理速度,包括7B不同规模的Base和Instruct模型
9-
* [Mixtral系列](https://arxiv.org/abs/2401.04088) 是Mistral AI采用MoE(Mixture of Experts)架构设计的基座大模型,在大多数基准测试中优于同级别的llama模型,MoE结合了多个专家模型的优势来解决问题,在推理中仅需激活少量专家就可以达到非常好的效果,相比于传统大模型减少了较多的计算量;目前开源模型包括8x7B和8x22B两种不同规模的Base和Instruct模型
8+
* [Mistral 系列](https://arxiv.org/abs/2310.06825) 是 Mistral AI 研发的基座大模型,使用了分组查询注意力和滑动窗口注意力机制来提高模型性能表现和推理速度,包括7B 不同规模的 Base 和 Instruct 模型
9+
* [Mixtral 系列](https://arxiv.org/abs/2401.04088) 是 Mistral AI 采用 MoE(Mixture of Experts)架构设计的基座大模型,在大多数基准测试中优于同级别的 llama 模型,MoE 结合了多个专家模型的优势来解决问题,在推理中仅需激活少量专家就可以达到非常好的效果,相比于传统大模型减少了较多的计算量;目前开源模型包括8x7B 和8x22B 两种不同规模的 Base 和 Instruct 模型
1010

11-
## 模型支持
11+
## 已验证的模型
1212

13-
| Model |
14-
| :-----------------------------: |
15-
| mistralai/Mixtral-8x7B-v0.1(-Instruct) |
13+
|Model|
14+
|:-|
15+
|mistralai/Mixtral-8x7B-v0.1-Instruct|
1616

1717

1818
## 模型推理
1919

20-
下面以Mixtral-8x7B-Instruct-v0.1两卡为例介绍整体推理流程。
20+
下面以 Mixtral-8x7B-Instruct-v0.1两卡为例介绍整体推理流程。
2121

2222
BF16推理
2323

@@ -97,4 +97,4 @@ python -m paddle.distributed.launch \
9797
--mode "static" \
9898
--inference_model \
9999
--block_attn
100-
```
100+
```

llm/docs/predict/qwen.md

Lines changed: 27 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -1,30 +1,39 @@
11
# Qwen
22

3-
本文档展示了如何在 PaddleNLP中构建和运行[Qwen](https://huggingface.co/Qwen) 系列大模型。
3+
本文档展示了如何在 PaddleNLP 中构建和运行[Qwen](https://huggingface.co/Qwen) 系列大模型。
44

55
## 模型介绍
66

7-
* [通义千问(Qwen)](https://arxiv.org/abs/2205.01068) 是阿里云研发的通义千问大模型系列的模型, 包括 Qwen-1.8B、Qwen-7B、Qwen-14B和Qwen-72B等4个规模。Qwen 是基于 Transformer 的大语言模型, 在超大规模的预训练数据上进行训练得到。预训练数据类型多样,覆盖广泛,包括大量网络文本、专业书籍、代码等。
7+
* [通义千问(Qwen)](https://arxiv.org/abs/2205.01068) 是阿里云研发的通义千问大模型系列的模型, 包括 Qwen-1.8B、Qwen-7B、Qwen-14B 和 Qwen-72B 等4个规模。Qwen 是基于 Transformer 的大语言模型, 在超大规模的预训练数据上进行训练得到。预训练数据类型多样,覆盖广泛,包括大量网络文本、专业书籍、代码等。
88

9-
* [通义千问(Qwen1.5)](https://qwenlm.github.io/blog/qwen1.5/) 是阿里云研发的通义千问系列模型升级版。Qwen1.5包括0.5B、1.8B、4B、7B、14B、32B、72B和110B共计8个不同规模的Base和Chat模型
9+
* [通义千问(Qwen1.5)](https://qwenlm.github.io/blog/qwen1.5/) 是阿里云研发的通义千问系列模型升级版。Qwen1.5包括0.5B、1.8B、4B、7B、14B、32B、72B 和110B 共计8个不同规模的 Base 和 Chat 模型
1010

11-
* [通义千问(Qwen2)](https://qwenlm.github.io/blog/qwen2/) 是阿里云研发的通义千问系列模型升级版。Qwen2包括Qwen2-0.5B、Qwen2-1.5B、Qwen2-7B、Qwen2-57B-A14B 以及Qwen2-72B 共计5个不同规模的 Base 和 Instruct 模型。
11+
* [通义千问(Qwen2)](https://qwenlm.github.io/blog/qwen2/) 是阿里云研发的通义千问系列模型升级版。Qwen2包括 Qwen2-0.5B、Qwen2-1.5B、Qwen2-7B、Qwen2-57B-A14B 以及 Qwen2-72B 共计5个不同规模的 Base 和 Instruct 模型。
1212

13-
* [通义千问(Qwen-MoE)](https://qwenlm.github.io/blog/qwen2/) 是阿里云研发的通义千问系列模型升级版。Qwen-MoE包括Qwen1.5-MoE-A2.7B 以及 Qwen2-57B-A14B 共计2个不同规模的 Base、Chat 和 Instruct 模型。
13+
* [通义千问(Qwen-MoE)](https://qwenlm.github.io/blog/qwen2/) 是阿里云研发的通义千问系列模型升级版。Qwen-MoE 包括 Qwen1.5-MoE-A2.7B 以及 Qwen2-57B-A14B 共计2个不同规模的 Base、Chat 和 Instruct 模型。
1414

15-
## 模型支持
15+
## 已验证的模型
1616

17-
| Model |
18-
| :----------------------------: |
19-
| Qwen/Qwen2-0.5B(-Instruct) |
20-
| Qwen/Qwen2-1.5B(-Instruct) |
21-
| Qwen/Qwen2-7B(-Instruct) |
22-
| Qwen/Qwen1.5-MoE-A2.7B(-Chat) |
17+
|Model|
18+
|:-|
19+
|Qwen/Qwen2-0.5B-Instruct|
20+
|Qwen/Qwen2-1.5B-Instruct|
21+
|Qwen/Qwen2-7B-Instruct|
22+
|Qwen/Qwen1.5-MoE-A2.7B-Chat|
23+
|Qwen/Qwen2-57B-A14B-Instruct|
2324

25+
## 已验证的预量化模型
26+
27+
|Model|
28+
|:-|
29+
|Qwen/Qwen2-1.5B-Instruct-A8W8C8|
30+
|Qwen/Qwen2-1.5B-Instruct-A8W8-FP8|
31+
|Qwen/Qwen2-7B-Instruct-A8W8C8|
32+
|Qwen/Qwen2-7B-Instruct-A8W8-FP8|
2433

2534
## 模型推理
2635

27-
以Qwen/Qwen2-1.5B-Instruct为例
36+
以 Qwen/Qwen2-1.5B-Instruct 为例
2837

2938
BF16推理
3039

@@ -53,7 +62,7 @@ python predict/export_model.py --model_name_or_path Qwen/Qwen2-1.5B-Instruct --o
5362
python predict/predictor.py --model_name_or_path /path/to/exported_model --dtype bfloat16 --mode static --inference_model 1 --block_attn 1 --quant_type weight_only_int8
5463
```
5564

56-
下面量化推理所需要的模型需要根据[大模型量化教程](../quantization.md)产出。
65+
下面量化推理所需要的模型需要根据[大模型量化教程](../quantization.md)产出,如 checkpoints/qwen_ptq_ckpts,或者使用所提供的预先量化好的模型,如 Qwen/Qwen2-1.5B-Instruct-A8W8C8
5766

5867
INT8-A8W8推理
5968

@@ -72,10 +81,10 @@ INT8-A8W8C8推理
7281

7382
```shell
7483
# 动态图推理
75-
python predict/predictor.py --model_name_or_path checkpoints/qwen_ptq_ckpts --dtype bfloat16 --mode dynamic --inference_model 1 --block_attn 1 --quant_type a8w8 --cachekv_int8_type static
84+
python predict/predictor.py --model_name_or_path Qwen/Qwen2-1.5B-Instruct-A8W8C8 --dtype bfloat16 --mode dynamic --inference_model 1 --block_attn 1 --quant_type a8w8 --cachekv_int8_type static
7685

7786
# 动转静导出模型
78-
python predict/export_model.py --model_name_or_path checkpoints/qwen_ptq_ckpts --output_path /path/to/exported_model --dtype bfloat16 --inference_model 1 --block_attn 1 --quant_type a8w8 --cachekv_int8_type static
87+
python predict/export_model.py --model_name_or_path Qwen/Qwen2-1.5B-Instruct-A8W8C8 --output_path /path/to/exported_model --dtype bfloat16 --inference_model 1 --block_attn 1 --quant_type a8w8 --cachekv_int8_type static
7988

8089
# 静态图推理
8190
python predict/predictor.py --model_name_or_path /path/to/exported_model --dtype bfloat16 --mode static --inference_model 1 --block_attn 1 --quant_type a8w8 --cachekv_int8_type static
@@ -84,10 +93,10 @@ python predict/predictor.py --model_name_or_path /path/to/exported_model --dtype
8493
FP8-A8W8推理
8594
```shell
8695
# 动态图推理
87-
python predict/predictor.py --model_name_or_path checkpoints/qwen_ptq_ckpts --dtype bfloat16 --mode dynamic --inference_model 1 --block_attn 1 --quant_type a8w8_fp8
96+
python predict/predictor.py --model_name_or_path Qwen/Qwen2-7B-Instruct-A8W8-FP8 --dtype bfloat16 --mode dynamic --inference_model 1 --block_attn 1 --quant_type a8w8_fp8
8897

8998
# 动转静导出模型
90-
python predict/export_model.py --model_name_or_path checkpoints/qwen_ptq_ckpts --output_path /path/to/exported_model --dtype bfloat16 --inference_model 1 --block_attn 1 --quant_type a8w8_fp8
99+
python predict/export_model.py --model_name_or_path Qwen/Qwen2-7B-Instruct-A8W8-FP8 --output_path /path/to/exported_model --dtype bfloat16 --inference_model 1 --block_attn 1 --quant_type a8w8_fp8
91100

92101
# 静态图推理
93102
python predict/predictor.py --model_name_or_path /path/to/exported_model --dtype bfloat16 --mode static --inference_model 1 --block_attn 1 --quant_type a8w8_fp8

llm/predict/predictor.py

Lines changed: 2 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -886,12 +886,8 @@ def init_model_inputs(self, config: PredictorArgument):
886886
]
887887
)
888888
# self.model_inputs["src_mask/tgt_mask"] is read only, will not be updated!
889-
src_mask = (
890-
alibi_encoder + (1 - src_mask) * paddle.finfo(self.dtype).min
891-
).cast(self.dtype)
892-
tgt_mask = (
893-
alibi_decoder + (1 - tgt_mask) * paddle.finfo(self.dtype).min
894-
).cast(self.dtype)
889+
src_mask = (alibi_encoder + (1 - src_mask) * paddle.finfo(self.dtype).min).cast(self.dtype)
890+
tgt_mask = (alibi_decoder + (1 - tgt_mask) * paddle.finfo(self.dtype).min).cast(self.dtype)
895891
self.model_inputs["rope_emb"] = paddle.concat([src_mask.reshape([-1]), tgt_mask.reshape([-1])])
896892

897893
def _preprocess(self, input_text: list[str]):

paddlenlp/experimental/transformers/bloom/modeling.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -293,6 +293,7 @@ def forward(
293293

294294
@paddle.no_grad()
295295
def set_state_dict(self, state_dict, use_structured_name=True):
296+
self.transformer_block.init_weight()
296297
for k, v in state_dict.items():
297298
if k.find("word_embeddings.weight") >= 0:
298299
self.word_embeddings.weight.set_value(paddle.to_tensor(v))

paddlenlp/experimental/transformers/chatglm/modeling.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -377,6 +377,7 @@ def forward(
377377

378378
@paddle.no_grad()
379379
def set_state_dict(self, state_dict, use_structured_name=True):
380+
self.transformer_block.init_weight()
380381
dtype = paddle.get_default_dtype()
381382
config = self.config
382383
embed_dim = config.hidden_size

paddlenlp/experimental/transformers/chatglm_v2/modeling.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -290,6 +290,8 @@ def forward(
290290

291291
@paddle.no_grad()
292292
def set_state_dict(self, state_dict):
293+
self.transformer_block.init_weight()
294+
293295
# find the real name.
294296
def key(name):
295297
result_list = []

0 commit comments

Comments
 (0)