Commit 156182e

[INFER] update tune_cublaslt_gemm op and fix some bugs (#9222)

1 parent f9eb62e commit 156182e

File tree

6 files changed: +115 −65 lines changed

csrc/gpu/tune_cublaslt_gemm.cu
Lines changed: 54 additions & 50 deletions

@@ -18,11 +18,11 @@ limitations under the License. */

 #include <algorithm>
 #include <fstream>
+#include <iomanip>
 #include <iostream>
 #include <limits>
 #include <list>
 #include <vector>
-#include <iomanip>

 #include "helper.h"

@@ -105,6 +105,13 @@ static inline bool time_compare_algo_para(const algoSelect_t& algo_para_a,
   return (algo_para_a.time < algo_para_b.time);
 }

+// Returns the remaining free GPU memory on the current device, in bytes.
+size_t get_remaining_memory() {
+  size_t free, total;
+  CUDA_CHECK(cudaMemGetInfo(&free, &total));
+  return free;
+}
+
 template <typename InT, typename OutT, typename ScaleT = OutT>
 static void TestMatmulRun(cublasLtHandle_t ltHandle,
                           cublasLtMatmulDesc_t matmulDesc,

@@ -122,7 +129,10 @@ static void TestMatmulRun(cublasLtHandle_t ltHandle,
   cublasLtMatmulHeuristicResult_t heurResult;
   cublasStatus_t algoStatus = cublasLtMatmulAlgoCheck(
       ltHandle, matmulDesc, A_desc, B_desc, C_desc, C_desc, &algo, &heurResult);
-  if (algoStatus == CUBLAS_STATUS_SUCCESS) {
+
+  auto remainingMemorySize = 0.95 * get_remaining_memory();
+  if (algoStatus == CUBLAS_STATUS_SUCCESS &&
+      remainingMemorySize > heurResult.workspaceSize) {
     ScaleT alpha = static_cast<ScaleT>(1), beta = static_cast<ScaleT>(0);
     void* workSpace;
     CUDA_CHECK(cudaMalloc(&workSpace, heurResult.workspaceSize));

@@ -166,8 +176,13 @@ static void TestMatmulRun(cublasLtHandle_t ltHandle,
     }
     CUDA_CHECK(cudaFree(workSpace));
   } else {
-    std::cerr << "not enough workspace! current workspace is "
-              << heurResult.workspaceSize;
+    std::cerr << "Not enough workspace! Required "
+              << static_cast<double>(heurResult.workspaceSize) / 1024.0 /
+                     1024.0 / 1024.0
+              << " GiB, but remaining "
+              << static_cast<double>(remainingMemorySize) / 1024.0 / 1024.0 /
+                     1024.0
+              << " GiB" << std::endl;
     perfResults.status = CUBLAS_STATUS_NOT_SUPPORTED;  // Not enough workspace
   }
 }
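As context for the hunk above: the tuner now refuses to benchmark an algorithm whose workspace would not fit in 95% of the currently free device memory, instead of letting cudaMalloc fail mid-run. Below is a minimal standalone sketch of that guard pattern; the helper names and the error formatting are illustrative, not the op's actual API.

```cpp
// Minimal sketch of the workspace guard (CUDA runtime only).
#include <cstdio>
#include <cuda_runtime.h>

// Remaining free GPU memory on the current device, in bytes.
static size_t remaining_memory() {
  size_t free_bytes = 0, total_bytes = 0;
  cudaMemGetInfo(&free_bytes, &total_bytes);
  return free_bytes;
}

// Allocate a workspace only if it fits within 95% of free memory,
// mirroring the 0.95 safety factor added in this hunk.
static bool try_alloc_workspace(size_t required_bytes, void** workspace) {
  const double gib = 1024.0 * 1024.0 * 1024.0;
  const double budget = 0.95 * static_cast<double>(remaining_memory());
  if (budget <= static_cast<double>(required_bytes)) {
    std::fprintf(stderr,
                 "Not enough workspace! Required %.3f GiB, but %.3f GiB remaining\n",
                 required_bytes / gib, budget / gib);
    return false;
  }
  return cudaMalloc(workspace, required_bytes) == cudaSuccess;
}
```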
@@ -442,7 +457,7 @@ void FindAlgo(const cublasLtHandle_t& ltHandle,
     if (perfResults[i].status != CUBLAS_STATUS_SUCCESS) {
       std::clog << "algo " << algos[i].algoId << " tile " << algos[i].tile
                 << " stages " << algos[i].stages << " splitK_val "
-                << algos[i].splitK_val;
+                << algos[i].splitK_val << std::endl;
       algos[i].time = std::numeric_limits<float>::max();
       std::cerr << " TestMatmulRun with status " << perfResults[i].status
                 << std::endl;

@@ -467,7 +482,7 @@ class DevContext {};
 class CPUContext : public DevContext {};

 class CUBLASLTContext : public DevContext {
-public:
+ public:
  CUBLASLTContext() { CUDA_CHECK(cublasLtCreate(&handle)); }

  cublasLtHandle_t handle;

@@ -709,64 +724,51 @@ void GEMMInt8<int8_t, int32_t, CUBLASLTContext>(const CUBLASLTContext& dev_ctx,
   CUDA_CHECK(cudaFree(workSpace));
 }

-void TuneCublasltGemm(const paddle::Tensor& M,
-                      const paddle::Tensor& K,
+void TuneCublasltGemm(const paddle::Tensor& K,
                       const paddle::Tensor& N,
+                      const int M_start,
+                      const int M_end,
                       const std::string& dtype,
-                      bool is_test,
-                      bool is_read_from_file,
+                      const bool is_test,
+                      const bool is_read_from_file,
                       const std::string& path) {
-  // Ensure that M, K, and N are all one-dimensional Tensors. is_test !=
-  // is_read_from_file
-  assert(M.dims().size() == 1 && K.dims().size() == 1 && N.dims().size() == 1);
+  assert(M_end >= M_start);
+  assert(M_start >= 1);
+  assert(K.dims().size() == 1 && N.dims().size() == 1);
   assert(is_test != is_read_from_file);

-  auto M_cpu = M.copy_to(paddle::CPUPlace(), false);
   auto K_cpu = K.copy_to(paddle::CPUPlace(), false);
   auto N_cpu = N.copy_to(paddle::CPUPlace(), false);
-  int64_t* M_data = M_cpu.data<int64_t>();
   int64_t* K_data = K_cpu.data<int64_t>();
   int64_t* N_data = N_cpu.data<int64_t>();

-  int M_size = M.numel();
   int K_size = K.numel();
   int N_size = N.numel();
   assert(K_size == N_size);

-  int m_data = (int)M_data[0];
-  assert(m_data > 0);
-
   std::vector<int> mm;
-
-  int m = 1, step = 1;
-  while (m <= m_data) {
-    mm.push_back(m);
-    m += step;
-
+  int m = M_start, step = 1;
+  while (m <= M_end) {
     // update step
-    switch (m) {
-      case 4:
-        step = 4;
-        break;
-      case 16:
-        step = 16;
-        break;
-      case 64:
-        step = 32;
-        break;
-      case 256:
-        step = 64;
-        break;
-      case 512:
-        step = 128;
-        break;
-      case 1024:
-        step = 1024;
-        break;
-      case 8192:
-        step = 4096;
-        break;
+    if (m >= 8192) {
+      step = 4096;
+    } else if (m >= 1024) {
+      step = 1024;
+    } else if (m >= 512) {
+      step = 128;
+    } else if (m >= 256) {
+      step = 64;
+    } else if (m >= 64) {
+      step = 32;
+    } else if (m >= 16) {
+      step = 16;
+    } else if (m >= 4) {
+      step = 4;
+    } else {
+      step = 1;
     }
+    mm.push_back(m);
+    m += step;
   }

   for (int j = 0; j < mm.size(); j++) {

@@ -792,16 +794,18 @@ void TuneCublasltGemm(const paddle::Tensor& M,
           path);
     } else {
       // other dtype
-      std::cout << "Not currently supported" << std::endl;
+      throw std::runtime_error(dtype + " not currently supported");
     }
   }
 }

 PD_BUILD_OP(tune_cublaslt_gemm)
-    .Inputs({"M", "K", "N"})
+    .Inputs({"K", "N"})
     .Outputs({})
-    .Attrs({"dtype: std::string",
+    .Attrs({"M_start: int",
+            "M_end: int",
+            "dtype: std::string",
             "is_test: bool",
             "is_read_from_file: bool",
             "path: std::string"})

csrc/utils/tune_cublaslt_int8_gemm.py
Lines changed: 8 additions & 3 deletions

@@ -15,7 +15,8 @@
 import paddle
 from paddlenlp_ops import tune_cublaslt_gemm

-M_tensor = paddle.to_tensor([32768])
+M_start = 1
+M_end = 32768

 # llama3.1-8b
 k1 = [4096, 4096, 4096, 14336]

@@ -36,7 +37,11 @@
 K_tensor = paddle.to_tensor(k1 + k2 + k3 + k4)
 N_tensor = paddle.to_tensor(n1 + n2 + n3 + n4)

-Dtype = "int8"
 Path = "./cublaslt_gemm_search.csv"

-tune_cublaslt_gemm(M_tensor, K_tensor, N_tensor, Dtype, True, False, Path)
+tune_cublaslt_gemm(K_tensor, N_tensor, M_start, M_end, "int8", True, False, Path)
+
+# Shape formulas, per layer: [qkv, out_linear, ffn1, ffn2]
+# k = [hidden_size, hidden_size, hidden_size, intermediate_size//mp_size]
+# n = [((num_attention_heads//mp_size)+2*(num_key_value_heads//mp_size))*(hidden_size//num_attention_heads), hidden_size, 2*(intermediate_size//mp_size), hidden_size]
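As a sanity check on these formulas, take the Llama-3.1-8B configuration (hidden_size=4096, intermediate_size=14336, num_attention_heads=32, num_key_value_heads=8) with mp_size=1: k = [4096, 4096, 4096, 14336], matching k1 above, and n works out to [(32 + 2·8) · (4096/32), 4096, 2·14336, 4096] = [6144, 4096, 28672, 4096].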

paddlenlp/experimental/transformers/llama/modeling.py
Lines changed: 14 additions & 4 deletions

@@ -876,6 +876,8 @@ def set_state_dict(self, state_dict):
     ffn_hidden_size=self.intermediate_size,
     num_key_value_heads=self.num_key_value_heads,
     mp_size=self.config.tensor_parallel_degree,
+    concat_qkv=True,
+    concat_ffn1=True,
 )
 self.transformer_block.weight_scales = weight_scales_loader.scale
 self.transformer_block.act_scales = act_scale_loader.scale

@@ -1097,16 +1099,24 @@ def set_state_dict(self, state_dict):
     dtype=paddle.get_default_dtype(),
 )
 self.transformer_block.linear_shifts[idx].set_value(
-    paddle.to_tensor(state_dict["llama.layers.{}.self_attn.o_proj.shift_bias".format(idx)])
+    paddle.to_tensor(state_dict["llama.layers.{}.self_attn.o_proj.shift_bias".format(idx)]).astype(
+        paddle.get_default_dtype()
+    )
 )
 self.transformer_block.linear_smooths[idx].set_value(
-    paddle.to_tensor(state_dict["llama.layers.{}.self_attn.o_proj.smooth_weight".format(idx)])
+    paddle.to_tensor(
+        state_dict["llama.layers.{}.self_attn.o_proj.smooth_weight".format(idx)]
+    ).astype(paddle.get_default_dtype())
 )
 self.transformer_block.ffn2_shifts[idx].set_value(
-    paddle.to_tensor(state_dict["llama.layers.{}.mlp.down_proj.shift_bias".format(idx)])
+    paddle.to_tensor(state_dict["llama.layers.{}.mlp.down_proj.shift_bias".format(idx)]).astype(
+        paddle.get_default_dtype()
+    )
 )
 self.transformer_block.ffn2_smooths[idx].set_value(
-    paddle.to_tensor(state_dict["llama.layers.{}.mlp.down_proj.smooth_weight".format(idx)])
+    paddle.to_tensor(state_dict["llama.layers.{}.mlp.down_proj.smooth_weight".format(idx)]).astype(
+        paddle.get_default_dtype()
+    )
 )

 if self.shift:
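The added .astype(paddle.get_default_dtype()) calls cast the shift/smooth tensors to the model's working dtype before set_value, presumably because checkpoints commonly store them in float32 while the model runs in float16/bfloat16; without the cast, assigning them to a half-precision parameter can fail on a dtype mismatch. The same fix is applied to the mixtral and qwen2 modeling files below.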

paddlenlp/experimental/transformers/mixtral/modeling.py
Lines changed: 12 additions & 4 deletions

@@ -716,16 +716,24 @@ def set_state_dict(self, state_dict):
 if "a8w8" in self.quant_type:
     if self.shift_smooth_all_linears:
         self.transformer_block.linear_shifts[idx].set_value(
-            paddle.to_tensor(state_dict["mixtral.layers.{}.self_attn.o_proj.shift_bias".format(idx)])
+            paddle.to_tensor(
+                state_dict["mixtral.layers.{}.self_attn.o_proj.shift_bias".format(idx)]
+            ).astype(paddle.get_default_dtype())
         )
         self.transformer_block.linear_smooths[idx].set_value(
-            paddle.to_tensor(state_dict["mixtral.layers.{}.self_attn.o_proj.smooth_weight".format(idx)])
+            paddle.to_tensor(
+                state_dict["mixtral.layers.{}.self_attn.o_proj.smooth_weight".format(idx)]
+            ).astype(paddle.get_default_dtype())
         )
         self.transformer_block.ffn2_shifts[idx].set_value(
-            paddle.to_tensor(state_dict["mixtral.layers.{}.mlp.down_proj.shift_bias".format(idx)])
+            paddle.to_tensor(state_dict["mixtral.layers.{}.mlp.down_proj.shift_bias".format(idx)]).astype(
+                paddle.get_default_dtype()
+            )
         )
         self.transformer_block.ffn2_smooths[idx].set_value(
-            paddle.to_tensor(state_dict["mixtral.layers.{}.mlp.down_proj.smooth_weight".format(idx)])
+            paddle.to_tensor(
+                state_dict["mixtral.layers.{}.mlp.down_proj.smooth_weight".format(idx)]
+            ).astype(paddle.get_default_dtype())
         )

 if self.shift:

paddlenlp/experimental/transformers/qwen2/modeling.py
Lines changed: 14 additions & 4 deletions

@@ -453,6 +453,8 @@ def set_state_dict(self, state_dict):
     ffn_hidden_size=self.intermediate_size,
     num_key_value_heads=self.num_key_value_heads,
     mp_size=self.config.tensor_parallel_degree,
+    concat_qkv=True,
+    concat_ffn1=True,
 )
 self.transformer_block.weight_scales = weight_scales_loader.scale
 self.transformer_block.act_scales = act_scale_loader.scale

@@ -704,16 +706,24 @@ def set_state_dict(self, state_dict):
     dtype=paddle.get_default_dtype(),
 )
 self.transformer_block.linear_shifts[idx].set_value(
-    paddle.to_tensor(state_dict["qwen2.layers.{}.self_attn.o_proj.shift_bias".format(idx)])
+    paddle.to_tensor(state_dict["qwen2.layers.{}.self_attn.o_proj.shift_bias".format(idx)]).astype(
+        paddle.get_default_dtype()
+    )
 )
 self.transformer_block.linear_smooths[idx].set_value(
-    paddle.to_tensor(state_dict["qwen2.layers.{}.self_attn.o_proj.smooth_weight".format(idx)])
+    paddle.to_tensor(
+        state_dict["qwen2.layers.{}.self_attn.o_proj.smooth_weight".format(idx)]
+    ).astype(paddle.get_default_dtype())
 )
 self.transformer_block.ffn2_shifts[idx].set_value(
-    paddle.to_tensor(state_dict["qwen2.layers.{}.mlp.down_proj.shift_bias".format(idx)])
+    paddle.to_tensor(state_dict["qwen2.layers.{}.mlp.down_proj.shift_bias".format(idx)]).astype(
+        paddle.get_default_dtype()
+    )
 )
 self.transformer_block.ffn2_smooths[idx].set_value(
-    paddle.to_tensor(state_dict["qwen2.layers.{}.mlp.down_proj.smooth_weight".format(idx)])
+    paddle.to_tensor(state_dict["qwen2.layers.{}.mlp.down_proj.smooth_weight".format(idx)]).astype(
+        paddle.get_default_dtype()
+    )
 )

 if self.shift:

paddlenlp/experimental/transformers/utils.py
Lines changed: 13 additions & 0 deletions

@@ -108,6 +108,8 @@ def __init__(
     ffn_hidden_size,
     num_key_value_heads=-1,
     mp_size=1,
+    concat_qkv=False,
+    concat_ffn1=False,
 ):
     self.key_map = key_map_dict
     self.scale = {}

@@ -126,6 +128,17 @@ def __init__(
     n = num_head * dim_head
     self.scale[scale_type] = np.full([num_of_layers, n], fill_value=0.1, dtype="float32")

+    # concat qkv and ffn1
+    if concat_qkv:
+        self.scale["qkv_weight_scale"] = np.full(
+            [num_of_layers, qkv_out_size // mp_size], fill_value=0.1, dtype="float32"
+        )
+
+    if concat_ffn1:
+        self.scale["ffn1_weight_scale"] = np.full(
+            [num_of_layers, ffn_hidden_size * 2 // mp_size], fill_value=0.1, dtype="float32"
+        )
+

 class EmptyCacheScale:
     """
