
Commit 71cc404

[Tokenizer]Add Chat template (#8226)
* add jinja template
* add in cfg
* add system
* add notes
* fix apply chat template
* add generation flag
* update jinja ut
* fix error
* add syntax error check
* fix syntax error
* refresh codecov
* update special token map in render
* update save
* add multi round chat ut
* use new api
* rm pdb
* add split ut
* fix des
* fix ut
1 parent 7c3ab53 commit 71cc404
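At a glance, the commit lets a tokenizer carry a Jinja chat template and render it through `apply_chat_template`. A minimal usage sketch (the checkpoint name is illustrative, not part of this commit):

```python
from paddlenlp.transformers import AutoTokenizer

# Illustrative checkpoint: any tokenizer whose tokenizer_config.json defines
# a Jinja "chat_template" entry takes the new code path added here.
tokenizer = AutoTokenizer.from_pretrained("some-org/some-chat-model")

conversation = [{"role": "user", "content": "Hello, who are you?"}]

# Render the template without tokenizing, to inspect the prompt string.
prompt = tokenizer.apply_chat_template(conversation, tokenize=False)

# Or tokenize directly; apply_chat_template forces add_special_tokens=False
# because the template itself is expected to emit the special tokens.
features = tokenizer.apply_chat_template(conversation)
```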

File tree

5 files changed: +283 −17 lines changed


llm/data.py

Lines changed: 2 additions & 2 deletions
@@ -90,7 +90,7 @@ def tokenize_example(tokenizer, example, data_args):
     return tokenized_source, tokenized_target_input_ids


-def tokenize_rounds_example(tokenizer, example, data_args):
+def tokenize_rounds_example(tokenizer, example, data_args, **kwargs):
     """tokenize multi-rounds examples with chat_template.json

     Args:
@@ -117,7 +117,7 @@ def tokenize_rounds_example(tokenizer, example, data_args):

     # 1. only tokenize input_ids
     conversation_result: list[tuple[list[int], list[int]]] = tokenizer.encode_chat_inputs(
-        conversations, context_data=context_data
+        conversations, context_data=context_data, **kwargs
     )
     system_ids = conversation_result.pop("system", []) or []
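The `**kwargs` pass-through lets data pipelines forward template options such as `add_generation_prompt` down to `encode_chat_inputs`. A hedged call-site sketch (the shape of `example` follows the multi-round list layout used elsewhere in this diff; the key name is an assumption):

```python
# Hypothetical call site: one finished round, flags forwarded via **kwargs.
example = {"conversations": [["Hi, who are you?", "I am an assistant."]]}
features = tokenize_rounds_example(
    tokenizer, example, data_args, add_generation_prompt=True
)
```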

paddlenlp/transformers/tokenizer_utils.py

Lines changed: 148 additions & 15 deletions
@@ -31,7 +31,8 @@
 import numpy as np
 import paddle
 import six
-from jinja2.exceptions import TemplateError
+from jinja2 import Template
+from jinja2.exceptions import TemplateError, TemplateSyntaxError
 from jinja2.sandbox import ImmutableSandboxedEnvironment
 from paddle.utils import try_import

@@ -58,7 +59,7 @@
     TextInputPair,
     TruncationStrategy,
 )
-from .utils import InitTrackerMeta, fn_args_to_dict
+from .utils import InitTrackerMeta, convert_to_dict_message, fn_args_to_dict

 __all__ = [
     "PretrainedTokenizer",
@@ -516,7 +517,7 @@ class ChatTemplate:

     @staticmethod
     @lru_cache()
-    def _compile_jinja_template(chat_template):
+    def _compile_jinja_template(chat_template) -> Template:
         def raise_exception(message):
             raise TemplateError(message)
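For context, compiling a template with jinja2's sandbox typically looks like the sketch below, reconstructed from the imports in this file rather than copied from the method body:

```python
from jinja2.exceptions import TemplateError
from jinja2.sandbox import ImmutableSandboxedEnvironment

def compile_jinja_template(chat_template: str):
    # The immutable sandbox blocks state-mutating attribute access from
    # inside templates; raise_exception lets a template signal misuse.
    def raise_exception(message):
        raise TemplateError(message)

    env = ImmutableSandboxedEnvironment(trim_blocks=True, lstrip_blocks=True)
    env.globals["raise_exception"] = raise_exception
    return env.from_string(chat_template)
```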

@@ -598,7 +599,6 @@ def __call__(self, conversations: list[list[str]] | str, context_data: Dict[str, Any] = {}):
             raise ValueError(
                 "The length of last conversation must be one, eg: [[user-query, bot-answer], [user-query, bot-answer], ..., [user-query]]"
             )
-
         if len(conversations[-1]) > 1:
             logger.warning(
                 f"The last conversation is not a single-round, chat-template will skip the conversation: {conversations[-1][1:]}"
@@ -623,7 +623,7 @@ class ChatTemplateMixin:

     def apply_chat_template(
         self,
-        conversation: List[List[str, str]] | str,
+        conversation: List[List[str, str] | Dict[str, str]] | str,
         tokenize: bool = True,
         context_data: Dict[str, Any] = {},
         **tokenizer_kwargs
@@ -638,6 +638,26 @@ def apply_chat_template(
         Returns:
             str | dict[str, numpy.ndarray | paddle.Tensor]: return the result of applied data
         """
+        if not self.chat_template:
+            raise ValueError("chat_template is not set, please set chat_template first.")
+        elif isinstance(self.chat_template, Template):
+            add_generation_prompt = tokenizer_kwargs.pop("add_generation_prompt", True)
+            query = self._apply_chat_template(conversation, add_generation_prompt=add_generation_prompt)
+        elif isinstance(self.chat_template, ChatTemplate):
+            query = self._apply_chat_template_paddle(conversation, context_data)
+
+        if not tokenize:
+            return query
+
+        # chat_template should not add special tokens
+        tokenizer_kwargs["add_special_tokens"] = False
+        return self(query, **tokenizer_kwargs)
+
+    def _apply_chat_template_paddle(
+        self,
+        conversation: List[List[str, str]] | str,
+        context_data: Dict[str, Any] = {},
+    ) -> str | dict[str, numpy.ndarray | paddle.Tensor]:
         context_data = self.chat_template._init_context_data(context_data)

         if isinstance(conversation, str):
@@ -649,14 +669,32 @@ def apply_chat_template(
             )

         query = self.chat_template(conversation, context_data=context_data)
-        if not tokenize:
-            return query
+        return query

-        # chat_template should not add special tokens
-        tokenizer_kwargs["add_special_tokens"] = False
-        return self(query, **tokenizer_kwargs)
+    def _apply_chat_template(
+        self,
+        conversation: List[List[str, str] | Dict[str, str]] | str,
+        add_generation_prompt=True,
+    ) -> str | dict[str, numpy.ndarray | paddle.Tensor]:
+        if isinstance(conversation, str):
+            conversations = [{"role": "user", "content": conversation}]
+        elif isinstance(conversation, list):
+            assert len(conversation) > 0, "empty conversation is not allowed"
+            if isinstance(conversation[0], list):
+                conversations = convert_to_dict_message(conversation)
+            elif isinstance(conversation[0], dict):
+                conversations = conversation
+            else:
+                raise ValueError(
+                    "apply_chat_template does not support applying batch conversations, "
+                    "so you should apply the conversation one by one."
+                )
+        query = self.chat_template.render(
+            messages=conversations, **self.special_tokens_map, add_generation_prompt=add_generation_prompt
+        )
+        return query

-    def encode_chat_inputs(self, conversations: List[List[str, str]], context_data: Dict[str, Any] = {}):
+    def encode_chat_inputs(self, conversations: List[List[str, str]], context_data: Dict[str, Any] = {}, **kwargs):
         """Encodes conversation to pairs of token ids.
         Turn 0: bos + system + sep + user bot + eos
         Turn t: sep + bot + query bot + eos
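`_apply_chat_template` accepts three conversation shapes; a usage sketch assuming a tokenizer whose `chat_template` is a compiled Jinja `Template`:

```python
# 1) A bare string becomes a single user message.
tokenizer.apply_chat_template("Hello!", tokenize=False)

# 2) Paddle-style list rounds go through convert_to_dict_message first.
tokenizer.apply_chat_template(
    [["Hello!", "Hi, how can I help?"], ["Tell me a joke."]], tokenize=False
)

# 3) Role-dict messages are passed to the template unchanged.
tokenizer.apply_chat_template(
    [
        {"role": "user", "content": "Hello!"},
        {"role": "assistant", "content": "Hi, how can I help?"},
        {"role": "user", "content": "Tell me a joke."},
    ],
    tokenize=False,
)
```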
@@ -668,6 +706,16 @@ def encode_chat_inputs(self, conversations: List[List[str, str]], context_data: Dict[str, Any] = {}):
         Returns:
             List[list[int], list[int]]: the pair of input_ids and target_ids
         """
+        if not self.chat_template:
+            raise ValueError("chat_template is not set, please set chat_template first.")
+        elif isinstance(self.chat_template, Template):
+            add_generation_prompt = kwargs.pop("add_generation_prompt", True)
+            query = self._encode_chat_inputs(conversations, context_data, add_generation_prompt=add_generation_prompt)
+        elif isinstance(self.chat_template, ChatTemplate):
+            query = self._encode_chat_inputs_paddle(conversations, context_data)
+        return query
+
+    def _encode_chat_inputs_paddle(self, conversations: List[List[str, str]], context_data: Dict[str, Any] = {}):
         context_data = self.chat_template._init_context_data(context_data)
         # encode system
         result = {}
@@ -692,6 +740,77 @@ def encode_chat_inputs(self, conversations: List[List[str, str]], context_data: Dict[str, Any] = {}):
         result["conversations"] = conversation_ids
         return result

+    def _encode_chat_inputs(
+        self,
+        conversations: List[List[str, str]],
+        context_data: Dict[str, Any] = {},
+        system: str = None,
+        add_generation_prompt=True,
+    ):
+        result = {}
+
+        # Some templates do not support a system message, so check that first.
+        if system:
+            try:
+                self.chat_template.render(messages={"role": "system", "content": system})
+            except Exception as e:
+                raise ValueError("System is not supported in this tokenizer.", e)
+
+        # convert list messages to role-dict messages
+        conversation_dict = []
+        origin_msg = []
+        for round in conversations:
+            round_role = [
+                {"role": "user", "content": round[0]},
+                {"role": "assistant", "content": round[1]},
+            ]
+            origin_msg.extend(round_role)
+            conversation_dict.append(round_role)
+        ans = []
+
+        # get the answer of each single round, then render the chat entirely
+        # and split it by the single-round answers
+        # attention: each answer should include the end token!
+        for conv in conversation_dict:
+            roundi = [system] + conv if system else conv
+            roundi_str = self.chat_template.render(
+                messages=roundi, add_generation_prompt=False, **self.special_tokens_map
+            )
+            roundi_no_ans = [system] + [conv[0]] if system else [conv[0]]
+            roundi_no_ans_str = self.chat_template.render(
+                messages=roundi_no_ans, add_generation_prompt=add_generation_prompt, **self.special_tokens_map
+            )
+            ans_roundi = roundi_str[len(roundi_no_ans_str) :]
+            ans.append(ans_roundi)
+
+        non_learnable_parts = self._extract_non_learnable_parts(origin_msg, ans)
+        assert len(non_learnable_parts) == len(ans)
+
+        conversation_ids = []
+        for i in range(len(non_learnable_parts)):
+            conversation_ids.append(
+                self.batch_encode(
+                    [non_learnable_parts[i], ans[i]],
+                    add_special_tokens=False,
+                    padding=False,
+                )["input_ids"]
+            )
+
+        result["conversations"] = conversation_ids
+        return result
+
+    def _extract_non_learnable_parts(self, origin_msg: List[Dict[str, str]], split_s: List[str]):
+        """Split the entire chat by the specified strings and extract the non-learnable parts."""
+        # escape regex metacharacters in the split strings, e.g. | -> \|
+        regex_pattern = "|".join(map(re.escape, split_s))
+        # split the fully rendered chat by the escaped strings
+        non_learnable_parts = re.split(
+            r"(?:%s)" % regex_pattern,
+            self.chat_template.render(messages=origin_msg, add_generation_prompt=False, **self.special_tokens_map),
+        )
+        if non_learnable_parts[-1] == "":
+            non_learnable_parts.pop()
+        return non_learnable_parts
+
     @classmethod
     def from_pretrained(cls, pretrained_model_name_or_path, *args, **kwargs):
         cache_dir = kwargs.pop("cache_dir", None)
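The heart of `_encode_chat_inputs` is a render-twice trick: each round is rendered with and without its answer, and the string difference is that round's learnable answer (end token included); the full chat is then split on those answers so everything in between becomes non-learnable. A toy illustration with a made-up template:

```python
# Hypothetical template output: "<|user|>{q}<|bot|>{a}</s>"
round_full = "<|user|>Hi<|bot|>Hello!</s>"   # round rendered with the answer
round_no_ans = "<|user|>Hi<|bot|>"           # same round, generation prompt only
answer = round_full[len(round_no_ans):]      # -> "Hello!</s>" (end token kept)

# _extract_non_learnable_parts then re.split()s the fully rendered chat on
# the collected answers, yielding the prompt/separator segments.
```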
@@ -713,6 +832,13 @@ def from_pretrained(cls, pretrained_model_name_or_path, *args, **kwargs):
         if not os.path.exists(chat_template_file):
             return tokenizer

+        if tokenizer.chat_template is not None:
+            logger.warning(
+                "Chat-template already exists in config file, it will be overwritten by chat_template.json file."
+            )
+            logger.warning(
+                "`chat_template.json` will be deprecated in the future! Please set it in `tokenizer_config.json`."
+            )
         tokenizer.init_chat_template(chat_template_file)
         return tokenizer

@@ -724,9 +850,16 @@ def init_chat_template(self, chat_template: str | dict):
         """
         if isinstance(chat_template, str):
             if not os.path.exists(chat_template):
-                raise FileNotFoundError("The chat-template file does not exist: {}".format(chat_template))
-
-            self.chat_template = ChatTemplate.from_file(chat_template)
+                try:
+                    self.chat_template: Template = ChatTemplate._compile_jinja_template(chat_template)
+                except TemplateSyntaxError:
+                    # it is neither a jinja string nor a path string
+                    raise TemplateSyntaxError(
+                        "The chat-template in json is not a valid jinja string: {}".format(chat_template),
+                        lineno=0,  # fake lineno, required by the exception signature
+                    )
+            else:
+                self.chat_template = ChatTemplate.from_file(chat_template)
         elif isinstance(chat_template, dict):
             self.chat_template = ChatTemplate.from_dict(chat_template)
         elif isinstance(chat_template, ChatTemplate):
@@ -737,7 +870,7 @@ def init_chat_template(self, chat_template: str | dict):
     def save_resources(self, save_directory):
         super().save_resources(save_directory)

-        if self.chat_template is not None:
+        if isinstance(self.chat_template, ChatTemplate):  # future: remove once ChatTemplate is deprecated
             chat_template_file = os.path.join(save_directory, CHAT_TEMPLATE_CONFIG_NAME)
             with open(chat_template_file, "w", encoding="utf-8") as f:
                 json.dump(asdict(self.chat_template), f, ensure_ascii=False, indent=4)
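After this change `init_chat_template` accepts either a path to the legacy `chat_template.json` or a raw Jinja string; a sketch of both branches (the template string is illustrative):

```python
# A raw Jinja string is compiled straight into a jinja2 Template.
tokenizer.init_chat_template(
    "{% for message in messages %}{{ message['content'] }}{% endfor %}"
)

# An existing file path still loads the legacy ChatTemplate format.
tokenizer.init_chat_template("./chat_template.json")

# Anything that is neither an existing path nor valid Jinja raises
# TemplateSyntaxError.
```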

paddlenlp/transformers/tokenizer_utils_base.py

Lines changed: 3 additions & 0 deletions
@@ -1575,6 +1575,9 @@ def convert_added_tokens(obj):

         # TODO(guosheng): avoid reduplication of position args and key word args
         tokenizer = cls(*init_args, **init_kwargs)
+        chat_template = init_kwargs.pop("chat_template", None)
+        if chat_template is not None:
+            tokenizer.init_chat_template(chat_template)
         special_tokens_map_file = resolved_vocab_files.pop("special_tokens_map_file", None)
         if special_tokens_map_file is not None:
             with open(special_tokens_map_file, encoding="utf-8") as special_tokens_map_handle:
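With this hook, a `chat_template` entry stored in `tokenizer_config.json` is applied automatically on load; a hedged sketch (local path assumed):

```python
from paddlenlp.transformers import AutoTokenizer

# tokenizer_config.json of ./my-checkpoint contains, e.g.:
#   "chat_template": "{% for m in messages %}{{ m['content'] }}{% endfor %}"
tokenizer = AutoTokenizer.from_pretrained("./my-checkpoint")
assert tokenizer.chat_template is not None  # initialized during from_pretrained
```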

paddlenlp/transformers/utils.py

Lines changed: 19 additions & 0 deletions
@@ -89,6 +89,25 @@ def convert_ndarray_dtype(np_array: np.ndarray, target_dtype: str) -> np.ndarray:
     return np_array.astype(target_dtype)


+def convert_to_dict_message(conversation: List[List[str]]):
+    """Convert list-style chat messages to role-dict chat messages."""
+    conversations = []
+    for index, item in enumerate(conversation):
+        assert 1 <= len(item) <= 2, "Each round in the conversation should have 1 or 2 elements."
+        if isinstance(item[0], str):
+            conversations.append({"role": "user", "content": item[0]})
+            if len(item) == 2 and isinstance(item[1], str):
+                conversations.append({"role": "assistant", "content": item[1]})
+            else:
+                # a single-element round is only allowed as the last round;
+                # anywhere else the conversation is malformed
+                if index != len(conversation) - 1:
+                    raise ValueError(f"Round {index} is not a complete round")
+        else:
+            raise ValueError("Each round in the list should be a string")
+    return conversations
+
+
 def get_scale_by_dtype(dtype: str = None, return_positive: bool = True) -> float:
     """get scale value by dtype
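A quick sketch of the conversion this helper performs; a user-only round is permitted only as the final round:

```python
convert_to_dict_message([["Hi", "Hello!"], ["Tell me a joke."]])
# -> [{"role": "user", "content": "Hi"},
#     {"role": "assistant", "content": "Hello!"},
#     {"role": "user", "content": "Tell me a joke."}]
```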
