[INFER] update tune_cublaslt_gemm op and fix some bugs #9222

Merged: 5 commits, Oct 11, 2024
104 changes: 54 additions & 50 deletions csrc/gpu/tune_cublaslt_gemm.cu
@@ -18,11 +18,11 @@ limitations under the License. */

#include <algorithm>
#include <fstream>
#include <iomanip>
#include <iostream>
#include <limits>
#include <list>
#include <vector>
#include <iomanip>

#include "helper.h"

@@ -105,6 +105,13 @@ static inline bool time_compare_algo_para(const algoSelect_t& algo_para_a,
return (algo_para_a.time < algo_para_b.time);
}

// Get the remaining GPU memory on the current device (in bytes)
size_t get_remaining_memory() {
size_t free, total;
CUDA_CHECK(cudaMemGetInfo(&free, &total));
return free;
}

template <typename InT, typename OutT, typename ScaleT = OutT>
static void TestMatmulRun(cublasLtHandle_t ltHandle,
cublasLtMatmulDesc_t matmulDesc,
@@ -122,7 +129,10 @@ static void TestMatmulRun(cublasLtHandle_t ltHandle,
cublasLtMatmulHeuristicResult_t heurResult;
cublasStatus_t algoStatus = cublasLtMatmulAlgoCheck(
ltHandle, matmulDesc, A_desc, B_desc, C_desc, C_desc, &algo, &heurResult);
if (algoStatus == CUBLAS_STATUS_SUCCESS) {

auto remainingMemorySize = 0.95 * get_remaining_memory();
if (algoStatus == CUBLAS_STATUS_SUCCESS &&
remainingMemorySize > heurResult.workspaceSize) {
ScaleT alpha = static_cast<ScaleT>(1), beta = static_cast<ScaleT>(0);
void* workSpace;
CUDA_CHECK(cudaMalloc(&workSpace, heurResult.workspaceSize));
@@ -166,8 +176,13 @@ static void TestMatmulRun(cublasLtHandle_t ltHandle,
}
CUDA_CHECK(cudaFree(workSpace));
} else {
std::cerr << "not enough workspace! current workspace is "
<< heurResult.workspaceSize;
std::cerr << "Not enough workspace! Required "
<< static_cast<double>(heurResult.workspaceSize) / 1024.0 /
1024.0 / 1024.0
<< " GiB" << ", But remaining "
<< static_cast<double>(remainingMemorySize) / 1024.0 / 1024.0 /
1024.0
<< " GiB" << std::endl;
perfResults.status = CUBLAS_STATUS_NOT_SUPPORTED; // Not enough workspace
}
}
@@ -442,7 +457,7 @@ void FindAlgo(const cublasLtHandle_t& ltHandle,
if (perfResults[i].status != CUBLAS_STATUS_SUCCESS) {
std::clog << "algo " << algos[i].algoId << " tile " << algos[i].tile
<< " stages " << algos[i].stages << " splitK_val "
<< algos[i].splitK_val;
<< algos[i].splitK_val << std::endl;
algos[i].time = std::numeric_limits<float>::max();
std::cerr << " TestMatmulRun with status " << perfResults[i].status
<< std::endl;
@@ -467,7 +482,7 @@ class DevContext {};
class CPUContext : public DevContext {};

class CUBLASLTContext : public DevContext {
public:
public:
CUBLASLTContext() { CUDA_CHECK(cublasLtCreate(&handle)); }

cublasLtHandle_t handle;
@@ -709,64 +724,51 @@ void GEMMInt8<int8_t, int32_t, CUBLASLTContext>(const CUBLASLTContext& dev_ctx,
CUDA_CHECK(cudaFree(workSpace));
}

void TuneCublasltGemm(const paddle::Tensor& M,
const paddle::Tensor& K,
void TuneCublasltGemm(const paddle::Tensor& K,
const paddle::Tensor& N,
const int M_start,
const int M_end,
const std::string& dtype,
bool is_test,
bool is_read_from_file,
const bool is_test,
const bool is_read_from_file,
const std::string& path) {
// Ensure that M, K, and N are all one-dimensional Tensors. is_test !=
// is_read_from_file
assert(M.dims().size() == 1 && K.dims().size() == 1 && N.dims().size() == 1);
assert(M_end >= M_start);
assert(M_start >= 1);
assert(K.dims().size() == 1 && N.dims().size() == 1);
assert(is_test != is_read_from_file);

auto M_cpu = M.copy_to(paddle::CPUPlace(), false);
auto K_cpu = K.copy_to(paddle::CPUPlace(), false);
auto N_cpu = N.copy_to(paddle::CPUPlace(), false);
int64_t* M_data = M_cpu.data<int64_t>();
int64_t* K_data = K_cpu.data<int64_t>();
int64_t* N_data = N_cpu.data<int64_t>();

int M_size = M.numel();
int K_size = K.numel();
int N_size = N.numel();
assert(K_size == N_size);

int m_data = (int)M_data[0];
assert(m_data > 0);

std::vector<int> mm;

int m = 1, step = 1;
while (m <= m_data) {
mm.push_back(m);
m += step;

int m = M_start, step = 1;
while (m <= M_end) {
// update step
switch (m) {
case 4:
step = 4;
break;
case 16:
step = 16;
break;
case 64:
step = 32;
break;
case 256:
step = 64;
break;
case 512:
step = 128;
break;
case 1024:
step = 1024;
break;
case 8192:
step = 4096;
break;
if (m >= 8192) {
step = 4096;
} else if (m >= 1024) {
step = 1024;
} else if (m >= 512) {
step = 128;
} else if (m >= 256) {
step = 64;
} else if (m >= 64) {
step = 32;
} else if (m >= 16) {
step = 16;
} else if (m >= 4) {
step = 4;
} else {
step = 1;
}
mm.push_back(m);
m += step;
}

for (int j = 0; j < mm.size(); j++) {
@@ -792,16 +794,18 @@ void TuneCublasltGemm(const paddle::Tensor& M,
path);
} else {
// other dtype
std::cout << "Not currently supported" << std::endl;
throw std::runtime_error(dtype + " is not currently supported");
}
}
}
}

PD_BUILD_OP(tune_cublaslt_gemm)
.Inputs({"M", "K", "N"})
.Inputs({"K", "N"})
.Outputs({})
.Attrs({"dtype: std::string",
.Attrs({"M_start: int",
"M_end: int",
"dtype: std::string",
"is_test: bool",
"is_read_from_file: bool",
"path: std::string"})
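With the new interface, M is no longer passed in as a tensor: the op sweeps M itself from M_start to M_end, using a step that grows with M so small shapes are sampled densely and large ones sparsely. A minimal Python sketch of that sweep, mirroring the if/else chain above (the actual tuning runs inside the CUDA op):

def m_sweep(m_start, m_end):
    """Sketch of the M values tuned between m_start and m_end."""
    mm, m = [], m_start
    while m <= m_end:
        # step grows with m, matching the C++ loop above
        if m >= 8192:
            step = 4096
        elif m >= 1024:
            step = 1024
        elif m >= 512:
            step = 128
        elif m >= 256:
            step = 64
        elif m >= 64:
            step = 32
        elif m >= 16:
            step = 16
        elif m >= 4:
            step = 4
        else:
            step = 1
        mm.append(m)
        m += step
    return mm

print(m_sweep(1, 32768)[:12])  # [1, 2, 3, 4, 8, 12, 16, 32, 48, 64, 96, 128]

For M_start=1, M_end=32768 (the values used in the tuning script below) this yields only a few dozen candidate M values instead of one run per integer M.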
11 changes: 8 additions & 3 deletions csrc/utils/tune_cublaslt_int8_gemm.py
@@ -15,7 +15,8 @@
import paddle
from paddlenlp_ops import tune_cublaslt_gemm

M_tensor = paddle.to_tensor([32768])
M_start = 1
M_end = 32768

# llama3.1-8b
k1 = [4096, 4096, 4096, 14336]
@@ -36,7 +37,11 @@
K_tensor = paddle.to_tensor(k1 + k2 + k3 + k4)
N_tensor = paddle.to_tensor(n1 + n2 + n3 + n4)

Dtype = "int8"
Path = "./cublaslt_gemm_search.csv"

tune_cublaslt_gemm(M_tensor, K_tensor, N_tensor, Dtype, True, False, Path)
tune_cublaslt_gemm(K_tensor, N_tensor, M_start, M_end, "int8", True, False, Path)

# Shape calculation formulas
# [qkv, out_linear, ffn1, ffn2]
# k = [hidden_size, hidden_size, hidden_size, intermediate_size//mp_size]
# n = [((num_attention_heads//mp_size)+2*(num_key_value_heads//mp_size))*(hidden_size//num_attention_heads), hidden_size, 2*(intermediate_size//mp_size), hidden_size]
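As a concrete check of the shape formulas above, a small sketch with config values taken from LLaMA-3.1-8B and mp_size = 1 assumed; it reproduces the k1 list used earlier in the script:

# Hypothetical config values (LLaMA-3.1-8B, mp_size = 1), used only for this sketch
hidden_size = 4096
intermediate_size = 14336
num_attention_heads = 32
num_key_value_heads = 8
mp_size = 1
head_dim = hidden_size // num_attention_heads

# [qkv, out_linear, ffn1, ffn2]
k = [hidden_size, hidden_size, hidden_size, intermediate_size // mp_size]
n = [
    (num_attention_heads // mp_size + 2 * (num_key_value_heads // mp_size)) * head_dim,
    hidden_size,
    2 * (intermediate_size // mp_size),
    hidden_size,
]
print(k)  # [4096, 4096, 4096, 14336]
print(n)  # [6144, 4096, 28672, 4096]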
18 changes: 14 additions & 4 deletions paddlenlp/experimental/transformers/llama/modeling.py
@@ -876,6 +876,8 @@ def set_state_dict(self, state_dict):
ffn_hidden_size=self.intermediate_size,
num_key_value_heads=self.num_key_value_heads,
mp_size=self.config.tensor_parallel_degree,
concat_qkv=True,
concat_ffn1=True,
)
self.transformer_block.weight_scales = weight_scales_loader.scale
self.transformer_block.act_scales = act_scale_loader.scale
@@ -1097,16 +1099,24 @@ def set_state_dict(self, state_dict):
dtype=paddle.get_default_dtype(),
)
self.transformer_block.linear_shifts[idx].set_value(
paddle.to_tensor(state_dict["llama.layers.{}.self_attn.o_proj.shift_bias".format(idx)])
paddle.to_tensor(state_dict["llama.layers.{}.self_attn.o_proj.shift_bias".format(idx)]).astype(
paddle.get_default_dtype()
)
)
self.transformer_block.linear_smooths[idx].set_value(
paddle.to_tensor(state_dict["llama.layers.{}.self_attn.o_proj.smooth_weight".format(idx)])
paddle.to_tensor(
state_dict["llama.layers.{}.self_attn.o_proj.smooth_weight".format(idx)]
).astype(paddle.get_default_dtype())
)
self.transformer_block.ffn2_shifts[idx].set_value(
paddle.to_tensor(state_dict["llama.layers.{}.mlp.down_proj.shift_bias".format(idx)])
paddle.to_tensor(state_dict["llama.layers.{}.mlp.down_proj.shift_bias".format(idx)]).astype(
paddle.get_default_dtype()
)
)
self.transformer_block.ffn2_smooths[idx].set_value(
paddle.to_tensor(state_dict["llama.layers.{}.mlp.down_proj.smooth_weight".format(idx)])
paddle.to_tensor(state_dict["llama.layers.{}.mlp.down_proj.smooth_weight".format(idx)]).astype(
paddle.get_default_dtype()
)
)

if self.shift:
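The added .astype(paddle.get_default_dtype()) calls make sure that shift/smooth tensors loaded from the checkpoint (often stored as float32) are cast to the dtype of the parameter they are copied into. A minimal, self-contained illustration of the pattern; the parameter name and shape here are invented for the sketch:

import numpy as np
import paddle

# hypothetical checkpoint entry, stored as float32 on disk
ckpt_value = np.full([8], 0.5, dtype="float32")

# hypothetical runtime parameter created with the default dtype
shift_bias = paddle.create_parameter(shape=[8], dtype=paddle.get_default_dtype())

# cast the loaded tensor to the parameter's dtype before set_value, as the diff above does
shift_bias.set_value(paddle.to_tensor(ckpt_value).astype(paddle.get_default_dtype()))

The same fix is applied to the mixtral and qwen2 modeling files below.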
16 changes: 12 additions & 4 deletions paddlenlp/experimental/transformers/mixtral/modeling.py
@@ -716,16 +716,24 @@ def set_state_dict(self, state_dict):
if "a8w8" in self.quant_type:
if self.shift_smooth_all_linears:
self.transformer_block.linear_shifts[idx].set_value(
paddle.to_tensor(state_dict["mixtral.layers.{}.self_attn.o_proj.shift_bias".format(idx)])
paddle.to_tensor(
state_dict["mixtral.layers.{}.self_attn.o_proj.shift_bias".format(idx)]
).astype(paddle.get_default_dtype())
)
self.transformer_block.linear_smooths[idx].set_value(
paddle.to_tensor(state_dict["mixtral.layers.{}.self_attn.o_proj.smooth_weight".format(idx)])
paddle.to_tensor(
state_dict["mixtral.layers.{}.self_attn.o_proj.smooth_weight".format(idx)]
).astype(paddle.get_default_dtype())
)
self.transformer_block.ffn2_shifts[idx].set_value(
paddle.to_tensor(state_dict["mixtral.layers.{}.mlp.down_proj.shift_bias".format(idx)])
paddle.to_tensor(state_dict["mixtral.layers.{}.mlp.down_proj.shift_bias".format(idx)]).astype(
paddle.get_default_dtype()
)
)
self.transformer_block.ffn2_smooths[idx].set_value(
paddle.to_tensor(state_dict["mixtral.layers.{}.mlp.down_proj.smooth_weight".format(idx)])
paddle.to_tensor(
state_dict["mixtral.layers.{}.mlp.down_proj.smooth_weight".format(idx)]
).astype(paddle.get_default_dtype())
)

if self.shift:
18 changes: 14 additions & 4 deletions paddlenlp/experimental/transformers/qwen2/modeling.py
@@ -453,6 +453,8 @@ def set_state_dict(self, state_dict):
ffn_hidden_size=self.intermediate_size,
num_key_value_heads=self.num_key_value_heads,
mp_size=self.config.tensor_parallel_degree,
concat_qkv=True,
concat_ffn1=True,
)
self.transformer_block.weight_scales = weight_scales_loader.scale
self.transformer_block.act_scales = act_scale_loader.scale
@@ -704,16 +706,24 @@ def set_state_dict(self, state_dict):
dtype=paddle.get_default_dtype(),
)
self.transformer_block.linear_shifts[idx].set_value(
paddle.to_tensor(state_dict["qwen2.layers.{}.self_attn.o_proj.shift_bias".format(idx)])
paddle.to_tensor(state_dict["qwen2.layers.{}.self_attn.o_proj.shift_bias".format(idx)]).astype(
paddle.get_default_dtype()
)
)
self.transformer_block.linear_smooths[idx].set_value(
paddle.to_tensor(state_dict["qwen2.layers.{}.self_attn.o_proj.smooth_weight".format(idx)])
paddle.to_tensor(
state_dict["qwen2.layers.{}.self_attn.o_proj.smooth_weight".format(idx)]
).astype(paddle.get_default_dtype())
)
self.transformer_block.ffn2_shifts[idx].set_value(
paddle.to_tensor(state_dict["qwen2.layers.{}.mlp.down_proj.shift_bias".format(idx)])
paddle.to_tensor(state_dict["qwen2.layers.{}.mlp.down_proj.shift_bias".format(idx)]).astype(
paddle.get_default_dtype()
)
)
self.transformer_block.ffn2_smooths[idx].set_value(
paddle.to_tensor(state_dict["qwen2.layers.{}.mlp.down_proj.smooth_weight".format(idx)])
paddle.to_tensor(state_dict["qwen2.layers.{}.mlp.down_proj.smooth_weight".format(idx)]).astype(
paddle.get_default_dtype()
)
)

if self.shift:
13 changes: 13 additions & 0 deletions paddlenlp/experimental/transformers/utils.py
@@ -108,6 +108,8 @@
ffn_hidden_size,
num_key_value_heads=-1,
mp_size=1,
concat_qkv=False,
concat_ffn1=False,
):
self.key_map = key_map_dict
self.scale = {}
@@ -126,6 +128,17 @@
n = num_head * dim_head
self.scale[scale_type] = np.full([num_of_layers, n], fill_value=0.1, dtype="float32")

# concat qkv and ffn1
if concat_qkv:
self.scale["qkv_weight_scale"] = np.full(

Check warning on line 133 in paddlenlp/experimental/transformers/utils.py

View check run for this annotation

Codecov / codecov/patch

paddlenlp/experimental/transformers/utils.py#L132-L133

Added lines #L132 - L133 were not covered by tests
[num_of_layers, qkv_out_size // mp_size], fill_value=0.1, dtype="float32"
)

if concat_ffn1:
self.scale["ffn1_weight_scale"] = np.full(

Check warning on line 138 in paddlenlp/experimental/transformers/utils.py

View check run for this annotation

Codecov / codecov/patch

paddlenlp/experimental/transformers/utils.py#L137-L138

Added lines #L137 - L138 were not covered by tests
[num_of_layers, ffn_hidden_size * 2 // mp_size], fill_value=0.1, dtype="float32"
)


class EmptyCacheScale:
"""
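With concat_qkv/concat_ffn1 enabled, the scale loader above also fabricates placeholder scales for the fused QKV and FFN1 weights. A standalone sketch of the resulting shapes; qkv_out_size is assumed here to be (num_heads + 2 * num_key_value_heads) * head_dim, which is not visible in this hunk:

import numpy as np

# hypothetical config, chosen to match the LLaMA-3.1-8B numbers used earlier
num_of_layers = 32
num_heads, num_key_value_heads, head_dim = 32, 8, 128
ffn_hidden_size = 14336
mp_size = 1

qkv_out_size = (num_heads + 2 * num_key_value_heads) * head_dim  # assumption, see above
scale = {
    "qkv_weight_scale": np.full([num_of_layers, qkv_out_size // mp_size], fill_value=0.1, dtype="float32"),
    "ffn1_weight_scale": np.full([num_of_layers, ffn_hidden_size * 2 // mp_size], fill_value=0.1, dtype="float32"),
}
print(scale["qkv_weight_scale"].shape)   # (32, 6144)
print(scale["ffn1_weight_scale"].shape)  # (32, 28672)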