diff --git a/tests/transformers/bert/test_modeling.py b/tests/transformers/bert/test_modeling.py index 81ef6ded0f2d..cf1f3baf1325 100644 --- a/tests/transformers/bert/test_modeling.py +++ b/tests/transformers/bert/test_modeling.py @@ -310,14 +310,12 @@ def create_and_check_for_question_answering( start_positions=sequence_labels, end_positions=sequence_labels, return_dict=self.return_dict) - if token_labels is not None: + if sequence_labels is not None: result = result[1:] - elif paddle.is_tensor(result): - result = [result] - self.parent.assertEqual(result[1].shape, + self.parent.assertEqual(result[0].shape, [self.batch_size, self.seq_length]) - self.parent.assertEqual(result[2].shape, + self.parent.assertEqual(result[1].shape, [self.batch_size, self.seq_length]) def create_and_check_for_sequence_classification(