Commit ed80c13

Merge branch 'develop' of https://github.com/PaddlePaddle/PaddleNLP into qwen2-fp8
2 parents 1bde9b8 + c28caf7 commit ed80c13

65 files changed: +2350 -94 lines changed

.gitignore

Lines changed: 3 additions & 1 deletion
@@ -125,5 +125,7 @@ FETCH_HEAD
 .vscode
 ./ppdiffusers/ppdiffusers/version.py
 
+# third party
+csrc/gpu/cutlass_kernels/cutlass
 dataset/
-output/
+output/

csrc/README.md

Lines changed: 13 additions & 0 deletions
@@ -13,3 +13,16 @@ pip install -r requirements.txt
 ```shell
 python setup_cuda.py install
 ```
+
+### Manually installing the Cutlass library
+1. Visit the Cutlass repository: [NVIDIA/cutlass](https://github.com/NVIDIA/cutlass)
+
+2. Clone the code:
+git clone -b v3.5.0 --single-branch https://github.com/NVIDIA/cutlass.git
+
+3. Place the downloaded `cutlass` directory at `csrc/gpu/cutlass_kernels/cutlass`
+
+4. Recompile the CUDA operators
+```shell
+python setup_cuda.py install
+```
Lines changed: 42 additions & 0 deletions
@@ -0,0 +1,42 @@
+// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+#pragma once
+
+#include "cutlass/half.h"
+#include "cutlass/bfloat16.h"
+#include "paddle/extension.h"
+
+template <paddle::DataType D>
+class CutlassDtypeTraits;
+
+template <>
+class CutlassDtypeTraits<paddle::DataType::FLOAT32> {
+public:
+  typedef float DataType;
+  typedef float data_t;
+};
+
+template <>
+class CutlassDtypeTraits<paddle::DataType::FLOAT16> {
+public:
+  typedef cutlass::half_t DataType;
+  typedef paddle::float16 data_t;
+};
+
+template <>
+class CutlassDtypeTraits<paddle::DataType::BFLOAT16> {
+public:
+  typedef cutlass::bfloat16_t DataType;
+  typedef paddle::bfloat16 data_t;
+};
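
Note on the header above: `CutlassDtypeTraits` maps a `paddle::DataType` enum value to a CUTLASS element type (`DataType`) and a Paddle storage type (`data_t`) at compile time. The commit does not show how the trait is consumed, so the following is only a minimal sketch of the usual pattern, a runtime `switch` that forwards to a template instantiated per dtype. To stay compilable without the Paddle or CUTLASS headers it uses stand-ins: the `sketch` namespace, `Dtype`, `half_t`, `bfloat16_t`, `DtypeTraits`, `ElementBytesFor`, and `ElementBytes` are all hypothetical names, not part of PaddleNLP.

```cpp
#include <cstddef>
#include <cstdint>
#include <stdexcept>

namespace sketch {

// Stand-in (assumption) for paddle::DataType.
enum class Dtype { FLOAT32, FLOAT16, BFLOAT16 };

// Stand-ins (assumptions) for cutlass::half_t and cutlass::bfloat16_t:
// both are 16-bit storage types.
struct half_t { std::uint16_t storage; };
struct bfloat16_t { std::uint16_t storage; };

// Primary template is declared but never defined, so an unsupported
// dtype fails at compile time, mirroring the header in the diff.
template <Dtype D>
struct DtypeTraits;

template <>
struct DtypeTraits<Dtype::FLOAT32> { typedef float DataType; };

template <>
struct DtypeTraits<Dtype::FLOAT16> { typedef half_t DataType; };

template <>
struct DtypeTraits<Dtype::BFLOAT16> { typedef bfloat16_t DataType; };

// Compile-time side: resolve the element type from the trait.
template <Dtype D>
std::size_t ElementBytesFor() {
  typedef typename DtypeTraits<D>::DataType T;
  return sizeof(T);
}

// Runtime side: dispatch an enum value to the matching instantiation.
inline std::size_t ElementBytes(Dtype dtype) {
  switch (dtype) {
    case Dtype::FLOAT32:
      return ElementBytesFor<Dtype::FLOAT32>();
    case Dtype::FLOAT16:
      return ElementBytesFor<Dtype::FLOAT16>();
    case Dtype::BFLOAT16:
      return ElementBytesFor<Dtype::BFLOAT16>();
  }
  throw std::invalid_argument("unsupported dtype");
}

}  // namespace sketch
```

In a real kernel launcher the body of `ElementBytesFor` would presumably be replaced by a call into a CUTLASS GEMM templated on `DataType`, with `data_t` used to read the Paddle tensor; that wiring is not shown in this commit.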
6 files renamed without changes.
