Commit fbd16a1: Add RingFlashAttention for context parallel
1 parent: ae0bea9

10 files changed: +848, -23 lines

csrc/generation/flash_attn_bwd.cc

Lines changed: 92 additions & 0 deletions
@@ -0,0 +1,92 @@
// Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#include "paddle/extension.h"
#include <iostream>
#include <vector>

using paddle::Tensor;

namespace paddle {
namespace experimental {

PADDLE_API void flash_attn_grad(const Tensor& q,
                                const Tensor& k,
                                const Tensor& v,
                                const Tensor& out,
                                const Tensor& softmax_lse,
                                const Tensor& seed_offset,
                                const paddle::optional<Tensor> &attn_mask,
                                const Tensor& out_grad,
                                float dropout,
                                bool causal, Tensor* q_grad, Tensor* k_grad, Tensor* v_grad);

}
}  // namespace paddle


std::vector<Tensor> SRFlashAttnBwd(const Tensor &q,
                                   const Tensor &k,
                                   const Tensor &v,
                                   const Tensor &out,
                                   const Tensor &softmax_lse,
                                   const Tensor &seed_offset,
                                   const paddle::optional<Tensor> &attn_mask,
                                   const Tensor &out_grad,
                                   float dropout,
                                   bool causal);


std::vector<Tensor> SRFlashAttnBwd(const Tensor &q,
                                   const Tensor &k,
                                   const Tensor &v,
                                   const Tensor &out,
                                   const Tensor &softmax_lse,
                                   const Tensor &seed_offset,
                                   const paddle::optional<Tensor> &attn_mask,
                                   const Tensor &out_grad,
                                   float dropout,
                                   bool causal) {
    std::vector<Tensor> res(3);
    paddle::experimental::flash_attn_grad(q, k, v, out, softmax_lse, seed_offset, attn_mask,
                                          out_grad, dropout, causal, &res[0], &res[1],
                                          &res[2]);
    return res;
}


std::vector<paddle::DataType> SRFlashAttnBwdDtype(paddle::DataType q_dtype,
                                                  paddle::DataType k_dtype,
                                                  paddle::DataType v_dtype) {
    return {q_dtype, k_dtype, v_dtype};
}


std::vector<std::vector<int64_t>> SRFlashAttnBwdInferShape(
    std::vector<int64_t> q_shape, std::vector<int64_t> k_shape,
    std::vector<int64_t> v_shape) {
    return {q_shape, k_shape, v_shape};
}


PD_BUILD_OP(flash_attn_bwd)
    .Inputs({"q", "k", "v", "out", "softmax_lse", "seed_offset", "attn_mask", "out_grad"})
    .Outputs({"q_grad", "k_grad", "v_grad"})
    .Attrs({"dropout: float", "causal: bool"})
    .SetKernelFn(PD_KERNEL(SRFlashAttnBwd))
    .SetInferShapeFn(PD_INFER_SHAPE(SRFlashAttnBwdInferShape))
    .SetInferDtypeFn(PD_INFER_DTYPE(SRFlashAttnBwdDtype));
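A minimal usage sketch of the op registered above, assuming the built extension is importable as paddlenlp_ops; the placeholder tensors stand in for values saved by a flash-attention forward pass and are not part of this commit.

# Minimal sketch, assuming the custom-op extension built from csrc/ is
# importable as `paddlenlp_ops`, and that q, k, v, out, softmax_lse,
# seed_offset and out_grad come from a prior flash-attention forward pass
# (placeholders here, not defined in this snippet).
import paddlenlp_ops  # assumed module name

q_grad, k_grad, v_grad = paddlenlp_ops.flash_attn_bwd(
    q, k, v,          # forward inputs
    out,              # forward output
    softmax_lse,      # log-sum-exp statistics saved by the forward kernel
    seed_offset,      # dropout RNG state saved by the forward kernel
    None,             # attn_mask is optional
    out_grad,         # gradient w.r.t. out
    0.0,              # dropout
    True,             # causal
)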

csrc/setup_cuda.py

Lines changed: 1 addition & 0 deletions
@@ -78,6 +78,7 @@ def get_gencode_flags():
         "./generation/step.cu",
         "./generation/quant_int8.cu",
         "./generation/dequant_int8.cu",
+        "./generation/flash_attn_bwd.cc",
     ],
     extra_compile_args={
         "cxx": ["-O3"],

llm/llama/run_trainer_tp2cp2.sh

Lines changed: 88 additions & 0 deletions
@@ -0,0 +1,88 @@
# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.


set -x
unset CUDA_VISIBLE_DEVICES

rm -rf log
rm -rf output

unset PADDLE_ELASTIC_JOB_ID
unset PADDLE_TRAINER_ENDPOINTS
unset DISTRIBUTED_TRAINER_ENDPOINTS
unset FLAGS_START_PORT
unset PADDLE_ELASTIC_TIMEOUT

# export FLAGS_embedding_deterministic=1
# export FLAGS_cudnn_deterministic=1
# export FLAGS_flash_attn_version=v1
# export USE_FAST_LN=0


max_seq_length=1024

master=127.0.0.1
port=36677

max_steps=10000
log_dir=seq_${max_seq_length}_log
echo "log_dir:${log_dir}"
rm -rf $log_dir

export PYTHONPATH=../../:$PYTHONPATH
python -u -m paddle.distributed.launch \
    --master $master:$port \
    --gpus "3,4,5,7" \
    --log_dir "./$log_dir" \
    run_pretrain.py \
    --model_name_or_path "facebook/llama-7b" \
    --tokenizer_name_or_path "facebook/llama-7b" \
    --input_dir "./data" \
    --output_dir "./output" \
    --split 949,50,1 \
    --max_seq_length $max_seq_length \
    --per_device_train_batch_size 1 \
    --gradient_accumulation_steps 4 \
    --per_device_eval_batch_size 4 \
    --bf16 \
    --fp16_opt_level "O2" \
    --use_flash_attention 1 \
    --virtual_pp_degree 1 \
    --pp_recompute_interval 1 \
    --learning_rate 0.00001 \
    --min_learning_rate 0.000001 \
    --max_steps $max_steps \
    --weight_decay 0.01 \
    --warmup_ratio 0.01 \
    --max_grad_norm 1.0 \
    --logging_steps 1 \
    --dataloader_num_workers 1 \
    --eval_steps 1001 \
    --disable_tqdm true \
    --continue_training 0 \
    --do_train \
    --device "gpu" \
    --enable_linear_fused_grad_add false \
    --recompute_use_reentrant true \
    --data_cache "./data_cache" \
    --pipeline_parallel_degree 1 \
    --cp_parallel_degree 2 \
    --tensor_parallel_degree 2 \
    --sequence_parallel false \
    --skip_profile_timer true \
    --amp_master_grad \
    --report_to "visualdl" \
    --logging_dir "./visualdl_log" \
    --save_steps 2000000 \
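As a rough consistency check on this launch config: with pipeline_parallel_degree 1, tensor_parallel_degree 2 and cp_parallel_degree 2, one model replica spans 2 x 2 = 4 ranks, which matches the four GPUs passed to --gpus. A small illustrative sketch of that arithmetic (not part of the commit):

# Illustrative check that the parallel degrees above fit the GPU list.
tp, cp, pp = 2, 2, 1             # --tensor_parallel_degree, --cp_parallel_degree, --pipeline_parallel_degree
gpus = "3,4,5,7".split(",")      # --gpus
ranks_per_replica = tp * cp * pp
assert len(gpus) % ranks_per_replica == 0  # 4 % 4 == 0, so data-parallel degree is 1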

llm/run_pretrain.py

Lines changed: 4 additions & 0 deletions
@@ -485,11 +485,15 @@ def main():
     config.attention_probs_dropout_prob = model_args.attention_probs_dropout_prob

     config.sep_parallel_degree = training_args.sep_parallel_degree
+    config.cp_parallel_degree = training_args.cp_parallel_degree
     if config.sequence_parallel:
         assert config.tensor_parallel_degree > 1, "tensor_parallel_degree must be larger than 1 for sequence parallel."
         assert (
             config.num_attention_heads % config.sep_parallel_degree == 0
         ), f"num_attention_heads:{config.num_attention_heads} must be divisible by sep_parallel_degree {config.sep_parallel_degree}"
+        assert (
+            config.seq_length % config.cp_parallel_degree == 0
+        ), f"seq_length:{config.seq_length} must be divisible by cp_parallel_degree {config.cp_parallel_degree}"

     if get_env_device() == "xpu" and training_args.gradient_accumulation_steps > 1:
         try:
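The new assertion reflects how context parallelism shards the sequence dimension: each context-parallel rank holds seq_length / cp_parallel_degree tokens, so the division must be exact. A tiny illustration with the values from the example script (illustrative only):

# Illustration of the added check: every cp rank gets an equal slice of the sequence.
seq_length = 1024          # --max_seq_length in the example script
cp_parallel_degree = 2     # --cp_parallel_degree
assert seq_length % cp_parallel_degree == 0
tokens_per_cp_rank = seq_length // cp_parallel_degree  # 512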

paddlenlp/trainer/trainer.py

Lines changed: 13 additions & 2 deletions
@@ -81,6 +81,7 @@
     from ..quantization.quantization_linear import QuantizationLinear
 except:
     QuantizationLinear = None
+from ..transformers.context_parallel_utils import split_inputs_sequence_dim_load_balance
 from ..transformers.model_utils import (
     PretrainedModel,
     _add_variant,
@@ -763,6 +764,8 @@ def train(
         trainable_numel = int(trainable_numel_tensor.item()) // self.args.dataset_world_size
         if self.args.sep_parallel_degree > 0:
             trainable_numel = trainable_numel // self.args.sep_parallel_degree
+        if self.args.cp_parallel_degree > 0:
+            trainable_numel = trainable_numel // self.args.cp_parallel_degree
         # the numel is roughly, because the tensor parallel still hold own bias or layer_norm weight without splited
         # so, the trainable numel is a little bigger than real.
         logger.debug(f" Number of trainable parameters = {trainable_numel:,} (all devices, roughly)")
@@ -897,6 +900,8 @@ def _inner_training_loop(
             for step, inputs in enumerate(epoch_iterator):
                 if self.args.use_hybrid_parallel and self.args.sep_parallel_degree > 1:
                     inputs = split_inputs_sequence_dim(inputs)
+                if self.args.use_hybrid_parallel and self.args.cp_parallel_degree > 1:
+                    inputs = split_inputs_sequence_dim_load_balance(inputs)
                 self.timers and self.timers("read-data").stop()
                 os.environ["TRAINER_GLOBAL_STEP"] = str(self.state.global_step)
                 self.callback_handler.on_load_data_end(args, self.state, self.control, inputs=inputs)
@@ -1006,7 +1011,11 @@ def _inner_training_loop(
                         assert reshard_util.is_sharding_opt(self.optimizer)
                         self.optimizer._inner_opt.reduce_gradients(list(parameters_list), self.optimizer._hcg)

-                    if self.optimizer._dp_enable or getattr(self.optimizer, "_sep_enable", False):
+                    if (
+                        self.optimizer._dp_enable
+                        or getattr(self.optimizer, "_sep_enable", False)
+                        or getattr(self.optimizer, "_cp_enable", False)
+                    ):
                         fused_allreduce_gradients(list(parameters_list), self.optimizer._hcg)

                     self.timers and self.timers("all-reduce").stop()
@@ -1733,6 +1742,7 @@ def _wrap_model(self, model, training=True):
         in_sharding_parallel_mode = self.sharding is not None
         in_tensor_parallel_mode = self.args.tensor_parallel_degree > 1
         in_sep_parallel_mode = self.args.sep_parallel_degree > 1
+        in_cp_parallel_mode = self.args.cp_parallel_degree > 1

         # Multi-gpu training
         if (
@@ -1743,6 +1753,7 @@ def _wrap_model(self, model, training=True):
                 or in_sharding_parallel_mode
                 or in_tensor_parallel_mode
                 or in_sep_parallel_mode
+                or in_cp_parallel_mode
             )
         ):
             model = paddle.DataParallel(model)
@@ -1870,7 +1881,7 @@ def get_expected_keys(inputs, keys):
         if (
             not in_pipeline_parallel_mode
             and not in_sharding_parallel_mode
-            and (in_tensor_parallel_mode or in_sep_parallel_mode)
+            and (in_tensor_parallel_mode or in_sep_parallel_mode or in_cp_parallel_mode)
         ):
             if self.args.amp_master_grad:
                 mix_precision_utils.MixPrecisionLayer(model, dtype=self.amp_dtype)  # return value has no use
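split_inputs_sequence_dim_load_balance itself lives in paddlenlp.transformers.context_parallel_utils and is not shown in this diff. The sketch below illustrates the usual load-balanced ("zigzag") sequence split used with causal ring attention, where each rank takes one chunk from the front of the sequence and its mirror chunk from the back; it is written under that assumption and is not the actual implementation.

# Illustrative zigzag split for context parallelism (an assumption about the
# load-balancing scheme; the real helper is split_inputs_sequence_dim_load_balance).
import paddle

def zigzag_split(x, cp_degree, cp_rank, seq_axis=1):
    # Cut the sequence into 2 * cp_degree chunks and give rank i chunks
    # i and (2 * cp_degree - 1 - i), so causal-attention work is balanced.
    chunks = paddle.split(x, 2 * cp_degree, axis=seq_axis)
    keep = [chunks[cp_rank], chunks[2 * cp_degree - 1 - cp_rank]]
    return paddle.concat(keep, axis=seq_axis)

# Example: 8 tokens, cp_degree=2 -> rank 0 holds tokens [0, 1, 6, 7].
tokens = paddle.arange(8).reshape([1, 8])
print(zigzag_split(tokens, cp_degree=2, cp_rank=0))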
