Commit 839dd33

Add a test for special token addition.
1 parent 77360f6 · commit 839dd33

File tree

1 file changed: +25 −0 lines


tests/transformers/test_tokenizer_common.py

Lines changed: 25 additions & 0 deletions
@@ -1156,6 +1156,31 @@ def test_maximum_encoding_length_pair_input(self):
 
         # self.assertEqual(encoded_masked, encoded_1)
 
+    def test_special_token_addition(self):
+        for tokenizer, pretrained_name, kwargs in self.tokenizers_list:
+            with self.subTest(f"{tokenizer.__class__.__name__} ({pretrained_name})"):
+                # Create tokenizer and add an additional special token
+                tokenizer_1 = tokenizer.from_pretrained(pretrained_name)
+                tokenizer_1.add_special_tokens({"additional_special_tokens": ["<tok>"]})
+                self.assertEqual(tokenizer_1.additional_special_tokens, ["<tok>"])
+                with tempfile.TemporaryDirectory() as tmp_dir:
+                    tokenizer_1.save_pretrained(tmp_dir)
+                    # Load the above tokenizer and add the same special token a second time
+                    tokenizer_2 = tokenizer.from_pretrained(pretrained_name)
+                    tokenizer_2.add_special_tokens({"additional_special_tokens": ["<tok>"]})
+                    self.assertEqual(tokenizer_2.additional_special_tokens, ["<tok>"])
+
+                    tokenizer_2.add_special_tokens({"additional_special_tokens": ["<tok>", "<other>"]})
+                    self.assertEqual(tokenizer_2.additional_special_tokens, ["<tok>", "<other>"])
+                    tokenizer_2.add_special_tokens({"additional_special_tokens": ["<other>", "<another>"]})
+                    self.assertEqual(tokenizer_2.additional_special_tokens, ["<other>", "<another>"])
+
+                    tokenizer_2.add_special_tokens(
+                        {"additional_special_tokens": ["<tok>"]},
+                        replace_additional_special_tokens=False,
+                    )
+                    self.assertEqual(tokenizer_2.additional_special_tokens, ["<other>", "<another>", "<tok>"])
+
     def test_special_tokens_mask(self):
         tokenizers = self.get_tokenizers(do_lower_case=False)
         for tokenizer in tokenizers:
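
For reference, the contract this test pins down: add_special_tokens replaces any previously registered additional_special_tokens by default, and only appends when replace_additional_special_tokens=False is passed. Below is a minimal runnable sketch of that behavior, assuming the Hugging Face transformers API (which exposes the same add_special_tokens signature and flag exercised here); AutoTokenizer and the "bert-base-uncased" checkpoint are illustrative choices, not part of the commit.

    # Minimal sketch of the behavior asserted above. Assumes the Hugging Face
    # transformers library; the checkpoint name is illustrative only.
    from transformers import AutoTokenizer

    tok = AutoTokenizer.from_pretrained("bert-base-uncased")

    # Default: each call replaces the previous additional_special_tokens list.
    tok.add_special_tokens({"additional_special_tokens": ["<tok>"]})
    tok.add_special_tokens({"additional_special_tokens": ["<other>", "<another>"]})
    assert tok.additional_special_tokens == ["<other>", "<another>"]

    # replace_additional_special_tokens=False appends instead of overwriting.
    tok.add_special_tokens(
        {"additional_special_tokens": ["<tok>"]},
        replace_additional_special_tokens=False,
    )
    assert tok.additional_special_tokens == ["<other>", "<another>", "<tok>"]

The tokenizer_2 assertions are what motivate the flag: replacing by default keeps repeated add_special_tokens calls from accumulating duplicates, while passing the flag restores append behavior when callers want to extend the existing list.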
