From 0e6af02807d9ee8055bc9d4ebda9a5b269e4a11a Mon Sep 17 00:00:00 2001 From: Southpika <513923576@qq.com> Date: Wed, 3 Apr 2024 11:50:29 +0800 Subject: [PATCH 01/20] add jinja template --- paddlenlp/transformers/tokenizer_utils.py | 139 ++++++++++++++++++++-- 1 file changed, 127 insertions(+), 12 deletions(-) diff --git a/paddlenlp/transformers/tokenizer_utils.py b/paddlenlp/transformers/tokenizer_utils.py index 3620669fefe6..95fd3f0913c6 100644 --- a/paddlenlp/transformers/tokenizer_utils.py +++ b/paddlenlp/transformers/tokenizer_utils.py @@ -31,7 +31,8 @@ import numpy as np import paddle import six -from jinja2.exceptions import TemplateError +from jinja2 import Template +from jinja2.exceptions import TemplateError, TemplateSyntaxError from jinja2.sandbox import ImmutableSandboxedEnvironment from paddle.utils import try_import @@ -516,7 +517,7 @@ class ChatTemplate: @staticmethod @lru_cache() - def _compile_jinja_template(chat_template): + def _compile_jinja_template(chat_template) -> Template: def raise_exception(message): raise TemplateError(message) @@ -598,7 +599,6 @@ def __call__(self, conversations: list[list[str]] | str, context_data: Dict[str, raise ValueError( "The length of last conversation must be one, eg: [[user-query, bot-answer], [user-query, bot-answer], ..., [user-query]]" ) - if len(conversations[-1]) > 1: logger.warning( f"The last conversation is not a single-round, chat-template will skip the conversation: {conversations[-1][1:]}" @@ -623,7 +623,7 @@ class ChatTemplateMixin: def apply_chat_template( self, - conversation: List[List[str, str]] | str, + conversation: List[List[str, str] | Dict[str, str]] | str, tokenize: bool = True, context_data: Dict[str, Any] = {}, **tokenizer_kwargs @@ -638,6 +638,25 @@ def apply_chat_template( Returns: str | dict[str, numpy.ndarray | paddle.Tensor]: return the result of applied data """ + if not self.chat_template: + raise ValueError("chat_template is not set, please set chat_template first.") + elif isinstance(self.chat_template, Template): + query = self._apply_chat_template(conversation, context_data) + elif isinstance(self.chat_template, ChatTemplate): + query = self._apply_chat_template_paddle(conversation, context_data) + + if not tokenize: + return query + + # chat_template should not add special tokens + tokenizer_kwargs["add_special_tokens"] = False + return self(query, **tokenizer_kwargs) + + def _apply_chat_template_paddle( + self, + conversation: List[List[str, str]] | str, + context_data: Dict[str, Any] = {}, + ) -> str | dict[str, numpy.ndarray | paddle.Tensor]: context_data = self.chat_template._init_context_data(context_data) if isinstance(conversation, str): @@ -649,12 +668,31 @@ def apply_chat_template( ) query = self.chat_template(conversation, context_data=context_data) - if not tokenize: - return query + return query - # chat_template should not add special tokens - tokenizer_kwargs["add_special_tokens"] = False - return self(query, **tokenizer_kwargs) + def _apply_chat_template( + self, + conversation: List[List[str, str] | Dict[str, str]] | str, + context_data: Dict[str, Any] = {}, + ) -> str | dict[str, numpy.ndarray | paddle.Tensor]: + if isinstance(conversation, str): + conversation = [{"role": "user", "content": conversation}] + elif isinstance(conversation, list): + for index, item in enumerate(conversation): + if isinstance(item, dict): + break + elif isinstance(item, str): + if index % 2 == 0: + conversation[index] = {"role": "user", "content": item} + else: + conversation[index] = {"role": "assistant", "content": item} + else: + raise ValueError( + "apply_chat_template do not support appling batch conversations, " + "so you should apply the conversation one by one." + ) + query = self.chat_template.render(messages=conversation, **self.special_tokens_map) + return query def encode_chat_inputs(self, conversations: List[List[str, str]], context_data: Dict[str, Any] = {}): """Encodes conversation to pairs of token ids. @@ -668,6 +706,16 @@ def encode_chat_inputs(self, conversations: List[List[str, str]], context_data: Returns: List[list[int], list[int]]: the pair of input_ids and target_ids """ + if not self.chat_template: + raise ValueError("chat_template is not set, please set chat_template first.") + elif isinstance(self.chat_template, Template): + query = self._encode_chat_inputs(conversations, context_data) + elif isinstance(self.chat_template, ChatTemplate): + query = self._encode_chat_inputs_paddle(conversations, context_data) + return query + + def _encode_chat_inputs_paddle(self, conversations: List[List[str, str]], context_data: Dict[str, Any] = {}): + breakpoint() context_data = self.chat_template._init_context_data(context_data) # encode system result = {} @@ -692,6 +740,63 @@ def encode_chat_inputs(self, conversations: List[List[str, str]], context_data: result["conversations"] = conversation_ids return result + def _encode_chat_inputs( + self, conversations: List[List[str, str]], context_data: Dict[str, Any] = {}, add_generation_prompt=False + ): + result = {} + # conversation = [] + # if origin_msg[0]['role'] == 'system': + # system = origin_msg.pop(0) + # try: + # self.chat_template.render(system) + # except Exception as e: + # raise RuntimeError("System is not supported", e) + # else: + # system = None + + conversation = [] + origin_msg = [] + for round in conversations: + round_role = [ + {"role": "user", "content": round[0]}, + {"role": "assistant", "content": round[1]}, + ] + origin_msg.extend(round_role) + conversation.append(round_role) + ans = [] + system = None + for conv in conversation: + roundi = [system] + conv if system else conv + roundi_str = self.chat_template.render(messages=roundi, add_generation_prompt=add_generation_prompt) + roundi_no_ans = [system] + [conv[0]] if system else [conv[0]] + roundi_no_ans_str = self.chat_template.render( + messages=roundi_no_ans, add_generation_prompt=add_generation_prompt + ) + ans_roundi = roundi_str[len(roundi_no_ans_str) :] + ans.append(ans_roundi) + + regex_pattern = "|".join(map(re.escape, ans)) + non_learnable_parts = re.split( + r"(?:%s)" % regex_pattern, + self.chat_template.render(messages=origin_msg, add_generation_prompt=add_generation_prompt), + ) + if non_learnable_parts[-1] == "": + non_learnable_parts.pop() + + assert len(non_learnable_parts) == len(ans) + + conversation_ids = [] + for i in range(len(non_learnable_parts)): + conversation_ids.append( + self.batch_encode([non_learnable_parts[i], ans[i]], add_special_tokens=False, padding=False)[ + "input_ids" + ] + ) + print(self.batch_decode(conversation_ids[i])) + + result["conversations"] = conversation_ids + return result + @classmethod def from_pretrained(cls, pretrained_model_name_or_path, *args, **kwargs): cache_dir = kwargs.pop("cache_dir", None) @@ -713,6 +818,10 @@ def from_pretrained(cls, pretrained_model_name_or_path, *args, **kwargs): if not os.path.exists(chat_template_file): return tokenizer + if tokenizer.chat_template is not None: + logger.warning( + "Chat-template already exists in config file, it will be overwritten by chat_template.json file." + ) tokenizer.init_chat_template(chat_template_file) return tokenizer @@ -724,9 +833,15 @@ def init_chat_template(self, chat_template: str | dict): """ if isinstance(chat_template, str): if not os.path.exists(chat_template): - raise FileNotFoundError("The chat-template file does not exist: {}".format(chat_template)) - - self.chat_template = ChatTemplate.from_file(chat_template) + try: + self.chat_template: Template = ChatTemplate._compile_jinja_template(chat_template) + except TemplateSyntaxError: + # It is neither jinjia string nor path string + raise TemplateSyntaxError( + "The chat-template in json is not valid jinja string: {}".format(chat_template) + ) + else: + self.chat_template = ChatTemplate.from_file(chat_template) elif isinstance(chat_template, dict): self.chat_template = ChatTemplate.from_dict(chat_template) elif isinstance(chat_template, ChatTemplate): From 518b9edcdd4a4a651855b1c3183d8e6b517c9a47 Mon Sep 17 00:00:00 2001 From: Southpika <513923576@qq.com> Date: Tue, 12 Mar 2024 16:34:15 +0800 Subject: [PATCH 02/20] add in cfg --- paddlenlp/transformers/tokenizer_utils_base.py | 3 +++ 1 file changed, 3 insertions(+) diff --git a/paddlenlp/transformers/tokenizer_utils_base.py b/paddlenlp/transformers/tokenizer_utils_base.py index eeb99117a6d3..6c11df383f55 100644 --- a/paddlenlp/transformers/tokenizer_utils_base.py +++ b/paddlenlp/transformers/tokenizer_utils_base.py @@ -1574,6 +1574,9 @@ def convert_added_tokens(obj): # TODO(guosheng): avoid reduplication of position args and key word args tokenizer = cls(*init_args, **init_kwargs) + chat_template = init_kwargs.pop("chat_template", None) + if chat_template is not None: + tokenizer.init_chat_template(chat_template) special_tokens_map_file = resolved_vocab_files.pop("special_tokens_map_file", None) if special_tokens_map_file is not None: with open(special_tokens_map_file, encoding="utf-8") as special_tokens_map_handle: From e50325f3ac8a8cfcf76c13865d819e4cc65767fb Mon Sep 17 00:00:00 2001 From: Southpika <513923576@qq.com> Date: Sun, 7 Apr 2024 10:26:02 +0800 Subject: [PATCH 03/20] add system --- paddlenlp/transformers/tokenizer_utils.py | 24 +++++++++++------------ 1 file changed, 11 insertions(+), 13 deletions(-) diff --git a/paddlenlp/transformers/tokenizer_utils.py b/paddlenlp/transformers/tokenizer_utils.py index 95fd3f0913c6..bd6f861308d9 100644 --- a/paddlenlp/transformers/tokenizer_utils.py +++ b/paddlenlp/transformers/tokenizer_utils.py @@ -715,7 +715,6 @@ def encode_chat_inputs(self, conversations: List[List[str, str]], context_data: return query def _encode_chat_inputs_paddle(self, conversations: List[List[str, str]], context_data: Dict[str, Any] = {}): - breakpoint() context_data = self.chat_template._init_context_data(context_data) # encode system result = {} @@ -747,14 +746,13 @@ def _encode_chat_inputs( # conversation = [] # if origin_msg[0]['role'] == 'system': # system = origin_msg.pop(0) - # try: - # self.chat_template.render(system) - # except Exception as e: - # raise RuntimeError("System is not supported", e) - # else: - # system = None - - conversation = [] + try: + self.chat_template.render({"role": "system", "content": ""}) + except Exception as e: + system = None + logger.debug(e) + + conversation_dict = [] origin_msg = [] for round in conversations: round_role = [ @@ -762,10 +760,10 @@ def _encode_chat_inputs( {"role": "assistant", "content": round[1]}, ] origin_msg.extend(round_role) - conversation.append(round_role) + conversation_dict.append(round_role) ans = [] - system = None - for conv in conversation: + + for conv in conversation_dict: roundi = [system] + conv if system else conv roundi_str = self.chat_template.render(messages=roundi, add_generation_prompt=add_generation_prompt) roundi_no_ans = [system] + [conv[0]] if system else [conv[0]] @@ -792,7 +790,7 @@ def _encode_chat_inputs( "input_ids" ] ) - print(self.batch_decode(conversation_ids[i])) + # print(self.batch_decode(conversation_ids[i])) result["conversations"] = conversation_ids return result From c6ac05018d1d5d936bea645c4fbf8f93f052e920 Mon Sep 17 00:00:00 2001 From: Southpika <513923576@qq.com> Date: Sun, 7 Apr 2024 10:29:49 +0800 Subject: [PATCH 04/20] add notes --- paddlenlp/transformers/tokenizer_utils.py | 9 ++++++--- 1 file changed, 6 insertions(+), 3 deletions(-) diff --git a/paddlenlp/transformers/tokenizer_utils.py b/paddlenlp/transformers/tokenizer_utils.py index bd6f861308d9..a52033feb0d7 100644 --- a/paddlenlp/transformers/tokenizer_utils.py +++ b/paddlenlp/transformers/tokenizer_utils.py @@ -749,6 +749,7 @@ def _encode_chat_inputs( try: self.chat_template.render({"role": "system", "content": ""}) except Exception as e: + # some tokenizer do not support chat_template, they raise error in jinja string system = None logger.debug(e) @@ -786,9 +787,11 @@ def _encode_chat_inputs( conversation_ids = [] for i in range(len(non_learnable_parts)): conversation_ids.append( - self.batch_encode([non_learnable_parts[i], ans[i]], add_special_tokens=False, padding=False)[ - "input_ids" - ] + self.batch_encode( + [non_learnable_parts[i], ans[i]], + add_special_tokens=False, + padding=False, + )["input_ids"] ) # print(self.batch_decode(conversation_ids[i])) From f1bc9352a911c3552382dae553f5ce980008f232 Mon Sep 17 00:00:00 2001 From: Southpika <513923576@qq.com> Date: Sun, 7 Apr 2024 11:15:31 +0800 Subject: [PATCH 05/20] fix apply chat template --- paddlenlp/transformers/tokenizer_utils.py | 24 +++++++++++++++++------ 1 file changed, 18 insertions(+), 6 deletions(-) diff --git a/paddlenlp/transformers/tokenizer_utils.py b/paddlenlp/transformers/tokenizer_utils.py index a52033feb0d7..f0a8899015fc 100644 --- a/paddlenlp/transformers/tokenizer_utils.py +++ b/paddlenlp/transformers/tokenizer_utils.py @@ -674,24 +674,36 @@ def _apply_chat_template( self, conversation: List[List[str, str] | Dict[str, str]] | str, context_data: Dict[str, Any] = {}, + add_generation_prompt=True, ) -> str | dict[str, numpy.ndarray | paddle.Tensor]: if isinstance(conversation, str): - conversation = [{"role": "user", "content": conversation}] + conversations = [{"role": "user", "content": conversation}] elif isinstance(conversation, list): + conversations = [] for index, item in enumerate(conversation): if isinstance(item, dict): + conversations = conversation break - elif isinstance(item, str): - if index % 2 == 0: - conversation[index] = {"role": "user", "content": item} + elif isinstance(item, list): + assert 1 <= len(item) <= 2 + if isinstance(item[0], str): + conversations.append({"role": "user", "content": item[0]}) + if len(item) == 2 and isinstance(item[1], str): + conversations.append({"role": "assistant", "content": item[1]}) + else: + # item里只有一个元素,说明为最后一轮 + if index != len(conversation) - 1: + raise ValueError(f"Round {index} has error round") else: - conversation[index] = {"role": "assistant", "content": item} + raise ValueError("Each round in list should be string") else: raise ValueError( "apply_chat_template do not support appling batch conversations, " "so you should apply the conversation one by one." ) - query = self.chat_template.render(messages=conversation, **self.special_tokens_map) + query = self.chat_template.render( + messages=conversations, **self.special_tokens_map, add_generation_prompt=add_generation_prompt + ) return query def encode_chat_inputs(self, conversations: List[List[str, str]], context_data: Dict[str, Any] = {}): From 89a5188740e7729eedd8ac0371ca65572527d72c Mon Sep 17 00:00:00 2001 From: Southpika <513923576@qq.com> Date: Sun, 7 Apr 2024 15:45:17 +0800 Subject: [PATCH 06/20] add generation flag --- llm/data.py | 4 ++-- paddlenlp/transformers/tokenizer_utils.py | 29 +++++++++++++---------- 2 files changed, 19 insertions(+), 14 deletions(-) diff --git a/llm/data.py b/llm/data.py index 5d44c72c8abd..24d83f0ba53c 100644 --- a/llm/data.py +++ b/llm/data.py @@ -90,7 +90,7 @@ def tokenize_example(tokenizer, example, data_args): return tokenized_source, tokenized_target_input_ids -def tokenize_rounds_example(tokenizer, example, data_args): +def tokenize_rounds_example(tokenizer, example, data_args, **kwargs): """tokenize multi-rounds examples with chat_template.json Args: @@ -117,7 +117,7 @@ def tokenize_rounds_example(tokenizer, example, data_args): # 1. only tokenize input_ids conversation_result: list[tuple[list[int], list[int]]] = tokenizer.encode_chat_inputs( - conversations, context_data=context_data + conversations, context_data=context_data, **kwargs ) system_ids = conversation_result.pop("system", []) or [] diff --git a/paddlenlp/transformers/tokenizer_utils.py b/paddlenlp/transformers/tokenizer_utils.py index f0a8899015fc..ffa0a4d416c7 100644 --- a/paddlenlp/transformers/tokenizer_utils.py +++ b/paddlenlp/transformers/tokenizer_utils.py @@ -641,7 +641,8 @@ def apply_chat_template( if not self.chat_template: raise ValueError("chat_template is not set, please set chat_template first.") elif isinstance(self.chat_template, Template): - query = self._apply_chat_template(conversation, context_data) + add_generation_prompt = tokenizer_kwargs.pop("add_generation_prompt", True) + query = self._apply_chat_template(conversation, context_data, add_generation_prompt=add_generation_prompt) elif isinstance(self.chat_template, ChatTemplate): query = self._apply_chat_template_paddle(conversation, context_data) @@ -706,7 +707,7 @@ def _apply_chat_template( ) return query - def encode_chat_inputs(self, conversations: List[List[str, str]], context_data: Dict[str, Any] = {}): + def encode_chat_inputs(self, conversations: List[List[str, str]], context_data: Dict[str, Any] = {}, **kwargs): """Encodes conversation to pairs of token ids. Turn 0: bos + system + sep + user bot + eos Turn t: sep + bot + query bot + eos @@ -721,7 +722,8 @@ def encode_chat_inputs(self, conversations: List[List[str, str]], context_data: if not self.chat_template: raise ValueError("chat_template is not set, please set chat_template first.") elif isinstance(self.chat_template, Template): - query = self._encode_chat_inputs(conversations, context_data) + add_generation_prompt = kwargs.pop("add_generation_prompt", True) + query = self._encode_chat_inputs(conversations, context_data, add_generation_prompt=add_generation_prompt) elif isinstance(self.chat_template, ChatTemplate): query = self._encode_chat_inputs_paddle(conversations, context_data) return query @@ -752,18 +754,21 @@ def _encode_chat_inputs_paddle(self, conversations: List[List[str, str]], contex return result def _encode_chat_inputs( - self, conversations: List[List[str, str]], context_data: Dict[str, Any] = {}, add_generation_prompt=False + self, + conversations: List[List[str, str]], + context_data: Dict[str, Any] = {}, + system: str = None, + add_generation_prompt=True, ): result = {} # conversation = [] # if origin_msg[0]['role'] == 'system': # system = origin_msg.pop(0) - try: - self.chat_template.render({"role": "system", "content": ""}) - except Exception as e: - # some tokenizer do not support chat_template, they raise error in jinja string - system = None - logger.debug(e) + if system: + try: + self.chat_template.render(messages={"role": "system", "content": system}) + except Exception as e: + raise ValueError("System is not supported in this tokenizer.", e) conversation_dict = [] origin_msg = [] @@ -778,7 +783,7 @@ def _encode_chat_inputs( for conv in conversation_dict: roundi = [system] + conv if system else conv - roundi_str = self.chat_template.render(messages=roundi, add_generation_prompt=add_generation_prompt) + roundi_str = self.chat_template.render(messages=roundi, add_generation_prompt=False) roundi_no_ans = [system] + [conv[0]] if system else [conv[0]] roundi_no_ans_str = self.chat_template.render( messages=roundi_no_ans, add_generation_prompt=add_generation_prompt @@ -789,7 +794,7 @@ def _encode_chat_inputs( regex_pattern = "|".join(map(re.escape, ans)) non_learnable_parts = re.split( r"(?:%s)" % regex_pattern, - self.chat_template.render(messages=origin_msg, add_generation_prompt=add_generation_prompt), + self.chat_template.render(messages=origin_msg, add_generation_prompt=False), ) if non_learnable_parts[-1] == "": non_learnable_parts.pop() From f35e5b33096c5bc41faac5047f871c68e6e73a34 Mon Sep 17 00:00:00 2001 From: Southpika <513923576@qq.com> Date: Sun, 7 Apr 2024 15:42:26 +0800 Subject: [PATCH 07/20] update jinja ut --- tests/transformers/test_chat_template.py | 63 +++++++++++++++++++++++- 1 file changed, 62 insertions(+), 1 deletion(-) diff --git a/tests/transformers/test_chat_template.py b/tests/transformers/test_chat_template.py index 0d7e2e68ae9f..0736778f471e 100644 --- a/tests/transformers/test_chat_template.py +++ b/tests/transformers/test_chat_template.py @@ -95,7 +95,7 @@ def test_inference_template(self): self.assertEqual(final_query, expected_query) -class ChatTemplateIntegrationTest(unittest.TestCase): +class TemplateIntegrationTest(unittest.TestCase): def test_linlyai_chinese_llama_2_chat_template(self): tokenizer = AutoTokenizer.from_pretrained("linly-ai/chinese-llama-2-7b") query = "你好" @@ -282,3 +282,64 @@ def test_inference_template_with_context_data(self): final_query = tokenizer.apply_chat_template(query, context_data=context_data, tokenize=False) expected_query = "你是一个人工智能助手<>-<>\nHuman: 你好 Bot:" self.assertEqual(final_query, expected_query) + + +class ChatTemplateIntegrationTest(unittest.TestCase): + class DataArg: + def __init__(self, max_length, src_length: Optional[int] = None): + self.max_length: int = max_length + if src_length is None: + src_length = self.max_length - 8 + + self.src_length: int = src_length + + def setUp(self) -> None: + self.tokenizer = AutoTokenizer.from_pretrained("qwen/qwen-7b-chat") + qwen_jinja = "{% for message in messages %}{% if loop.first and messages[0]['role'] != 'system' %}{{ '<|im_start|>system\nYou are a helpful assistant<|im_end|>\n' }}{% endif %}{{'<|im_start|>' + message['role'] + '\n' + message['content'] + '<|im_end|>' + '\n'}}{% endfor %}{% if add_generation_prompt %}{{ '<|im_start|>assistant\n' }}{% endif %}" + self.tokenizer.init_chat_template(qwen_jinja) + return super().setUp() + + def test_chat_template(self): + # test single turn + query = "你好" + final_query = self.tokenizer.apply_chat_template(query, tokenize=False) + expected_query = f"<|im_start|>system\nYou are a helpful assistant<|im_end|>\n<|im_start|>user\n{query}<|im_end|>\n<|im_start|>assistant\n" + self.assertEqual(final_query, expected_query) + + # test multi turns conversation + query = [["你好", "您好,我是个人人工智能助手"], ["今天吃啥"]] + final_query = self.tokenizer.apply_chat_template(query, tokenize=False) + expected_query = "<|im_start|>system\nYou are a helpful assistant<|im_end|>\n<|im_start|>user\n你好<|im_end|>\n<|im_start|>assistant\n您好,我是个人人工智能助手<|im_end|>\n<|im_start|>user\n今天吃啥<|im_end|>\n<|im_start|>assistant\n" + self.assertEqual(final_query, expected_query) + + def test_system_error(self): + # test system messaage error + error_jinja = "{{ bos_token }}{% if messages[0]['role'] == 'system' %}{{ raise_exception('System role not supported') }}{% endif %}" + self.tokenizer.init_chat_template(error_jinja) + from jinja2.exceptions import TemplateError + + with self.assertRaises(TemplateError): + self.tokenizer.apply_chat_template([{"role": "system", "content": ""}]) + + def test_round_error(self): + query = [["你好", "您好,我是个人人工智能助手"], ["今天吃啥"], ["你好", "您好"]] + with self.assertRaises(ValueError): + self.tokenizer.apply_chat_template(query, tokenize=False) + + def test_train_format(self): + from data import tokenize_rounds_example + + fake_data_args = self.DataArg(50, src_length=50) + example = {"src": ["你好"], "tgt": ["您好,我是个人人工智能助手"]} + result, tgt_id = tokenize_rounds_example(self.tokenizer, example, fake_data_args, add_generation_prompt=True) + sentence_result = self.tokenizer.decode(result["input_ids"]) + expected_sentence = "<|im_start|>system\nYou are a helpful assistant<|im_end|>\n<|im_start|>user\n你好<|im_end|>\n<|im_start|>assistant\n您好,我是个人人工智能助手<|im_end|>\n" + self.assertEqual(expected_sentence, sentence_result) + + tgt_idx = len( + self.tokenizer.encode( + "<|im_start|>system\nYou are a helpful assistant<|im_end|>\n<|im_start|>user\n你好<|im_end|>\n<|im_start|>assistant\n" + )["input_ids"] + ) + self.assertEqual(tgt_id[tgt_idx - 1], -100) + self.assertNotEqual(tgt_id[tgt_idx], -100) From c0cfbc7501a5bd923096c56f0adda82f224be014 Mon Sep 17 00:00:00 2001 From: Southpika <513923576@qq.com> Date: Sun, 7 Apr 2024 15:47:29 +0800 Subject: [PATCH 08/20] fix error --- tests/transformers/test_chat_template.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/tests/transformers/test_chat_template.py b/tests/transformers/test_chat_template.py index 0736778f471e..8dbd30b2ddb1 100644 --- a/tests/transformers/test_chat_template.py +++ b/tests/transformers/test_chat_template.py @@ -95,7 +95,7 @@ def test_inference_template(self): self.assertEqual(final_query, expected_query) -class TemplateIntegrationTest(unittest.TestCase): +class ChatTemplateIntegrationTest(unittest.TestCase): def test_linlyai_chinese_llama_2_chat_template(self): tokenizer = AutoTokenizer.from_pretrained("linly-ai/chinese-llama-2-7b") query = "你好" @@ -284,7 +284,7 @@ def test_inference_template_with_context_data(self): self.assertEqual(final_query, expected_query) -class ChatTemplateIntegrationTest(unittest.TestCase): +class TemplateIntegrationTest(unittest.TestCase): class DataArg: def __init__(self, max_length, src_length: Optional[int] = None): self.max_length: int = max_length From 167057b1091a959564df3b992cb62d5b9b2206fe Mon Sep 17 00:00:00 2001 From: Southpika <513923576@qq.com> Date: Sun, 7 Apr 2024 16:15:08 +0800 Subject: [PATCH 09/20] add syntax error check --- tests/transformers/test_chat_template.py | 10 ++++++++++ 1 file changed, 10 insertions(+) diff --git a/tests/transformers/test_chat_template.py b/tests/transformers/test_chat_template.py index 8dbd30b2ddb1..85e1be26b902 100644 --- a/tests/transformers/test_chat_template.py +++ b/tests/transformers/test_chat_template.py @@ -326,6 +326,16 @@ def test_round_error(self): with self.assertRaises(ValueError): self.tokenizer.apply_chat_template(query, tokenize=False) + def test_jinja_syntax_error(self): + # test system messaage error + error_jinja = ( + "{ bos_token }{% if messages[0]['role'] == 'system' %}{ raise_exception('System role not supported')}" + ) + from jinja2.exceptions import TemplateSyntaxError + + with self.assertRaises(TemplateSyntaxError): + self.tokenizer.init_chat_template(error_jinja) + def test_train_format(self): from data import tokenize_rounds_example From d98369ebb85f04e63aeafebb349db615da58a430 Mon Sep 17 00:00:00 2001 From: Southpika <513923576@qq.com> Date: Sun, 7 Apr 2024 16:57:21 +0800 Subject: [PATCH 10/20] fix syntax error --- paddlenlp/transformers/tokenizer_utils.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/paddlenlp/transformers/tokenizer_utils.py b/paddlenlp/transformers/tokenizer_utils.py index ffa0a4d416c7..8522dcd47ad7 100644 --- a/paddlenlp/transformers/tokenizer_utils.py +++ b/paddlenlp/transformers/tokenizer_utils.py @@ -856,7 +856,8 @@ def init_chat_template(self, chat_template: str | dict): except TemplateSyntaxError: # It is neither jinjia string nor path string raise TemplateSyntaxError( - "The chat-template in json is not valid jinja string: {}".format(chat_template) + "The chat-template in json is not valid jinja string: {}".format(chat_template), + lineno=0, # fake lineno, useless required msg ) else: self.chat_template = ChatTemplate.from_file(chat_template) From 263a888f40ec16091aeee183a2ed8f3ebd931b4b Mon Sep 17 00:00:00 2001 From: Southpika <513923576@qq.com> Date: Mon, 8 Apr 2024 09:40:35 +0800 Subject: [PATCH 11/20] refresh codev --- tests/transformers/test_chat_template.py | 1 + 1 file changed, 1 insertion(+) diff --git a/tests/transformers/test_chat_template.py b/tests/transformers/test_chat_template.py index 85e1be26b902..2c0bb192e814 100644 --- a/tests/transformers/test_chat_template.py +++ b/tests/transformers/test_chat_template.py @@ -322,6 +322,7 @@ def test_system_error(self): self.tokenizer.apply_chat_template([{"role": "system", "content": ""}]) def test_round_error(self): + # error round, 1 is not a valid role query = [["你好", "您好,我是个人人工智能助手"], ["今天吃啥"], ["你好", "您好"]] with self.assertRaises(ValueError): self.tokenizer.apply_chat_template(query, tokenize=False) From 9677f25f6e6b2691f92a7f220a4ffdb41ebead16 Mon Sep 17 00:00:00 2001 From: Southpika <513923576@qq.com> Date: Tue, 9 Apr 2024 18:40:54 +0800 Subject: [PATCH 12/20] refresh codecov --- tests/transformers/test_chat_template.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/transformers/test_chat_template.py b/tests/transformers/test_chat_template.py index 2c0bb192e814..5effff95f383 100644 --- a/tests/transformers/test_chat_template.py +++ b/tests/transformers/test_chat_template.py @@ -322,7 +322,7 @@ def test_system_error(self): self.tokenizer.apply_chat_template([{"role": "system", "content": ""}]) def test_round_error(self): - # error round, 1 is not a valid role + # error round, 1 is not a valid role. query = [["你好", "您好,我是个人人工智能助手"], ["今天吃啥"], ["你好", "您好"]] with self.assertRaises(ValueError): self.tokenizer.apply_chat_template(query, tokenize=False) From 393f87f12e373243b4d33c181cb2176a9c039b66 Mon Sep 17 00:00:00 2001 From: Southpika <513923576@qq.com> Date: Fri, 12 Apr 2024 15:30:02 +0800 Subject: [PATCH 13/20] update special token map in render --- paddlenlp/transformers/tokenizer_utils.py | 8 +++++--- 1 file changed, 5 insertions(+), 3 deletions(-) diff --git a/paddlenlp/transformers/tokenizer_utils.py b/paddlenlp/transformers/tokenizer_utils.py index 8522dcd47ad7..b96d50256e70 100644 --- a/paddlenlp/transformers/tokenizer_utils.py +++ b/paddlenlp/transformers/tokenizer_utils.py @@ -783,10 +783,12 @@ def _encode_chat_inputs( for conv in conversation_dict: roundi = [system] + conv if system else conv - roundi_str = self.chat_template.render(messages=roundi, add_generation_prompt=False) + roundi_str = self.chat_template.render( + messages=roundi, add_generation_prompt=False, **self.special_tokens_map + ) roundi_no_ans = [system] + [conv[0]] if system else [conv[0]] roundi_no_ans_str = self.chat_template.render( - messages=roundi_no_ans, add_generation_prompt=add_generation_prompt + messages=roundi_no_ans, add_generation_prompt=add_generation_prompt, **self.special_tokens_map ) ans_roundi = roundi_str[len(roundi_no_ans_str) :] ans.append(ans_roundi) @@ -794,7 +796,7 @@ def _encode_chat_inputs( regex_pattern = "|".join(map(re.escape, ans)) non_learnable_parts = re.split( r"(?:%s)" % regex_pattern, - self.chat_template.render(messages=origin_msg, add_generation_prompt=False), + self.chat_template.render(messages=origin_msg, add_generation_prompt=False, **self.special_tokens_map), ) if non_learnable_parts[-1] == "": non_learnable_parts.pop() From 2124df7920f014ad88cd57af8a0756fb867a511f Mon Sep 17 00:00:00 2001 From: Southpika <513923576@qq.com> Date: Tue, 16 Apr 2024 14:16:38 +0800 Subject: [PATCH 14/20] update save --- paddlenlp/transformers/tokenizer_utils.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/paddlenlp/transformers/tokenizer_utils.py b/paddlenlp/transformers/tokenizer_utils.py index b96d50256e70..cc1dbe112691 100644 --- a/paddlenlp/transformers/tokenizer_utils.py +++ b/paddlenlp/transformers/tokenizer_utils.py @@ -873,7 +873,7 @@ def init_chat_template(self, chat_template: str | dict): def save_resources(self, save_directory): super().save_resources(save_directory) - if self.chat_template is not None: + if isinstance(self.chat_template, ChatTemplate): # Future remove if ChatTemplate is deprecated chat_template_file = os.path.join(save_directory, CHAT_TEMPLATE_CONFIG_NAME) with open(chat_template_file, "w", encoding="utf-8") as f: json.dump(asdict(self.chat_template), f, ensure_ascii=False, indent=4) From 5993027c5a2252ec6e015b2a4ad3705d2d8f6ad3 Mon Sep 17 00:00:00 2001 From: Southpika <513923576@qq.com> Date: Sun, 28 Apr 2024 17:14:54 +0800 Subject: [PATCH 15/20] add multi round chat ut --- tests/transformers/test_chat_template.py | 24 ++++++++++++++++++++++++ 1 file changed, 24 insertions(+) diff --git a/tests/transformers/test_chat_template.py b/tests/transformers/test_chat_template.py index 5effff95f383..29ba8ec83b0f 100644 --- a/tests/transformers/test_chat_template.py +++ b/tests/transformers/test_chat_template.py @@ -354,3 +354,27 @@ def test_train_format(self): ) self.assertEqual(tgt_id[tgt_idx - 1], -100) self.assertNotEqual(tgt_id[tgt_idx], -100) + + def test_train_format_multi(self): + from data import tokenize_rounds_example + + fake_data_args = self.DataArg(50, src_length=50) + example = {"src": ["用户Round 1", "用户Round 2"], "tgt": ["回答Round 1", "回答Round 2"]} + result, tgt_id = tokenize_rounds_example(self.tokenizer, example, fake_data_args, add_generation_prompt=True) + + tgt_idx_1 = len( + self.tokenizer.encode( + "<|im_start|>system\nYou are a helpful assistant<|im_end|>\n<|im_start|>user\n用户Round 1<|im_end|>\n<|im_start|>assistant\n" + )["input_ids"] + ) + tgt_idx_2 = len( + self.tokenizer.encode( + "<|im_start|>system\nYou are a helpful assistant<|im_end|>\n<|im_start|>user\n用户Round 1<|im_end|>\n<|im_start|>assistant\n" + "回答Round 1<|im_end|>\n<|im_start|>user\n用户Round 2<|im_end|>\n<|im_start|>assistant\n" + )["input_ids"] + ) + + self.assertEqual(tgt_id[tgt_idx_1 - 1], -100) + self.assertNotEqual(tgt_id[tgt_idx_1], -100) + self.assertEqual(tgt_id[tgt_idx_2 - 1], -100) + self.assertNotEqual(tgt_id[tgt_idx_2], -100) From 533d2ccdd6531b72c3ba2780974ab65c7ca417aa Mon Sep 17 00:00:00 2001 From: Southpika <513923576@qq.com> Date: Sun, 28 Apr 2024 17:15:32 +0800 Subject: [PATCH 16/20] use new api --- paddlenlp/transformers/tokenizer_utils.py | 72 +++++++++++------------ paddlenlp/transformers/utils.py | 18 ++++++ 2 files changed, 53 insertions(+), 37 deletions(-) diff --git a/paddlenlp/transformers/tokenizer_utils.py b/paddlenlp/transformers/tokenizer_utils.py index cc1dbe112691..58e5a5b19bf3 100644 --- a/paddlenlp/transformers/tokenizer_utils.py +++ b/paddlenlp/transformers/tokenizer_utils.py @@ -59,7 +59,7 @@ TextInputPair, TruncationStrategy, ) -from .utils import InitTrackerMeta, fn_args_to_dict +from .utils import InitTrackerMeta, convert_to_dict_chat, fn_args_to_dict __all__ = [ "PretrainedTokenizer", @@ -642,7 +642,7 @@ def apply_chat_template( raise ValueError("chat_template is not set, please set chat_template first.") elif isinstance(self.chat_template, Template): add_generation_prompt = tokenizer_kwargs.pop("add_generation_prompt", True) - query = self._apply_chat_template(conversation, context_data, add_generation_prompt=add_generation_prompt) + query = self._apply_chat_template(conversation, add_generation_prompt=add_generation_prompt) elif isinstance(self.chat_template, ChatTemplate): query = self._apply_chat_template_paddle(conversation, context_data) @@ -674,34 +674,22 @@ def _apply_chat_template_paddle( def _apply_chat_template( self, conversation: List[List[str, str] | Dict[str, str]] | str, - context_data: Dict[str, Any] = {}, add_generation_prompt=True, ) -> str | dict[str, numpy.ndarray | paddle.Tensor]: if isinstance(conversation, str): conversations = [{"role": "user", "content": conversation}] elif isinstance(conversation, list): - conversations = [] - for index, item in enumerate(conversation): - if isinstance(item, dict): - conversations = conversation - break - elif isinstance(item, list): - assert 1 <= len(item) <= 2 - if isinstance(item[0], str): - conversations.append({"role": "user", "content": item[0]}) - if len(item) == 2 and isinstance(item[1], str): - conversations.append({"role": "assistant", "content": item[1]}) - else: - # item里只有一个元素,说明为最后一轮 - if index != len(conversation) - 1: - raise ValueError(f"Round {index} has error round") - else: - raise ValueError("Each round in list should be string") - else: - raise ValueError( - "apply_chat_template do not support appling batch conversations, " - "so you should apply the conversation one by one." - ) + assert len(conversation) > 0, "empty conversation is not allowed" + if isinstance(conversation[0], list): + conversations = convert_to_dict_chat(conversation) + breakpoint() + elif isinstance(conversation[0], dict): + conversations = conversation + else: + raise ValueError( + "apply_chat_template do not support appling batch conversations, " + "so you should apply the conversation one by one." + ) query = self.chat_template.render( messages=conversations, **self.special_tokens_map, add_generation_prompt=add_generation_prompt ) @@ -761,15 +749,15 @@ def _encode_chat_inputs( add_generation_prompt=True, ): result = {} - # conversation = [] - # if origin_msg[0]['role'] == 'system': - # system = origin_msg.pop(0) + + # Some template do not support system msg, so we need to check it first. if system: try: self.chat_template.render(messages={"role": "system", "content": system}) except Exception as e: raise ValueError("System is not supported in this tokenizer.", e) + # convert list msg to role dict msg conversation_dict = [] origin_msg = [] for round in conversations: @@ -781,6 +769,8 @@ def _encode_chat_inputs( conversation_dict.append(round_role) ans = [] + # get answer in single round, then compile the chat entirely and split by single round ans + # https://ku.baidu-int.com/knowledge/HFVrC7hq1Q/yKeL8Lljko/YkH5mORwJ3/aeec5d5a3eb84c for conv in conversation_dict: roundi = [system] + conv if system else conv roundi_str = self.chat_template.render( @@ -793,14 +783,7 @@ def _encode_chat_inputs( ans_roundi = roundi_str[len(roundi_no_ans_str) :] ans.append(ans_roundi) - regex_pattern = "|".join(map(re.escape, ans)) - non_learnable_parts = re.split( - r"(?:%s)" % regex_pattern, - self.chat_template.render(messages=origin_msg, add_generation_prompt=False, **self.special_tokens_map), - ) - if non_learnable_parts[-1] == "": - non_learnable_parts.pop() - + non_learnable_parts = self._splited_by_specified_words(origin_msg, ans) assert len(non_learnable_parts) == len(ans) conversation_ids = [] @@ -812,11 +795,23 @@ def _encode_chat_inputs( padding=False, )["input_ids"] ) - # print(self.batch_decode(conversation_ids[i])) result["conversations"] = conversation_ids return result + def _splited_by_specified_words(self, origin_msg: List[Dict[str, str]], split_s: List[str]): + """Split the entire chat by specified words.""" + # distingish and replace the special words in original string to an uncompiled form: Like | -> \| + regex_pattern = "|".join(map(re.escape, split_s)) + # splited by replaced specified words + non_learnable_parts = re.split( + r"(?:%s)" % regex_pattern, + self.chat_template.render(messages=origin_msg, add_generation_prompt=False, **self.special_tokens_map), + ) + if non_learnable_parts[-1] == "": + non_learnable_parts.pop() + return non_learnable_parts + @classmethod def from_pretrained(cls, pretrained_model_name_or_path, *args, **kwargs): cache_dir = kwargs.pop("cache_dir", None) @@ -842,6 +837,9 @@ def from_pretrained(cls, pretrained_model_name_or_path, *args, **kwargs): logger.warning( "Chat-template already exists in config file, it will be overwritten by chat_template.json file." ) + logger.warning( + "`chat_template.json` will be deprecated in the future! Please set it in `tokenizer_config.json`." + ) tokenizer.init_chat_template(chat_template_file) return tokenizer diff --git a/paddlenlp/transformers/utils.py b/paddlenlp/transformers/utils.py index c4919b1bdebb..cc35d274eb73 100644 --- a/paddlenlp/transformers/utils.py +++ b/paddlenlp/transformers/utils.py @@ -89,6 +89,24 @@ def convert_ndarray_dtype(np_array: np.ndarray, target_dtype: str) -> np.ndarray return np_array.astype(target_dtype) +def convert_to_dict_chat(conversation: List[List[str]]): + """Convert the list of chat messages to a role dictionary chat messages.""" + conversations = [] + for index, item in enumerate(conversation): + assert 1 <= len(item) <= 2, "Each Rounds in conversation should have 1 or 2 elements." + if isinstance(item[0], str): + conversations.append({"role": "user", "content": item[0]}) + if len(item) == 2 and isinstance(item[1], str): + conversations.append({"role": "assistant", "content": item[1]}) + else: + # item里只有一个元素,说明为最后一轮 + if index != len(conversation) - 1: + raise ValueError(f"Round {index} has error round") + else: + raise ValueError("Each round in list should be string") + return conversations + + def get_scale_by_dtype(dtype: str = None, return_positive: bool = True) -> float: """get scale value by dtype From f477f618271e8ed426a3aa76a17ad86dd9a61141 Mon Sep 17 00:00:00 2001 From: Southpika <513923576@qq.com> Date: Sun, 28 Apr 2024 17:17:09 +0800 Subject: [PATCH 17/20] rm pdb --- paddlenlp/transformers/tokenizer_utils.py | 1 - 1 file changed, 1 deletion(-) diff --git a/paddlenlp/transformers/tokenizer_utils.py b/paddlenlp/transformers/tokenizer_utils.py index 58e5a5b19bf3..98876e4891fe 100644 --- a/paddlenlp/transformers/tokenizer_utils.py +++ b/paddlenlp/transformers/tokenizer_utils.py @@ -682,7 +682,6 @@ def _apply_chat_template( assert len(conversation) > 0, "empty conversation is not allowed" if isinstance(conversation[0], list): conversations = convert_to_dict_chat(conversation) - breakpoint() elif isinstance(conversation[0], dict): conversations = conversation else: From 574527987f0691f7f24f81010adddda0f1ea79dc Mon Sep 17 00:00:00 2001 From: Southpika <513923576@qq.com> Date: Sun, 28 Apr 2024 17:27:04 +0800 Subject: [PATCH 18/20] add split ut --- tests/transformers/test_chat_template.py | 15 +++++++++++++++ 1 file changed, 15 insertions(+) diff --git a/tests/transformers/test_chat_template.py b/tests/transformers/test_chat_template.py index 29ba8ec83b0f..90fab88861b9 100644 --- a/tests/transformers/test_chat_template.py +++ b/tests/transformers/test_chat_template.py @@ -378,3 +378,18 @@ def test_train_format_multi(self): self.assertNotEqual(tgt_id[tgt_idx_1], -100) self.assertEqual(tgt_id[tgt_idx_2 - 1], -100) self.assertNotEqual(tgt_id[tgt_idx_2], -100) + + def test_split_answer(self): + original_msg = [ + {"role": "user", "content": "用户Round 1"}, + {"role": "assistant", "content": "|回答Round 1|"}, + {"role": "user", "content": "用户Round 2"}, + {"role": "assistant", "content": "_回答Round 2?"}, + ] + answer = ["|回答Round 1|<|im_end|>\n", "_回答Round 2?<|im_end|>\n"] + split_part = self.tokenizer._splited_by_specified_words(original_msg, answer) + self.assertEqual(len(split_part), 2) + self.assertEqual( + split_part[0], + "<|im_start|>system\nYou are a helpful assistant<|im_end|>\n<|im_start|>user\n用户Round 1<|im_end|>\n<|im_start|>assistant\n", + ) From b9e356074261202cc2ac9032357a241f107ae1e5 Mon Sep 17 00:00:00 2001 From: Southpika <513923576@qq.com> Date: Sun, 28 Apr 2024 17:45:43 +0800 Subject: [PATCH 19/20] fix des --- paddlenlp/transformers/tokenizer_utils.py | 12 ++++++------ paddlenlp/transformers/utils.py | 5 +++-- 2 files changed, 9 insertions(+), 8 deletions(-) diff --git a/paddlenlp/transformers/tokenizer_utils.py b/paddlenlp/transformers/tokenizer_utils.py index 98876e4891fe..e5a4b68c4c23 100644 --- a/paddlenlp/transformers/tokenizer_utils.py +++ b/paddlenlp/transformers/tokenizer_utils.py @@ -59,7 +59,7 @@ TextInputPair, TruncationStrategy, ) -from .utils import InitTrackerMeta, convert_to_dict_chat, fn_args_to_dict +from .utils import InitTrackerMeta, convert_to_dict_message, fn_args_to_dict __all__ = [ "PretrainedTokenizer", @@ -681,7 +681,7 @@ def _apply_chat_template( elif isinstance(conversation, list): assert len(conversation) > 0, "empty conversation is not allowed" if isinstance(conversation[0], list): - conversations = convert_to_dict_chat(conversation) + conversations = convert_to_dict_message(conversation) elif isinstance(conversation[0], dict): conversations = conversation else: @@ -769,7 +769,7 @@ def _encode_chat_inputs( ans = [] # get answer in single round, then compile the chat entirely and split by single round ans - # https://ku.baidu-int.com/knowledge/HFVrC7hq1Q/yKeL8Lljko/YkH5mORwJ3/aeec5d5a3eb84c + # attention: answer should include end token! for conv in conversation_dict: roundi = [system] + conv if system else conv roundi_str = self.chat_template.render( @@ -782,7 +782,7 @@ def _encode_chat_inputs( ans_roundi = roundi_str[len(roundi_no_ans_str) :] ans.append(ans_roundi) - non_learnable_parts = self._splited_by_specified_words(origin_msg, ans) + non_learnable_parts = self._extract_non_learnable_parts(origin_msg, ans) assert len(non_learnable_parts) == len(ans) conversation_ids = [] @@ -798,8 +798,8 @@ def _encode_chat_inputs( result["conversations"] = conversation_ids return result - def _splited_by_specified_words(self, origin_msg: List[Dict[str, str]], split_s: List[str]): - """Split the entire chat by specified words.""" + def _extract_non_learnable_parts(self, origin_msg: List[Dict[str, str]], split_s: List[str]): + """Split the entire chat by specified words. Extract the non-learnable parts.""" # distingish and replace the special words in original string to an uncompiled form: Like | -> \| regex_pattern = "|".join(map(re.escape, split_s)) # splited by replaced specified words diff --git a/paddlenlp/transformers/utils.py b/paddlenlp/transformers/utils.py index cc35d274eb73..f785a5358af4 100644 --- a/paddlenlp/transformers/utils.py +++ b/paddlenlp/transformers/utils.py @@ -89,7 +89,7 @@ def convert_ndarray_dtype(np_array: np.ndarray, target_dtype: str) -> np.ndarray return np_array.astype(target_dtype) -def convert_to_dict_chat(conversation: List[List[str]]): +def convert_to_dict_message(conversation: List[List[str]]): """Convert the list of chat messages to a role dictionary chat messages.""" conversations = [] for index, item in enumerate(conversation): @@ -99,7 +99,8 @@ def convert_to_dict_chat(conversation: List[List[str]]): if len(item) == 2 and isinstance(item[1], str): conversations.append({"role": "assistant", "content": item[1]}) else: - # item里只有一个元素,说明为最后一轮 + # If there is only one element in item, it must be the last round. + # If it is not the last round, it must be an error. if index != len(conversation) - 1: raise ValueError(f"Round {index} has error round") else: From 7d60cffd77575eaf1227c1fe7fab9cf90e9c9e8e Mon Sep 17 00:00:00 2001 From: Southpika <513923576@qq.com> Date: Sun, 28 Apr 2024 18:15:46 +0800 Subject: [PATCH 20/20] fix ut --- tests/transformers/test_chat_template.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/transformers/test_chat_template.py b/tests/transformers/test_chat_template.py index 90fab88861b9..4f2641793a97 100644 --- a/tests/transformers/test_chat_template.py +++ b/tests/transformers/test_chat_template.py @@ -387,7 +387,7 @@ def test_split_answer(self): {"role": "assistant", "content": "_回答Round 2?"}, ] answer = ["|回答Round 1|<|im_end|>\n", "_回答Round 2?<|im_end|>\n"] - split_part = self.tokenizer._splited_by_specified_words(original_msg, answer) + split_part = self.tokenizer._extract_non_learnable_parts(original_msg, answer) self.assertEqual(len(split_part), 2) self.assertEqual( split_part[0],