Skip to content

[AutoParallel] Add Sequence Parallel for Static LLaMA #7746

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 5 commits into from
Jan 4, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
74 changes: 74 additions & 0 deletions llm/llama/auto_parallel/run_auto_sp.sh
Original file line number Diff line number Diff line change
@@ -0,0 +1,74 @@
# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

# just for debug auto_parallel

set -x
unset CUDA_VISIBLE_DEVICES

task_name="llama_auto_dp2mp2pp2_vpp2_sp"
# rm -rf output/$task_name/ # ckpt is saved in 'output/''
rm -rf "output/$task_name""_log"

export PARALLEL_CROSS_ENTROPY=true
export FLAGS_call_stack_level=2
export PYTHONPATH=../../../:$PYTHONPATH
python -u -m paddle.distributed.launch \
--gpus "0,1,2,3,4,5,6,7" \
--log_dir "output/$task_name""_log" \
run_pretrain_auto.py \
--model_type "llama" \
--model_name_or_path "facebook/llama-7b" \
--tokenizer_name_or_path "facebook/llama-7b" \
--input_dir "./data" \
--output_dir "output/$task_name" \
--split 949,50,1 \
--max_seq_length 2048 \
--per_device_train_batch_size 1 \
--per_device_eval_batch_size 8 \
--gradient_accumulation_steps 8 \
--use_flash_attention 0 \
--use_fused_rms_norm 0 \
--fp16 0 \
--fp16_opt_level "O2" \
--scale_loss 1024 \
--tensor_parallel_degree 2 \
--pipeline_parallel_degree 2 \
--virtual_pp_degree 2 \
--pipeline_schedule_mode "VPP" \
--sharding_parallel_degree 1 \
--sharding "stage2" \
--learning_rate 0.0001 \
--min_learning_rate 0.00001 \
--max_steps 10 \
--save_steps 5000 \
--weight_decay 0.01 \
--warmup_ratio 0.01 \
--max_grad_norm 1.0 \
--logging_steps 1 \
--dataloader_num_workers 1 \
--eval_steps 1000 \
--report_to "visualdl" \
--disable_tqdm true \
--continue_training 0 \
--recompute 1 \
--recompute_granularity full \
--do_train \
--do_eval \
--device "gpu" \
--data_impl "mmap" \
--parallel_mode "auto" \
--sequence_parallel true \

# --resume_from_checkpoint "output/llama_auto_serial/checkpoint-2" \
110 changes: 108 additions & 2 deletions paddlenlp/transformers/llama/modeling_auto.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@
from typing import Optional, Tuple

import paddle
import paddle.distributed as dist
import paddle.nn.functional as F
from paddle import nn
from paddle.distributed import fleet
Expand Down Expand Up @@ -362,10 +363,24 @@ def forward(
) -> Tuple[paddle.Tensor, Optional[paddle.Tensor], Optional[Tuple[paddle.Tensor]]]:
"""Input shape: Batch x Time x Channel"""
# [bs, seq_len, num_head * head_dim] -> [seq_len / n, bs, num_head * head_dim] (n is model parallelism)
# enter tp region
if self.config.sequence_parallel:
mesh = get_mesh(self.ipp)
if "dp" in mesh.dim_names:
hidden_states = dist.reshard(
hidden_states,
get_mesh(self.ipp),
[dist.Shard(1), dist.Replicate()],
)
else:
hidden_states = dist.reshard(
hidden_states,
get_mesh(self.ipp),
[dist.Replicate()],
)

if self.fuse_attention_qkv:
target_shape = [0, 0, self.num_heads, 3 * self.head_dim]

fleet.auto.shard_tensor(self.qkv_proj.weight, *get_dist_attr([None, "mp"], self.ipp))

