diff --git a/llm/data.py b/llm/data.py index 767bd5a88a29..a7b51264bcaa 100644 --- a/llm/data.py +++ b/llm/data.py @@ -90,7 +90,7 @@ def tokenize_example(tokenizer, example, data_args): return tokenized_source, tokenized_target_input_ids -def tokenize_rounds_example(tokenizer, example, data_args): +def tokenize_rounds_example(tokenizer, example, data_args, **kwargs): """tokenize multi-rounds examples with chat_template.json Args: @@ -117,7 +117,7 @@ def tokenize_rounds_example(tokenizer, example, data_args): # 1. only tokenize input_ids conversation_result: list[tuple[list[int], list[int]]] = tokenizer.encode_chat_inputs( - conversations, context_data=context_data + conversations, context_data=context_data, **kwargs ) system_ids = conversation_result.pop("system", []) or [] diff --git a/paddlenlp/transformers/tokenizer_utils.py b/paddlenlp/transformers/tokenizer_utils.py index 3620669fefe6..e5a4b68c4c23 100644 --- a/paddlenlp/transformers/tokenizer_utils.py +++ b/paddlenlp/transformers/tokenizer_utils.py @@ -31,7 +31,8 @@ import numpy as np import paddle import six -from jinja2.exceptions import TemplateError +from jinja2 import Template +from jinja2.exceptions import TemplateError, TemplateSyntaxError from jinja2.sandbox import ImmutableSandboxedEnvironment from paddle.utils import try_import @@ -58,7 +59,7 @@ TextInputPair, TruncationStrategy, ) -from .utils import InitTrackerMeta, fn_args_to_dict +from .utils import InitTrackerMeta, convert_to_dict_message, fn_args_to_dict __all__ = [ "PretrainedTokenizer", @@ -516,7 +517,7 @@ class ChatTemplate: @staticmethod @lru_cache() - def _compile_jinja_template(chat_template): + def _compile_jinja_template(chat_template) -> Template: def raise_exception(message): raise TemplateError(message) @@ -598,7 +599,6 @@ def __call__(self, conversations: list[list[str]] | str, context_data: Dict[str, raise ValueError( "The length of last conversation must be one, eg: [[user-query, bot-answer], [user-query, bot-answer], ..., [user-query]]" ) - if len(conversations[-1]) > 1: logger.warning( f"The last conversation is not a single-round, chat-template will skip the conversation: {conversations[-1][1:]}" @@ -623,7 +623,7 @@ class ChatTemplateMixin: def apply_chat_template( self, - conversation: List[List[str, str]] | str, + conversation: List[List[str, str] | Dict[str, str]] | str, tokenize: bool = True, context_data: Dict[str, Any] = {}, **tokenizer_kwargs @@ -638,6 +638,26 @@ def apply_chat_template( Returns: str | dict[str, numpy.ndarray | paddle.Tensor]: return the result of applied data """ + if not self.chat_template: + raise ValueError("chat_template is not set, please set chat_template first.") + elif isinstance(self.chat_template, Template): + add_generation_prompt = tokenizer_kwargs.pop("add_generation_prompt", True) + query = self._apply_chat_template(conversation, add_generation_prompt=add_generation_prompt) + elif isinstance(self.chat_template, ChatTemplate): + query = self._apply_chat_template_paddle(conversation, context_data) + + if not tokenize: + return query + + # chat_template should not add special tokens + tokenizer_kwargs["add_special_tokens"] = False + return self(query, **tokenizer_kwargs) + + def _apply_chat_template_paddle( + self, + conversation: List[List[str, str]] | str, + context_data: Dict[str, Any] = {}, + ) -> str | dict[str, numpy.ndarray | paddle.Tensor]: context_data = self.chat_template._init_context_data(context_data) if isinstance(conversation, str): @@ -649,14 +669,32 @@ def apply_chat_template( ) query = 
self.chat_template(conversation, context_data=context_data)
-        if not tokenize:
-            return query
+        return query
 
-        # chat_template should not add special tokens
-        tokenizer_kwargs["add_special_tokens"] = False
-        return self(query, **tokenizer_kwargs)
+    def _apply_chat_template(
+        self,
+        conversation: List[List[str, str] | Dict[str, str]] | str,
+        add_generation_prompt=True,
+    ) -> str | dict[str, numpy.ndarray | paddle.Tensor]:
+        if isinstance(conversation, str):
+            conversations = [{"role": "user", "content": conversation}]
+        elif isinstance(conversation, list):
+            assert len(conversation) > 0, "empty conversation is not allowed"
+            if isinstance(conversation[0], list):
+                conversations = convert_to_dict_message(conversation)
+            elif isinstance(conversation[0], dict):
+                conversations = conversation
+            else:
+                raise ValueError(
+                    "apply_chat_template does not support batched conversations, "
+                    "so you should apply each conversation one by one."
+                )
+        query = self.chat_template.render(
+            messages=conversations, **self.special_tokens_map, add_generation_prompt=add_generation_prompt
+        )
+        return query
 
-    def encode_chat_inputs(self, conversations: List[List[str, str]], context_data: Dict[str, Any] = {}):
+    def encode_chat_inputs(self, conversations: List[List[str, str]], context_data: Dict[str, Any] = {}, **kwargs):
         """Encodes conversation to pairs of token ids.
         Turn 0: bos + system + sep + user bot + eos
         Turn t: sep + bot + query bot + eos
@@ -668,6 +706,16 @@ def encode_chat_inputs(self, conversations: List[List[str, str]], context_data:
         Returns:
             List[list[int], list[int]]: the pair of input_ids and target_ids
         """
+        if not self.chat_template:
+            raise ValueError("chat_template is not set, please set chat_template first.")
+        elif isinstance(self.chat_template, Template):
+            add_generation_prompt = kwargs.pop("add_generation_prompt", True)
+            query = self._encode_chat_inputs(conversations, context_data, add_generation_prompt=add_generation_prompt)
+        elif isinstance(self.chat_template, ChatTemplate):
+            query = self._encode_chat_inputs_paddle(conversations, context_data)
+        return query
+
+    def _encode_chat_inputs_paddle(self, conversations: List[List[str, str]], context_data: Dict[str, Any] = {}):
         context_data = self.chat_template._init_context_data(context_data)
         # encode system
         result = {}
@@ -692,6 +740,77 @@ def encode_chat_inputs(self, conversations: List[List[str, str]], context_data:
         result["conversations"] = conversation_ids
         return result
 
+    def _encode_chat_inputs(
+        self,
+        conversations: List[List[str, str]],
+        context_data: Dict[str, Any] = {},
+        system: str = None,
+        add_generation_prompt=True,
+    ):
+        result = {}
+
+        # Some templates do not support a system message, so check for support first.
+        if system:
+            try:
+                self.chat_template.render(messages=[{"role": "system", "content": system}])
+            except Exception as e:
+                raise ValueError("System is not supported in this tokenizer.", e)
+            system = {"role": "system", "content": system}
+
+        # convert list messages to role-dict messages
+        conversation_dict = []
+        origin_msg = []
+        for round in conversations:
+            round_role = [
+                {"role": "user", "content": round[0]},
+                {"role": "assistant", "content": round[1]},
+            ]
+            origin_msg.extend(round_role)
+            conversation_dict.append(round_role)
+        ans = []
+
+        # Recover each round's answer by rendering the round with and without it and
+        # taking the suffix difference; the whole chat is later rendered and split
+        # by those answers.
+        # Attention: each answer must include the end token!
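+        # Illustrative note: with the qwen template used in the tests below, a full
+        # round renders as "...<|im_start|>assistant\n回答<|im_end|>\n" while the
+        # no-answer render stops at "...<|im_start|>assistant\n", so the suffix
+        # difference is exactly the answer plus its end token: "回答<|im_end|>\n".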
+        for conv in conversation_dict:
+            roundi = [system] + conv if system else conv
+            roundi_str = self.chat_template.render(
+                messages=roundi, add_generation_prompt=False, **self.special_tokens_map
+            )
+            roundi_no_ans = [system] + [conv[0]] if system else [conv[0]]
+            roundi_no_ans_str = self.chat_template.render(
+                messages=roundi_no_ans, add_generation_prompt=add_generation_prompt, **self.special_tokens_map
+            )
+            ans_roundi = roundi_str[len(roundi_no_ans_str) :]
+            ans.append(ans_roundi)
+
+        non_learnable_parts = self._extract_non_learnable_parts(origin_msg, ans)
+        assert len(non_learnable_parts) == len(ans)
+
+        conversation_ids = []
+        for i in range(len(non_learnable_parts)):
+            conversation_ids.append(
+                self.batch_encode(
+                    [non_learnable_parts[i], ans[i]],
+                    add_special_tokens=False,
+                    padding=False,
+                )["input_ids"]
+            )
+
+        result["conversations"] = conversation_ids
+        return result
+
+    def _extract_non_learnable_parts(self, origin_msg: List[Dict[str, str]], split_s: List[str]):
+        """Split the fully rendered chat by the answer strings and extract the non-learnable parts."""
+        # escape regex metacharacters in the answer strings so they match literally, e.g. | -> \|
+        regex_pattern = "|".join(map(re.escape, split_s))
+        # split the fully rendered chat by the escaped answer strings
+        non_learnable_parts = re.split(
+            r"(?:%s)" % regex_pattern,
+            self.chat_template.render(messages=origin_msg, add_generation_prompt=False, **self.special_tokens_map),
+        )
+        if non_learnable_parts[-1] == "":
+            non_learnable_parts.pop()
+        return non_learnable_parts
+
     @classmethod
     def from_pretrained(cls, pretrained_model_name_or_path, *args, **kwargs):
         cache_dir = kwargs.pop("cache_dir", None)
@@ -713,6 +832,13 @@ def from_pretrained(cls, pretrained_model_name_or_path, *args, **kwargs):
         if not os.path.exists(chat_template_file):
             return tokenizer
 
+        if tokenizer.chat_template is not None:
+            logger.warning(
+                "A chat-template already exists in the config file; it will be overwritten by the chat_template.json file."
+            )
+            logger.warning(
+                "`chat_template.json` will be deprecated in the future! Please set it in `tokenizer_config.json`."
+            )
         tokenizer.init_chat_template(chat_template_file)
         return tokenizer
 
@@ -724,9 +850,16 @@ def init_chat_template(self, chat_template: str | dict):
         """
         if isinstance(chat_template, str):
             if not os.path.exists(chat_template):
-                raise FileNotFoundError("The chat-template file does not exist: {}".format(chat_template))
-
-            self.chat_template = ChatTemplate.from_file(chat_template)
+                try:
+                    self.chat_template: Template = ChatTemplate._compile_jinja_template(chat_template)
+                except TemplateSyntaxError:
+                    # the string is neither a valid jinja template nor an existing file path
+                    raise TemplateSyntaxError(
+                        "The chat-template is not a valid jinja string: {}".format(chat_template),
+                        lineno=0,  # required by TemplateSyntaxError, but meaningless here
+                    )
+            else:
+                self.chat_template = ChatTemplate.from_file(chat_template)
         elif isinstance(chat_template, dict):
             self.chat_template = ChatTemplate.from_dict(chat_template)
         elif isinstance(chat_template, ChatTemplate):
@@ -737,7 +870,7 @@ def init_chat_template(self, chat_template: str | dict):
 
     def save_resources(self, save_directory):
         super().save_resources(save_directory)
-        if self.chat_template is not None:
+        if isinstance(self.chat_template, ChatTemplate):  # remove this branch once ChatTemplate is deprecated
             chat_template_file = os.path.join(save_directory, CHAT_TEMPLATE_CONFIG_NAME)
             with open(chat_template_file, "w", encoding="utf-8") as f:
                 json.dump(asdict(self.chat_template), f, ensure_ascii=False, indent=4)
diff --git a/paddlenlp/transformers/tokenizer_utils_base.py b/paddlenlp/transformers/tokenizer_utils_base.py
index 77cf08a1b5d2..4f9da5ea7785 100644
--- a/paddlenlp/transformers/tokenizer_utils_base.py
+++ b/paddlenlp/transformers/tokenizer_utils_base.py
@@ -1575,6 +1575,9 @@ def convert_added_tokens(obj):
         # TODO(guosheng): avoid reduplication of position args and key word args
         tokenizer = cls(*init_args, **init_kwargs)
+        chat_template = init_kwargs.pop("chat_template", None)
+        if chat_template is not None:
+            tokenizer.init_chat_template(chat_template)
         special_tokens_map_file = resolved_vocab_files.pop("special_tokens_map_file", None)
         if special_tokens_map_file is not None:
             with open(special_tokens_map_file, encoding="utf-8") as special_tokens_map_handle:
diff --git a/paddlenlp/transformers/utils.py b/paddlenlp/transformers/utils.py
index c4919b1bdebb..f785a5358af4 100644
--- a/paddlenlp/transformers/utils.py
+++ b/paddlenlp/transformers/utils.py
@@ -89,6 +89,25 @@ def convert_ndarray_dtype(np_array: np.ndarray, target_dtype: str) -> np.ndarray
     return np_array.astype(target_dtype)
 
+def convert_to_dict_message(conversation: List[List[str]]):
+    """Convert a list-of-rounds conversation to role-dict chat messages."""
+    conversations = []
+    for index, item in enumerate(conversation):
+        assert 1 <= len(item) <= 2, "Each round in a conversation should have 1 or 2 elements."
+        if isinstance(item[0], str):
+            conversations.append({"role": "user", "content": item[0]})
+            if len(item) == 2 and isinstance(item[1], str):
+                conversations.append({"role": "assistant", "content": item[1]})
+            else:
+                # A round may omit the assistant answer only when it is the last
+                # round; anywhere else this is an error.
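+                # e.g. [["q1", "a1"], ["q2"]] is valid, while [["q1"], ["q2", "a2"]]
+                # raises below, since round 0 lacks an answer but is not last.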
+                if index != len(conversation) - 1:
+                    raise ValueError(f"Round {index} has no answer but is not the last round.")
+        else:
+            raise ValueError("Each message in a round must be a string.")
+    return conversations
+
+
 def get_scale_by_dtype(dtype: str = None, return_positive: bool = True) -> float:
     """get scale value by dtype
 
diff --git a/tests/transformers/test_chat_template.py b/tests/transformers/test_chat_template.py
index 0d7e2e68ae9f..4f2641793a97 100644
--- a/tests/transformers/test_chat_template.py
+++ b/tests/transformers/test_chat_template.py
@@ -282,3 +282,114 @@ def test_inference_template_with_context_data(self):
         final_query = tokenizer.apply_chat_template(query, context_data=context_data, tokenize=False)
         expected_query = "你是一个人工智能助手<>-<>\nHuman: 你好 Bot:"
         self.assertEqual(final_query, expected_query)
+
+
+class TemplateIntegrationTest(unittest.TestCase):
+    class DataArg:
+        def __init__(self, max_length, src_length: Optional[int] = None):
+            self.max_length: int = max_length
+            if src_length is None:
+                src_length = self.max_length - 8
+
+            self.src_length: int = src_length
+
+    def setUp(self) -> None:
+        self.tokenizer = AutoTokenizer.from_pretrained("qwen/qwen-7b-chat")
+        qwen_jinja = "{% for message in messages %}{% if loop.first and messages[0]['role'] != 'system' %}{{ '<|im_start|>system\nYou are a helpful assistant<|im_end|>\n' }}{% endif %}{{'<|im_start|>' + message['role'] + '\n' + message['content'] + '<|im_end|>' + '\n'}}{% endfor %}{% if add_generation_prompt %}{{ '<|im_start|>assistant\n' }}{% endif %}"
+        self.tokenizer.init_chat_template(qwen_jinja)
+        return super().setUp()
+
+    def test_chat_template(self):
+        # test single turn
+        query = "你好"
+        final_query = self.tokenizer.apply_chat_template(query, tokenize=False)
+        expected_query = f"<|im_start|>system\nYou are a helpful assistant<|im_end|>\n<|im_start|>user\n{query}<|im_end|>\n<|im_start|>assistant\n"
+        self.assertEqual(final_query, expected_query)
+
+        # test multi-turn conversation
+        query = [["你好", "您好,我是个人人工智能助手"], ["今天吃啥"]]
+        final_query = self.tokenizer.apply_chat_template(query, tokenize=False)
+        expected_query = "<|im_start|>system\nYou are a helpful assistant<|im_end|>\n<|im_start|>user\n你好<|im_end|>\n<|im_start|>assistant\n您好,我是个人人工智能助手<|im_end|>\n<|im_start|>user\n今天吃啥<|im_end|>\n<|im_start|>assistant\n"
+        self.assertEqual(final_query, expected_query)
+
+    def test_system_error(self):
+        # test a template that rejects system messages
+        error_jinja = "{{ bos_token }}{% if messages[0]['role'] == 'system' %}{{ raise_exception('System role not supported') }}{% endif %}"
+        self.tokenizer.init_chat_template(error_jinja)
+        from jinja2.exceptions import TemplateError
+
+        with self.assertRaises(TemplateError):
+            self.tokenizer.apply_chat_template([{"role": "system", "content": ""}])
+
+    def test_round_error(self):
+        # round 1 has only a query but is not the last round, so it is invalid
+        query = [["你好", "您好,我是个人人工智能助手"], ["今天吃啥"], ["你好", "您好"]]
+        with self.assertRaises(ValueError):
+            self.tokenizer.apply_chat_template(query, tokenize=False)
+
+    def test_jinja_syntax_error(self):
+        # test an invalid jinja template string
+        error_jinja = (
+            "{ bos_token }{% if messages[0]['role'] == 'system' %}{ raise_exception('System role not supported')}"
+        )
+        from jinja2.exceptions import TemplateSyntaxError
+
+        with self.assertRaises(TemplateSyntaxError):
+            self.tokenizer.init_chat_template(error_jinja)
+
+    def test_train_format(self):
+        from data import tokenize_rounds_example
+
+        fake_data_args = self.DataArg(50, src_length=50)
+        example = {"src": ["你好"], "tgt": ["您好,我是个人人工智能助手"]}
+        result, tgt_id = tokenize_rounds_example(self.tokenizer, example, fake_data_args, add_generation_prompt=True)
+        sentence_result = self.tokenizer.decode(result["input_ids"])
+        expected_sentence = "<|im_start|>system\nYou are a helpful assistant<|im_end|>\n<|im_start|>user\n你好<|im_end|>\n<|im_start|>assistant\n您好,我是个人人工智能助手<|im_end|>\n"
+        self.assertEqual(expected_sentence, sentence_result)
+
+        tgt_idx = len(
+            self.tokenizer.encode(
+                "<|im_start|>system\nYou are a helpful assistant<|im_end|>\n<|im_start|>user\n你好<|im_end|>\n<|im_start|>assistant\n"
+            )["input_ids"]
+        )
+        self.assertEqual(tgt_id[tgt_idx - 1], -100)
+        self.assertNotEqual(tgt_id[tgt_idx], -100)
+
+    def test_train_format_multi(self):
+        from data import tokenize_rounds_example
+
+        fake_data_args = self.DataArg(50, src_length=50)
+        example = {"src": ["用户Round 1", "用户Round 2"], "tgt": ["回答Round 1", "回答Round 2"]}
+        result, tgt_id = tokenize_rounds_example(self.tokenizer, example, fake_data_args, add_generation_prompt=True)
+
+        tgt_idx_1 = len(
+            self.tokenizer.encode(
+                "<|im_start|>system\nYou are a helpful assistant<|im_end|>\n<|im_start|>user\n用户Round 1<|im_end|>\n<|im_start|>assistant\n"
+            )["input_ids"]
+        )
+        tgt_idx_2 = len(
+            self.tokenizer.encode(
+                "<|im_start|>system\nYou are a helpful assistant<|im_end|>\n<|im_start|>user\n用户Round 1<|im_end|>\n<|im_start|>assistant\n"
+                "回答Round 1<|im_end|>\n<|im_start|>user\n用户Round 2<|im_end|>\n<|im_start|>assistant\n"
+            )["input_ids"]
+        )
+
+        self.assertEqual(tgt_id[tgt_idx_1 - 1], -100)
+        self.assertNotEqual(tgt_id[tgt_idx_1], -100)
+        self.assertEqual(tgt_id[tgt_idx_2 - 1], -100)
+        self.assertNotEqual(tgt_id[tgt_idx_2], -100)
+
+    def test_split_answer(self):
+        original_msg = [
+            {"role": "user", "content": "用户Round 1"},
+            {"role": "assistant", "content": "|回答Round 1|"},
+            {"role": "user", "content": "用户Round 2"},
+            {"role": "assistant", "content": "_回答Round 2?"},
+        ]
+        answer = ["|回答Round 1|<|im_end|>\n", "_回答Round 2?<|im_end|>\n"]
+        split_part = self.tokenizer._extract_non_learnable_parts(original_msg, answer)
+        self.assertEqual(len(split_part), 2)
+        self.assertEqual(
+            split_part[0],
+            "<|im_start|>system\nYou are a helpful assistant<|im_end|>\n<|im_start|>user\n用户Round 1<|im_end|>\n<|im_start|>assistant\n",
+        )
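
Usage sketch (illustrative only, not part of the patch): how the new jinja
chat-template path fits together end to end. The model name and the jinja string
are taken from the tests above; everything else uses only APIs shown in this diff.

    from paddlenlp.transformers import AutoTokenizer

    tokenizer = AutoTokenizer.from_pretrained("qwen/qwen-7b-chat")

    # a raw jinja string is compiled via ChatTemplate._compile_jinja_template;
    # a file path or dict still goes through the legacy ChatTemplate class
    qwen_jinja = (
        "{% for message in messages %}"
        "{% if loop.first and messages[0]['role'] != 'system' %}"
        "{{ '<|im_start|>system\nYou are a helpful assistant<|im_end|>\n' }}"
        "{% endif %}"
        "{{ '<|im_start|>' + message['role'] + '\n' + message['content'] + '<|im_end|>' + '\n' }}"
        "{% endfor %}"
        "{% if add_generation_prompt %}{{ '<|im_start|>assistant\n' }}{% endif %}"
    )
    tokenizer.init_chat_template(qwen_jinja)

    # 1. a bare string is wrapped as a single user message
    prompt = tokenizer.apply_chat_template("你好", tokenize=False)

    # 2. a list of [query, answer] rounds is converted by convert_to_dict_message;
    #    only the last round may omit the answer
    prompt = tokenizer.apply_chat_template([["你好", "您好"], ["今天吃啥"]], tokenize=False)

    # 3. role-dict messages are rendered as-is
    prompt = tokenizer.apply_chat_template([{"role": "user", "content": "今天吃啥"}], tokenize=False)

    # with tokenize=True (the default) the rendered prompt is encoded with
    # add_special_tokens=False, since the template already inserts special tokens
    features = tokenizer.apply_chat_template("你好")

    # training path: encode_chat_inputs returns per-round (prefix, answer) id
    # pairs, which tokenize_rounds_example in llm/data.py turns into labels with
    # the non-learnable prefix masked to -100
    encoded = tokenizer.encode_chat_inputs(
        [["用户Round 1", "回答Round 1"], ["用户Round 2", "回答Round 2"]],
        add_generation_prompt=True,
    )
    for prefix_ids, answer_ids in encoded["conversations"]:
        pass  # prefix_ids: loss-masked; answer_ids: learned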