diff --git a/llm/llama/auto_parallel/run_auto_sp.sh b/llm/llama/auto_parallel/run_auto_sp.sh
new file mode 100644
index 000000000000..4e13d1fbfb21
--- /dev/null
+++ b/llm/llama/auto_parallel/run_auto_sp.sh
@@ -0,0 +1,74 @@
+# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+# just for debug auto_parallel
+
+set -x
+unset CUDA_VISIBLE_DEVICES
+
+task_name="llama_auto_dp2mp2pp2_vpp2_sp"
+# rm -rf output/$task_name/ # ckpt is saved in 'output/''
+rm -rf "output/$task_name""_log"
+
+export PARALLEL_CROSS_ENTROPY=true
+export FLAGS_call_stack_level=2
+export PYTHONPATH=../../../:$PYTHONPATH
+python -u -m paddle.distributed.launch \
+    --gpus "0,1,2,3,4,5,6,7" \
+    --log_dir "output/$task_name""_log" \
+    run_pretrain_auto.py \
+    --model_type "llama" \
+    --model_name_or_path "facebook/llama-7b" \
+    --tokenizer_name_or_path "facebook/llama-7b" \
+    --input_dir "./data" \
+    --output_dir "output/$task_name" \
+    --split 949,50,1 \
+    --max_seq_length 2048 \
+    --per_device_train_batch_size 1 \
+    --per_device_eval_batch_size 8 \
+    --gradient_accumulation_steps 8 \
+    --use_flash_attention 0 \
+    --use_fused_rms_norm 0 \
+    --fp16 0 \
+    --fp16_opt_level "O2" \
+    --scale_loss 1024 \
+    --tensor_parallel_degree 2 \
+    --pipeline_parallel_degree 2 \
+    --virtual_pp_degree 2 \
+    --pipeline_schedule_mode "VPP" \
+    --sharding_parallel_degree 1 \
+    --sharding "stage2" \
+    --learning_rate 0.0001 \
+    --min_learning_rate 0.00001 \
+    --max_steps 10 \
+    --save_steps 5000 \
+    --weight_decay 0.01 \
+    --warmup_ratio 0.01 \
+    --max_grad_norm 1.0 \
+    --logging_steps 1 \
+    --dataloader_num_workers 1 \
+    --eval_steps 1000 \
+    --report_to "visualdl" \
+    --disable_tqdm true \
+    --continue_training 0 \
+    --recompute 1 \
+    --recompute_granularity full \
+    --do_train \
+    --do_eval \
+    --device "gpu" \
+    --data_impl "mmap" \
+    --parallel_mode "auto" \
+    --sequence_parallel true \
+
+    # --resume_from_checkpoint "output/llama_auto_serial/checkpoint-2" \
diff --git a/paddlenlp/transformers/llama/modeling_auto.py b/paddlenlp/transformers/llama/modeling_auto.py
index cf141d64870f..c4ee48def480 100644
--- a/paddlenlp/transformers/llama/modeling_auto.py
+++ b/paddlenlp/transformers/llama/modeling_auto.py
@@ -21,6 +21,7 @@
 from typing import Optional, Tuple

 import paddle
+import paddle.distributed as dist
 import paddle.nn.functional as F
 from paddle import nn
 from paddle.distributed import fleet
@@ -362,10 +363,24 @@ def forward(
     ) -> Tuple[paddle.Tensor, Optional[paddle.Tensor], Optional[Tuple[paddle.Tensor]]]:
         """Input shape: Batch x Time x Channel"""
         # [bs, seq_len, num_head * head_dim] -> [seq_len / n, bs, num_head * head_dim] (n is model parallelism)
+        # enter tp region
+        if self.config.sequence_parallel:
+            mesh = get_mesh(self.ipp)
+            if "dp" in mesh.dim_names:
+                hidden_states = dist.reshard(
+                    hidden_states,
+                    get_mesh(self.ipp),
+                    [dist.Shard(1), dist.Replicate()],
+                )
+            else:
+                hidden_states = dist.reshard(
+                    hidden_states,
+                    get_mesh(self.ipp),
+                    [dist.Replicate()],
+                )
         if self.fuse_attention_qkv:
             target_shape = [0, 0, self.num_heads, 3 * self.head_dim]
-            fleet.auto.shard_tensor(self.qkv_proj.weight, *get_dist_attr([None, "mp"], self.ipp))
             mix_layer = self.qkv_proj(hidden_states)
@@ -383,6 +398,11 @@ def forward(
             key_states = self.k_proj(hidden_states).reshape(shape=target_key_value_shape)
             value_states = self.v_proj(hidden_states).reshape(shape=target_key_value_shape)

+        if self.config.sequence_parallel:
+            query_states = paddle.transpose(query_states, [1, 0, 2, 3])
+            key_states = paddle.transpose(key_states, [1, 0, 2, 3])
+            value_states = paddle.transpose(value_states, [1, 0, 2, 3])
+
         kv_seq_len = key_states.shape[-3]

         if past_key_value is not None:
@@ -459,6 +479,22 @@ def forward(
         fleet.auto.shard_tensor(self.o_proj.weight, *get_dist_attr(["mp", None], self.ipp))
         attn_output = self.o_proj(attn_output)
+        # enter sp region
+        if self.config.sequence_parallel:
+            attn_output = paddle.transpose(attn_output, [1, 0, 2])
+            mesh = get_mesh(self.ipp)
+            if "dp" in mesh.dim_names:
+                attn_output = dist.reshard(
+                    attn_output,
+                    get_mesh(self.ipp),
+                    [dist.Shard(1), dist.Shard(0)],
+                )
+            else:
+                attn_output = dist.reshard(
+                    attn_output,
+                    get_mesh(self.ipp),
+                    [dist.Shard(0)],
+                )

         if not output_attentions:
             attn_weights = None
@@ -565,7 +601,39 @@ def forward(
         # Fully Connected
         residual = hidden_states
         hidden_states = self.post_attention_layernorm(hidden_states)
+
+        # enter tp region
+        if self.config.sequence_parallel:
+            mesh = get_mesh(self.ipp)
+            if "dp" in mesh.dim_names:
+                hidden_states = dist.reshard(
+                    hidden_states,
+                    get_mesh(self.ipp),
+                    [dist.Shard(1), dist.Replicate()],
+                )
+            else:
+                hidden_states = dist.reshard(
+                    hidden_states,
+                    get_mesh(self.ipp),
+                    [dist.Replicate()],
+                )
+
         hidden_states = self.mlp(hidden_states)
+        # enter sp region
+        if self.config.sequence_parallel:
+            mesh = get_mesh(self.ipp)
+            if "dp" in mesh.dim_names:
+                hidden_states = dist.reshard(
+                    hidden_states,
+                    get_mesh(self.ipp),
+                    [dist.Shard(1), dist.Shard(0)],
+                )
+            else:
+                hidden_states = dist.reshard(
+                    hidden_states,
+                    get_mesh(self.ipp),
+                    [dist.Shard(0)],
+                )
         hidden_states = residual + hidden_states

         outputs = (hidden_states,)
@@ -830,6 +898,24 @@ def forward(
         )  # [bs, 1, seq_len, seq_len]

         hidden_states = inputs_embeds
+        if self.config.sequence_parallel:
+            # [B, S, H] -> [S, B, H]
+            emb_transpose = fleet.auto.shard_op(paddle.transpose, get_mesh(0))
+            hidden_states = emb_transpose(hidden_states, [1, 0, 2])
+            # enter sp region
+            mesh = get_mesh(0)
+            if "dp" in mesh.dim_names:
+                hidden_states = dist.reshard(
+                    hidden_states,
+                    get_mesh(0),
+                    [dist.Shard(1), dist.Shard(0)],
+                )
+            else:
+                hidden_states = dist.reshard(
+                    hidden_states,
+                    get_mesh(0),
+                    [dist.Shard(0)],
+                )

         # decoder layers
         all_hidden_states = () if output_hidden_states else None
@@ -838,7 +924,10 @@ def forward(
         for idx, (decoder_layer) in enumerate(self.layers):
             ipp = decoder_layer.ipp
-            fleet.auto.shard_tensor(hidden_states, *get_dist_attr(["dp", None, None], ipp))
+            if self.config.sequence_parallel:
+                fleet.auto.shard_tensor(hidden_states, *get_dist_attr(["mp", "dp", None], ipp))
+            else:
+                fleet.auto.shard_tensor(hidden_states, *get_dist_attr(["dp", None, None], ipp))
             decoder_layer = fleet.auto.shard_op(decoder_layer, get_mesh(ipp))

             if output_hidden_states:
@@ -846,6 +935,7 @@ def forward(
             past_key_value = past_key_values[idx] if past_key_values is not None else None
             has_gradient = not hidden_states.stop_gradient
+
             if (
                 self.enable_recompute
                 and idx not in self.no_recompute_layers
@@ -1107,6 +1197,22 @@ def forward(
         )

         hidden_states = outputs[0]  # [bs, seq_len, dim]
+        # enter tp region
+        if self.config.sequence_parallel:
+            mesh = get_mesh(-1)
+            if "dp" in mesh.dim_names:
+                hidden_states = dist.reshard(
+                    hidden_states,
+                    get_mesh(-1),
+                    [dist.Shard(1), dist.Replicate()],
+                )
+            else:
+                hidden_states = dist.reshard(
+                    hidden_states,
+                    get_mesh(-1),
+                    [dist.Replicate()],
+                )
+            hidden_states = paddle.transpose(hidden_states, [1, 0, 2])

         # if labels is None,means we need full output, instead of tensor_parallel_output
         # tensor_parallel_output is togather with ParallelCrossEntropy
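
A note on the pattern used throughout the diff: each "enter tp region" / "enter sp region" pair is the usual sequence-parallel hand-off. Activations are kept as [seq_len, batch, hidden] and sharded along the sequence dimension on the "mp" mesh axis between layers (the sp region); right before each tensor-parallel projection they are resharded to be replicated along "mp" (the tp region), and split along the sequence dimension again afterwards. The sketch below is not part of the patch; it only reproduces that placement change in isolation with Paddle's dynamic auto-parallel API. The 2x2 mesh, the toy tensor sizes, and the file name sp_reshard_sketch.py are illustrative assumptions; it would be launched on 4 GPUs, e.g. python -m paddle.distributed.launch --devices "0,1,2,3" sp_reshard_sketch.py.

# sp_reshard_sketch.py -- illustrative sketch only, assumptions as noted above
import paddle
import paddle.distributed as dist

# 2-D mesh with the same dim names the modeling code checks via mesh.dim_names.
mesh = dist.ProcessMesh([[0, 1], [2, 3]], dim_names=["dp", "mp"])

S, B, H = 8, 2, 16  # toy sizes; activations are laid out [seq_len, batch, hidden]
x = paddle.randn([S, B, H])

# sp region: batch split over "dp" (Shard(1)), sequence split over "mp" (Shard(0)).
x_sp = dist.shard_tensor(x, mesh, [dist.Shard(1), dist.Shard(0)])

# enter tp region: gather the sequence dimension so a column-parallel projection
# sees the whole sequence; placements become [Shard(1), Replicate()] as in the diff.
x_tp = dist.reshard(x_sp, mesh, [dist.Shard(1), dist.Replicate()])

# ... tensor-parallel attention / MLP would run here ...

# enter sp region: split the sequence dimension again after the row-parallel output.
x_sp = dist.reshard(x_tp, mesh, [dist.Shard(1), dist.Shard(0)])

print("global shape:", x_sp.shape)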