From 63b2be840b47f13a27786db74a7980dbc49eeb8d Mon Sep 17 00:00:00 2001 From: zhangyuqin1998 Date: Wed, 8 May 2024 04:20:22 +0000 Subject: [PATCH 1/7] Add RingFlashAttention for context parallel --- csrc/generation/flash_attn_bwd.cc | 92 +++ csrc/setup_cuda.py | 1 + llm/llama/run_trainer_tp2cp2.sh | 88 +++ llm/run_pretrain.py | 4 + paddlenlp/trainer/trainer.py | 15 +- paddlenlp/trainer/training_args.py | 61 +- paddlenlp/transformers/configuration_utils.py | 3 +- .../transformers/context_parallel_utils.py | 101 +++ paddlenlp/transformers/llama/fusion_ops.py | 39 +- paddlenlp/transformers/llama/modeling.py | 32 +- .../transformers/ring_flash_attention.py | 436 ++++++++++++ .../ring_flash_attention_back_up.py | 664 ++++++++++++++++++ 12 files changed, 1516 insertions(+), 20 deletions(-) create mode 100644 csrc/generation/flash_attn_bwd.cc create mode 100644 llm/llama/run_trainer_tp2cp2.sh create mode 100644 paddlenlp/transformers/context_parallel_utils.py create mode 100644 paddlenlp/transformers/ring_flash_attention.py create mode 100644 paddlenlp/transformers/ring_flash_attention_back_up.py diff --git a/csrc/generation/flash_attn_bwd.cc b/csrc/generation/flash_attn_bwd.cc new file mode 100644 index 000000000000..3acd55cbdbe9 --- /dev/null +++ b/csrc/generation/flash_attn_bwd.cc @@ -0,0 +1,92 @@ +// Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "paddle/extension.h" +#include +#include + +using paddle::Tensor; + +namespace paddle { +namespace experimental { + +PADDLE_API void flash_attn_grad(const Tensor& q, + const Tensor& k, + const Tensor& v, + const Tensor& out, + const Tensor& softmax_lse, + const Tensor& seed_offset, + const paddle::optional &attn_mask, + const Tensor& out_grad, + float dropout, + bool causal, Tensor* q_grad, Tensor* k_grad, Tensor* v_grad); + +} +} // namespace paddle + + + +std::vector SRFlashAttnBwd(const Tensor &q, + const Tensor &k, + const Tensor &v, + const Tensor &out, + const Tensor &softmax_lse, + const Tensor &seed_offset, + const paddle::optional &attn_mask, + const Tensor &out_grad, + float dropout, + bool causal); + + +std::vector SRFlashAttnBwd(const Tensor &q, + const Tensor &k, + const Tensor &v, + const Tensor &out, + const Tensor &softmax_lse, + const Tensor &seed_offset, + const paddle::optional &attn_mask, + const Tensor &out_grad, + float dropout, + bool causal){ + std::vector res(3); + paddle::experimental::flash_attn_grad(q, k, v, out, softmax_lse, seed_offset, attn_mask, + out_grad, dropout, causal, &res[0], &res[1], + &res[2]); + return res; +} + + + +std::vector SRFlashAttnBwdDtype(paddle::DataType q_dtype, + paddle::DataType k_dtype, + paddle::DataType v_dtype) { + return {q_dtype, k_dtype, v_dtype}; + +} + + +std::vector> SRFlashAttnBwdInferShape( + std::vector q_shape, std::vector k_shape, + std::vector v_shape) { + return {q_shape, k_shape, v_shape}; +} + + +PD_BUILD_OP(flash_attn_bwd) + .Inputs({"q", "k", "v", "out", "softmax_lse", "seed_offset", "attn_mask", "out_grad"}) + .Outputs({"q_grad", "k_grad", "v_grad"}) + .Attrs({"dropout: float", "causal: bool"}) + .SetKernelFn(PD_KERNEL(SRFlashAttnBwd)) + .SetInferShapeFn(PD_INFER_SHAPE(SRFlashAttnBwdInferShape)) + .SetInferDtypeFn(PD_INFER_DTYPE(SRFlashAttnBwdDtype)); diff --git a/csrc/setup_cuda.py b/csrc/setup_cuda.py index 0b25ef3eac98..dc0ba9895027 100644 --- a/csrc/setup_cuda.py +++ b/csrc/setup_cuda.py @@ -77,6 +77,7 @@ def get_gencode_flags(): "./generation/step.cu", "./generation/quant_int8.cu", "./generation/dequant_int8.cu", + "./generation/flash_attn_bwd.cc", ], extra_compile_args={ "cxx": ["-O3"], diff --git a/llm/llama/run_trainer_tp2cp2.sh b/llm/llama/run_trainer_tp2cp2.sh new file mode 100644 index 000000000000..954d59c3100b --- /dev/null +++ b/llm/llama/run_trainer_tp2cp2.sh @@ -0,0 +1,88 @@ +# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +set -x +unset CUDA_VISIBLE_DEVICES + +rm -rf log +rm -rf output + +unset PADDLE_ELASTIC_JOB_ID +unset PADDLE_TRAINER_ENDPOINTS +unset DISTRIBUTED_TRAINER_ENDPOINTS +unset FLAGS_START_PORT +unset PADDLE_ELASTIC_TIMEOUT + +# export FLAGS_embedding_deterministic=1 +# export FLAGS_cudnn_deterministic=1 +# export FLAGS_flash_attn_version=v1 +# export USE_FAST_LN=0 + + +max_seq_length=1024 + +master=127.0.0.1 +port=36677 + +max_steps=10000 +log_dir=seq_${max_seq_length}_log +echo "log_dir:${log_dir}" +rm -rf $log_dir + +export PYTHONPATH=../../:$PYTHONPATH +python -u -m paddle.distributed.launch \ + --master $master:$port \ + --gpus "3,4,5,7" \ + --log_dir "./$log_dir" \ + run_pretrain.py \ + --model_name_or_path "facebook/llama-7b" \ + --tokenizer_name_or_path "facebook/llama-7b" \ + --input_dir "./data" \ + --output_dir "./output" \ + --split 949,50,1 \ + --max_seq_length $max_seq_length \ + --per_device_train_batch_size 1 \ + --gradient_accumulation_steps 4 \ + --per_device_eval_batch_size 4 \ + --bf16 \ + --fp16_opt_level "O2" \ + --use_flash_attention 1 \ + --virtual_pp_degree 1 \ + --pp_recompute_interval 1 \ + --learning_rate 0.00001 \ + --min_learning_rate 0.000001 \ + --max_steps $max_steps \ + --weight_decay 0.01 \ + --warmup_ratio 0.01 \ + --max_grad_norm 1.0 \ + --logging_steps 1 \ + --dataloader_num_workers 1 \ + --eval_steps 1001 \ + --disable_tqdm true \ + --continue_training 0 \ + --do_train \ + --device "gpu" \ + --enable_linear_fused_grad_add false \ + --recompute_use_reentrant true \ + --data_cache "./data_cache" \ + --pipeline_parallel_degree 1 \ + --cp_parallel_degree 2 \ + --tensor_parallel_degree 2 \ + --sequence_parallel false \ + --skip_profile_timer true \ + --amp_master_grad \ + --report_to "visualdl" \ + --logging_dir "./visualdl_log" \ + --save_steps 2000000 \ diff --git a/llm/run_pretrain.py b/llm/run_pretrain.py index e58888772a5d..cd9d91a22320 100644 --- a/llm/run_pretrain.py +++ b/llm/run_pretrain.py @@ -485,11 +485,15 @@ def main(): config.attention_probs_dropout_prob = model_args.attention_probs_dropout_prob config.sep_parallel_degree = training_args.sep_parallel_degree + config.cp_parallel_degree = training_args.cp_parallel_degree if config.sequence_parallel: assert config.tensor_parallel_degree > 1, "tensor_parallel_degree must be larger than 1 for sequence parallel." assert ( config.num_attention_heads % config.sep_parallel_degree == 0 ), f"num_attention_heads:{config.num_attention_heads} must be divisible by sep_parallel_degree {config.sep_parallel_degree}" + assert ( + config.seq_length % config.cp_parallel_degree == 0 + ), f"seq_length:{config.seq_length} must be divisible by cp_parallel_degree {config.cp_parallel_degree}" if get_env_device() == "xpu" and training_args.gradient_accumulation_steps > 1: try: diff --git a/paddlenlp/trainer/trainer.py b/paddlenlp/trainer/trainer.py index 01b902478622..caa4cb14de53 100644 --- a/paddlenlp/trainer/trainer.py +++ b/paddlenlp/trainer/trainer.py @@ -81,6 +81,7 @@ from ..quantization.quantization_linear import QuantizationLinear except: QuantizationLinear = None +from ..transformers.context_parallel_utils import split_inputs_sequence_dim_load_balance from ..transformers.model_utils import ( PretrainedModel, _add_variant, @@ -763,6 +764,8 @@ def train( trainable_numel = int(trainable_numel_tensor.item()) // self.args.dataset_world_size if self.args.sep_parallel_degree > 0: trainable_numel = trainable_numel // self.args.sep_parallel_degree + if self.args.cp_parallel_degree > 0: + trainable_numel = trainable_numel // self.args.cp_parallel_degree # the numel is roughly, because the tensor parallel still hold own bias or layer_norm weight without splited # so, the trainable numel is a little bigger than real. logger.debug(f" Number of trainable parameters = {trainable_numel:,} (all devices, roughly)") @@ -897,6 +900,8 @@ def _inner_training_loop( for step, inputs in enumerate(epoch_iterator): if self.args.use_hybrid_parallel and self.args.sep_parallel_degree > 1: inputs = split_inputs_sequence_dim(inputs) + if self.args.use_hybrid_parallel and self.args.cp_parallel_degree > 1: + inputs = split_inputs_sequence_dim_load_balance(inputs) self.timers and self.timers("read-data").stop() os.environ["TRAINER_GLOBAL_STEP"] = str(self.state.global_step) self.callback_handler.on_load_data_end(args, self.state, self.control, inputs=inputs) @@ -1006,7 +1011,11 @@ def _inner_training_loop( assert reshard_util.is_sharding_opt(self.optimizer) self.optimizer._inner_opt.reduce_gradients(list(parameters_list), self.optimizer._hcg) - if self.optimizer._dp_enable or getattr(self.optimizer, "_sep_enable", False): + if ( + self.optimizer._dp_enable + or getattr(self.optimizer, "_sep_enable", False) + or getattr(self.optimizer, "_cp_enable", False) + ): fused_allreduce_gradients(list(parameters_list), self.optimizer._hcg) self.timers and self.timers("all-reduce").stop() @@ -1760,6 +1769,7 @@ def _wrap_model(self, model, training=True): in_sharding_parallel_mode = self.sharding is not None in_tensor_parallel_mode = self.args.tensor_parallel_degree > 1 in_sep_parallel_mode = self.args.sep_parallel_degree > 1 + in_cp_parallel_mode = self.args.cp_parallel_degree > 1 # Multi-gpu training if ( @@ -1770,6 +1780,7 @@ def _wrap_model(self, model, training=True): or in_sharding_parallel_mode or in_tensor_parallel_mode or in_sep_parallel_mode + or in_cp_parallel_mode ) ): model = paddle.DataParallel(model) @@ -1897,7 +1908,7 @@ def get_expected_keys(inputs, keys): if ( not in_pipeline_parallel_mode and not in_sharding_parallel_mode - and (in_tensor_parallel_mode or in_sep_parallel_mode) + and (in_tensor_parallel_mode or in_sep_parallel_mode or in_cp_parallel_mode) ): if self.args.amp_master_grad: mix_precision_utils.MixPrecisionLayer(model, dtype=self.amp_dtype) # return value has no use diff --git a/paddlenlp/trainer/training_args.py b/paddlenlp/trainer/training_args.py index 7b792ad34e7a..8293fc912e70 100644 --- a/paddlenlp/trainer/training_args.py +++ b/paddlenlp/trainer/training_args.py @@ -230,6 +230,10 @@ class TrainingArguments: The paddle sequence parallel strategy. It can reduce the GPU memory of activation to 1/sep, and it is orthogonal to data parallel, sharding stage1, tensor parallel and pipeline parallel strategy. ) + cp_parallel_degree (`int`, *optional*, defaults to `-1`)( + The paddle context parallel strategy. It can reduce the GPU memory of activation to 1/cp, and it is orthogonal to + data parallel, sharding stage1, tensor parallel and pipeline parallel strategy. + ) data_parallel_config (`str`, *optional*)( Some additional configs which affect data parallel performance, we provide some option to config it. following config is support: @@ -583,6 +587,15 @@ class TrainingArguments: ) }, ) + cp_parallel_degree: int = field( + default=-1, + metadata={ + "help": ( + "The paddle context parallel strategy. It can reduce the GPU memory of activation to 1/cp, and it is orthogonal to " + "data parallel, sharding stage1, tensor parallel and pipeline parallel strategy. " + ) + }, + ) data_parallel_config: str = field( default="", metadata={ @@ -918,6 +931,7 @@ def __post_init__(self): if world_size > 1: tensor_parallel_degree = max(self.tensor_parallel_degree, 1) sep_parallel_degree = max(self.sep_parallel_degree, 1) + cp_parallel_degree = max(self.cp_parallel_degree, 1) pipeline_parallel_degree = max(self.pipeline_parallel_degree, 1) assert ( @@ -927,7 +941,7 @@ def __post_init__(self): if self.sharding_parallel_degree == -1: if len(self.sharding) > 0: self.sharding_parallel_degree = world_size // ( - tensor_parallel_degree * sep_parallel_degree * pipeline_parallel_degree + tensor_parallel_degree * sep_parallel_degree * cp_parallel_degree * pipeline_parallel_degree ) sharding_parallel_degree = max(self.sharding_parallel_degree, 1) @@ -936,7 +950,11 @@ def __post_init__(self): self.sharding = [] self.data_parallel_degree = world_size // ( - sharding_parallel_degree * tensor_parallel_degree * sep_parallel_degree * pipeline_parallel_degree + sharding_parallel_degree + * tensor_parallel_degree + * sep_parallel_degree + * cp_parallel_degree + * pipeline_parallel_degree ) if ( @@ -944,12 +962,14 @@ def __post_init__(self): or tensor_parallel_degree > 1 or pipeline_parallel_degree > 1 or self.sep_parallel_degree > 1 + or self.cp_parallel_degree > 1 ): self.use_hybrid_parallel = True self.sharding_parallel_degree = sharding_parallel_degree self.tensor_parallel_degree = tensor_parallel_degree self.pipeline_parallel_degree = pipeline_parallel_degree self.sep_parallel_degree = sep_parallel_degree + self.cp_parallel_degree = cp_parallel_degree if not self.use_hybrid_parallel: self.sharding = [] @@ -957,6 +977,7 @@ def __post_init__(self): self.tensor_parallel_degree = -1 self.pipeline_parallel_degree = -1 self.sep_parallel_degree = -1 + self.cp_parallel_degree = -1 if self.hybrid_parallel_topo_order is None: self.hybrid_parallel_topo_order = "pp_first" @@ -1140,18 +1161,41 @@ def is_segment_parallel_supported(): logger.warning("segment parallel is not supported!!!, Ignore it.") return support_sep + def is_context_parallel_supported(): + import inspect + + members = [name for (name, date) in inspect.getmembers(fleet.HybridCommunicateGroup)] + support_cp = "get_cp_parallel_world_size" in members + if not support_cp: + logger.warning("context parallel is not supported!!!, Ignore it.") + return support_cp + if self.hybrid_parallel_topo_order == "pp_first": - if is_segment_parallel_supported(): + if is_context_parallel_supported(): + order = ["dp", "pp", "sharding", "sep", "cp", "mp"] + elif is_segment_parallel_supported(): order = ["dp", "pp", "sharding", "sep", "mp"] else: order = ["dp", "pp", "sharding", "mp"] if self.hybrid_parallel_topo_order == "sharding_first": - if is_segment_parallel_supported(): + if is_context_parallel_supported(): + order = ["dp", "sharding", "pp", "sep", "cp", "mp"] + elif is_segment_parallel_supported(): order = ["dp", "sharding", "pp", "sep", "mp"] else: order = ["dp", "sharding", "pp", "mp"] - if is_segment_parallel_supported(): + if is_context_parallel_supported(): + hybrid_configs = { + "dp_degree": self.data_parallel_degree, + "mp_degree": self.tensor_parallel_degree, + "pp_degree": self.pipeline_parallel_degree, + "sharding_degree": self.sharding_parallel_degree, + "sep_degree": self.sep_parallel_degree, + "cp_degree": self.cp_parallel_degree, + "order": order, + } + elif is_segment_parallel_supported(): hybrid_configs = { "dp_degree": self.data_parallel_degree, "mp_degree": self.tensor_parallel_degree, @@ -1241,6 +1285,7 @@ def is_segment_parallel_supported(): elif self.enable_auto_parallel: self.tensor_parallel_degree = max(self.tensor_parallel_degree, 1) self.sep_parallel_degree = max(self.sep_parallel_degree, 1) + self.cp_parallel_degree = max(self.cp_parallel_degree, 1) self.pipeline_parallel_degree = max(self.pipeline_parallel_degree, 1) assert ( @@ -1250,7 +1295,10 @@ def is_segment_parallel_supported(): if self.sharding_parallel_degree == -1: if len(self.sharding) > 0: self.sharding_parallel_degree = world_size // ( - self.tensor_parallel_degree * self.sep_parallel_degree * self.pipeline_parallel_degree + self.tensor_parallel_degree + * self.sep_parallel_degree + * self.cp_parallel_degree + * self.pipeline_parallel_degree ) self.sharding_parallel_degree = max(self.sharding_parallel_degree, 1) @@ -1262,6 +1310,7 @@ def is_segment_parallel_supported(): self.sharding_parallel_degree * self.tensor_parallel_degree * self.sep_parallel_degree + * self.cp_parallel_degree * self.pipeline_parallel_degree ) diff --git a/paddlenlp/transformers/configuration_utils.py b/paddlenlp/transformers/configuration_utils.py index 4bda24695a48..99957f2057e7 100644 --- a/paddlenlp/transformers/configuration_utils.py +++ b/paddlenlp/transformers/configuration_utils.py @@ -465,8 +465,9 @@ def __init__(self, **kwargs): # Parameters for tensor parallel self.tensor_parallel_degree = kwargs.pop("tensor_parallel_degree", -1) self.tensor_parallel_rank = kwargs.pop("tensor_parallel_rank", 0) - # Parameters for sep + # Parameters for sep and cp self.sep_parallel_degree = kwargs.pop("sep_parallel_degree", -1) + self.cp_parallel_degree = kwargs.pop("cp_parallel_degree", -1) # If set to True, this option is used with fleet.meta_parallel.ParallelCrossEntropy # to calculate cross-entropy loss for parallel model. self.tensor_parallel_output = kwargs.pop("tensor_parallel_output", False) diff --git a/paddlenlp/transformers/context_parallel_utils.py b/paddlenlp/transformers/context_parallel_utils.py new file mode 100644 index 000000000000..b89c020a4a94 --- /dev/null +++ b/paddlenlp/transformers/context_parallel_utils.py @@ -0,0 +1,101 @@ +# Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +import paddle +from paddle.distributed.fleet import fleet + + +def split_inputs_sequence_dim_load_balance(inputs, rank=None, degree=None): + if degree is None and rank is None: + _hcg = fleet.get_hybrid_communicate_group() + degree = _hcg.get_cp_parallel_world_size() + rank = _hcg.get_cp_parallel_rank() + assert isinstance(degree, int) and isinstance( + rank, int + ), f"degree:{type(degree)} and rank:{type(rank)} must be int" + if degree <= 1: + return inputs + + def do_split_sequence_dim_load_balance(data, rank, degree): + if data is None: + return None + assert isinstance(data, paddle.Tensor), f"data should be paddle.Tensor, but is type:{type(data)}" + assert len(data.shape) == 2, f"data dims should be 2, but shaped: {data.shape}" + sliced_datas = paddle.split(data, num_or_sections=degree * 2, axis=-1) + sliced_data0, sliced_data1 = sliced_datas[rank], sliced_datas[degree * 2 - 1 - rank] + return paddle.concat([sliced_data0, sliced_data1], axis=-1) + + if isinstance(inputs, paddle.Tensor): + return do_split_sequence_dim_load_balance(inputs, rank, degree) + elif isinstance(inputs, dict): + res = {} + for k, tensor in inputs.items(): + res[k] = do_split_sequence_dim_load_balance(tensor, rank, degree) + elif isinstance(inputs, list): + res = [] + for tensor in inputs: + res.append(do_split_sequence_dim_load_balance(tensor, rank, degree)) + else: + raise ValueError(f"the inputs should be a list or a dict, but is type: {type(inputs)}") + return res + + +def split_inputs_sequence_dim(inputs, rank=None, degree=None): + if degree is None and rank is None: + _hcg = fleet.get_hybrid_communicate_group() + degree = _hcg.get_sep_parallel_world_size() + rank = _hcg.get_sep_parallel_rank() + if degree == 1: + degree = _hcg.get_cp_parallel_world_size() + rank = _hcg.get_cp_parallel_rank() + assert isinstance(degree, int) and isinstance( + rank, int + ), f"degree:{type(degree)} and rank:{type(rank)} must be int" + if degree <= 1: + return inputs + + def do_split_sequence_dim(data, rank, degree): + if data is None: + return None + assert isinstance(data, paddle.Tensor), f"data should be paddle.Tensor, but is type:{type(data)}" + assert len(data.shape) == 2, f"data dims should be 2, but shaped: {data.shape}" + sliced_data = paddle.split(data, num_or_sections=degree, axis=-1)[rank] + return sliced_data + + if isinstance(inputs, paddle.Tensor): + return do_split_sequence_dim(inputs, rank, degree) + elif isinstance(inputs, dict): + res = {} + for k, tensor in inputs.items(): + res[k] = do_split_sequence_dim(tensor, rank, degree) + elif isinstance(inputs, list): + res = [] + for tensor in inputs: + res.append(do_split_sequence_dim(tensor, rank, degree)) + else: + raise ValueError(f"the inputs should be a list or a dict, but is type: {type(inputs)}") + return res diff --git a/paddlenlp/transformers/llama/fusion_ops.py b/paddlenlp/transformers/llama/fusion_ops.py index 6009a80911d5..5cff7002b494 100644 --- a/paddlenlp/transformers/llama/fusion_ops.py +++ b/paddlenlp/transformers/llama/fusion_ops.py @@ -51,14 +51,22 @@ def swiglu(x, y=None): except: flash_attention = None +from paddlenlp.transformers.ring_flash_attention import RingFlashAttention +from paddlenlp.transformers.context_parallel_utils import split_inputs_sequence_dim_load_balance -def fusion_rope(query_states, key_states, value_states, hidden_states, position_ids, past_key_value, rotary_emb): +def fusion_rope(query_states, key_states, value_states, hidden_states, position_ids, past_key_value, rotary_emb, cp_parallel_degree=-1): if get_env_device() != "gcu": assert past_key_value is None, "fuse rotary not support cache kv for now" batch_size, seq_length, num_heads, head_dim = query_states.shape _, kv_seq_len, num_key_value_heads, _ = key_states.shape + if cp_parallel_degree > 1: + assert get_env_device() == "gpu", "context parallel only support cuda device for now" + kv_seq_len *= cp_parallel_degree if get_env_device() != "gcu": cos, sin = rotary_emb(value_states, seq_len=kv_seq_len) + if cp_parallel_degree > 1: + cos = split_inputs_sequence_dim_load_balance(cos) + sin = split_inputs_sequence_dim_load_balance(sin) if get_env_device() == "npu": query_states = core.eager._run_custom_op("fused_rope", query_states, cos, sin)[0] key_states = core.eager._run_custom_op("fused_rope", key_states, cos, sin)[0] @@ -142,6 +150,8 @@ def fusion_flash_attention( if version != "0.0.0" and version <= "2.5.2": if alibi is not None: raise ValueError("Flash Attention doesn't support alibi") + if config.cp_parallel_degree > 1: + raise ValueError(f"Context parallel is not implemented in version {version}") attn_output, attn_weights = flash_attention( query_states, key_states, @@ -154,6 +164,8 @@ def fusion_flash_attention( alibi = alibi.reshape([bsz, num_heads, 1, -1]) attention_mask = attention_mask.cast(alibi.dtype) + alibi if get_env_device() == "npu": + if config.cp_parallel_degree > 1: + raise ValueError(f"Context parallel is not implemented for npu") attn_output = core.eager._run_custom_op( "flash_attention_npu", query_states, @@ -168,6 +180,8 @@ def fusion_flash_attention( npu_is_casual, )[0] elif get_env_device() == "gcu": + if config.cp_parallel_degree > 1: + raise ValueError(f"Context parallel is not implemented for gcu") attn_output = core.eager._run_custom_op( "fused_sdp_flash_attention_gcu", query_states, @@ -179,13 +193,22 @@ def fusion_flash_attention( True, )[0] else: - attn_output = F.scaled_dot_product_attention( - query_states, - key_states, - value_states, - attn_mask=attention_mask, - is_causal=attention_mask is None, - ) + if config.cp_parallel_degree > 1: + attn_output = RingFlashAttention.apply( + query_states, + key_states, + value_states, + attn_mask=None, + is_causal=True, + ) + else: + attn_output = F.scaled_dot_product_attention( + query_states, + key_states, + value_states, + attn_mask=attention_mask, + is_causal=attention_mask is None, + ) attn_weights = None if reshard_layer is not None: diff --git a/paddlenlp/transformers/llama/modeling.py b/paddlenlp/transformers/llama/modeling.py index 366f7ff3c083..18eda70d627d 100755 --- a/paddlenlp/transformers/llama/modeling.py +++ b/paddlenlp/transformers/llama/modeling.py @@ -99,6 +99,7 @@ def swiglu(x, y=None): ] + def _get_interleave(n): def _get_interleave_power_of_2(n): start = 2 ** (-(2 ** -(math.log2(n) - 3))) @@ -233,6 +234,9 @@ def scaled_dot_product_attention( # Torch Flash Attention input [ bz, nhead, seqlen, head_dim] else: + if config.cp_parallel_degree > 1: + raise ValueError("Context parallel requires `use_flash_attention=True`") + # [ bz, seqlen, nhead, head_dim] -> [bs, nhead, seq_len, head_dim] query_states = paddle.transpose(query_states, [0, 2, 1, 3]) # merge with the next tranpose @@ -765,7 +769,7 @@ def __init__(self, config: LlamaConfig, layerwise_recompute: bool = False): assert self.num_key_value_heads % config.sep_parallel_degree == 0 assert self.num_heads % config.sep_parallel_degree == 0 self.reshard_layer = ReshardLayer() - + self.context_parallel = config.cp_parallel_degree > 1 self.config = config def _init_rope(self): @@ -932,6 +936,17 @@ def forward( if self.reshard_layer is not None: batch_size, seq_length, _, _ = query_states.shape position_ids = paddle.arange(seq_length, dtype="int64").expand((batch_size, seq_length)) + if self.context_parallel: + batch_size, seq_length, _, _ = query_states.shape + group = fleet.get_hybrid_communicate_group().get_cp_parallel_group() + chunk_size = seq_length // 2 + chunk_num = group.nranks * 2 + rank = group.rank + first_chunk_ids = paddle.arange(rank * chunk_size, (rank + 1) * chunk_size, dtype="int64") + second_chunk_ids = paddle.arange( + (chunk_num - rank - 1) * chunk_size, (chunk_num - rank) * chunk_size, dtype="int64" + ) + position_ids = paddle.concat([first_chunk_ids, second_chunk_ids]).expand((batch_size, seq_length)) if self.use_fused_rope: query_states, key_states = fusion_ops.fusion_rope( query_states, @@ -941,9 +956,12 @@ def forward( position_ids, past_key_value, self.rotary_emb, + self.cp_parallel_degree ) else: + if self.context_parallel: + kv_seq_len *= self.config.cp_parallel_degree if self.config.use_long_sequence_strategies: cos, sin = self.rotary_emb(seq_len=kv_seq_len) cos = cos[None, :, None, :] @@ -1512,6 +1530,8 @@ def forward( # [seq_len * bs / n, num_head * head_dim] (n is mp parallelism) inputs_embeds = ScatterOp.apply(inputs_embeds) + if self.config.cp_parallel_degree > 1 and (attention_mask is not None or self.config.alibi): + raise NotImplementedError("Ring FlashAttention dosen't support attention_mask or alibi") # embed positions if attention_mask is None: # [bs, seq_len] @@ -1657,7 +1677,10 @@ def forward(self, prediction_scores, masked_lm_labels): if self.config.sep_parallel_degree > 1: _hcg = fleet.get_hybrid_communicate_group() - masked_lm_loss = ConcatSePMaskedLoss.apply(masked_lm_loss, axis=1, group=_hcg.get_sep_parallel_group()) + masked_lm_loss = ConcatMaskedLoss.apply(masked_lm_loss, axis=1, group=_hcg.get_sep_parallel_group()) + if self.config.cp_parallel_degree > 1: + _hcg = fleet.get_hybrid_communicate_group() + masked_lm_loss = ConcatMaskedLoss.apply(masked_lm_loss, axis=1, group=_hcg.get_cp_parallel_group()) # skip ignore_index which loss == 0 # masked_lm_loss = masked_lm_loss[masked_lm_loss > 0] # loss = paddle.mean(masked_lm_loss) @@ -1673,7 +1696,7 @@ def forward(self, prediction_scores, masked_lm_labels): return loss -class ConcatSePMaskedLoss(PyLayer): +class ConcatMaskedLoss(PyLayer): @staticmethod def forward(ctx, inp, axis, group): inputs = [] @@ -1728,6 +1751,9 @@ def forward(self, hidden_states, tensor_parallel_output=None): if self.config.sep_parallel_degree > 1: assert seq_length % self.config.sep_parallel_degree == 0 seq_length = seq_length // self.config.sep_parallel_degree + if self.config.cp_parallel_degree > 1: + assert seq_length % self.config.cp_parallel_degree == 0 + seq_length = seq_length // self.config.cp_parallel_degree hidden_states = paddle.reshape_(hidden_states, [-1, seq_length, self.config.hidden_size]) if tensor_parallel_output is None: diff --git a/paddlenlp/transformers/ring_flash_attention.py b/paddlenlp/transformers/ring_flash_attention.py new file mode 100644 index 000000000000..b8cadaf87e48 --- /dev/null +++ b/paddlenlp/transformers/ring_flash_attention.py @@ -0,0 +1,436 @@ +# Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +# paddlenlp/transformers/ring_attention.py + +import random + +import numpy as np +import paddle +import paddle.distributed as dist +from custom_setup_ops import flash_attn_bwd +from paddle import _C_ops +from paddle.autograd.py_layer import PyLayer +from paddle.nn.functional.flash_attention import scaled_dot_product_attention + + +class RingCommunicator: + def __init__(self, group, local_key, local_value): + self._k_buffer = [paddle.zeros_like(local_key) for _ in range(2)] + self._v_buffer = [paddle.zeros_like(local_value) for _ in range(2)] + + self._k_buffer[0] = local_key.clone() + self._v_buffer[0] = local_value.clone() + + self._next_buffer_idx = 0 + + self.group = group + self.group_rank = group.rank + self.send_rank = self.group.ranks[(self.group_rank + 1) % self.group.world_size] + self.recv_rank = self.group.ranks[(self.group_rank - 1) % self.group.world_size] + + self._reqs = [] + + def wait(self): + # for req in self._reqs: + # req.wait() + # self._reqs = None + paddle.device.synchronize() + + def add_to_buffers(self, key, value): + if key.shape != self._k_buffer[self._next_buffer_idx].shape: + self._k_buffer[self._next_buffer_idx][:, : key.shape[1], :, :] += key + self._v_buffer[self._next_buffer_idx][:, : key.shape[1], :, :] += value + else: + self._k_buffer[self._next_buffer_idx] += key + self._v_buffer[self._next_buffer_idx] += value + + def get_buffers(self): + return self._k_buffer[self._next_buffer_idx], self._v_buffer[self._next_buffer_idx] + + def send_recv(self): + send_k_op = dist.P2POp(dist.isend, self._k_buffer[self._next_buffer_idx], self.send_rank, self.group) + send_v_op = dist.P2POp(dist.isend, self._v_buffer[self._next_buffer_idx], self.send_rank, self.group) + recv_k_op = dist.P2POp(dist.irecv, self._k_buffer[(self._next_buffer_idx + 1) % 2], self.recv_rank, self.group) + recv_v_op = dist.P2POp(dist.irecv, self._v_buffer[(self._next_buffer_idx + 1) % 2], self.recv_rank, self.group) + + self._next_buffer_idx = (self._next_buffer_idx + 1) % 2 + + ops = [send_k_op, send_v_op, recv_k_op, recv_v_op] + + self._reqs = dist.batch_isend_irecv(ops) + + +def update_out_and_lse(old_out, old_lse, block_out, block_lse, second_chunk_only=False): + if second_chunk_only: + second_chunk_out = old_out[:, old_out.shape[1] // 2 :, :, :] + second_chunk_lse = old_lse[:, old_lse.shape[1] // 2 :, :, :] + second_chunk_out, second_chunk_lse = update_out_and_lse( + second_chunk_out, second_chunk_lse, block_out, block_lse + ) + old_out[:, old_out.shape[1] // 2 :, :, :] = second_chunk_out + old_lse[:, old_lse.shape[1] // 2 :, :, :] = second_chunk_lse + return old_out, old_lse + else: + lse = paddle.log(1 + paddle.exp(block_lse - old_lse)) + old_lse + return old_out * paddle.exp(old_lse - lse) + block_out * paddle.exp(block_lse - lse), lse + + +def get_chunk_id(rank, cp_size): + return rank, (2 * cp_size - 1 - rank) + + +def concat_masks(attn_masks_list, rank, cp_size): + assert len(attn_masks_list) == 2 * cp_size + first_chunk_id, second_chunk_id = get_chunk_id(rank, cp_size) + return paddle.concat([attn_masks_list[first_chunk_id], attn_masks_list[second_chunk_id]], axis=3) + + +def balanced_ring_flash_attention_fwd_func( + group, + local_query, + local_key, + local_value, + fixed_seed_offset=None, + attn_mask=None, + dropout=0.0, + is_causal=False, + training=True, +): + cp_size = group.world_size + rank = group.rank + + comm_buffer = RingCommunicator(group, local_key, local_value) + local_q_seq_len = local_query.shape[1] + + if attn_mask is not None: + attn_masks_list = paddle.split(attn_mask, num_or_sections=cp_size * 2, axis=3) + if is_causal: + local_query_second_chunk = local_query[:, local_q_seq_len // 2 :, :, :].clone().contiguous() + for step in range(cp_size): + block_k, block_v = comm_buffer.get_buffers() + + if step != cp_size - 1: + comm_buffer.send_recv() + + if not is_causal: + # out [bs, seq, nhead, headdim] + # lse [bs, nhead, seq] + block_out, _, block_lse, _ = _C_ops.flash_attn( + local_query, + block_k, + block_v, + fixed_seed_offset, + None if attn_mask is None else concat_masks(attn_masks_list, (group.rank - step) % cp_size, cp_size), + dropout, + False, + False, + not training, + "", + ) + block_lse = paddle.unsqueeze(paddle.transpose(block_lse, [0, 2, 1]), axis=-1) + + if step == 0: + out, lse = block_out, block_lse + else: + out, lse = update_out_and_lse(out, lse, block_out, block_lse) + else: + if step == 0: + block_out, _, block_lse, _ = _C_ops.flash_attn( + local_query, block_k, block_v, fixed_seed_offset, None, dropout, True, False, not training, "" + ) + block_lse = paddle.unsqueeze(paddle.transpose(block_lse, [0, 2, 1]), axis=-1) + out, lse = block_out, block_lse + elif step > rank: + block_out, _, block_lse, _ = _C_ops.flash_attn( + local_query_second_chunk, + block_k, + block_v, + fixed_seed_offset, + None, + dropout, + False, + False, + not training, + "", + ) + block_lse = block_lse[:, :, 0 : (local_q_seq_len // 2)] + block_lse = paddle.unsqueeze(paddle.transpose(block_lse, [0, 2, 1]), axis=-1) + out, lse = update_out_and_lse(out, lse, block_out, block_lse, True) + else: + block_out, _, block_lse, _ = _C_ops.flash_attn( + local_query, + block_k[:, : local_q_seq_len // 2, :, :], + block_v[:, : local_q_seq_len // 2, :, :], + fixed_seed_offset, + None, + dropout, + False, + False, + not training, + "", + ) + block_lse = paddle.unsqueeze(paddle.transpose(block_lse, [0, 2, 1]), axis=-1) + out, lse = update_out_and_lse(out, lse, block_out, block_lse) + + # if step != cp_size - 1: + # comm_buffer.wait() + paddle.device.synchronize() + + out = out.to(local_query.dtype) + lse = paddle.transpose(paddle.squeeze(lse, axis=-1), [0, 2, 1]) + return out, lse + + +def balanced_ring_flash_attention_bwd_func( + group, + out_grad, + local_query, + local_key, + local_value, + local_out, + lse, + fixed_seed_offset, + attn_mask, + dropout=0.0, + is_causal=False, +): + cp_size = group.world_size + rank = group.rank + + local_q_seq_len = local_query.shape[1] + + query_grad_buffer = paddle.zeros_like(local_query).to("float32") + key_grad_buffer = paddle.zeros_like(local_key).to("float32") + value_grad_buffer = paddle.zeros_like(local_value).to("float32") + + kv_comm_buffer = RingCommunicator(group, local_key, local_value) + grad_comm_buffer = RingCommunicator(group, key_grad_buffer, value_grad_buffer) + + if is_causal: + local_query_second_chunk = local_query[:, local_q_seq_len // 2 :, :, :].clone().contiguous() + local_out_second_chunk = local_out[:, local_q_seq_len // 2 :, :, :].clone().contiguous() + lse_second_chunk = lse[:, :, local_q_seq_len // 2 :].clone().contiguous() + out_grad_second_chunk = out_grad[:, local_q_seq_len // 2 :, :, :].clone().contiguous() + + if attn_mask is not None: + attn_masks_list = paddle.split(attn_mask, num_or_sections=cp_size * 2, axis=3) + + for step in range(cp_size): + block_k, block_v = kv_comm_buffer.get_buffers() + + if step != cp_size - 1: + kv_comm_buffer.send_recv() + + if not is_causal: + block_q_grad, block_k_grad, block_v_grad = flash_attn_bwd( + local_query, + block_k, + block_v, + local_out, + lse, + fixed_seed_offset, + None if attn_mask is None else concat_masks(attn_masks_list, (group.rank - step) % cp_size, cp_size), + out_grad, + dropout, + False, + ) + query_grad_buffer += block_q_grad + else: + if step == 0: + block_q_grad, block_k_grad, block_v_grad = flash_attn_bwd( + local_query, block_k, block_v, local_out, lse, fixed_seed_offset, None, out_grad, dropout, True + ) + query_grad_buffer += block_q_grad + elif step > rank: + block_q_grad, block_k_grad, block_v_grad = flash_attn_bwd( + local_query_second_chunk, + block_k, + block_v, + local_out_second_chunk, + lse_second_chunk, + fixed_seed_offset, + None, + out_grad_second_chunk, + dropout, + False, + ) + query_grad_buffer[:, local_q_seq_len // 2 :, :, :] += block_q_grad + else: + block_q_grad, block_k_grad, block_v_grad = flash_attn_bwd( + local_query, + block_k[:, : local_q_seq_len // 2, :, :], + block_v[:, : local_q_seq_len // 2, :, :], + local_out, + lse, + fixed_seed_offset, + None, + out_grad, + dropout, + False, + ) + query_grad_buffer += block_q_grad + + # if step != cp_size - 1: + # kv_comm_buffer.wait() + # if step != 0: + # grad_comm_buffer.wait() + paddle.device.synchronize() + + grad_comm_buffer.add_to_buffers(block_k_grad, block_v_grad) + grad_comm_buffer.send_recv() + + grad_comm_buffer.wait() + key_grad_buffer, value_grad_buffer = grad_comm_buffer.get_buffers() + + dtype = local_query.dtype + return query_grad_buffer.to(dtype), key_grad_buffer.to(dtype), value_grad_buffer.to(dtype) + + +class RingFlashAttention(PyLayer): + @staticmethod + def forward( + ctx, + query, + key, + value, + group=None, + fixed_seed_offset=None, + attn_mask=None, + dropout=0.0, + is_causal=False, + training=True, + ): + if dropout > 0.0: + raise NotImplementedError("Dropout is not supported in ring attention yet.") + if group is None: + group = dist.fleet.get_hybrid_communicate_group().get_cp_parallel_group() + if attn_mask is not None: + is_causal = False + + out, lse = balanced_ring_flash_attention_fwd_func( + group, query, key, value, fixed_seed_offset, attn_mask, dropout, is_causal, training + ) + ctx.save_for_backward(query, key, value, out, lse, attn_mask) + ctx.group = group + ctx.fixed_seed_offset = fixed_seed_offset + ctx.dropout = dropout + ctx.is_causal = is_causal + return out + + @staticmethod + def backward(ctx, out_grad): + query, key, value, out, lse, attn_mask = ctx.saved_tensor() + group = ctx.group + fixed_seed_offset = ctx.fixed_seed_offset + dropout = ctx.dropout + is_causal = ctx.is_causal + + if fixed_seed_offset is None: + fixed_seed_offset = paddle.to_tensor([0, 0], place=paddle.CPUPlace(), dtype=paddle.int64).contiguous() + + query_grad, key_grad, value_grad = balanced_ring_flash_attention_bwd_func( + group, out_grad, query, key, value, out, lse, fixed_seed_offset, attn_mask, dropout, is_causal + ) + if attn_mask is not None and not attn_mask.stop_gradient: + return query_grad, key_grad, value_grad, None + else: + return query_grad, key_grad, value_grad + + +import unittest + + +class TestRingFlashAttention(unittest.TestCase): + def setUp(self): + paddle.distributed.init_parallel_env() + self.group = paddle.distributed.new_group(range(paddle.distributed.get_world_size()), backend="nccl") + self.degree = self.group.world_size + self.rank = self.group.rank + + seed = 42 + random.seed(seed) + np.random.seed(seed) + paddle.seed(seed) + + def generate_full_data(self, batch_size, seq_len, num_head, head_dim): + query = (paddle.randn([batch_size, seq_len, num_head, head_dim], dtype=paddle.float32)).to("gpu", "float16") + key = (paddle.randn([batch_size, seq_len, num_head, head_dim], dtype=paddle.float32)).to("gpu", "float16") + value = (paddle.randn([batch_size, seq_len, num_head, head_dim], dtype=paddle.float32)).to("gpu", "float16") + + query.stop_gradient = False + key.stop_gradient = False + value.stop_gradient = False + + return query, key, value + + def split_belanced_data(self, input): + sliced_datas = paddle.split(input, num_or_sections=self.degree * 2, axis=1) + sliced_data0, sliced_data1 = sliced_datas[self.rank], sliced_datas[self.degree * 2 - 1 - self.rank] + return paddle.concat([sliced_data0, sliced_data1], axis=1).detach() + + def single_test(self, bsz, seq_len_per_device, head_num, head_dim, is_causal, use_mask): + query, key, value = self.generate_full_data(bsz, seq_len_per_device * self.degree, head_num, head_dim) + + local_query = self.split_belanced_data(query) + local_key = self.split_belanced_data(key) + local_value = self.split_belanced_data(value) + + local_query.stop_gradient = False + local_key.stop_gradient = False + local_value.stop_gradient = False + + if use_mask: + mask_shape = (1, 1, query.shape[1], query.shape[1]) + mask = np.random.random(mask_shape) + attn_mask = paddle.to_tensor(mask, place=query.place, dtype=query.dtype) + attn_mask = paddle.ones(mask_shape).to(query.dtype) + attn_mask_list = paddle.split(attn_mask, axis=2, num_or_sections=self.degree * 2) + first_chunk_id, second_chunk_id = get_chunk_id(self.rank, self.degree) + local_attn_mask = paddle.concat([attn_mask_list[first_chunk_id], attn_mask_list[second_chunk_id]], axis=2) + else: + attn_mask = None + local_attn_mask = None + + local_out = RingFlashAttention.apply( + local_query, local_key, local_value, self.group, is_causal=is_causal, attn_mask=local_attn_mask + ) + ref_out = scaled_dot_product_attention(query, key, value, is_causal=is_causal, attn_mask=attn_mask) + ref_local_out = self.split_belanced_data(ref_out) + np.testing.assert_allclose(local_out.numpy(), ref_local_out.numpy(), rtol=5e-03, atol=1e-03) + + local_out.backward() + ref_out.backward() + + ref_local_query_grad = self.split_belanced_data(query.grad) + ref_local_key_grad = self.split_belanced_data(key.grad) + ref_local_value_grad = self.split_belanced_data(value.grad) + + np.testing.assert_allclose(local_query.grad.numpy(), ref_local_query_grad.numpy(), rtol=5e-03, atol=1e-03) + np.testing.assert_allclose(local_key.grad.numpy(), ref_local_key_grad.numpy(), rtol=5e-03, atol=1e-03) + np.testing.assert_allclose(local_value.grad.numpy(), ref_local_value_grad.numpy(), rtol=5e-03, atol=1e-03) + + def test_normal_flash_attention(self): + self.single_test(1, 256, 1, 256, False, False) + + def test_masked_flash_attention(self): + self.single_test(1, 256, 1, 256, False, True) + + def test_casual_flash_attention(self): + self.single_test(1, 256, 1, 256, True, False) + + +if __name__ == "__main__": + unittest.main() +# python -m paddle.distributed.launch --gpus=0,1,2,3,4,5,6,7 ring_flash_attention.py diff --git a/paddlenlp/transformers/ring_flash_attention_back_up.py b/paddlenlp/transformers/ring_flash_attention_back_up.py new file mode 100644 index 000000000000..5d6266967887 --- /dev/null +++ b/paddlenlp/transformers/ring_flash_attention_back_up.py @@ -0,0 +1,664 @@ +# paddlenlp/transformers/ring_attention.py + +import paddle +import paddle.distributed as dist +from paddle import _C_ops +from paddle.nn.functional.flash_attention import scaled_dot_product_attention +from paddle.autograd.py_layer import PyLayer +from custom_setup_ops import flash_attn_bwd +from paddle.framework import core + +import random +import numpy as np + +class RingCommunicator: + def __init__(self, group, local_key, local_value): + self._k_buffer = [paddle.zeros_like(local_key) for _ in range(2)] + self._v_buffer = [paddle.zeros_like(local_value) for _ in range(2)] + + self._k_buffer[0] = local_key.clone() + self._v_buffer[0] = local_value.clone() + + self._next_buffer_idx = 0 + + self.group = group + self.group_rank = group.rank + self.send_rank = self.group.ranks[(self.group_rank + 1) % self.group.world_size] + self.recv_rank = self.group.ranks[(self.group_rank - 1) % self.group.world_size] + + self._reqs = [] + + def wait(self): + # for req in self._reqs: + # req.wait() + # self._reqs = None + paddle.device.synchronize() + + def add_to_buffers(self, key, value): + if key.shape != self._k_buffer[self._next_buffer_idx].shape: + self._k_buffer[self._next_buffer_idx][:, :key.shape[1], :, :] += key + self._v_buffer[self._next_buffer_idx][:, :key.shape[1], :, :] += value + else: + self._k_buffer[self._next_buffer_idx] += key + self._v_buffer[self._next_buffer_idx] += value + + def get_buffers(self): + return self._k_buffer[self._next_buffer_idx], self._v_buffer[self._next_buffer_idx] + + def send_recv(self): + send_k_op = dist.P2POp(dist.isend, self._k_buffer[self._next_buffer_idx], self.send_rank, self.group) + send_v_op = dist.P2POp(dist.isend, self._v_buffer[self._next_buffer_idx], self.send_rank, self.group) + recv_k_op = dist.P2POp(dist.irecv, self._k_buffer[(self._next_buffer_idx + 1) % 2], self.recv_rank, self.group) + recv_v_op = dist.P2POp(dist.irecv, self._v_buffer[(self._next_buffer_idx + 1) % 2], self.recv_rank, self.group) + + self._next_buffer_idx = (self._next_buffer_idx + 1) % 2 + + ops = [send_k_op, send_v_op, recv_k_op, recv_v_op] + + self._reqs = dist.batch_isend_irecv(ops) + + +def update_out_and_lse(old_out, old_lse, block_out, block_lse, second_chunk_only=False): + if second_chunk_only: + second_chunk_out = old_out[:,old_out.shape[1]//2:, :, :] + second_chunk_lse = old_lse[:,old_lse.shape[1]//2:, :, :] + second_chunk_out, second_chunk_lse = update_out_and_lse(second_chunk_out, second_chunk_lse, block_out, block_lse) + old_out[:,old_out.shape[1]//2:, :, :] = second_chunk_out + old_lse[:,old_lse.shape[1]//2:, :, :] = second_chunk_lse + return old_out, old_lse + else: + lse = paddle.log(1 + paddle.exp(block_lse - old_lse)) + old_lse + return old_out * paddle.exp(old_lse - lse) + block_out * paddle.exp(block_lse - lse), lse + + +def get_chunk_id(rank, cp_size): + return rank, (2 * cp_size - 1 - rank) + + +def concat_masks(attn_masks_list, rank, cp_size): + assert len(attn_masks_list) == 2 * cp_size + first_chunk_id, second_chunk_id = get_chunk_id(rank, cp_size) + return paddle.concat([attn_masks_list[first_chunk_id], attn_masks_list[second_chunk_id]], axis=3) + + +def balanced_ring_flash_attention_fwd_funcV2(group, local_query, local_key, local_value, fixed_seed_offset=None, attn_mask=None, dropout=0.0, is_causal=False, training=True): + cp_size = group.world_size + rank = group.rank + + comm_buffer = RingCommunicator(group, local_key, local_value) + local_q_seq_len = local_query.shape[1] + + computation_streams = [paddle.device.Stream(), paddle.device.current_stream()] + update_out_and_lse_done = paddle.device.Event() + block_out_buffer = [paddle.zeros_like(local_query) for _ in range(2)] + block_lse_buffer = [paddle.zeros([local_query.shape[0], local_query.shape[2], local_query.shape[1]], dtype=local_query.dtype) for _ in range(2)] + paddle.device.synchronize() + + if attn_mask is not None: + attn_masks_list = paddle.split(attn_mask, num_or_sections=cp_size * 2, axis=3) + if is_causal: + local_query_second_chunk = local_query[:, local_q_seq_len // 2:, :, :].clone().contiguous() + for step in range(cp_size + 1): + block_k, block_v = comm_buffer.get_buffers() + + if step != cp_size - 1: + comm_buffer.send_recv() + + if not is_causal: + # out [bs, seq, nhead, headdim] + # lse [bs, nhead, seq] + if step < cp_size: + with paddle.device.stream_guard(computation_streams[step % 2]): + block_out_buffer[step % 2], _, block_lse_buffer[step % 2], _ = _C_ops.flash_attn( + local_query, + block_k, + block_v, + fixed_seed_offset, + None if attn_mask is None else concat_masks(attn_masks_list, (group.rank - step) % cp_size, cp_size), + dropout, + False, + False, + not training, + "" + ) + if step > 0: + if step > 1: + computation_streams[(step - 1)% 2].wait_event(update_out_and_lse_done) + with paddle.device.stream_guard(computation_streams[(step - 1)% 2]): + core.nvprof_nvtx_push(f"update_out_and_lse step {step}") + block_out = block_out_buffer[(step - 1) % 2] + block_lse = block_lse_buffer[(step - 1) % 2][:, :, 0:local_q_seq_len] + block_lse = paddle.unsqueeze(paddle.transpose(block_lse, [0, 2, 1]), axis=-1) + if step - 1 == 0: + out, lse= block_out, block_lse + else: + out, lse = update_out_and_lse(out, lse, block_out, block_lse) + core.nvprof_nvtx_pop() + if step < cp_size: + computation_streams[(step - 1)% 2].record_event(update_out_and_lse_done) + else: + if step < cp_size: + with paddle.device.stream_guard(computation_streams[step % 2]): + if step == 0: + block_out_buffer[step % 2], _, block_lse_buffer[step % 2], _ = _C_ops.flash_attn( + local_query, + block_k, + block_v, + fixed_seed_offset, + None, + dropout, + True, + False, + not training, + "") + elif step > rank: + block_out_buffer[step % 2], _, block_lse_buffer[step % 2], _ = _C_ops.flash_attn( + local_query_second_chunk, + block_k, + block_v, + fixed_seed_offset, + None, + dropout, + False, + False, + not training, + "" + ) + else: + block_out_buffer[step % 2], _, block_lse_buffer[step % 2], _ = _C_ops.flash_attn( + local_query, + block_k[:, :local_q_seq_len // 2, :, :], + block_v[:, :local_q_seq_len // 2, :, :], + fixed_seed_offset, + None, + dropout, + False, + False, + not training, + "" + ) + if step > 0: + if step > 1: + computation_streams[(step - 1)% 2].wait_event(update_out_and_lse_done) + with paddle.device.stream_guard(computation_streams[(step - 1)% 2]): + core.nvprof_nvtx_push(f"update_out_and_lse step {step}") + block_out = block_out_buffer[(step - 1) % 2] + block_lse = block_lse_buffer[(step - 1) % 2][:, :, 0:local_q_seq_len] + block_lse = paddle.unsqueeze(paddle.transpose(block_lse, [0, 2, 1]), axis=-1) + if step - 1 == 0: + out, lse= block_out, block_lse + elif step - 1 > rank: + block_lse = block_lse[:, :(local_q_seq_len//2), :, :] + out, lse = update_out_and_lse(out, lse, block_out, block_lse, True) + else: + out, lse = update_out_and_lse(out, lse, block_out, block_lse) + core.nvprof_nvtx_pop() + if step < cp_size: + computation_streams[(step - 1)% 2].record_event(update_out_and_lse_done) + # if step != cp_size - 1: + # comm_buffer.wait() + paddle.device.synchronize() + + out = out.to(local_query.dtype) + lse = paddle.transpose(paddle.squeeze(lse, axis=-1), [0, 2, 1]) + return out, lse + +def balanced_ring_flash_attention_fwd_func(group, local_query, local_key, local_value, fixed_seed_offset=None, attn_mask=None, dropout=0.0, is_causal=False, training=True): + cp_size = group.world_size + rank = group.rank + + comm_buffer = RingCommunicator(group, local_key, local_value) + local_q_seq_len = local_query.shape[1] + + if attn_mask is not None: + attn_masks_list = paddle.split(attn_mask, num_or_sections=cp_size * 2, axis=3) + if is_causal: + local_query_second_chunk = local_query[:, local_q_seq_len // 2:, :, :].clone().contiguous() + for step in range(cp_size): + block_k, block_v = comm_buffer.get_buffers() + + if step != cp_size - 1: + comm_buffer.send_recv() + + if not is_causal: + # out [bs, seq, nhead, headdim] + # lse [bs, nhead, seq] + block_out, _, block_lse, _ = _C_ops.flash_attn( + local_query, + block_k, + block_v, + fixed_seed_offset, + None if attn_mask is None else concat_masks(attn_masks_list, (group.rank - step) % cp_size, cp_size), + dropout, + False, + False, + not training, + "") + block_lse = block_lse[:, :, 0:local_q_seq_len] + block_lse = paddle.unsqueeze(paddle.transpose(block_lse, [0, 2, 1]), axis=-1) + + if step == 0: + out, lse= block_out, block_lse + else: + out, lse = update_out_and_lse(out, lse, block_out, block_lse) + else: + if step == 0: + block_out, _, block_lse, _ = _C_ops.flash_attn( + local_query, + block_k, + block_v, + fixed_seed_offset, + None, + dropout, + True, + False, + not training, + "") + block_lse = block_lse[:, :, 0:local_q_seq_len] + block_lse = paddle.unsqueeze(paddle.transpose(block_lse, [0, 2, 1]), axis=-1) + out, lse= block_out, block_lse + elif step > rank: + block_out, _, block_lse, _ = _C_ops.flash_attn( + local_query_second_chunk, + block_k, + block_v, + fixed_seed_offset, + None, + dropout, + False, + False, + not training, + "" + ) + block_lse = block_lse[:, :, 0:(local_q_seq_len//2)] + block_lse = paddle.unsqueeze(paddle.transpose(block_lse, [0, 2, 1]), axis=-1) + core.nvprof_nvtx_push("update_out_and_lse") + out, lse = update_out_and_lse(out, lse, block_out, block_lse, True) + core.nvprof_nvtx_pop() + else: + block_out, _, block_lse, _ = _C_ops.flash_attn( + local_query, + block_k[:, :local_q_seq_len // 2, :, :], + block_v[:, :local_q_seq_len // 2, :, :], + fixed_seed_offset, + None, + dropout, + False, + False, + not training, + "" + ) + block_lse = block_lse[:, :, 0:local_q_seq_len] + block_lse = paddle.unsqueeze(paddle.transpose(block_lse, [0, 2, 1]), axis=-1) + core.nvprof_nvtx_push("update_out_and_lse") + out, lse = update_out_and_lse(out, lse, block_out, block_lse) + core.nvprof_nvtx_pop() + + if step != cp_size - 1: + comm_buffer.wait() + + out = out.to(local_query.dtype) + lse = paddle.transpose(paddle.squeeze(lse, axis=-1), [0, 2, 1]) + return out, lse + + +def balanced_ring_flash_attention_bwd_func(group, out_grad, local_query, local_key, local_value, local_out, lse, fixed_seed_offset, attn_mask, dropout=0.0, is_causal=False): + cp_size = group.world_size + rank = group.rank + + local_q_seq_len = local_query.shape[1] + + query_grad_buffer = paddle.zeros_like(local_query).to("float32") + key_grad_buffer = paddle.zeros_like(local_key).to("float32") + value_grad_buffer = paddle.zeros_like(local_value).to("float32") + + kv_comm_buffer = RingCommunicator(group, local_key, local_value) + grad_comm_buffer = RingCommunicator(group, key_grad_buffer, value_grad_buffer) + + if is_causal: + local_query_second_chunk = local_query[:, local_q_seq_len // 2:, :, :].clone().contiguous() + local_out_second_chunk = local_out[:, local_q_seq_len // 2:, :, :].clone().contiguous() + lse_second_chunk = lse[:, :, local_q_seq_len // 2:].clone().contiguous() + out_grad_second_chunk = out_grad[:, local_q_seq_len // 2:, :, :].clone().contiguous() + + + if attn_mask is not None: + attn_masks_list = paddle.split(attn_mask, num_or_sections=cp_size * 2, axis=3) + + for step in range(cp_size): + block_k, block_v = kv_comm_buffer.get_buffers() + + if step != cp_size - 1: + kv_comm_buffer.send_recv() + + if not is_causal: + block_q_grad, block_k_grad, block_v_grad = flash_attn_bwd( + local_query, + block_k, + block_v, + local_out, + lse, + fixed_seed_offset, + None if attn_mask is None else concat_masks(attn_masks_list, (group.rank - step) % cp_size, cp_size), + out_grad, + dropout, + False) + query_grad_buffer += block_q_grad + else: + if step == 0: + block_q_grad, block_k_grad, block_v_grad = flash_attn_bwd( + local_query, + block_k, + block_v, + local_out, + lse, + fixed_seed_offset, + None, + out_grad, + dropout, + True) + query_grad_buffer += block_q_grad + elif step > rank: + block_q_grad, block_k_grad, block_v_grad = flash_attn_bwd( + local_query_second_chunk, + block_k, + block_v, + local_out_second_chunk, + lse_second_chunk, + fixed_seed_offset, + None, + out_grad_second_chunk, + dropout, + False) + query_grad_buffer[:, local_q_seq_len // 2:, :, :] += block_q_grad + else: + block_q_grad, block_k_grad, block_v_grad = flash_attn_bwd( + local_query, + block_k[:, :local_q_seq_len // 2, :, :], + block_v[:, :local_q_seq_len // 2, :, :], + local_out, + lse, + fixed_seed_offset, + None, + out_grad, + dropout, + False) + query_grad_buffer += block_q_grad + + # if step != cp_size - 1: + # kv_comm_buffer.wait() + # if step != 0: + # grad_comm_buffer.wait() + paddle.device.synchronize() + + grad_comm_buffer.add_to_buffers(block_k_grad, block_v_grad) + grad_comm_buffer.send_recv() + + grad_comm_buffer.wait() + key_grad_buffer, value_grad_buffer = grad_comm_buffer.get_buffers() + + dtype = local_query.dtype + return query_grad_buffer.to(dtype), key_grad_buffer.to(dtype), value_grad_buffer.to(dtype) + +def ring_flash_attention_fwd_func(group, local_query, local_key, local_value, fixed_seed_offset=None, attn_mask=None, dropout=0.0, is_causal=False, training=True): + cp_size = group.world_size + + comm_buffer = RingCommunicator(group, local_key, local_value) + local_q_seq_len = local_query.shape[1] + + if attn_mask is not None: + cur_attn_masks = paddle.split(attn_mask, num_or_sections=cp_size, axis=3) + + for step in range(cp_size): + block_k, block_v = comm_buffer.get_buffers() + + if step != cp_size - 1: + comm_buffer.send_recv() + + if not is_causal or step <= group.rank: + _causal = is_causal and 0 == step + # out [bs, seq, nhead, headdim] + # lse [bs, nhead, seq] + + block_out, _, block_lse, _ = _C_ops.flash_attn( + local_query, + block_k, + block_v, + fixed_seed_offset, + None if attn_mask is None else cur_attn_masks[(group.rank - step) % cp_size], + dropout, + _causal, + False, + not training, + "") + block_lse = block_lse[:, :, 0:local_q_seq_len] + block_lse = paddle.unsqueeze(paddle.transpose(block_lse, [0, 2, 1]), axis=-1) + + if step == 0: + out, lse= block_out, block_lse + else: + out, lse = update_out_and_lse(out, lse, block_out, block_lse) + + if step != cp_size - 1: + comm_buffer.wait() + + out = out.to(local_query.dtype) + lse = paddle.transpose(paddle.squeeze(lse, axis=-1), [0, 2, 1]) + return out, lse + +def ring_flash_attention_bwd_func(group, out_grad, local_query, local_key, local_value, local_out, lse, fixed_seed_offset, attn_mask, dropout=0.0, is_causal=False): + cp_size = group.world_size + + query_grad_buffer = paddle.zeros_like(local_query).to("float32") + key_grad_buffer = paddle.zeros_like(local_key).to("float32") + value_grad_buffer = paddle.zeros_like(local_value).to("float32") + + kv_comm_buffer = RingCommunicator(group, local_key, local_value) + grad_comm_buffer = RingCommunicator(group, key_grad_buffer, value_grad_buffer) + + if attn_mask is not None: + cur_attn_masks = paddle.split(attn_mask, num_or_sections=cp_size, axis=3) + + for step in range(cp_size): + block_k, block_v = kv_comm_buffer.get_buffers() + + if step != cp_size - 1: + kv_comm_buffer.send_recv() + + if step <= group.rank or not is_causal: + _causal = is_causal and step == 0 + block_q_grad, block_k_grad, block_v_grad = flash_attn_bwd( + local_query, + block_k, + block_v, + local_out, + lse, + fixed_seed_offset, + None if attn_mask is None else cur_attn_masks[(group.rank - step) % cp_size], + out_grad, + dropout, + _causal) + query_grad_buffer += block_q_grad + grad_comm_buffer.wait() + grad_comm_buffer.add_to_buffers(block_k_grad, block_v_grad) + elif step != 0: + grad_comm_buffer.wait() + + grad_comm_buffer.send_recv() + + if step != cp_size - 1: + kv_comm_buffer.wait() + + grad_comm_buffer.wait() + key_grad_buffer, value_grad_buffer = grad_comm_buffer.get_buffers() + + dtype = local_query.dtype + return query_grad_buffer.to(dtype), key_grad_buffer.to(dtype), value_grad_buffer.to(dtype) + + +class RingFlashAttention(PyLayer): + @staticmethod + def forward(ctx, query, key, value, group=None, fixed_seed_offset=None, attn_mask=None, dropout=0.0, is_causal=False, training=True): + if dropout > 0.0: + raise NotImplementedError("Dropout is not supported in ring attention yet.") + if group is None: + group = dist.fleet.get_hybrid_communicate_group().get_cp_parallel_group() + if attn_mask is not None: + is_causal = False + + out, lse = balanced_ring_flash_attention_fwd_func(group, query, key, value, fixed_seed_offset, attn_mask, dropout, is_causal, training) + ctx.save_for_backward(query, key, value, out, lse, attn_mask) + ctx.group = group + ctx.fixed_seed_offset = fixed_seed_offset + ctx.dropout = dropout + ctx.is_causal = is_causal + return out + + @staticmethod + def backward(ctx, out_grad): + query, key, value, out, lse, attn_mask = ctx.saved_tensor() + group = ctx.group + fixed_seed_offset = ctx.fixed_seed_offset + dropout = ctx.dropout + is_causal = ctx.is_causal + + if fixed_seed_offset is None: + fixed_seed_offset = paddle.to_tensor([0, 0], place=paddle.CPUPlace(), dtype=paddle.int64).contiguous() + + query_grad, key_grad, value_grad = balanced_ring_flash_attention_bwd_func(group, out_grad, query, key, value, out, lse, fixed_seed_offset, attn_mask, dropout, is_causal) + if attn_mask is not None and not attn_mask.stop_gradient: + return query_grad, key_grad, value_grad, None + else: + return query_grad, key_grad, value_grad + + +class BaselineRingFlashAttention(PyLayer): + @staticmethod + def forward(ctx, query, key, value, group=None, fixed_seed_offset=None, attn_mask=None, dropout=0.0, is_causal=False, training=True): + if dropout > 0.0: + raise NotImplementedError("Dropout is not supported in ring attention yet.") + if group is None: + group = dist.fleet.get_hybrid_communicate_group().get_cp_parallel_group() + + if attn_mask is not None: + is_causal = False + + out, lse = ring_flash_attention_fwd_func(group, query, key, value, fixed_seed_offset, attn_mask, dropout, is_causal, training) + + ctx.save_for_backward(query, key, value, out, lse, attn_mask) + ctx.group = group + ctx.fixed_seed_offset = fixed_seed_offset + ctx.dropout = dropout + ctx.is_causal = is_causal + + return out + + @staticmethod + def backward(ctx, out_grad): + query, key, value, out, lse, attn_mask = ctx.saved_tensor() + group = ctx.group + fixed_seed_offset = ctx.fixed_seed_offset + dropout = ctx.dropout + is_causal = ctx.is_causal + + if fixed_seed_offset is None: + fixed_seed_offset = paddle.to_tensor([0, 0], place=paddle.CPUPlace(), dtype=paddle.int64).contiguous() + + query_grad, key_grad, value_grad = ring_flash_attention_bwd_func(group, out_grad, query, key, value, out, lse, fixed_seed_offset, attn_mask, dropout, is_causal) + if attn_mask is not None and not attn_mask.stop_gradient: + return query_grad, key_grad, value_grad, None + else: + return query_grad, key_grad, value_grad + + +import unittest +import time + +def generate_full_data(batch_size, seq_len, num_head, head_dim): + query = (paddle.randn([batch_size, seq_len, num_head, head_dim], dtype=paddle.float32)).to("gpu", "float16") + key = (paddle.randn([batch_size, seq_len, num_head, head_dim], dtype=paddle.float32)).to("gpu", "float16") + value = (paddle.randn([batch_size, seq_len, num_head, head_dim], dtype=paddle.float32)).to("gpu", "float16") + + query.stop_gradient = False + key.stop_gradient = False + value.stop_gradient = False + + return query, key, value + +def split_belanced_data(input, rank, degree): + sliced_datas = paddle.split(input, num_or_sections=degree * 2, axis=1) + sliced_data0, sliced_data1 = sliced_datas[rank], sliced_datas[degree * 2 - 1 - rank] + return paddle.concat([sliced_data0, sliced_data1], axis=1).detach() + +def test_new(): + paddle.distributed.init_parallel_env() + group = paddle.distributed.new_group(range(paddle.distributed.get_world_size()), backend="nccl") + degree = group.world_size + rank = group.rank + + seed=42 + random.seed(seed) + np.random.seed(seed) + paddle.seed(seed) + + # query, key, value = generate_full_data(8, 512 * 2 * degree, 36, 256) + query, key, value = generate_full_data(1, 128 * 2 * degree, 1, 256) + is_causal = True + + local_query = split_belanced_data(query, rank, degree) + local_key = split_belanced_data(key, rank, degree) + local_value = split_belanced_data(value, rank, degree) + + local_query.stop_gradient = False + local_key.stop_gradient = False + local_value.stop_gradient = False + + mask_shape = (1, 1, query.shape[1], query.shape[1]) + mask = np.random.random(mask_shape) + attn_mask = paddle.to_tensor(mask, place=query.place, dtype=query.dtype) + attn_mask = paddle.ones(mask_shape).to(query.dtype) + attn_mask_list = paddle.split(attn_mask, axis=2, num_or_sections=degree * 2) + first_chunk_id, second_chunk_id = get_chunk_id(rank, degree) + local_attn_mask = paddle.concat([attn_mask_list[first_chunk_id], attn_mask_list[second_chunk_id]], axis=2) + + local_out = RingFlashAttention.apply(local_query, local_key, local_value, group, is_causal=is_causal) + ref_out = scaled_dot_product_attention(query, key, value, is_causal=is_causal) + ref_local_out = split_belanced_data(ref_out, rank, degree) + + np.testing.assert_allclose(local_out.numpy(), ref_local_out.numpy(), rtol=5e-03, atol=1e-03) + + local_out.backward() + ref_out.backward() + + ref_local_query_grad = split_belanced_data(query.grad, rank, degree) + ref_local_key_grad = split_belanced_data(key.grad, rank, degree) + ref_local_value_grad = split_belanced_data(value.grad, rank, degree) + + np.testing.assert_allclose(local_query.grad.numpy(), ref_local_query_grad.numpy(), rtol=5e-03, atol=1e-03) + np.testing.assert_allclose(local_key.grad.numpy(), ref_local_key_grad.numpy(), rtol=5e-03, atol=1e-03) + np.testing.assert_allclose(local_value.grad.numpy(), ref_local_value_grad.numpy(), rtol=5e-03, atol=1e-03) + + # return + + # epoch = 1000 + # start_time = time.time() + # for iter_id in range(0, epoch): + # if iter_id == 10: + # core.nvprof_start() + # if iter_id == 15: + # core.nvprof_stop() + # return + # core.nvprof_nvtx_push(f"Forward {iter_id}") + # temp_out = RingFlashAttention.apply(local_query, local_key, local_value, group, is_causal=is_causal) + # core.nvprof_nvtx_pop() + # core.nvprof_nvtx_push(f"Backward {iter_id}") + # temp_out.backward() + # core.nvprof_nvtx_pop() + + # end_time = time.time() + # execution_time = end_time - start_time + # print(f"RingFlashAttention执行时间: {execution_time}秒") + +if __name__ == "__main__": + # unittest.main() + test_new() From ab562b7ca76afe7ab7fb581eed417872967c6c2f Mon Sep 17 00:00:00 2001 From: zhangyuqin1998 Date: Fri, 31 May 2024 12:15:12 +0800 Subject: [PATCH 2/7] update, using sep_group --- paddlenlp/transformers/llama/fusion_ops.py | 20 +- paddlenlp/transformers/llama/modeling.py | 4 +- .../transformers/ring_flash_attention.py | 234 +++--- .../ring_flash_attention_back_up.py | 664 ------------------ scripts/regression/ci_case.sh | 11 + scripts/regression/run_ci.sh | 2 +- .../transformers/test_ring_flash_attention.py | 124 ++++ 7 files changed, 242 insertions(+), 817 deletions(-) delete mode 100644 paddlenlp/transformers/ring_flash_attention_back_up.py create mode 100644 tests/transformers/test_ring_flash_attention.py diff --git a/paddlenlp/transformers/llama/fusion_ops.py b/paddlenlp/transformers/llama/fusion_ops.py index 5cff7002b494..24e48342cfa0 100644 --- a/paddlenlp/transformers/llama/fusion_ops.py +++ b/paddlenlp/transformers/llama/fusion_ops.py @@ -52,9 +52,18 @@ def swiglu(x, y=None): flash_attention = None from paddlenlp.transformers.ring_flash_attention import RingFlashAttention -from paddlenlp.transformers.context_parallel_utils import split_inputs_sequence_dim_load_balance -def fusion_rope(query_states, key_states, value_states, hidden_states, position_ids, past_key_value, rotary_emb, cp_parallel_degree=-1): + +def fusion_rope( + query_states, + key_states, + value_states, + hidden_states, + position_ids, + past_key_value, + rotary_emb, + cp_parallel_degree=-1, +): if get_env_device() != "gcu": assert past_key_value is None, "fuse rotary not support cache kv for now" batch_size, seq_length, num_heads, head_dim = query_states.shape @@ -64,9 +73,6 @@ def fusion_rope(query_states, key_states, value_states, hidden_states, position_ kv_seq_len *= cp_parallel_degree if get_env_device() != "gcu": cos, sin = rotary_emb(value_states, seq_len=kv_seq_len) - if cp_parallel_degree > 1: - cos = split_inputs_sequence_dim_load_balance(cos) - sin = split_inputs_sequence_dim_load_balance(sin) if get_env_device() == "npu": query_states = core.eager._run_custom_op("fused_rope", query_states, cos, sin)[0] key_states = core.eager._run_custom_op("fused_rope", key_states, cos, sin)[0] @@ -165,7 +171,7 @@ def fusion_flash_attention( attention_mask = attention_mask.cast(alibi.dtype) + alibi if get_env_device() == "npu": if config.cp_parallel_degree > 1: - raise ValueError(f"Context parallel is not implemented for npu") + raise ValueError("Context parallel is not implemented for npu") attn_output = core.eager._run_custom_op( "flash_attention_npu", query_states, @@ -181,7 +187,7 @@ def fusion_flash_attention( )[0] elif get_env_device() == "gcu": if config.cp_parallel_degree > 1: - raise ValueError(f"Context parallel is not implemented for gcu") + raise ValueError("Context parallel is not implemented for gcu") attn_output = core.eager._run_custom_op( "fused_sdp_flash_attention_gcu", query_states, diff --git a/paddlenlp/transformers/llama/modeling.py b/paddlenlp/transformers/llama/modeling.py index 18eda70d627d..c526bce17f17 100755 --- a/paddlenlp/transformers/llama/modeling.py +++ b/paddlenlp/transformers/llama/modeling.py @@ -99,7 +99,6 @@ def swiglu(x, y=None): ] - def _get_interleave(n): def _get_interleave_power_of_2(n): start = 2 ** (-(2 ** -(math.log2(n) - 3))) @@ -956,7 +955,7 @@ def forward( position_ids, past_key_value, self.rotary_emb, - self.cp_parallel_degree + self.config.cp_parallel_degree, ) else: @@ -972,7 +971,6 @@ def forward( ) else: cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len) - query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin, position_ids) # [bs, seq_len, num_head, head_dim] diff --git a/paddlenlp/transformers/ring_flash_attention.py b/paddlenlp/transformers/ring_flash_attention.py index b8cadaf87e48..60fbf441d5eb 100644 --- a/paddlenlp/transformers/ring_flash_attention.py +++ b/paddlenlp/transformers/ring_flash_attention.py @@ -14,15 +14,22 @@ # paddlenlp/transformers/ring_attention.py -import random - -import numpy as np import paddle import paddle.distributed as dist -from custom_setup_ops import flash_attn_bwd +import paddle.nn.functional as F from paddle import _C_ops from paddle.autograd.py_layer import PyLayer -from paddle.nn.functional.flash_attention import scaled_dot_product_attention + +try: + from paddlenlp_ops import flash_attn_bwd +except (ImportError, ModuleNotFoundError): + from paddlenlp.utils.log import logger + + logger.warning( + "if you run ring_flash_attention.py, please ensure you install " + "the paddlenlp_ops by following the instructions " + "provided at https://github.com/PaddlePaddle/PaddleNLP/blob/develop/csrc/README.md" + ) class RingCommunicator: @@ -43,15 +50,19 @@ def __init__(self, group, local_key, local_value): self._reqs = [] def wait(self): - # for req in self._reqs: - # req.wait() - # self._reqs = None + # TODO(zhangyuqin1998):batch_isend_irecv异步流下,无法wait,需要修复。对性能有影响。 paddle.device.synchronize() def add_to_buffers(self, key, value): if key.shape != self._k_buffer[self._next_buffer_idx].shape: - self._k_buffer[self._next_buffer_idx][:, : key.shape[1], :, :] += key - self._v_buffer[self._next_buffer_idx][:, : key.shape[1], :, :] += value + k_buffer_chunk = paddle.slice( + self._k_buffer[self._next_buffer_idx], axes=[1], starts=[0], ends=[key.shape[1]] + ) + v_buffer_chunk = paddle.slice( + self._v_buffer[self._next_buffer_idx], axes=[1], starts=[0], ends=[value.shape[1]] + ) + k_buffer_chunk += key + v_buffer_chunk += value else: self._k_buffer[self._next_buffer_idx] += key self._v_buffer[self._next_buffer_idx] += value @@ -73,18 +84,23 @@ def send_recv(self): def update_out_and_lse(old_out, old_lse, block_out, block_lse, second_chunk_only=False): + if old_out is None and old_lse is None: + return block_out.to("float32"), block_lse.to("float32") + if second_chunk_only: - second_chunk_out = old_out[:, old_out.shape[1] // 2 :, :, :] - second_chunk_lse = old_lse[:, old_lse.shape[1] // 2 :, :, :] + second_chunk_out_ = paddle.slice(old_out, axes=[1], starts=[old_out.shape[1] // 2], ends=[old_out.shape[1]]) + second_chunk_lse_ = paddle.slice(old_lse, axes=[1], starts=[old_lse.shape[1] // 2], ends=[old_lse.shape[1]]) second_chunk_out, second_chunk_lse = update_out_and_lse( - second_chunk_out, second_chunk_lse, block_out, block_lse + second_chunk_out_, second_chunk_lse_, block_out, block_lse ) - old_out[:, old_out.shape[1] // 2 :, :, :] = second_chunk_out - old_lse[:, old_lse.shape[1] // 2 :, :, :] = second_chunk_lse + paddle.assign(second_chunk_out, second_chunk_out_) + paddle.assign(second_chunk_lse, second_chunk_lse_) return old_out, old_lse else: - lse = paddle.log(1 + paddle.exp(block_lse - old_lse)) + old_lse - return old_out * paddle.exp(old_lse - lse) + block_out * paddle.exp(block_lse - lse), lse + block_out, block_lse = block_out.to("float32"), block_lse.to("float32") + with paddle.amp.auto_cast(enable=False, dtype="bfloat16"): + lse = old_lse - F.log_sigmoid(old_lse - block_lse) + return old_out - (old_out - block_out) * F.sigmoid(block_lse - old_lse), lse def get_chunk_id(rank, cp_size): @@ -114,10 +130,14 @@ def balanced_ring_flash_attention_fwd_func( comm_buffer = RingCommunicator(group, local_key, local_value) local_q_seq_len = local_query.shape[1] + out, lse, k_cache, v_cache = None, None, dict(), dict() + if attn_mask is not None: attn_masks_list = paddle.split(attn_mask, num_or_sections=cp_size * 2, axis=3) if is_causal: - local_query_second_chunk = local_query[:, local_q_seq_len // 2 :, :, :].clone().contiguous() + local_query_second_chunk = paddle.slice( + local_query, axes=[1], starts=[local_q_seq_len // 2], ends=[local_q_seq_len] + ) for step in range(cp_size): block_k, block_v = comm_buffer.get_buffers() @@ -139,19 +159,16 @@ def balanced_ring_flash_attention_fwd_func( not training, "", ) - block_lse = paddle.unsqueeze(paddle.transpose(block_lse, [0, 2, 1]), axis=-1) - - if step == 0: - out, lse = block_out, block_lse - else: - out, lse = update_out_and_lse(out, lse, block_out, block_lse) + block_lse = paddle.unsqueeze_(paddle.transpose_(block_lse, [0, 2, 1]), axis=-1) + out, lse = update_out_and_lse(out, lse, block_out, block_lse) else: + # block_k and block_v is from rank (group.rank - step) % cp_size if step == 0: block_out, _, block_lse, _ = _C_ops.flash_attn( local_query, block_k, block_v, fixed_seed_offset, None, dropout, True, False, not training, "" ) - block_lse = paddle.unsqueeze(paddle.transpose(block_lse, [0, 2, 1]), axis=-1) - out, lse = block_out, block_lse + block_lse = paddle.unsqueeze_(paddle.transpose_(block_lse, [0, 2, 1]), axis=-1) + out, lse = update_out_and_lse(out, lse, block_out, block_lse) elif step > rank: block_out, _, block_lse, _ = _C_ops.flash_attn( local_query_second_chunk, @@ -165,14 +182,16 @@ def balanced_ring_flash_attention_fwd_func( not training, "", ) - block_lse = block_lse[:, :, 0 : (local_q_seq_len // 2)] - block_lse = paddle.unsqueeze(paddle.transpose(block_lse, [0, 2, 1]), axis=-1) + block_lse = paddle.slice(block_lse, axes=[1], starts=[0], ends=[local_q_seq_len // 2]) + block_lse = paddle.unsqueeze_(paddle.transpose_(block_lse, [0, 2, 1]), axis=-1) out, lse = update_out_and_lse(out, lse, block_out, block_lse, True) else: + block_k = paddle.slice(block_k, axes=[1], starts=[0], ends=[local_q_seq_len // 2]) + block_v = paddle.slice(block_v, axes=[1], starts=[0], ends=[local_q_seq_len // 2]) block_out, _, block_lse, _ = _C_ops.flash_attn( local_query, - block_k[:, : local_q_seq_len // 2, :, :], - block_v[:, : local_q_seq_len // 2, :, :], + block_k, + block_v, fixed_seed_offset, None, dropout, @@ -181,20 +200,23 @@ def balanced_ring_flash_attention_fwd_func( not training, "", ) - block_lse = paddle.unsqueeze(paddle.transpose(block_lse, [0, 2, 1]), axis=-1) + block_lse = paddle.unsqueeze_(paddle.transpose_(block_lse, [0, 2, 1]), axis=-1) out, lse = update_out_and_lse(out, lse, block_out, block_lse) + k_cache[step] = block_k + v_cache[step] = block_v - # if step != cp_size - 1: - # comm_buffer.wait() + # TODO(zhangyuqin1998):batch_isend_irecv异步流下,无法wait,需要修复。对性能有影响。 paddle.device.synchronize() out = out.to(local_query.dtype) - lse = paddle.transpose(paddle.squeeze(lse, axis=-1), [0, 2, 1]) - return out, lse + lse = paddle.transpose_(paddle.squeeze_(lse, axis=-1), [0, 2, 1]) + return out, lse, k_cache, v_cache def balanced_ring_flash_attention_bwd_func( group, + k_cache, + v_cache, out_grad, local_query, local_key, @@ -208,21 +230,27 @@ def balanced_ring_flash_attention_bwd_func( ): cp_size = group.world_size rank = group.rank - local_q_seq_len = local_query.shape[1] - query_grad_buffer = paddle.zeros_like(local_query).to("float32") - key_grad_buffer = paddle.zeros_like(local_key).to("float32") - value_grad_buffer = paddle.zeros_like(local_value).to("float32") + query_grad_buffer = paddle.zeros_like(local_query) + key_grad_buffer = paddle.zeros_like(local_key) + value_grad_buffer = paddle.zeros_like(local_value) kv_comm_buffer = RingCommunicator(group, local_key, local_value) grad_comm_buffer = RingCommunicator(group, key_grad_buffer, value_grad_buffer) if is_causal: - local_query_second_chunk = local_query[:, local_q_seq_len // 2 :, :, :].clone().contiguous() - local_out_second_chunk = local_out[:, local_q_seq_len // 2 :, :, :].clone().contiguous() - lse_second_chunk = lse[:, :, local_q_seq_len // 2 :].clone().contiguous() - out_grad_second_chunk = out_grad[:, local_q_seq_len // 2 :, :, :].clone().contiguous() + local_query_second_chunk = paddle.slice( + local_query, axes=[1], starts=[local_q_seq_len // 2], ends=[local_q_seq_len] + ) + local_out_second_chunk = paddle.slice( + local_out, axes=[1], starts=[local_q_seq_len // 2], ends=[local_q_seq_len] + ) + lse_second_chunk = paddle.slice(lse, axes=[2], starts=[local_q_seq_len // 2], ends=[local_q_seq_len]) + out_grad_second_chunk = paddle.slice(out_grad, axes=[1], starts=[local_q_seq_len // 2], ends=[local_q_seq_len]) + query_grad_buffer_second_chunk = paddle.slice( + query_grad_buffer, axes=[1], starts=[local_q_seq_len // 2], ends=[local_q_seq_len] + ) if attn_mask is not None: attn_masks_list = paddle.split(attn_mask, num_or_sections=cp_size * 2, axis=3) @@ -266,12 +294,12 @@ def balanced_ring_flash_attention_bwd_func( dropout, False, ) - query_grad_buffer[:, local_q_seq_len // 2 :, :, :] += block_q_grad + query_grad_buffer_second_chunk += block_q_grad else: block_q_grad, block_k_grad, block_v_grad = flash_attn_bwd( local_query, - block_k[:, : local_q_seq_len // 2, :, :], - block_v[:, : local_q_seq_len // 2, :, :], + k_cache[step], + v_cache[step], local_out, lse, fixed_seed_offset, @@ -282,10 +310,7 @@ def balanced_ring_flash_attention_bwd_func( ) query_grad_buffer += block_q_grad - # if step != cp_size - 1: - # kv_comm_buffer.wait() - # if step != 0: - # grad_comm_buffer.wait() + # TODO(zhangyuqin1998):batch_isend_irecv异步流下,无法wait,需要修复。对性能有影响。 paddle.device.synchronize() grad_comm_buffer.add_to_buffers(block_k_grad, block_v_grad) @@ -319,10 +344,10 @@ def forward( if attn_mask is not None: is_causal = False - out, lse = balanced_ring_flash_attention_fwd_func( + out, lse, k_cache, v_cache = balanced_ring_flash_attention_fwd_func( group, query, key, value, fixed_seed_offset, attn_mask, dropout, is_causal, training ) - ctx.save_for_backward(query, key, value, out, lse, attn_mask) + ctx.save_for_backward(query, key, value, out, lse, attn_mask, k_cache, v_cache) ctx.group = group ctx.fixed_seed_offset = fixed_seed_offset ctx.dropout = dropout @@ -331,106 +356,31 @@ def forward( @staticmethod def backward(ctx, out_grad): - query, key, value, out, lse, attn_mask = ctx.saved_tensor() + query, key, value, out, lse, attn_mask, k_cache, v_cache = ctx.saved_tensor() group = ctx.group fixed_seed_offset = ctx.fixed_seed_offset dropout = ctx.dropout is_causal = ctx.is_causal if fixed_seed_offset is None: - fixed_seed_offset = paddle.to_tensor([0, 0], place=paddle.CPUPlace(), dtype=paddle.int64).contiguous() + fixed_seed_offset = paddle.to_tensor([0, 0], place=paddle.CPUPlace(), dtype=paddle.int64) query_grad, key_grad, value_grad = balanced_ring_flash_attention_bwd_func( - group, out_grad, query, key, value, out, lse, fixed_seed_offset, attn_mask, dropout, is_causal + group, + k_cache, + v_cache, + out_grad, + query, + key, + value, + out, + lse, + fixed_seed_offset, + attn_mask, + dropout, + is_causal, ) if attn_mask is not None and not attn_mask.stop_gradient: return query_grad, key_grad, value_grad, None else: return query_grad, key_grad, value_grad - - -import unittest - - -class TestRingFlashAttention(unittest.TestCase): - def setUp(self): - paddle.distributed.init_parallel_env() - self.group = paddle.distributed.new_group(range(paddle.distributed.get_world_size()), backend="nccl") - self.degree = self.group.world_size - self.rank = self.group.rank - - seed = 42 - random.seed(seed) - np.random.seed(seed) - paddle.seed(seed) - - def generate_full_data(self, batch_size, seq_len, num_head, head_dim): - query = (paddle.randn([batch_size, seq_len, num_head, head_dim], dtype=paddle.float32)).to("gpu", "float16") - key = (paddle.randn([batch_size, seq_len, num_head, head_dim], dtype=paddle.float32)).to("gpu", "float16") - value = (paddle.randn([batch_size, seq_len, num_head, head_dim], dtype=paddle.float32)).to("gpu", "float16") - - query.stop_gradient = False - key.stop_gradient = False - value.stop_gradient = False - - return query, key, value - - def split_belanced_data(self, input): - sliced_datas = paddle.split(input, num_or_sections=self.degree * 2, axis=1) - sliced_data0, sliced_data1 = sliced_datas[self.rank], sliced_datas[self.degree * 2 - 1 - self.rank] - return paddle.concat([sliced_data0, sliced_data1], axis=1).detach() - - def single_test(self, bsz, seq_len_per_device, head_num, head_dim, is_causal, use_mask): - query, key, value = self.generate_full_data(bsz, seq_len_per_device * self.degree, head_num, head_dim) - - local_query = self.split_belanced_data(query) - local_key = self.split_belanced_data(key) - local_value = self.split_belanced_data(value) - - local_query.stop_gradient = False - local_key.stop_gradient = False - local_value.stop_gradient = False - - if use_mask: - mask_shape = (1, 1, query.shape[1], query.shape[1]) - mask = np.random.random(mask_shape) - attn_mask = paddle.to_tensor(mask, place=query.place, dtype=query.dtype) - attn_mask = paddle.ones(mask_shape).to(query.dtype) - attn_mask_list = paddle.split(attn_mask, axis=2, num_or_sections=self.degree * 2) - first_chunk_id, second_chunk_id = get_chunk_id(self.rank, self.degree) - local_attn_mask = paddle.concat([attn_mask_list[first_chunk_id], attn_mask_list[second_chunk_id]], axis=2) - else: - attn_mask = None - local_attn_mask = None - - local_out = RingFlashAttention.apply( - local_query, local_key, local_value, self.group, is_causal=is_causal, attn_mask=local_attn_mask - ) - ref_out = scaled_dot_product_attention(query, key, value, is_causal=is_causal, attn_mask=attn_mask) - ref_local_out = self.split_belanced_data(ref_out) - np.testing.assert_allclose(local_out.numpy(), ref_local_out.numpy(), rtol=5e-03, atol=1e-03) - - local_out.backward() - ref_out.backward() - - ref_local_query_grad = self.split_belanced_data(query.grad) - ref_local_key_grad = self.split_belanced_data(key.grad) - ref_local_value_grad = self.split_belanced_data(value.grad) - - np.testing.assert_allclose(local_query.grad.numpy(), ref_local_query_grad.numpy(), rtol=5e-03, atol=1e-03) - np.testing.assert_allclose(local_key.grad.numpy(), ref_local_key_grad.numpy(), rtol=5e-03, atol=1e-03) - np.testing.assert_allclose(local_value.grad.numpy(), ref_local_value_grad.numpy(), rtol=5e-03, atol=1e-03) - - def test_normal_flash_attention(self): - self.single_test(1, 256, 1, 256, False, False) - - def test_masked_flash_attention(self): - self.single_test(1, 256, 1, 256, False, True) - - def test_casual_flash_attention(self): - self.single_test(1, 256, 1, 256, True, False) - - -if __name__ == "__main__": - unittest.main() -# python -m paddle.distributed.launch --gpus=0,1,2,3,4,5,6,7 ring_flash_attention.py diff --git a/paddlenlp/transformers/ring_flash_attention_back_up.py b/paddlenlp/transformers/ring_flash_attention_back_up.py deleted file mode 100644 index 5d6266967887..000000000000 --- a/paddlenlp/transformers/ring_flash_attention_back_up.py +++ /dev/null @@ -1,664 +0,0 @@ -# paddlenlp/transformers/ring_attention.py - -import paddle -import paddle.distributed as dist -from paddle import _C_ops -from paddle.nn.functional.flash_attention import scaled_dot_product_attention -from paddle.autograd.py_layer import PyLayer -from custom_setup_ops import flash_attn_bwd -from paddle.framework import core - -import random -import numpy as np - -class RingCommunicator: - def __init__(self, group, local_key, local_value): - self._k_buffer = [paddle.zeros_like(local_key) for _ in range(2)] - self._v_buffer = [paddle.zeros_like(local_value) for _ in range(2)] - - self._k_buffer[0] = local_key.clone() - self._v_buffer[0] = local_value.clone() - - self._next_buffer_idx = 0 - - self.group = group - self.group_rank = group.rank - self.send_rank = self.group.ranks[(self.group_rank + 1) % self.group.world_size] - self.recv_rank = self.group.ranks[(self.group_rank - 1) % self.group.world_size] - - self._reqs = [] - - def wait(self): - # for req in self._reqs: - # req.wait() - # self._reqs = None - paddle.device.synchronize() - - def add_to_buffers(self, key, value): - if key.shape != self._k_buffer[self._next_buffer_idx].shape: - self._k_buffer[self._next_buffer_idx][:, :key.shape[1], :, :] += key - self._v_buffer[self._next_buffer_idx][:, :key.shape[1], :, :] += value - else: - self._k_buffer[self._next_buffer_idx] += key - self._v_buffer[self._next_buffer_idx] += value - - def get_buffers(self): - return self._k_buffer[self._next_buffer_idx], self._v_buffer[self._next_buffer_idx] - - def send_recv(self): - send_k_op = dist.P2POp(dist.isend, self._k_buffer[self._next_buffer_idx], self.send_rank, self.group) - send_v_op = dist.P2POp(dist.isend, self._v_buffer[self._next_buffer_idx], self.send_rank, self.group) - recv_k_op = dist.P2POp(dist.irecv, self._k_buffer[(self._next_buffer_idx + 1) % 2], self.recv_rank, self.group) - recv_v_op = dist.P2POp(dist.irecv, self._v_buffer[(self._next_buffer_idx + 1) % 2], self.recv_rank, self.group) - - self._next_buffer_idx = (self._next_buffer_idx + 1) % 2 - - ops = [send_k_op, send_v_op, recv_k_op, recv_v_op] - - self._reqs = dist.batch_isend_irecv(ops) - - -def update_out_and_lse(old_out, old_lse, block_out, block_lse, second_chunk_only=False): - if second_chunk_only: - second_chunk_out = old_out[:,old_out.shape[1]//2:, :, :] - second_chunk_lse = old_lse[:,old_lse.shape[1]//2:, :, :] - second_chunk_out, second_chunk_lse = update_out_and_lse(second_chunk_out, second_chunk_lse, block_out, block_lse) - old_out[:,old_out.shape[1]//2:, :, :] = second_chunk_out - old_lse[:,old_lse.shape[1]//2:, :, :] = second_chunk_lse - return old_out, old_lse - else: - lse = paddle.log(1 + paddle.exp(block_lse - old_lse)) + old_lse - return old_out * paddle.exp(old_lse - lse) + block_out * paddle.exp(block_lse - lse), lse - - -def get_chunk_id(rank, cp_size): - return rank, (2 * cp_size - 1 - rank) - - -def concat_masks(attn_masks_list, rank, cp_size): - assert len(attn_masks_list) == 2 * cp_size - first_chunk_id, second_chunk_id = get_chunk_id(rank, cp_size) - return paddle.concat([attn_masks_list[first_chunk_id], attn_masks_list[second_chunk_id]], axis=3) - - -def balanced_ring_flash_attention_fwd_funcV2(group, local_query, local_key, local_value, fixed_seed_offset=None, attn_mask=None, dropout=0.0, is_causal=False, training=True): - cp_size = group.world_size - rank = group.rank - - comm_buffer = RingCommunicator(group, local_key, local_value) - local_q_seq_len = local_query.shape[1] - - computation_streams = [paddle.device.Stream(), paddle.device.current_stream()] - update_out_and_lse_done = paddle.device.Event() - block_out_buffer = [paddle.zeros_like(local_query) for _ in range(2)] - block_lse_buffer = [paddle.zeros([local_query.shape[0], local_query.shape[2], local_query.shape[1]], dtype=local_query.dtype) for _ in range(2)] - paddle.device.synchronize() - - if attn_mask is not None: - attn_masks_list = paddle.split(attn_mask, num_or_sections=cp_size * 2, axis=3) - if is_causal: - local_query_second_chunk = local_query[:, local_q_seq_len // 2:, :, :].clone().contiguous() - for step in range(cp_size + 1): - block_k, block_v = comm_buffer.get_buffers() - - if step != cp_size - 1: - comm_buffer.send_recv() - - if not is_causal: - # out [bs, seq, nhead, headdim] - # lse [bs, nhead, seq] - if step < cp_size: - with paddle.device.stream_guard(computation_streams[step % 2]): - block_out_buffer[step % 2], _, block_lse_buffer[step % 2], _ = _C_ops.flash_attn( - local_query, - block_k, - block_v, - fixed_seed_offset, - None if attn_mask is None else concat_masks(attn_masks_list, (group.rank - step) % cp_size, cp_size), - dropout, - False, - False, - not training, - "" - ) - if step > 0: - if step > 1: - computation_streams[(step - 1)% 2].wait_event(update_out_and_lse_done) - with paddle.device.stream_guard(computation_streams[(step - 1)% 2]): - core.nvprof_nvtx_push(f"update_out_and_lse step {step}") - block_out = block_out_buffer[(step - 1) % 2] - block_lse = block_lse_buffer[(step - 1) % 2][:, :, 0:local_q_seq_len] - block_lse = paddle.unsqueeze(paddle.transpose(block_lse, [0, 2, 1]), axis=-1) - if step - 1 == 0: - out, lse= block_out, block_lse - else: - out, lse = update_out_and_lse(out, lse, block_out, block_lse) - core.nvprof_nvtx_pop() - if step < cp_size: - computation_streams[(step - 1)% 2].record_event(update_out_and_lse_done) - else: - if step < cp_size: - with paddle.device.stream_guard(computation_streams[step % 2]): - if step == 0: - block_out_buffer[step % 2], _, block_lse_buffer[step % 2], _ = _C_ops.flash_attn( - local_query, - block_k, - block_v, - fixed_seed_offset, - None, - dropout, - True, - False, - not training, - "") - elif step > rank: - block_out_buffer[step % 2], _, block_lse_buffer[step % 2], _ = _C_ops.flash_attn( - local_query_second_chunk, - block_k, - block_v, - fixed_seed_offset, - None, - dropout, - False, - False, - not training, - "" - ) - else: - block_out_buffer[step % 2], _, block_lse_buffer[step % 2], _ = _C_ops.flash_attn( - local_query, - block_k[:, :local_q_seq_len // 2, :, :], - block_v[:, :local_q_seq_len // 2, :, :], - fixed_seed_offset, - None, - dropout, - False, - False, - not training, - "" - ) - if step > 0: - if step > 1: - computation_streams[(step - 1)% 2].wait_event(update_out_and_lse_done) - with paddle.device.stream_guard(computation_streams[(step - 1)% 2]): - core.nvprof_nvtx_push(f"update_out_and_lse step {step}") - block_out = block_out_buffer[(step - 1) % 2] - block_lse = block_lse_buffer[(step - 1) % 2][:, :, 0:local_q_seq_len] - block_lse = paddle.unsqueeze(paddle.transpose(block_lse, [0, 2, 1]), axis=-1) - if step - 1 == 0: - out, lse= block_out, block_lse - elif step - 1 > rank: - block_lse = block_lse[:, :(local_q_seq_len//2), :, :] - out, lse = update_out_and_lse(out, lse, block_out, block_lse, True) - else: - out, lse = update_out_and_lse(out, lse, block_out, block_lse) - core.nvprof_nvtx_pop() - if step < cp_size: - computation_streams[(step - 1)% 2].record_event(update_out_and_lse_done) - # if step != cp_size - 1: - # comm_buffer.wait() - paddle.device.synchronize() - - out = out.to(local_query.dtype) - lse = paddle.transpose(paddle.squeeze(lse, axis=-1), [0, 2, 1]) - return out, lse - -def balanced_ring_flash_attention_fwd_func(group, local_query, local_key, local_value, fixed_seed_offset=None, attn_mask=None, dropout=0.0, is_causal=False, training=True): - cp_size = group.world_size - rank = group.rank - - comm_buffer = RingCommunicator(group, local_key, local_value) - local_q_seq_len = local_query.shape[1] - - if attn_mask is not None: - attn_masks_list = paddle.split(attn_mask, num_or_sections=cp_size * 2, axis=3) - if is_causal: - local_query_second_chunk = local_query[:, local_q_seq_len // 2:, :, :].clone().contiguous() - for step in range(cp_size): - block_k, block_v = comm_buffer.get_buffers() - - if step != cp_size - 1: - comm_buffer.send_recv() - - if not is_causal: - # out [bs, seq, nhead, headdim] - # lse [bs, nhead, seq] - block_out, _, block_lse, _ = _C_ops.flash_attn( - local_query, - block_k, - block_v, - fixed_seed_offset, - None if attn_mask is None else concat_masks(attn_masks_list, (group.rank - step) % cp_size, cp_size), - dropout, - False, - False, - not training, - "") - block_lse = block_lse[:, :, 0:local_q_seq_len] - block_lse = paddle.unsqueeze(paddle.transpose(block_lse, [0, 2, 1]), axis=-1) - - if step == 0: - out, lse= block_out, block_lse - else: - out, lse = update_out_and_lse(out, lse, block_out, block_lse) - else: - if step == 0: - block_out, _, block_lse, _ = _C_ops.flash_attn( - local_query, - block_k, - block_v, - fixed_seed_offset, - None, - dropout, - True, - False, - not training, - "") - block_lse = block_lse[:, :, 0:local_q_seq_len] - block_lse = paddle.unsqueeze(paddle.transpose(block_lse, [0, 2, 1]), axis=-1) - out, lse= block_out, block_lse - elif step > rank: - block_out, _, block_lse, _ = _C_ops.flash_attn( - local_query_second_chunk, - block_k, - block_v, - fixed_seed_offset, - None, - dropout, - False, - False, - not training, - "" - ) - block_lse = block_lse[:, :, 0:(local_q_seq_len//2)] - block_lse = paddle.unsqueeze(paddle.transpose(block_lse, [0, 2, 1]), axis=-1) - core.nvprof_nvtx_push("update_out_and_lse") - out, lse = update_out_and_lse(out, lse, block_out, block_lse, True) - core.nvprof_nvtx_pop() - else: - block_out, _, block_lse, _ = _C_ops.flash_attn( - local_query, - block_k[:, :local_q_seq_len // 2, :, :], - block_v[:, :local_q_seq_len // 2, :, :], - fixed_seed_offset, - None, - dropout, - False, - False, - not training, - "" - ) - block_lse = block_lse[:, :, 0:local_q_seq_len] - block_lse = paddle.unsqueeze(paddle.transpose(block_lse, [0, 2, 1]), axis=-1) - core.nvprof_nvtx_push("update_out_and_lse") - out, lse = update_out_and_lse(out, lse, block_out, block_lse) - core.nvprof_nvtx_pop() - - if step != cp_size - 1: - comm_buffer.wait() - - out = out.to(local_query.dtype) - lse = paddle.transpose(paddle.squeeze(lse, axis=-1), [0, 2, 1]) - return out, lse - - -def balanced_ring_flash_attention_bwd_func(group, out_grad, local_query, local_key, local_value, local_out, lse, fixed_seed_offset, attn_mask, dropout=0.0, is_causal=False): - cp_size = group.world_size - rank = group.rank - - local_q_seq_len = local_query.shape[1] - - query_grad_buffer = paddle.zeros_like(local_query).to("float32") - key_grad_buffer = paddle.zeros_like(local_key).to("float32") - value_grad_buffer = paddle.zeros_like(local_value).to("float32") - - kv_comm_buffer = RingCommunicator(group, local_key, local_value) - grad_comm_buffer = RingCommunicator(group, key_grad_buffer, value_grad_buffer) - - if is_causal: - local_query_second_chunk = local_query[:, local_q_seq_len // 2:, :, :].clone().contiguous() - local_out_second_chunk = local_out[:, local_q_seq_len // 2:, :, :].clone().contiguous() - lse_second_chunk = lse[:, :, local_q_seq_len // 2:].clone().contiguous() - out_grad_second_chunk = out_grad[:, local_q_seq_len // 2:, :, :].clone().contiguous() - - - if attn_mask is not None: - attn_masks_list = paddle.split(attn_mask, num_or_sections=cp_size * 2, axis=3) - - for step in range(cp_size): - block_k, block_v = kv_comm_buffer.get_buffers() - - if step != cp_size - 1: - kv_comm_buffer.send_recv() - - if not is_causal: - block_q_grad, block_k_grad, block_v_grad = flash_attn_bwd( - local_query, - block_k, - block_v, - local_out, - lse, - fixed_seed_offset, - None if attn_mask is None else concat_masks(attn_masks_list, (group.rank - step) % cp_size, cp_size), - out_grad, - dropout, - False) - query_grad_buffer += block_q_grad - else: - if step == 0: - block_q_grad, block_k_grad, block_v_grad = flash_attn_bwd( - local_query, - block_k, - block_v, - local_out, - lse, - fixed_seed_offset, - None, - out_grad, - dropout, - True) - query_grad_buffer += block_q_grad - elif step > rank: - block_q_grad, block_k_grad, block_v_grad = flash_attn_bwd( - local_query_second_chunk, - block_k, - block_v, - local_out_second_chunk, - lse_second_chunk, - fixed_seed_offset, - None, - out_grad_second_chunk, - dropout, - False) - query_grad_buffer[:, local_q_seq_len // 2:, :, :] += block_q_grad - else: - block_q_grad, block_k_grad, block_v_grad = flash_attn_bwd( - local_query, - block_k[:, :local_q_seq_len // 2, :, :], - block_v[:, :local_q_seq_len // 2, :, :], - local_out, - lse, - fixed_seed_offset, - None, - out_grad, - dropout, - False) - query_grad_buffer += block_q_grad - - # if step != cp_size - 1: - # kv_comm_buffer.wait() - # if step != 0: - # grad_comm_buffer.wait() - paddle.device.synchronize() - - grad_comm_buffer.add_to_buffers(block_k_grad, block_v_grad) - grad_comm_buffer.send_recv() - - grad_comm_buffer.wait() - key_grad_buffer, value_grad_buffer = grad_comm_buffer.get_buffers() - - dtype = local_query.dtype - return query_grad_buffer.to(dtype), key_grad_buffer.to(dtype), value_grad_buffer.to(dtype) - -def ring_flash_attention_fwd_func(group, local_query, local_key, local_value, fixed_seed_offset=None, attn_mask=None, dropout=0.0, is_causal=False, training=True): - cp_size = group.world_size - - comm_buffer = RingCommunicator(group, local_key, local_value) - local_q_seq_len = local_query.shape[1] - - if attn_mask is not None: - cur_attn_masks = paddle.split(attn_mask, num_or_sections=cp_size, axis=3) - - for step in range(cp_size): - block_k, block_v = comm_buffer.get_buffers() - - if step != cp_size - 1: - comm_buffer.send_recv() - - if not is_causal or step <= group.rank: - _causal = is_causal and 0 == step - # out [bs, seq, nhead, headdim] - # lse [bs, nhead, seq] - - block_out, _, block_lse, _ = _C_ops.flash_attn( - local_query, - block_k, - block_v, - fixed_seed_offset, - None if attn_mask is None else cur_attn_masks[(group.rank - step) % cp_size], - dropout, - _causal, - False, - not training, - "") - block_lse = block_lse[:, :, 0:local_q_seq_len] - block_lse = paddle.unsqueeze(paddle.transpose(block_lse, [0, 2, 1]), axis=-1) - - if step == 0: - out, lse= block_out, block_lse - else: - out, lse = update_out_and_lse(out, lse, block_out, block_lse) - - if step != cp_size - 1: - comm_buffer.wait() - - out = out.to(local_query.dtype) - lse = paddle.transpose(paddle.squeeze(lse, axis=-1), [0, 2, 1]) - return out, lse - -def ring_flash_attention_bwd_func(group, out_grad, local_query, local_key, local_value, local_out, lse, fixed_seed_offset, attn_mask, dropout=0.0, is_causal=False): - cp_size = group.world_size - - query_grad_buffer = paddle.zeros_like(local_query).to("float32") - key_grad_buffer = paddle.zeros_like(local_key).to("float32") - value_grad_buffer = paddle.zeros_like(local_value).to("float32") - - kv_comm_buffer = RingCommunicator(group, local_key, local_value) - grad_comm_buffer = RingCommunicator(group, key_grad_buffer, value_grad_buffer) - - if attn_mask is not None: - cur_attn_masks = paddle.split(attn_mask, num_or_sections=cp_size, axis=3) - - for step in range(cp_size): - block_k, block_v = kv_comm_buffer.get_buffers() - - if step != cp_size - 1: - kv_comm_buffer.send_recv() - - if step <= group.rank or not is_causal: - _causal = is_causal and step == 0 - block_q_grad, block_k_grad, block_v_grad = flash_attn_bwd( - local_query, - block_k, - block_v, - local_out, - lse, - fixed_seed_offset, - None if attn_mask is None else cur_attn_masks[(group.rank - step) % cp_size], - out_grad, - dropout, - _causal) - query_grad_buffer += block_q_grad - grad_comm_buffer.wait() - grad_comm_buffer.add_to_buffers(block_k_grad, block_v_grad) - elif step != 0: - grad_comm_buffer.wait() - - grad_comm_buffer.send_recv() - - if step != cp_size - 1: - kv_comm_buffer.wait() - - grad_comm_buffer.wait() - key_grad_buffer, value_grad_buffer = grad_comm_buffer.get_buffers() - - dtype = local_query.dtype - return query_grad_buffer.to(dtype), key_grad_buffer.to(dtype), value_grad_buffer.to(dtype) - - -class RingFlashAttention(PyLayer): - @staticmethod - def forward(ctx, query, key, value, group=None, fixed_seed_offset=None, attn_mask=None, dropout=0.0, is_causal=False, training=True): - if dropout > 0.0: - raise NotImplementedError("Dropout is not supported in ring attention yet.") - if group is None: - group = dist.fleet.get_hybrid_communicate_group().get_cp_parallel_group() - if attn_mask is not None: - is_causal = False - - out, lse = balanced_ring_flash_attention_fwd_func(group, query, key, value, fixed_seed_offset, attn_mask, dropout, is_causal, training) - ctx.save_for_backward(query, key, value, out, lse, attn_mask) - ctx.group = group - ctx.fixed_seed_offset = fixed_seed_offset - ctx.dropout = dropout - ctx.is_causal = is_causal - return out - - @staticmethod - def backward(ctx, out_grad): - query, key, value, out, lse, attn_mask = ctx.saved_tensor() - group = ctx.group - fixed_seed_offset = ctx.fixed_seed_offset - dropout = ctx.dropout - is_causal = ctx.is_causal - - if fixed_seed_offset is None: - fixed_seed_offset = paddle.to_tensor([0, 0], place=paddle.CPUPlace(), dtype=paddle.int64).contiguous() - - query_grad, key_grad, value_grad = balanced_ring_flash_attention_bwd_func(group, out_grad, query, key, value, out, lse, fixed_seed_offset, attn_mask, dropout, is_causal) - if attn_mask is not None and not attn_mask.stop_gradient: - return query_grad, key_grad, value_grad, None - else: - return query_grad, key_grad, value_grad - - -class BaselineRingFlashAttention(PyLayer): - @staticmethod - def forward(ctx, query, key, value, group=None, fixed_seed_offset=None, attn_mask=None, dropout=0.0, is_causal=False, training=True): - if dropout > 0.0: - raise NotImplementedError("Dropout is not supported in ring attention yet.") - if group is None: - group = dist.fleet.get_hybrid_communicate_group().get_cp_parallel_group() - - if attn_mask is not None: - is_causal = False - - out, lse = ring_flash_attention_fwd_func(group, query, key, value, fixed_seed_offset, attn_mask, dropout, is_causal, training) - - ctx.save_for_backward(query, key, value, out, lse, attn_mask) - ctx.group = group - ctx.fixed_seed_offset = fixed_seed_offset - ctx.dropout = dropout - ctx.is_causal = is_causal - - return out - - @staticmethod - def backward(ctx, out_grad): - query, key, value, out, lse, attn_mask = ctx.saved_tensor() - group = ctx.group - fixed_seed_offset = ctx.fixed_seed_offset - dropout = ctx.dropout - is_causal = ctx.is_causal - - if fixed_seed_offset is None: - fixed_seed_offset = paddle.to_tensor([0, 0], place=paddle.CPUPlace(), dtype=paddle.int64).contiguous() - - query_grad, key_grad, value_grad = ring_flash_attention_bwd_func(group, out_grad, query, key, value, out, lse, fixed_seed_offset, attn_mask, dropout, is_causal) - if attn_mask is not None and not attn_mask.stop_gradient: - return query_grad, key_grad, value_grad, None - else: - return query_grad, key_grad, value_grad - - -import unittest -import time - -def generate_full_data(batch_size, seq_len, num_head, head_dim): - query = (paddle.randn([batch_size, seq_len, num_head, head_dim], dtype=paddle.float32)).to("gpu", "float16") - key = (paddle.randn([batch_size, seq_len, num_head, head_dim], dtype=paddle.float32)).to("gpu", "float16") - value = (paddle.randn([batch_size, seq_len, num_head, head_dim], dtype=paddle.float32)).to("gpu", "float16") - - query.stop_gradient = False - key.stop_gradient = False - value.stop_gradient = False - - return query, key, value - -def split_belanced_data(input, rank, degree): - sliced_datas = paddle.split(input, num_or_sections=degree * 2, axis=1) - sliced_data0, sliced_data1 = sliced_datas[rank], sliced_datas[degree * 2 - 1 - rank] - return paddle.concat([sliced_data0, sliced_data1], axis=1).detach() - -def test_new(): - paddle.distributed.init_parallel_env() - group = paddle.distributed.new_group(range(paddle.distributed.get_world_size()), backend="nccl") - degree = group.world_size - rank = group.rank - - seed=42 - random.seed(seed) - np.random.seed(seed) - paddle.seed(seed) - - # query, key, value = generate_full_data(8, 512 * 2 * degree, 36, 256) - query, key, value = generate_full_data(1, 128 * 2 * degree, 1, 256) - is_causal = True - - local_query = split_belanced_data(query, rank, degree) - local_key = split_belanced_data(key, rank, degree) - local_value = split_belanced_data(value, rank, degree) - - local_query.stop_gradient = False - local_key.stop_gradient = False - local_value.stop_gradient = False - - mask_shape = (1, 1, query.shape[1], query.shape[1]) - mask = np.random.random(mask_shape) - attn_mask = paddle.to_tensor(mask, place=query.place, dtype=query.dtype) - attn_mask = paddle.ones(mask_shape).to(query.dtype) - attn_mask_list = paddle.split(attn_mask, axis=2, num_or_sections=degree * 2) - first_chunk_id, second_chunk_id = get_chunk_id(rank, degree) - local_attn_mask = paddle.concat([attn_mask_list[first_chunk_id], attn_mask_list[second_chunk_id]], axis=2) - - local_out = RingFlashAttention.apply(local_query, local_key, local_value, group, is_causal=is_causal) - ref_out = scaled_dot_product_attention(query, key, value, is_causal=is_causal) - ref_local_out = split_belanced_data(ref_out, rank, degree) - - np.testing.assert_allclose(local_out.numpy(), ref_local_out.numpy(), rtol=5e-03, atol=1e-03) - - local_out.backward() - ref_out.backward() - - ref_local_query_grad = split_belanced_data(query.grad, rank, degree) - ref_local_key_grad = split_belanced_data(key.grad, rank, degree) - ref_local_value_grad = split_belanced_data(value.grad, rank, degree) - - np.testing.assert_allclose(local_query.grad.numpy(), ref_local_query_grad.numpy(), rtol=5e-03, atol=1e-03) - np.testing.assert_allclose(local_key.grad.numpy(), ref_local_key_grad.numpy(), rtol=5e-03, atol=1e-03) - np.testing.assert_allclose(local_value.grad.numpy(), ref_local_value_grad.numpy(), rtol=5e-03, atol=1e-03) - - # return - - # epoch = 1000 - # start_time = time.time() - # for iter_id in range(0, epoch): - # if iter_id == 10: - # core.nvprof_start() - # if iter_id == 15: - # core.nvprof_stop() - # return - # core.nvprof_nvtx_push(f"Forward {iter_id}") - # temp_out = RingFlashAttention.apply(local_query, local_key, local_value, group, is_causal=is_causal) - # core.nvprof_nvtx_pop() - # core.nvprof_nvtx_push(f"Backward {iter_id}") - # temp_out.backward() - # core.nvprof_nvtx_pop() - - # end_time = time.time() - # execution_time = end_time - start_time - # print(f"RingFlashAttention执行时间: {execution_time}秒") - -if __name__ == "__main__": - # unittest.main() - test_new() diff --git a/scripts/regression/ci_case.sh b/scripts/regression/ci_case.sh index 32cfec4b59de..e19a42f8a756 100644 --- a/scripts/regression/ci_case.sh +++ b/scripts/regression/ci_case.sh @@ -1111,5 +1111,16 @@ else echo "only one gpu:${cudaid1} is set, skip test" fi +} +ring_flash_attention(){ +cd ${nlp_dir} +echo "test ring_flash_attention, cudaid1:${cudaid1}, cudaid2:${cudaid2}" +if [[ ${cudaid1} != ${cudaid2} ]]; then + time (python -m paddle.distributed.launch tests/transformers/test_ring_flash_attention.py >${log_path}/ring_flash_attention) >>${log_path}/ring_flash_attention 2>&1 + print_info $? ring_flash_attention +else + echo "only one gpu:${cudaid1} is set, skip test" +fi + } $1 diff --git a/scripts/regression/run_ci.sh b/scripts/regression/run_ci.sh index 74d0b1957af8..0f7f6fdf5ab0 100644 --- a/scripts/regression/run_ci.sh +++ b/scripts/regression/run_ci.sh @@ -33,7 +33,7 @@ all_P0case_dic=(["waybill_ie"]=3 ["msra_ner"]=15 ["glue"]=2 ["bert"]=2 ["skep"]= ["ernie-ctm"]=5 ["distilbert"]=5 ["transformer"]=5 ["pet"]=5 ["efl"]=5 ["p-tuning"]=5 ["ernie-doc"]=20 ["transformer-xl"]=5 \ ["question_matching"]=5 ["ernie-csc"]=5 ["nptag"]=5 ["ernie-m"]=5 ["taskflow"]=5 ["clue"]=5 ["textcnn"]=5 \ ["fast_generation"]=10 ["ernie-3.0"]=5 ["ernie-layout"]=5 ["uie"]=5 ["ernie-health"]=5 ["llm"]=5 \ -["ernie"]=2 ["ernie_m"]=5 ["ernie_layout"]=5 ["ernie_csc"]=5 ["ernie_ctm"]=5 ["ernie_doc"]=20 ["ernie_health"]=5 ["segment_parallel_utils"]=5) +["ernie"]=2 ["ernie_m"]=5 ["ernie_layout"]=5 ["ernie_csc"]=5 ["ernie_ctm"]=5 ["ernie_doc"]=20 ["ernie_health"]=5 ["segment_parallel_utils"]=5 ["ring_flash_attention"]=5) #################################### # Insatll paddlepaddle-gpu install_paddle(){ diff --git a/tests/transformers/test_ring_flash_attention.py b/tests/transformers/test_ring_flash_attention.py new file mode 100644 index 000000000000..134d2f9c011a --- /dev/null +++ b/tests/transformers/test_ring_flash_attention.py @@ -0,0 +1,124 @@ +# Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import random +import unittest + +import numpy as np +import paddle +from paddle.nn.functional.flash_attention import scaled_dot_product_attention + +from paddlenlp.transformers.ring_flash_attention import RingFlashAttention, get_chunk_id + + +class TestRingFlashAttention(unittest.TestCase): + def setUp(self): + paddle.distributed.init_parallel_env() + self.group = paddle.distributed.new_group(range(paddle.distributed.get_world_size()), backend="nccl") + self.degree = self.group.world_size + self.rank = self.group.rank + + seed = 42 + random.seed(seed) + np.random.seed(seed) + paddle.seed(seed) + + self.test_id = 0 + + def generate_full_data(self, batch_size, seq_len, num_head, head_dim): + query = paddle.randn([batch_size, seq_len, num_head, head_dim], dtype=paddle.bfloat16) + key = paddle.randn([batch_size, seq_len, num_head, head_dim], dtype=paddle.bfloat16) + value = paddle.randn([batch_size, seq_len, num_head, head_dim], dtype=paddle.bfloat16) + + query.stop_gradient = False + key.stop_gradient = False + value.stop_gradient = False + + return query, key, value + + def split_belanced_data(self, input): + sliced_datas = paddle.split(input, num_or_sections=self.degree * 2, axis=1) + sliced_data0, sliced_data1 = sliced_datas[self.rank], sliced_datas[self.degree * 2 - 1 - self.rank] + return paddle.concat([sliced_data0, sliced_data1], axis=1).detach() + + def single_test(self, bsz, seq_len_per_device, head_num, head_dim, is_causal, use_mask): + if self.degree < 2: + return + query, key, value = self.generate_full_data(bsz, seq_len_per_device * self.degree, head_num, head_dim) + + local_query = self.split_belanced_data(query) + local_key = self.split_belanced_data(key) + local_value = self.split_belanced_data(value) + + local_query.stop_gradient = False + local_key.stop_gradient = False + local_value.stop_gradient = False + + if use_mask: + mask_shape = (bsz, 1, query.shape[1], query.shape[1]) + mask = np.random.random(mask_shape) + attn_mask = paddle.to_tensor(mask, place=query.place, dtype=query.dtype) + attn_mask = paddle.ones(mask_shape).to(query.dtype) + attn_mask_list = paddle.split(attn_mask, axis=2, num_or_sections=self.degree * 2) + first_chunk_id, second_chunk_id = get_chunk_id(self.rank, self.degree) + local_attn_mask = paddle.concat([attn_mask_list[first_chunk_id], attn_mask_list[second_chunk_id]], axis=2) + else: + attn_mask = None + local_attn_mask = None + + with paddle.amp.auto_cast(enable=True, dtype="bfloat16"): + local_out = RingFlashAttention.apply( + local_query, local_key, local_value, self.group, is_causal=is_causal, attn_mask=local_attn_mask + ) + ref_out = scaled_dot_product_attention(query, key, value, is_causal=is_causal, attn_mask=attn_mask) + + local_out.mean().backward() + ref_out.mean().backward() + + ref_local_query_grad = self.split_belanced_data(query.grad) + ref_local_key_grad = self.split_belanced_data(key.grad) + ref_local_value_grad = self.split_belanced_data(value.grad) + + ref_local_out = self.split_belanced_data(ref_out) + + rtol = 1e-04 + atol = 5e-03 + np.testing.assert_allclose( + local_out.to("float32").numpy(), ref_local_out.to("float32").numpy(), rtol=rtol, atol=atol + ) + np.testing.assert_allclose( + local_query.grad.to("float32").numpy(), ref_local_query_grad.to("float32").numpy(), rtol=rtol, atol=atol + ) + np.testing.assert_allclose( + local_key.grad.to("float32").numpy(), ref_local_key_grad.to("float32").numpy(), rtol=rtol, atol=atol + ) + np.testing.assert_allclose( + local_value.grad.to("float32").numpy(), ref_local_value_grad.to("float32").numpy(), rtol=rtol, atol=atol + ) + + print(f"Test {self.test_id} passed!") + self.test_id += 1 + + def test_normal_flash_attention(self): + self.single_test(2, 1024, 2, 128, False, False) + + def test_masked_flash_attention(self): + self.single_test(2, 1024, 2, 128, False, True) + + def test_casual_flash_attention(self): + self.single_test(2, 1024, 2, 128, True, False) + + +if __name__ == "__main__": + unittest.main() From 94943a8b9b6e508c2f9132526b85f4666b3292ee Mon Sep 17 00:00:00 2001 From: zhangyuqin1998 Date: Fri, 31 May 2024 12:32:42 +0800 Subject: [PATCH 3/7] using sep group --- paddlenlp/trainer/trainer.py | 6 +--- paddlenlp/trainer/training_args.py | 29 ++----------------- .../transformers/context_parallel_utils.py | 8 ++--- paddlenlp/transformers/llama/modeling.py | 7 ++--- .../transformers/ring_flash_attention.py | 2 +- 5 files changed, 11 insertions(+), 41 deletions(-) diff --git a/paddlenlp/trainer/trainer.py b/paddlenlp/trainer/trainer.py index caa4cb14de53..8e1049ef4a6c 100644 --- a/paddlenlp/trainer/trainer.py +++ b/paddlenlp/trainer/trainer.py @@ -1011,11 +1011,7 @@ def _inner_training_loop( assert reshard_util.is_sharding_opt(self.optimizer) self.optimizer._inner_opt.reduce_gradients(list(parameters_list), self.optimizer._hcg) - if ( - self.optimizer._dp_enable - or getattr(self.optimizer, "_sep_enable", False) - or getattr(self.optimizer, "_cp_enable", False) - ): + if self.optimizer._dp_enable or getattr(self.optimizer, "_sep_enable", False): fused_allreduce_gradients(list(parameters_list), self.optimizer._hcg) self.timers and self.timers("all-reduce").stop() diff --git a/paddlenlp/trainer/training_args.py b/paddlenlp/trainer/training_args.py index 8293fc912e70..a578f416cdcc 100644 --- a/paddlenlp/trainer/training_args.py +++ b/paddlenlp/trainer/training_args.py @@ -1161,41 +1161,18 @@ def is_segment_parallel_supported(): logger.warning("segment parallel is not supported!!!, Ignore it.") return support_sep - def is_context_parallel_supported(): - import inspect - - members = [name for (name, date) in inspect.getmembers(fleet.HybridCommunicateGroup)] - support_cp = "get_cp_parallel_world_size" in members - if not support_cp: - logger.warning("context parallel is not supported!!!, Ignore it.") - return support_cp - if self.hybrid_parallel_topo_order == "pp_first": - if is_context_parallel_supported(): - order = ["dp", "pp", "sharding", "sep", "cp", "mp"] - elif is_segment_parallel_supported(): + if is_segment_parallel_supported(): order = ["dp", "pp", "sharding", "sep", "mp"] else: order = ["dp", "pp", "sharding", "mp"] if self.hybrid_parallel_topo_order == "sharding_first": - if is_context_parallel_supported(): - order = ["dp", "sharding", "pp", "sep", "cp", "mp"] - elif is_segment_parallel_supported(): + if is_segment_parallel_supported(): order = ["dp", "sharding", "pp", "sep", "mp"] else: order = ["dp", "sharding", "pp", "mp"] - if is_context_parallel_supported(): - hybrid_configs = { - "dp_degree": self.data_parallel_degree, - "mp_degree": self.tensor_parallel_degree, - "pp_degree": self.pipeline_parallel_degree, - "sharding_degree": self.sharding_parallel_degree, - "sep_degree": self.sep_parallel_degree, - "cp_degree": self.cp_parallel_degree, - "order": order, - } - elif is_segment_parallel_supported(): + if is_segment_parallel_supported(): hybrid_configs = { "dp_degree": self.data_parallel_degree, "mp_degree": self.tensor_parallel_degree, diff --git a/paddlenlp/transformers/context_parallel_utils.py b/paddlenlp/transformers/context_parallel_utils.py index b89c020a4a94..f005f385d94b 100644 --- a/paddlenlp/transformers/context_parallel_utils.py +++ b/paddlenlp/transformers/context_parallel_utils.py @@ -32,8 +32,8 @@ def split_inputs_sequence_dim_load_balance(inputs, rank=None, degree=None): if degree is None and rank is None: _hcg = fleet.get_hybrid_communicate_group() - degree = _hcg.get_cp_parallel_world_size() - rank = _hcg.get_cp_parallel_rank() + degree = _hcg.get_sep_parallel_world_size() + rank = _hcg.get_sep_parallel_rank() assert isinstance(degree, int) and isinstance( rank, int ), f"degree:{type(degree)} and rank:{type(rank)} must be int" @@ -70,8 +70,8 @@ def split_inputs_sequence_dim(inputs, rank=None, degree=None): degree = _hcg.get_sep_parallel_world_size() rank = _hcg.get_sep_parallel_rank() if degree == 1: - degree = _hcg.get_cp_parallel_world_size() - rank = _hcg.get_cp_parallel_rank() + degree = _hcg.get_sep_parallel_world_size() + rank = _hcg.get_sep_parallel_rank() assert isinstance(degree, int) and isinstance( rank, int ), f"degree:{type(degree)} and rank:{type(rank)} must be int" diff --git a/paddlenlp/transformers/llama/modeling.py b/paddlenlp/transformers/llama/modeling.py index c526bce17f17..b08d47ef0905 100755 --- a/paddlenlp/transformers/llama/modeling.py +++ b/paddlenlp/transformers/llama/modeling.py @@ -937,7 +937,7 @@ def forward( position_ids = paddle.arange(seq_length, dtype="int64").expand((batch_size, seq_length)) if self.context_parallel: batch_size, seq_length, _, _ = query_states.shape - group = fleet.get_hybrid_communicate_group().get_cp_parallel_group() + group = fleet.get_hybrid_communicate_group().get_sep_parallel_group() chunk_size = seq_length // 2 chunk_num = group.nranks * 2 rank = group.rank @@ -1673,12 +1673,9 @@ def forward(self, prediction_scores, masked_lm_labels): with paddle.amp.auto_cast(False): masked_lm_loss = self.loss_func(prediction_scores.astype("float32"), masked_lm_labels.unsqueeze(2)) - if self.config.sep_parallel_degree > 1: + if self.config.sep_parallel_degree > 1 or self.config.cp_parallel_degree > 1: _hcg = fleet.get_hybrid_communicate_group() masked_lm_loss = ConcatMaskedLoss.apply(masked_lm_loss, axis=1, group=_hcg.get_sep_parallel_group()) - if self.config.cp_parallel_degree > 1: - _hcg = fleet.get_hybrid_communicate_group() - masked_lm_loss = ConcatMaskedLoss.apply(masked_lm_loss, axis=1, group=_hcg.get_cp_parallel_group()) # skip ignore_index which loss == 0 # masked_lm_loss = masked_lm_loss[masked_lm_loss > 0] # loss = paddle.mean(masked_lm_loss) diff --git a/paddlenlp/transformers/ring_flash_attention.py b/paddlenlp/transformers/ring_flash_attention.py index 60fbf441d5eb..3ff4d9def8d8 100644 --- a/paddlenlp/transformers/ring_flash_attention.py +++ b/paddlenlp/transformers/ring_flash_attention.py @@ -340,7 +340,7 @@ def forward( if dropout > 0.0: raise NotImplementedError("Dropout is not supported in ring attention yet.") if group is None: - group = dist.fleet.get_hybrid_communicate_group().get_cp_parallel_group() + group = dist.fleet.get_hybrid_communicate_group().get_sep_parallel_group() if attn_mask is not None: is_causal = False From 812a13e3518995b533003d5b6af88ad13892ca9d Mon Sep 17 00:00:00 2001 From: zhangyuqin1998 Date: Fri, 31 May 2024 12:36:59 +0800 Subject: [PATCH 4/7] fix --- paddlenlp/trainer/training_args.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/paddlenlp/trainer/training_args.py b/paddlenlp/trainer/training_args.py index a578f416cdcc..b2653f0e3113 100644 --- a/paddlenlp/trainer/training_args.py +++ b/paddlenlp/trainer/training_args.py @@ -1178,7 +1178,9 @@ def is_segment_parallel_supported(): "mp_degree": self.tensor_parallel_degree, "pp_degree": self.pipeline_parallel_degree, "sharding_degree": self.sharding_parallel_degree, - "sep_degree": self.sep_parallel_degree, + "sep_degree": self.sep_parallel_degree + if self.sep_parallel_degree > 1 + else self.cp_parallel_degree, "order": order, } else: From 16eaedd5db1edabdde6eb29ad20a9e4e0a78a82c Mon Sep 17 00:00:00 2001 From: zhangyuqin1998 Date: Fri, 31 May 2024 12:45:13 +0800 Subject: [PATCH 5/7] fix --- .../transformers/context_parallel_utils.py | 37 ------------------- paddlenlp/transformers/llama/modeling.py | 7 ++-- 2 files changed, 4 insertions(+), 40 deletions(-) diff --git a/paddlenlp/transformers/context_parallel_utils.py b/paddlenlp/transformers/context_parallel_utils.py index f005f385d94b..7f8a69352764 100644 --- a/paddlenlp/transformers/context_parallel_utils.py +++ b/paddlenlp/transformers/context_parallel_utils.py @@ -62,40 +62,3 @@ def do_split_sequence_dim_load_balance(data, rank, degree): else: raise ValueError(f"the inputs should be a list or a dict, but is type: {type(inputs)}") return res - - -def split_inputs_sequence_dim(inputs, rank=None, degree=None): - if degree is None and rank is None: - _hcg = fleet.get_hybrid_communicate_group() - degree = _hcg.get_sep_parallel_world_size() - rank = _hcg.get_sep_parallel_rank() - if degree == 1: - degree = _hcg.get_sep_parallel_world_size() - rank = _hcg.get_sep_parallel_rank() - assert isinstance(degree, int) and isinstance( - rank, int - ), f"degree:{type(degree)} and rank:{type(rank)} must be int" - if degree <= 1: - return inputs - - def do_split_sequence_dim(data, rank, degree): - if data is None: - return None - assert isinstance(data, paddle.Tensor), f"data should be paddle.Tensor, but is type:{type(data)}" - assert len(data.shape) == 2, f"data dims should be 2, but shaped: {data.shape}" - sliced_data = paddle.split(data, num_or_sections=degree, axis=-1)[rank] - return sliced_data - - if isinstance(inputs, paddle.Tensor): - return do_split_sequence_dim(inputs, rank, degree) - elif isinstance(inputs, dict): - res = {} - for k, tensor in inputs.items(): - res[k] = do_split_sequence_dim(tensor, rank, degree) - elif isinstance(inputs, list): - res = [] - for tensor in inputs: - res.append(do_split_sequence_dim(tensor, rank, degree)) - else: - raise ValueError(f"the inputs should be a list or a dict, but is type: {type(inputs)}") - return res diff --git a/paddlenlp/transformers/llama/modeling.py b/paddlenlp/transformers/llama/modeling.py index b08d47ef0905..30cc87493ddd 100755 --- a/paddlenlp/transformers/llama/modeling.py +++ b/paddlenlp/transformers/llama/modeling.py @@ -768,7 +768,7 @@ def __init__(self, config: LlamaConfig, layerwise_recompute: bool = False): assert self.num_key_value_heads % config.sep_parallel_degree == 0 assert self.num_heads % config.sep_parallel_degree == 0 self.reshard_layer = ReshardLayer() - self.context_parallel = config.cp_parallel_degree > 1 + self.config = config def _init_rope(self): @@ -935,7 +935,7 @@ def forward( if self.reshard_layer is not None: batch_size, seq_length, _, _ = query_states.shape position_ids = paddle.arange(seq_length, dtype="int64").expand((batch_size, seq_length)) - if self.context_parallel: + if self.config.cp_parallel_degree > 1: batch_size, seq_length, _, _ = query_states.shape group = fleet.get_hybrid_communicate_group().get_sep_parallel_group() chunk_size = seq_length // 2 @@ -959,7 +959,7 @@ def forward( ) else: - if self.context_parallel: + if self.config.cp_parallel_degree > 1: kv_seq_len *= self.config.cp_parallel_degree if self.config.use_long_sequence_strategies: cos, sin = self.rotary_emb(seq_len=kv_seq_len) @@ -971,6 +971,7 @@ def forward( ) else: cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len) + query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin, position_ids) # [bs, seq_len, num_head, head_dim] From e7c4b1ecc9c64e6cae3d444417bff88816bc5a01 Mon Sep 17 00:00:00 2001 From: zhangyuqin1998 Date: Fri, 31 May 2024 20:10:30 +0800 Subject: [PATCH 6/7] update --- docs/trainer.md | 10 +++++- llm/llama/run_trainer_tp2cp2.sh | 8 ++--- llm/run_pretrain.py | 6 ++-- paddlenlp/trainer/trainer.py | 8 ++--- paddlenlp/trainer/training_args.py | 31 ++++++++++--------- paddlenlp/transformers/configuration_utils.py | 2 +- paddlenlp/transformers/llama/fusion_ops.py | 14 ++++----- paddlenlp/transformers/llama/modeling.py | 20 ++++++------ 8 files changed, 53 insertions(+), 46 deletions(-) diff --git a/docs/trainer.md b/docs/trainer.md index a1dde0af4f94..55df827bb3d7 100644 --- a/docs/trainer.md +++ b/docs/trainer.md @@ -576,7 +576,15 @@ Trainer 是一个简单,但功能完整的 Paddle训练和评估模块,并 following config is support: enable_allreduce_avg_in_gradinent_scale, it replace `allreduce_sum + scale` pattern with `allreduce_avg` when scale gradient in data_parallel, which improve the performance. ONLY supported for auto mode now. gradient_sync_after_accumulate, move gradient sync operations from backward into optimizer step when gradient accumulate enabling, which reduce the sync times to improve performance, but will increase the memory usage. ONLY supported for auto mode now. - + --context_parallel_degree + 上下文并行是将训练数据在序列维度进行切分的并行方法。 + 该方法使用Ring FlashAttention来保障切分后Attention结果的正确性。通过环状通信和迭代更新来得到完整的注意力分数。 + 默认值-1, 表示不启用上下文并行, + (`int`, 可选, 默认为 `-1`) + (注: 该方法需要修改模型结构, 目前支持LLAMA) + (注: 该方法对通信开销较大, 建议只有在序列长度超长时, 如1024k, 时才使用) + Context parallelism is a parallel method that segments training data in the sequence dimension. + This method uses Ring FlashAttention to ensure the correctness of the Attention result after segmentation. The complete attention score is obtained through ring communication and iterative updates. --recompute 是否使用重计算训练。可以节省显存。 重新计算前向过程以获取梯度,减少中间变量显存. diff --git a/llm/llama/run_trainer_tp2cp2.sh b/llm/llama/run_trainer_tp2cp2.sh index 954d59c3100b..1a684191deea 100644 --- a/llm/llama/run_trainer_tp2cp2.sh +++ b/llm/llama/run_trainer_tp2cp2.sh @@ -33,17 +33,13 @@ unset PADDLE_ELASTIC_TIMEOUT max_seq_length=1024 -master=127.0.0.1 -port=36677 - -max_steps=10000 +max_steps=1000 log_dir=seq_${max_seq_length}_log echo "log_dir:${log_dir}" rm -rf $log_dir export PYTHONPATH=../../:$PYTHONPATH python -u -m paddle.distributed.launch \ - --master $master:$port \ --gpus "3,4,5,7" \ --log_dir "./$log_dir" \ run_pretrain.py \ @@ -78,7 +74,7 @@ python -u -m paddle.distributed.launch \ --recompute_use_reentrant true \ --data_cache "./data_cache" \ --pipeline_parallel_degree 1 \ - --cp_parallel_degree 2 \ + --context_parallel_degree 2 \ --tensor_parallel_degree 2 \ --sequence_parallel false \ --skip_profile_timer true \ diff --git a/llm/run_pretrain.py b/llm/run_pretrain.py index cd9d91a22320..0f0b3122baae 100644 --- a/llm/run_pretrain.py +++ b/llm/run_pretrain.py @@ -485,15 +485,15 @@ def main(): config.attention_probs_dropout_prob = model_args.attention_probs_dropout_prob config.sep_parallel_degree = training_args.sep_parallel_degree - config.cp_parallel_degree = training_args.cp_parallel_degree + config.context_parallel_degree = training_args.context_parallel_degree if config.sequence_parallel: assert config.tensor_parallel_degree > 1, "tensor_parallel_degree must be larger than 1 for sequence parallel." assert ( config.num_attention_heads % config.sep_parallel_degree == 0 ), f"num_attention_heads:{config.num_attention_heads} must be divisible by sep_parallel_degree {config.sep_parallel_degree}" assert ( - config.seq_length % config.cp_parallel_degree == 0 - ), f"seq_length:{config.seq_length} must be divisible by cp_parallel_degree {config.cp_parallel_degree}" + config.seq_length % config.context_parallel_degree == 0 + ), f"seq_length:{config.seq_length} must be divisible by context_parallel_degree {config.context_parallel_degree}" if get_env_device() == "xpu" and training_args.gradient_accumulation_steps > 1: try: diff --git a/paddlenlp/trainer/trainer.py b/paddlenlp/trainer/trainer.py index 8e1049ef4a6c..7d5cc4a5ffc1 100644 --- a/paddlenlp/trainer/trainer.py +++ b/paddlenlp/trainer/trainer.py @@ -764,8 +764,8 @@ def train( trainable_numel = int(trainable_numel_tensor.item()) // self.args.dataset_world_size if self.args.sep_parallel_degree > 0: trainable_numel = trainable_numel // self.args.sep_parallel_degree - if self.args.cp_parallel_degree > 0: - trainable_numel = trainable_numel // self.args.cp_parallel_degree + if self.args.context_parallel_degree > 0: + trainable_numel = trainable_numel // self.args.context_parallel_degree # the numel is roughly, because the tensor parallel still hold own bias or layer_norm weight without splited # so, the trainable numel is a little bigger than real. logger.debug(f" Number of trainable parameters = {trainable_numel:,} (all devices, roughly)") @@ -900,7 +900,7 @@ def _inner_training_loop( for step, inputs in enumerate(epoch_iterator): if self.args.use_hybrid_parallel and self.args.sep_parallel_degree > 1: inputs = split_inputs_sequence_dim(inputs) - if self.args.use_hybrid_parallel and self.args.cp_parallel_degree > 1: + if self.args.use_hybrid_parallel and self.args.context_parallel_degree > 1: inputs = split_inputs_sequence_dim_load_balance(inputs) self.timers and self.timers("read-data").stop() os.environ["TRAINER_GLOBAL_STEP"] = str(self.state.global_step) @@ -1765,7 +1765,7 @@ def _wrap_model(self, model, training=True): in_sharding_parallel_mode = self.sharding is not None in_tensor_parallel_mode = self.args.tensor_parallel_degree > 1 in_sep_parallel_mode = self.args.sep_parallel_degree > 1 - in_cp_parallel_mode = self.args.cp_parallel_degree > 1 + in_cp_parallel_mode = self.args.context_parallel_degree > 1 # Multi-gpu training if ( diff --git a/paddlenlp/trainer/training_args.py b/paddlenlp/trainer/training_args.py index b2653f0e3113..ea9290024680 100644 --- a/paddlenlp/trainer/training_args.py +++ b/paddlenlp/trainer/training_args.py @@ -230,9 +230,9 @@ class TrainingArguments: The paddle sequence parallel strategy. It can reduce the GPU memory of activation to 1/sep, and it is orthogonal to data parallel, sharding stage1, tensor parallel and pipeline parallel strategy. ) - cp_parallel_degree (`int`, *optional*, defaults to `-1`)( - The paddle context parallel strategy. It can reduce the GPU memory of activation to 1/cp, and it is orthogonal to - data parallel, sharding stage1, tensor parallel and pipeline parallel strategy. + context_parallel_degree (`int`, *optional*, defaults to `-1`)( + Context parallelism is a parallel method that segments training data in the sequence dimension. + This method uses Ring FlashAttention to ensure the correctness of the Attention result after segmentation. The complete attention score is obtained through ring communication and iterative updates. ) data_parallel_config (`str`, *optional*)( Some additional configs which affect data parallel performance, we provide some option to config it. @@ -587,7 +587,7 @@ class TrainingArguments: ) }, ) - cp_parallel_degree: int = field( + context_parallel_degree: int = field( default=-1, metadata={ "help": ( @@ -931,7 +931,7 @@ def __post_init__(self): if world_size > 1: tensor_parallel_degree = max(self.tensor_parallel_degree, 1) sep_parallel_degree = max(self.sep_parallel_degree, 1) - cp_parallel_degree = max(self.cp_parallel_degree, 1) + context_parallel_degree = max(self.context_parallel_degree, 1) pipeline_parallel_degree = max(self.pipeline_parallel_degree, 1) assert ( @@ -941,7 +941,10 @@ def __post_init__(self): if self.sharding_parallel_degree == -1: if len(self.sharding) > 0: self.sharding_parallel_degree = world_size // ( - tensor_parallel_degree * sep_parallel_degree * cp_parallel_degree * pipeline_parallel_degree + tensor_parallel_degree + * sep_parallel_degree + * context_parallel_degree + * pipeline_parallel_degree ) sharding_parallel_degree = max(self.sharding_parallel_degree, 1) @@ -953,7 +956,7 @@ def __post_init__(self): sharding_parallel_degree * tensor_parallel_degree * sep_parallel_degree - * cp_parallel_degree + * context_parallel_degree * pipeline_parallel_degree ) @@ -962,14 +965,14 @@ def __post_init__(self): or tensor_parallel_degree > 1 or pipeline_parallel_degree > 1 or self.sep_parallel_degree > 1 - or self.cp_parallel_degree > 1 + or self.context_parallel_degree > 1 ): self.use_hybrid_parallel = True self.sharding_parallel_degree = sharding_parallel_degree self.tensor_parallel_degree = tensor_parallel_degree self.pipeline_parallel_degree = pipeline_parallel_degree self.sep_parallel_degree = sep_parallel_degree - self.cp_parallel_degree = cp_parallel_degree + self.context_parallel_degree = context_parallel_degree if not self.use_hybrid_parallel: self.sharding = [] @@ -977,7 +980,7 @@ def __post_init__(self): self.tensor_parallel_degree = -1 self.pipeline_parallel_degree = -1 self.sep_parallel_degree = -1 - self.cp_parallel_degree = -1 + self.context_parallel_degree = -1 if self.hybrid_parallel_topo_order is None: self.hybrid_parallel_topo_order = "pp_first" @@ -1180,7 +1183,7 @@ def is_segment_parallel_supported(): "sharding_degree": self.sharding_parallel_degree, "sep_degree": self.sep_parallel_degree if self.sep_parallel_degree > 1 - else self.cp_parallel_degree, + else self.context_parallel_degree, "order": order, } else: @@ -1264,7 +1267,7 @@ def is_segment_parallel_supported(): elif self.enable_auto_parallel: self.tensor_parallel_degree = max(self.tensor_parallel_degree, 1) self.sep_parallel_degree = max(self.sep_parallel_degree, 1) - self.cp_parallel_degree = max(self.cp_parallel_degree, 1) + self.context_parallel_degree = max(self.context_parallel_degree, 1) self.pipeline_parallel_degree = max(self.pipeline_parallel_degree, 1) assert ( @@ -1276,7 +1279,7 @@ def is_segment_parallel_supported(): self.sharding_parallel_degree = world_size // ( self.tensor_parallel_degree * self.sep_parallel_degree - * self.cp_parallel_degree + * self.context_parallel_degree * self.pipeline_parallel_degree ) @@ -1289,7 +1292,7 @@ def is_segment_parallel_supported(): self.sharding_parallel_degree * self.tensor_parallel_degree * self.sep_parallel_degree - * self.cp_parallel_degree + * self.context_parallel_degree * self.pipeline_parallel_degree ) diff --git a/paddlenlp/transformers/configuration_utils.py b/paddlenlp/transformers/configuration_utils.py index 99957f2057e7..093ea32e3bf6 100644 --- a/paddlenlp/transformers/configuration_utils.py +++ b/paddlenlp/transformers/configuration_utils.py @@ -467,7 +467,7 @@ def __init__(self, **kwargs): self.tensor_parallel_rank = kwargs.pop("tensor_parallel_rank", 0) # Parameters for sep and cp self.sep_parallel_degree = kwargs.pop("sep_parallel_degree", -1) - self.cp_parallel_degree = kwargs.pop("cp_parallel_degree", -1) + self.context_parallel_degree = kwargs.pop("context_parallel_degree", -1) # If set to True, this option is used with fleet.meta_parallel.ParallelCrossEntropy # to calculate cross-entropy loss for parallel model. self.tensor_parallel_output = kwargs.pop("tensor_parallel_output", False) diff --git a/paddlenlp/transformers/llama/fusion_ops.py b/paddlenlp/transformers/llama/fusion_ops.py index 24e48342cfa0..182663bdbc73 100644 --- a/paddlenlp/transformers/llama/fusion_ops.py +++ b/paddlenlp/transformers/llama/fusion_ops.py @@ -62,15 +62,15 @@ def fusion_rope( position_ids, past_key_value, rotary_emb, - cp_parallel_degree=-1, + context_parallel_degree=-1, ): if get_env_device() != "gcu": assert past_key_value is None, "fuse rotary not support cache kv for now" batch_size, seq_length, num_heads, head_dim = query_states.shape _, kv_seq_len, num_key_value_heads, _ = key_states.shape - if cp_parallel_degree > 1: + if context_parallel_degree > 1: assert get_env_device() == "gpu", "context parallel only support cuda device for now" - kv_seq_len *= cp_parallel_degree + kv_seq_len *= context_parallel_degree if get_env_device() != "gcu": cos, sin = rotary_emb(value_states, seq_len=kv_seq_len) if get_env_device() == "npu": @@ -156,7 +156,7 @@ def fusion_flash_attention( if version != "0.0.0" and version <= "2.5.2": if alibi is not None: raise ValueError("Flash Attention doesn't support alibi") - if config.cp_parallel_degree > 1: + if config.context_parallel_degree > 1: raise ValueError(f"Context parallel is not implemented in version {version}") attn_output, attn_weights = flash_attention( query_states, @@ -170,7 +170,7 @@ def fusion_flash_attention( alibi = alibi.reshape([bsz, num_heads, 1, -1]) attention_mask = attention_mask.cast(alibi.dtype) + alibi if get_env_device() == "npu": - if config.cp_parallel_degree > 1: + if config.context_parallel_degree > 1: raise ValueError("Context parallel is not implemented for npu") attn_output = core.eager._run_custom_op( "flash_attention_npu", @@ -186,7 +186,7 @@ def fusion_flash_attention( npu_is_casual, )[0] elif get_env_device() == "gcu": - if config.cp_parallel_degree > 1: + if config.context_parallel_degree > 1: raise ValueError("Context parallel is not implemented for gcu") attn_output = core.eager._run_custom_op( "fused_sdp_flash_attention_gcu", @@ -199,7 +199,7 @@ def fusion_flash_attention( True, )[0] else: - if config.cp_parallel_degree > 1: + if config.context_parallel_degree > 1: attn_output = RingFlashAttention.apply( query_states, key_states, diff --git a/paddlenlp/transformers/llama/modeling.py b/paddlenlp/transformers/llama/modeling.py index 30cc87493ddd..cdc1a2abe845 100755 --- a/paddlenlp/transformers/llama/modeling.py +++ b/paddlenlp/transformers/llama/modeling.py @@ -233,7 +233,7 @@ def scaled_dot_product_attention( # Torch Flash Attention input [ bz, nhead, seqlen, head_dim] else: - if config.cp_parallel_degree > 1: + if config.context_parallel_degree > 1: raise ValueError("Context parallel requires `use_flash_attention=True`") # [ bz, seqlen, nhead, head_dim] -> [bs, nhead, seq_len, head_dim] @@ -935,7 +935,7 @@ def forward( if self.reshard_layer is not None: batch_size, seq_length, _, _ = query_states.shape position_ids = paddle.arange(seq_length, dtype="int64").expand((batch_size, seq_length)) - if self.config.cp_parallel_degree > 1: + if self.config.context_parallel_degree > 1: batch_size, seq_length, _, _ = query_states.shape group = fleet.get_hybrid_communicate_group().get_sep_parallel_group() chunk_size = seq_length // 2 @@ -955,12 +955,12 @@ def forward( position_ids, past_key_value, self.rotary_emb, - self.config.cp_parallel_degree, + self.config.context_parallel_degree, ) else: - if self.config.cp_parallel_degree > 1: - kv_seq_len *= self.config.cp_parallel_degree + if self.config.context_parallel_degree > 1: + kv_seq_len *= self.config.context_parallel_degree if self.config.use_long_sequence_strategies: cos, sin = self.rotary_emb(seq_len=kv_seq_len) cos = cos[None, :, None, :] @@ -1529,7 +1529,7 @@ def forward( # [seq_len * bs / n, num_head * head_dim] (n is mp parallelism) inputs_embeds = ScatterOp.apply(inputs_embeds) - if self.config.cp_parallel_degree > 1 and (attention_mask is not None or self.config.alibi): + if self.config.context_parallel_degree > 1 and (attention_mask is not None or self.config.alibi): raise NotImplementedError("Ring FlashAttention dosen't support attention_mask or alibi") # embed positions if attention_mask is None: @@ -1674,7 +1674,7 @@ def forward(self, prediction_scores, masked_lm_labels): with paddle.amp.auto_cast(False): masked_lm_loss = self.loss_func(prediction_scores.astype("float32"), masked_lm_labels.unsqueeze(2)) - if self.config.sep_parallel_degree > 1 or self.config.cp_parallel_degree > 1: + if self.config.sep_parallel_degree > 1 or self.config.context_parallel_degree > 1: _hcg = fleet.get_hybrid_communicate_group() masked_lm_loss = ConcatMaskedLoss.apply(masked_lm_loss, axis=1, group=_hcg.get_sep_parallel_group()) # skip ignore_index which loss == 0 @@ -1747,9 +1747,9 @@ def forward(self, hidden_states, tensor_parallel_output=None): if self.config.sep_parallel_degree > 1: assert seq_length % self.config.sep_parallel_degree == 0 seq_length = seq_length // self.config.sep_parallel_degree - if self.config.cp_parallel_degree > 1: - assert seq_length % self.config.cp_parallel_degree == 0 - seq_length = seq_length // self.config.cp_parallel_degree + if self.config.context_parallel_degree > 1: + assert seq_length % self.config.context_parallel_degree == 0 + seq_length = seq_length // self.config.context_parallel_degree hidden_states = paddle.reshape_(hidden_states, [-1, seq_length, self.config.hidden_size]) if tensor_parallel_output is None: From 26b7059a864548ecf2438921aad1854426c381da Mon Sep 17 00:00:00 2001 From: zhangyuqin1998 Date: Mon, 3 Jun 2024 15:39:05 +0800 Subject: [PATCH 7/7] fix --- paddlenlp/trainer/training_args.py | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/paddlenlp/trainer/training_args.py b/paddlenlp/trainer/training_args.py index ea9290024680..2aa77bdeefbb 100644 --- a/paddlenlp/trainer/training_args.py +++ b/paddlenlp/trainer/training_args.py @@ -938,6 +938,10 @@ def __post_init__(self): world_size % (self.tensor_parallel_degree * self.pipeline_parallel_degree) == 0 ), f"Total world_size:{world_size} shoule be devided by tensor_parallel_degree: {self.tensor_parallel_degree} and pipeline_parallel_degree: {self.pipeline_parallel_degree}." + assert not ( + sep_parallel_degree > 1 and context_parallel_degree > 1 + ), f"sep parallel and context parallel cannot be used together, sep_parallel_degree:{sep_parallel_degree}, context_parallel_degree:{context_parallel_degree}." + if self.sharding_parallel_degree == -1: if len(self.sharding) > 0: self.sharding_parallel_degree = world_size // (