
Commit 4a58f61

add moe unittest
1 parent 8bed006 commit 4a58f61

File tree

9 files changed, +454782 -0 lines changed

Lines changed: 176 additions & 0 deletions
@@ -0,0 +1,176 @@
# Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import os

import numpy as np
import pytest

from paddlenlp.utils.downloader import get_path_from_url_with_filelock
from tests.parallel_launch import TestMultipleGpus
from tests.testing_utils import require_paddle_at_least_8_gpu, skip_for_none_ce_case
from tests.trainer.test_unified_checkpoint import remove_ckpt, remove_logs
from tests.trainer.trainer_utils import get_pretrain_arguments

environment_variables = {
    "NCCL_ALGO": "Tree",
    "NVIDIA_TF32_OVERRIDE": "0",
    "NCCL_IB_TIMEOUT": "22",
    "NCCL_DEBUG": "INFO",
    "FLAGS_embedding_deterministic": "1",
    "FLAGS_cudnn_deterministic": "1",
    "Flags_mp_aysnc_allreduce": "1",
    "Flags_skip_mp_c_identity": "1",
    "FLAGS_shard_norm_align_dp": "0",
    "FLAGS_shard_use_reduce": "1",
    "test_ci_no_save_model": "1",
}

moe_arguments = {
    "model_name_or_path": "./tests/trainer/unified-ckpt-qwen2moe",
    "dataset_name_or_path": "./unified_checkpoint/peft_input/data/",
    "output_dir": "./unified_checkpoint/checkpoints/qwen2moe_sft_ckpts",
    "per_device_train_batch_size": 1,
    "gradient_accumulation_steps": 8,
    "per_device_eval_batch_size": 8,
    "eval_accumulation_steps": 16,
    "learning_rate": 3e-04,
    "max_steps": 10,
    "save_steps": 6,
    "warmup_steps": 30,
    "logging_steps": 1,
    "evaluation_strategy": "no",
    "save_strategy": "steps",
    "src_length": 1024,
    "max_length": 2048,
    "bf16": "true",
    "fp16_opt_level": "O2",
    "do_train": "true",
    "do_eval": "false",
    "disable_tqdm": "true",
    "eval_with_do_generation": "false",
    "recompute": "true",
    "recompute_granularity": "full",
    "save_total_limit": 1,
    "tensor_parallel_degree": 1,
    "pipeline_parallel_degree": 1,
    "sharding": "",
    "lora": "false",
    "zero_padding": "false",
    "use_flash_attention": "false",
    "unified_checkpoint": 1,
    "continue_training": 0,
    "sequence_parallel": 0,
}


def check_acc(log_dir="log"):
    file_path = os.path.join(log_dir, "workerlog.n0.c0")
    cmd = "grep -a 'global_step: 10' " + file_path + " | awk -F ',' '{print $2}' | awk '{print $6}'"
    import subprocess

    res = subprocess.check_output(cmd, shell=True, text=True)
    res = [float(x) for x in res.split()]

    return res


seed = 2024

rng = np.random.default_rng(seed=seed)


@pytest.mark.xdist_group(name="UC")
class TestUnifiedCheckpointBase(TestMultipleGpus):
    @classmethod
    @property
    def __test__(cls):
        return cls != TestUnifiedCheckpointBase

    def setUp(self):
"""
103+
1. update runfirst and rerun to run defined different config
104+
2. update need_allclose to True if you want to check the result
105+
3. update rtol to the relative value you want to check
106+
"""
        self.configs = get_pretrain_arguments(moe_arguments)
        os.environ.update(environment_variables)

        file_ = "https://bj.bcebos.com/paddlenlp/datasets/examples/AdvertiseGen.tar.gz"
        input_dir = "unified_checkpoint/peft_input/"
        os.makedirs(input_dir, exist_ok=True)
        file_path = os.path.join(input_dir, "AdvertiseGen.tar.gz")
        if not os.path.exists(file_path):
            get_path_from_url_with_filelock(file_, root_dir=input_dir)

        self.need_allclose = True
        self.rtol = 1e-7

        self.run_file = "llm/run_finetune.py"

    def runfirst(self, train_args):
        self.run_n1c8(self.run_file, **train_args)

    def rerun(self, train_args):
        self.run_n1c8(self.run_file, **train_args)

    @require_paddle_at_least_8_gpu
    def testTP4DP2(self):
        remove_logs()
        remove_ckpt(moe_arguments["output_dir"])

        train_args = self.configs["TP4DP2"]
        self.runfirst(train_args)
        self.rerun(train_args)

        if self.need_allclose:
            res = check_acc()
            assert len(res) == 2
            np.testing.assert_allclose(res[0], res[1], self.rtol)

    @skip_for_none_ce_case
    @require_paddle_at_least_8_gpu
    def testTP2Sharding4(self):
        remove_logs()
        remove_ckpt(moe_arguments["output_dir"])

        train_args = self.configs["TP2Sharding4"]
        self.runfirst(train_args)
        self.rerun(train_args)

        if self.need_allclose:
            res = check_acc()
            assert len(res) == 2
            np.testing.assert_allclose(res[0], res[1], self.rtol)


@pytest.mark.xdist_group(name="UC")
class TestUnifiedCheckpointFull(TestUnifiedCheckpointBase):
    @skip_for_none_ce_case
    @require_paddle_at_least_8_gpu
    def testTP2Sharding4V2(self):
        remove_logs()
        remove_ckpt(moe_arguments["output_dir"])

        train_args = self.configs["TP2Sharding4"]
        train_args.update({"sharding_parallel_config": "split_param"})
        train_args.update({"amp_master_grad": True})
        self.runfirst(train_args)
        self.rerun(train_args)

        if self.need_allclose:
            res = check_acc()
            assert len(res) == 2
            np.testing.assert_allclose(res[0], res[1], self.rtol)
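For readers tracing what check_acc returns: the grep/awk pipeline pulls, from every workerlog line containing "global_step: 10", the sixth whitespace-separated token of the second comma-separated field (presumably the metric logged at the final step) for each of the two runs, and the test then asserts the first run and the rerun from the unified checkpoint agree within rtol. A pure-Python sketch of the same extraction, shown only as an illustration and assuming that log layout:

# Pure-Python equivalent of the grep/awk pipeline in check_acc, shown only to
# illustrate what it extracts; it assumes the same workerlog layout
# (comma-separated fields, metric as the sixth whitespace token of field 2).
import os


def check_acc_py(log_dir="log"):
    file_path = os.path.join(log_dir, "workerlog.n0.c0")
    values = []
    with open(file_path) as f:
        for line in f:
            if "global_step: 10" in line:                 # grep -a 'global_step: 10'
                field2 = line.split(",")[1]               # awk -F ',' '{print $2}'
                values.append(float(field2.split()[5]))   # awk '{print $6}'
    return values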

tests/trainer/trainer_utils.py

Lines changed: 8 additions & 0 deletions
@@ -141,6 +141,14 @@ def get_pretrain_arguments(pretrain_arguments):
     train_args["gradient_accumulation_steps"] = train_args["gradient_accumulation_steps"] // 8
     configs["DP8"] = train_args

+    train_args = copy.deepcopy(pretrain_arguments)
+    train_args["tensor_parallel_degree"] = 2
+    train_args["pipeline_parallel_degree"] = 1
+    train_args["sharding_parallel_degree"] = 2
+    train_args["sharding"] = "stage1"
+    train_args["gradient_accumulation_steps"] = train_args["gradient_accumulation_steps"] // 4
+    configs["TP2DP2Sharding2"] = train_args
+
     return configs
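The new TP2DP2Sharding2 entry fixes tensor parallelism at 2, pipeline parallelism at 1, and stage1 sharding at 2; on the single-node 8-GPU launches these tests use, the leftover factor of 2 is data parallelism, which gives the entry its name, and gradient accumulation is divided by 4 presumably to keep the global batch size equal to the base arguments. A small arithmetic sketch of that layout (a worked example, not code from the commit):

# Worked example of how "TP2DP2Sharding2" fills a single node with 8 GPUs.
num_gpus = 8
tensor_parallel_degree = 2
pipeline_parallel_degree = 1
sharding_parallel_degree = 2
data_parallel_degree = num_gpus // (
    tensor_parallel_degree * pipeline_parallel_degree * sharding_parallel_degree
)
assert data_parallel_degree == 2  # the "DP2" in the name

# Stage1 sharding and data parallelism both split the batch, so there are
# 4 batch replicas; dividing gradient_accumulation_steps (8 in moe_arguments)
# by 4 keeps the global batch size unchanged.
assert 8 // (data_parallel_degree * sharding_parallel_degree) == 2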

Lines changed: 34 additions & 0 deletions
@@ -0,0 +1,34 @@
{
  "architectures": [
    "Qwen2MoeForCausalLM"
  ],
  "attention_dropout": 0.0,
  "bos_token_id": 151643,
  "decoder_sparse_step": 1,
  "eos_token_id": 151643,
  "hidden_act": "silu",
  "hidden_size": 3584,
  "initializer_range": 0.02,
  "intermediate_size": 18944,
  "max_position_embeddings": 131072,
  "max_window_layers": 28,
  "model_type": "qwen2_moe",
  "moe_intermediate_size": 2560,
  "norm_topk_prob": false,
  "num_attention_heads": 28,
  "num_experts": 8,
  "num_experts_per_tok": 2,
  "num_hidden_layers": 8,
  "num_key_value_heads": 4,
  "output_router_logits": false,
  "rms_norm_eps": 1e-06,
  "rope_theta": 1000000.0,
  "router_aux_loss_coef": 0.001,
  "shared_expert_intermediate_size": 20480,
  "sliding_window": 131072,
  "tie_word_embeddings": false,
  "dtype": "bfloat16",
  "use_cache": true,
  "use_sliding_window": false,
  "vocab_size": 151936
}
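The MoE-specific fields in this config describe Qwen2MoE's routing: each of the 8 experts (num_experts) is an FFN of width moe_intermediate_size, every token is dispatched to its top 2 experts (num_experts_per_tok), and norm_topk_prob: false means the selected router probabilities are used without renormalization. A minimal NumPy sketch of that top-2 routing step, with made-up shapes and random placeholder weights, purely as an illustration of the config values rather than PaddleNLP's actual implementation:

import numpy as np

# Illustrative top-2 routing over 8 experts, matching num_experts=8 and
# num_experts_per_tok=2 from the config above. Router weights and token
# values are random placeholders, not model weights.
num_experts, top_k, hidden_size = 8, 2, 3584
rng = np.random.default_rng(0)

tokens = rng.standard_normal((4, hidden_size))                  # 4 example tokens
router_weight = rng.standard_normal((hidden_size, num_experts)) * 0.02

logits = tokens @ router_weight                                  # (4, 8) router scores
logits -= logits.max(axis=-1, keepdims=True)                     # stable softmax
probs = np.exp(logits) / np.exp(logits).sum(axis=-1, keepdims=True)

top2 = np.argsort(probs, axis=-1)[:, -top_k:]                    # chosen expert ids
top2_probs = np.take_along_axis(probs, top2, axis=-1)            # their gate values
# With norm_topk_prob=false these gate values weight the chosen experts'
# outputs as-is (plus the shared expert of width shared_expert_intermediate_size).
print(top2, top2_probs)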
Lines changed: 8 additions & 0 deletions
@@ -0,0 +1,8 @@
{
  "bos_token_id": 151643,
  "pad_token_id": 151643,
  "eos_token_id": [
    151645,
    151643
  ]
}