mix_layer = self.qkv_proj(hidden_states)
Expand All @@ -383,6 +398,11 @@ def forward(
key_states = self.k_proj(hidden_states).reshape(shape=target_key_value_shape)
value_states = self.v_proj(hidden_states).reshape(shape=target_key_value_shape)

if self.config.sequence_parallel:
query_states = paddle.transpose(query_states, [1, 0, 2, 3])
key_states = paddle.transpose(key_states, [1, 0, 2, 3])
value_states = paddle.transpose(value_states, [1, 0, 2, 3])

kv_seq_len = key_states.shape[-3]

if past_key_value is not None:
Expand Down Expand Up @@ -459,6 +479,22 @@ def forward(
fleet.auto.shard_tensor(self.o_proj.weight, *get_dist_attr(["mp", None], self.ipp))
attn_output = self.o_proj(attn_output)

# enter sp region
if self.config.sequence_parallel:
attn_output = paddle.transpose(attn_output, [1, 0, 2])
mesh = get_mesh(self.ipp)
if "dp" in mesh.dim_names:
attn_output = dist.reshard(
attn_output,
get_mesh(self.ipp),
[dist.Shard(1), dist.Shard(0)],
)
else:
attn_output = dist.reshard(
attn_output,
get_mesh(self.ipp),
[dist.Shard(0)],
)
if not output_attentions:
attn_weights = None

Expand Down Expand Up @@ -565,7 +601,39 @@ def forward(
# Fully Connected
residual = hidden_states
hidden_states = self.post_attention_layernorm(hidden_states)

# enter tp region
if self.config.sequence_parallel:
mesh = get_mesh(self.ipp)
if "dp" in mesh.dim_names:
hidden_states = dist.reshard(
hidden_states,
get_mesh(self.ipp),
[dist.Shard(1), dist.Replicate()],
)
else:
hidden_states = dist.reshard(
hidden_states,
get_mesh(self.ipp),
[dist.Replicate()],
)

hidden_states = self.mlp(hidden_states)
# enter sp region
if self.config.sequence_parallel:
mesh = get_mesh(self.ipp)
if "dp" in mesh.dim_names:
hidden_states = dist.reshard(
hidden_states,
get_mesh(self.ipp),
[dist.Shard(1), dist.Shard(0)],
)
else:
hidden_states = dist.reshard(
hidden_states,
get_mesh(self.ipp),
[dist.Shard(0)],
)
hidden_states = residual + hidden_states

outputs = (hidden_states,)
Expand Down Expand Up @@ -830,6 +898,24 @@ def forward(
) # [bs, 1, seq_len, seq_len]

hidden_states = inputs_embeds
if self.config.sequence_parallel:
# [B, S, H] -> [S, B, H]
emb_transpose = fleet.auto.shard_op(paddle.transpose, get_mesh(0))
hidden_states = emb_transpose(hidden_states, [1, 0, 2])
# enter sp region
mesh = get_mesh(0)
if "dp" in mesh.dim_names:
hidden_states = dist.reshard(
hidden_states,
get_mesh(0),
[dist.Shard(1), dist.Shard(0)],
)
else:
hidden_states = dist.reshard(
hidden_states,
get_mesh(0),
[dist.Shard(0)],
)

# decoder layers
all_hidden_states = () if output_hidden_states else None
Expand All @@ -838,14 +924,18 @@ def forward(

for idx, (decoder_layer) in enumerate(self.layers):
ipp = decoder_layer.ipp
fleet.auto.shard_tensor(hidden_states, *get_dist_attr(["dp", None, None], ipp))
if self.config.sequence_parallel:
fleet.auto.shard_tensor(hidden_states, *get_dist_attr(["mp", "dp", None], ipp))
else:
fleet.auto.shard_tensor(hidden_states, *get_dist_attr(["dp", None, None], ipp))
decoder_layer = fleet.auto.shard_op(decoder_layer, get_mesh(ipp))

if output_hidden_states:
all_hidden_states += (hidden_states,)
past_key_value = past_key_values[idx] if past_key_values is not None else None

has_gradient = not hidden_states.stop_gradient

if (
self.enable_recompute
and idx not in self.no_recompute_layers
Expand Down Expand Up @@ -1107,6 +1197,22 @@ def forward(
)

hidden_states = outputs[0] # [bs, seq_len, dim]
# enter tp region
if self.config.sequence_parallel:
mesh = get_mesh(-1)
if "dp" in mesh.dim_names:
hidden_states = dist.reshard(
hidden_states,
get_mesh(-1),
[dist.Shard(1), dist.Replicate()],
)
else:
hidden_states = dist.reshard(
hidden_states,
get_mesh(-1),
[dist.Replicate()],
)
hidden_states = paddle.transpose(hidden_states, [1, 0, 2])

# if labels is None,means we need full output, instead of tensor_parallel_output
# tensor_parallel_output is togather with ParallelCrossEntropy
Expand Down