Commit fbe3b78

add unittest
1 parent bcb5aea commit fbe3b78

3 files changed (+331, -2 lines)


paddlenlp/transformers/mixtral/modeling.py

Lines changed: 0 additions & 2 deletions
@@ -1070,8 +1070,6 @@ def __init__(self, config: MixtralConfig):
         )
         self.norm = MixtralRMSNorm(config)
 
-        self.gradient_checkpointing = False
-
     def get_input_embeddings(self):
         return self.embed_tokens

Lines changed: 13 additions & 0 deletions
@@ -0,0 +1,13 @@
# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

Lines changed: 318 additions & 0 deletions
@@ -0,0 +1,318 @@
# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
# Copyright 2020 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import annotations

import unittest

import paddle

from paddlenlp.transformers import MixtralConfig, MixtralForCausalLM, MixtralModel
from tests.transformers.test_configuration_common import ConfigTester
from tests.transformers.test_generation_utils import GenerationTesterMixin
from tests.transformers.test_modeling_common import (
    ModelTesterMixin,
    ids_tensor,
    random_attention_mask,
)


class MixtralModelTester:
    def __init__(
        self,
        parent,
        vocab_size=32000,
        hidden_size=64,
        num_hidden_layers=2,
        num_attention_heads=8,
        masked_softmax_fusion=True,
        layer_norm_epsilon=1e-5,
        initializer_range=0.02,
        is_training=True,
        use_cache=False,
        bos_token_id=1,
        eos_token_id=2,
        apply_residual_connection_post_layernorm=False,
        hidden_dropout=0.0,
        attention_dropout=0.0,
        attention_softmax_in_fp32=True,
        pretraining_tp=1,  # TP rank used when training with megatron
        dtype="bfloat16",
        slow_but_exact=False,
        batch_size: int = 2,
        seq_length: int = 10,
        type_sequence_label_size=2,
        activation_function="gelu",
        num_labels=3,
        num_choices=4,
        scope=None,
        dropout=0.56,
        use_input_mask: bool = False,
        use_labels: bool = False,
        return_dict=False,
    ):
        self.parent: MixtralModelTest = parent
        self.vocab_size = vocab_size
        self.hidden_size = hidden_size
        self.num_hidden_layers = num_hidden_layers
        self.num_attention_heads = num_attention_heads
        self.masked_softmax_fusion = masked_softmax_fusion
        self.layer_norm_epsilon = layer_norm_epsilon
        self.initializer_range = initializer_range
        self.is_training = is_training
        self.use_cache = use_cache
        self.bos_token_id = bos_token_id
        self.eos_token_id = eos_token_id
        self.apply_residual_connection_post_layernorm = apply_residual_connection_post_layernorm
        self.hidden_dropout = hidden_dropout
        self.attention_dropout = attention_dropout
        self.attention_softmax_in_fp32 = attention_softmax_in_fp32
        self.pretraining_tp = pretraining_tp
        self.dtype = dtype
        self.slow_but_exact = slow_but_exact

        self.batch_size = batch_size
        self.seq_length = seq_length
        self.type_sequence_label_size = type_sequence_label_size
        self.activation_function = activation_function
        self.num_labels = num_labels
        self.num_choices = num_choices
        self.scope = scope
        self.dropout = dropout

        self.use_input_mask = use_input_mask
        self.use_labels = use_labels
        self.return_dict = return_dict

    def prepare_config_and_inputs(self):
        input_ids = ids_tensor([self.batch_size, self.seq_length], self.vocab_size, dtype=paddle.int64)

        input_mask = None
        if self.use_input_mask:
            input_mask = random_attention_mask([self.batch_size, self.seq_length])

        sequence_labels = None
        token_labels = None
        choice_labels = None
        if self.use_labels:
            sequence_labels = ids_tensor([self.batch_size], self.type_sequence_label_size)
            token_labels = ids_tensor([self.batch_size, self.seq_length], self.num_labels)
            choice_labels = ids_tensor([self.batch_size], self.num_choices)

        config = self.get_config()
        return config, input_ids, input_mask, sequence_labels, token_labels, choice_labels

    def get_config(self) -> MixtralConfig:
        return MixtralConfig(
            vocab_size=self.vocab_size,
            hidden_size=self.hidden_size,
            num_hidden_layers=self.num_hidden_layers,
            num_attention_heads=self.num_attention_heads,
            masked_softmax_fusion=self.masked_softmax_fusion,
            layer_norm_epsilon=self.layer_norm_epsilon,
            initializer_range=self.initializer_range,
            use_cache=self.use_cache,
            bos_token_id=self.bos_token_id,
            eos_token_id=self.eos_token_id,
            apply_residual_connection_post_layernorm=self.apply_residual_connection_post_layernorm,
            hidden_dropout=self.hidden_dropout,
            attention_dropout=self.attention_dropout,
            attention_softmax_in_fp32=self.attention_softmax_in_fp32,
            pretraining_tp=self.pretraining_tp,
            dtype=self.dtype,
            slow_but_exact=self.slow_but_exact,
            activation_function=self.activation_function,
        )

    def create_and_check_model(
        self, config: MixtralConfig, input_ids, input_mask, sequence_labels, token_labels, choice_labels
    ):
        model = MixtralModel(config)
        model.eval()
        result = model(input_ids)
        self.parent.assertEqual(result[0].shape, [self.batch_size, self.seq_length, self.hidden_size])

    def create_and_check_model_attention_mask(
        self, config: MixtralConfig, input_ids, input_mask, sequence_labels, token_labels, choice_labels
    ):
        model = MixtralModel(config)
        model.eval()
        attn_mask_2d = random_attention_mask([self.batch_size, self.seq_length])
        result_2d = model(input_ids, attention_mask=attn_mask_2d)[0]
        batch, seq_length = input_ids.shape
        causal_mask = paddle.tril(paddle.ones((batch, seq_length, seq_length), dtype=attn_mask_2d.dtype))
        attn_mask_3d = causal_mask & attn_mask_2d.unsqueeze(-1)
        result_3d = model(input_ids, attention_mask=attn_mask_3d)[0]
        attn_mask_4d = attn_mask_3d.unsqueeze(1)
        result_4d = model(input_ids, attention_mask=attn_mask_4d)[0]
        result_no_attention_mask = model(input_ids, attention_mask=None)[0]
        # Assert non-padding tokens have the same logits across the different attention_mask shapes
        self.parent.assertTrue((result_2d[attn_mask_2d] == result_3d[attn_mask_2d]).all())
        self.parent.assertTrue((result_2d[attn_mask_2d] == result_4d[attn_mask_2d]).all())
        self.parent.assertTrue((result_2d[attn_mask_2d] == result_no_attention_mask[attn_mask_2d]).all())

    def create_and_check_model_past_large_inputs(
        self,
        config: MixtralConfig,
        input_ids,
        input_mask,
        sequence_labels,
        token_labels,
        choice_labels,
    ):
        model = MixtralModel(config)
        model.eval()

        # first forward pass
        outputs = model(input_ids, attention_mask=input_mask, use_cache=True, return_dict=self.return_dict)
        past_key_values = outputs.past_key_values if self.return_dict else outputs[2]

        # create hypothetical multiple next tokens and extend to next_input_ids
        next_tokens = ids_tensor((self.batch_size, 3), self.vocab_size)
        next_mask = ids_tensor((self.batch_size, 3), vocab_size=2)

        # append to next input_ids and attention_mask
        next_input_ids = paddle.concat([input_ids, next_tokens], axis=-1)
        next_attention_mask = paddle.concat([input_mask, next_mask], axis=-1)

        outputs = model(
            next_input_ids, attention_mask=next_attention_mask, output_hidden_states=True, return_dict=self.return_dict
        )

        output_from_no_past = outputs[2][0]

        outputs = model(
            next_tokens,
            attention_mask=next_attention_mask,
            past_key_values=past_key_values,
            output_hidden_states=True,
            return_dict=self.return_dict,
        )

        output_from_past = outputs[2][0]

        # select random slice
        random_slice_idx = ids_tensor((1,), output_from_past.shape[-1]).item()
        output_from_no_past_slice = output_from_no_past[:, -3:, random_slice_idx].detach()
        output_from_past_slice = output_from_past[:, :, random_slice_idx].detach()

        self.parent.assertTrue(output_from_past_slice.shape[1] == next_tokens.shape[1])

        # test that outputs are equal for slice
        self.parent.assertTrue(paddle.allclose(output_from_past_slice, output_from_no_past_slice, atol=1e-3))

    def prepare_config_and_inputs_for_common(self):
        config_and_inputs = self.prepare_config_and_inputs()
        (
            config,
            input_ids,
            input_mask,
            sequence_labels,
            token_labels,
            choice_labels,
        ) = config_and_inputs
        inputs_dict = {"input_ids": input_ids, "attention_mask": input_mask}
        return config, inputs_dict

    def create_and_check_lm_head_model(self, config, input_ids, input_mask, *args):
        model = MixtralForCausalLM(config)
        model.eval()

        result = model(
            input_ids,
            use_cache=True,
            labels=input_ids if self.parent.use_labels else None,
            return_dict=self.parent.return_dict,
        )
        if self.parent.use_labels:
            self.parent.assertIsInstance(result[0].item(), float)
            self.parent.assertEqual(result[1].shape, [self.batch_size, self.seq_length, self.vocab_size])
        else:
            self.parent.assertEqual(result[0].shape, [self.batch_size, self.seq_length, self.vocab_size])

    def check_model_position_ids(self, config, input_ids, input_mask, *args):
        model = MixtralForCausalLM(config)
        model.eval()

        result_no_position_id = model(
            input_ids,
            labels=input_ids if self.parent.use_labels else None,
            return_dict=self.parent.return_dict,
        )
        batch_size, seq_len = input_ids.shape
        position_ids = paddle.arange(seq_len).expand((batch_size, seq_len))
        result_position_id = model(
            input_ids,
            position_ids,
            labels=input_ids if self.parent.use_labels else None,
            return_dict=self.parent.return_dict,
        )
        if self.parent.use_labels:
            self.parent.assertTrue((result_position_id[1] == result_no_position_id[1]).all())
        else:
            self.parent.assertTrue((result_position_id[0] == result_no_position_id[0]).all())


class MixtralModelTest(ModelTesterMixin, GenerationTesterMixin, unittest.TestCase):
    base_model_class = MixtralModel
    return_dict = False
    use_labels = False
    use_test_model_name_list = False

    all_model_classes = (MixtralModel, MixtralForCausalLM)
    all_generative_model_classes = {MixtralForCausalLM: (MixtralModel, "mixtral")}

    def setUp(self):
        super().setUp()

        self.model_tester = MixtralModelTester(self)
        self.config_tester = ConfigTester(self, config_class=MixtralConfig, vocab_size=256, hidden_size=24)

    def _get_input_ids_and_config(self):
        config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()

        input_ids = inputs_dict[self.input_name]
        attention_mask = paddle.ones_like(input_ids, dtype=paddle.int64)

        max_batch_size = 2
        sequence_length = input_ids.shape[-1] // 2
        input_ids = input_ids[:max_batch_size, :sequence_length]
        attention_mask = attention_mask[:max_batch_size, :sequence_length]
        max_length = 3

        return config, input_ids, attention_mask, max_length

    def test_model(self):
        config_and_inputs = self.model_tester.prepare_config_and_inputs()
        self.model_tester.create_and_check_model(*config_and_inputs)

    def test_model_attention_mask(self):
        config_and_inputs = self.model_tester.prepare_config_and_inputs()
        self.model_tester.create_and_check_model_attention_mask(*config_and_inputs)

    def test_model_position_ids(self):
        config_and_inputs = self.model_tester.prepare_config_and_inputs()
        self.model_tester.check_model_position_ids(*config_and_inputs)

    def test_generate_without_input_ids(self):
        # this requires 4-D attention mask logic, which is not supported yet
        pass

    def test_mixtral_lm_head_model(self):
        config_and_inputs = self.model_tester.prepare_config_and_inputs()
        self.model_tester.create_and_check_lm_head_model(*config_and_inputs)


if __name__ == "__main__":
    unittest.main()
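
For readers skimming the diff, the shape check that MixtralModelTester.create_and_check_model performs can be reproduced with a small standalone sketch. This is illustrative only: it assumes a PaddleNLP build that exposes the MixtralConfig/MixtralModel classes imported by the test, mirrors the tester defaults above for the config values, and substitutes paddle.randint for the suite's ids_tensor helper.

import paddle
from paddlenlp.transformers import MixtralConfig, MixtralModel

# Tiny config mirroring MixtralModelTester defaults (vocab_size=32000, hidden_size=64,
# num_hidden_layers=2, num_attention_heads=8); other fields use MixtralConfig defaults.
config = MixtralConfig(
    vocab_size=32000,
    hidden_size=64,
    num_hidden_layers=2,
    num_attention_heads=8,
)

model = MixtralModel(config)
model.eval()

# Random token ids standing in for ids_tensor(...) from tests.transformers.test_modeling_common.
batch_size, seq_length = 2, 10
input_ids = paddle.randint(low=0, high=config.vocab_size, shape=[batch_size, seq_length], dtype="int64")

with paddle.no_grad():
    last_hidden_state = model(input_ids)[0]

# Mirrors the assertion in create_and_check_model: [batch_size, seq_length, hidden_size].
assert last_hidden_state.shape == [batch_size, seq_length, config.hidden_size]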
