Commit 5f0adfa

add tests for llama with GQA
1 parent 9188a56 commit 5f0adfa

File tree

1 file changed: +108 -0 lines changed

llm/llama/tests/test_GQA.py

Lines changed: 108 additions & 0 deletions
@@ -0,0 +1,108 @@
# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import unittest

import numpy as np
import paddle
import paddle.distributed.fleet as fleet
from paddle.distributed.fleet.meta_parallel.pipeline_parallel import PipelineParallel

from paddlenlp.transformers import AutoConfig, AutoModelForCausalLM, AutoModelForCausalLMPipe, AutoTokenizer


class TestLlama(unittest.TestCase):
    def test_sequence_model(self):
        # Derive pipeline/tensor parallel degrees from the launched world size.
        world_size = paddle.distributed.get_world_size()
        pp_degree = world_size
        tp_degree = 1
        if world_size > 2:
            pp_degree = 2
            assert world_size % pp_degree == 0
            tp_degree = world_size // pp_degree

        strategy = fleet.DistributedStrategy()
        strategy.hybrid_configs = {
            "dp_degree": 1,
            "mp_degree": tp_degree,
            "pp_degree": pp_degree,
            "sharding_degree": 1,
        }
        # strategy.pipeline_configs = {"enable_partial_send_recv": False if pp_degree > 1 else True}
        fleet.init(is_collective=True, strategy=strategy)
        hcg = fleet.get_hybrid_communicate_group()
        mp_group = hcg.get_model_parallel_group()
        tensor_parallel_rank = mp_group.rank

        if pp_degree > 1:
            model_class = AutoModelForCausalLMPipe
        else:
            model_class = AutoModelForCausalLM

        model_name_or_path = "meta-llama/Llama-2-7b"

        seq_len = 2048
        batch_size = 2

        tokenizer = AutoTokenizer.from_pretrained(model_name_or_path)
        config = AutoConfig.from_pretrained(model_name_or_path)
        config.seq_length = seq_len
        config.num_key_value_heads = 8  # GQA: fewer key/value heads than query heads
        config.max_position_embeddings = max(config.max_position_embeddings, seq_len)
        config.vocab_size = max(config.vocab_size, ((tokenizer.vocab_size - 1) // 128 + 1) * 128)
        config.use_flash_attention = True
        config.use_fused_rope = True
        config.use_fused_rms_norm = True
        config.fuse_attention_qkv = True
        config.recompute_granularity = "full"
        config.virtual_pp_degree = 1
        config.use_recompute = False

        config.tensor_parallel_degree = tp_degree
        config.tensor_parallel_rank = tensor_parallel_rank
        config.tensor_parallel_output = False
        config.sequence_parallel = False

        config.fuse_sequence_parallel_allreduce = False

        # hidden_size = 4096
        model = model_class.from_config(
            config,
            dtype="float16",
        )

        model.eval()

        input_ids = paddle.arange(100, 100 + batch_size * seq_len, dtype="int64").reshape([batch_size, seq_len])
        labels = paddle.arange(101, 101 + batch_size * seq_len, dtype="int64").reshape([batch_size, seq_len])

        attention_mask = None
        if pp_degree > 1:
            # Wrap the model for pipeline-parallel evaluation.
            pp_model = PipelineParallel(layers=model, hcg=hcg, strategy=strategy)
            pp_model.accumulate_steps = batch_size  # for micro_batch_size * acc_steps == batch_size
            ret = pp_model.eval_batch(data=[input_ids, labels], compute_loss=True)
        else:
            ret = model(input_ids=input_ids, labels=labels, attention_mask=attention_mask)
            ret = ret[0]

        print(f"ret mp{tp_degree} pp{pp_degree}", ret.item())
        ret_mp_pp = ret.item()


if __name__ == "__main__":
    TestLlama().test_sequence_model()
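
The GQA behaviour under test hinges on config.num_key_value_heads = 8 being smaller than the model's number of query heads. Below is a minimal, framework-agnostic numpy sketch of that key/value-head sharing; it is not part of the commit, and the query-head count and head dimension are the standard Llama-2-7b values, assumed here for illustration.

import numpy as np

num_attention_heads = 32   # Llama-2-7b query heads (assumed)
num_key_value_heads = 8    # value the test writes into the config
head_dim, seq_len = 128, 4

q = np.random.rand(num_attention_heads, seq_len, head_dim)
k = np.random.rand(num_key_value_heads, seq_len, head_dim)
v = np.random.rand(num_key_value_heads, seq_len, head_dim)

# Each key/value head serves a group of query heads, so K and V are
# repeated group-wise up to the full number of query heads.
group = num_attention_heads // num_key_value_heads
k_rep = np.repeat(k, group, axis=0)   # (32, seq_len, head_dim)
v_rep = np.repeat(v, group, axis=0)

# Standard scaled-dot-product attention per head.
scores = q @ k_rep.transpose(0, 2, 1) / np.sqrt(head_dim)
weights = np.exp(scores - scores.max(axis=-1, keepdims=True))
weights /= weights.sum(axis=-1, keepdims=True)
out = weights @ v_rep                 # (32, seq_len, head_dim)
print(out.shape)

Only the 8 key/value heads need to be stored per layer, which is what shrinks the KV cache relative to full multi-head attention.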

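A hypothetical way to exercise the tensor/pipeline-parallel branches of the test (the GPU list and working directory are assumptions, not part of the commit) is to start it under Paddle's distributed launcher; a single-card run without the launcher falls through to the plain AutoModelForCausalLM path.

    # assumes 4 visible GPUs and the repository root as the working directory
    python -m paddle.distributed.launch --gpus "0,1,2,3" llm/llama/tests/test_GQA.py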