[Tokenizer] Add Chat template #8226
```diff
@@ -31,7 +31,8 @@
 import numpy as np
 import paddle
 import six
-from jinja2.exceptions import TemplateError
+from jinja2 import Template
+from jinja2.exceptions import TemplateError, TemplateSyntaxError
 from jinja2.sandbox import ImmutableSandboxedEnvironment
 from paddle.utils import try_import
```
```diff
@@ -58,7 +59,7 @@
     TextInputPair,
     TruncationStrategy,
 )
-from .utils import InitTrackerMeta, fn_args_to_dict
+from .utils import InitTrackerMeta, convert_to_dict_message, fn_args_to_dict

 __all__ = [
     "PretrainedTokenizer",
```
```diff
@@ -516,7 +517,7 @@ class ChatTemplate:

     @staticmethod
     @lru_cache()
-    def _compile_jinja_template(chat_template):
+    def _compile_jinja_template(chat_template) -> Template:
         def raise_exception(message):
             raise TemplateError(message)
```
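The body of `_compile_jinja_template` is truncated in this hunk; only the `raise_exception` helper is visible. For readers following along, a minimal sketch of the sandboxed-compilation pattern it presumably follows (the `trim_blocks`/`lstrip_blocks` options are an assumption, not visible in the diff):

```python
from jinja2.exceptions import TemplateError
from jinja2.sandbox import ImmutableSandboxedEnvironment


def compile_jinja_template(chat_template: str):
    # Expose a raise_exception() helper so templates can signal errors,
    # e.g. {{ raise_exception('unknown role') }}.
    def raise_exception(message):
        raise TemplateError(message)

    # A sandboxed environment keeps templates from touching Python internals.
    env = ImmutableSandboxedEnvironment(trim_blocks=True, lstrip_blocks=True)
    env.globals["raise_exception"] = raise_exception
    return env.from_string(chat_template)
```

The `@lru_cache()` on the real method means a given template string is only compiled once per process.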
```diff
@@ -598,7 +599,6 @@ def __call__(self, conversations: list[list[str]] | str, context_data: Dict[str,
             raise ValueError(
                 "The length of last conversation must be one, eg: [[user-query, bot-answer], [user-query, bot-answer], ..., [user-query]]"
             )
-
         if len(conversations[-1]) > 1:
             logger.warning(
                 f"The last conversation is not a single-round, chat-template will skip the conversation: {conversations[-1][1:]}"
```
```diff
@@ -623,7 +623,7 @@ class ChatTemplateMixin:

     def apply_chat_template(
         self,
-        conversation: List[List[str, str]] | str,
+        conversation: List[List[str, str] | Dict[str, str]] | str,
         tokenize: bool = True,
         context_data: Dict[str, Any] = {},
         **tokenizer_kwargs
```
```diff
@@ -638,6 +638,26 @@ def apply_chat_template(
         Returns:
             str | dict[str, numpy.ndarray | paddle.Tensor]: return the result of applied data
         """
+        if not self.chat_template:
+            raise ValueError("chat_template is not set, please set chat_template first.")
+        elif isinstance(self.chat_template, Template):
+            add_generation_prompt = tokenizer_kwargs.pop("add_generation_prompt", True)
+            query = self._apply_chat_template(conversation, add_generation_prompt=add_generation_prompt)
+        elif isinstance(self.chat_template, ChatTemplate):
+            query = self._apply_chat_template_paddle(conversation, context_data)
+
+        if not tokenize:
+            return query
+
+        # chat_template should not add special tokens
+        tokenizer_kwargs["add_special_tokens"] = False
+        return self(query, **tokenizer_kwargs)
+
+    def _apply_chat_template_paddle(
+        self,
+        conversation: List[List[str, str]] | str,
+        context_data: Dict[str, Any] = {},
+    ) -> str | dict[str, numpy.ndarray | paddle.Tensor]:
+        context_data = self.chat_template._init_context_data(context_data)

         if isinstance(conversation, str):
```
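To make the new dispatch concrete, a hypothetical usage sketch (the checkpoint name is a placeholder; the rendered output depends on the model's template):

```python
from paddlenlp.transformers import AutoTokenizer

# Placeholder checkpoint name; any model shipping a Jinja chat template would do.
tokenizer = AutoTokenizer.from_pretrained("some-chat-model")

messages = [{"role": "user", "content": "Hello, who are you?"}]

# tokenize=False returns the rendered prompt string.
prompt = tokenizer.apply_chat_template(messages, tokenize=False)

# By default the rendered string is tokenized, with add_special_tokens forced
# to False because the template itself is responsible for special tokens.
inputs = tokenizer.apply_chat_template(messages)
```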
@@ -649,14 +669,32 @@ def apply_chat_template( | |
) | ||
|
||
query = self.chat_template(conversation, context_data=context_data) | ||
if not tokenize: | ||
return query | ||
return query | ||
|
||
# chat_template should not add special tokens | ||
tokenizer_kwargs["add_special_tokens"] = False | ||
return self(query, **tokenizer_kwargs) | ||
def _apply_chat_template( | ||
self, | ||
conversation: List[List[str, str] | Dict[str, str]] | str, | ||
add_generation_prompt=True, | ||
) -> str | dict[str, numpy.ndarray | paddle.Tensor]: | ||
Southpika marked this conversation as resolved.
Show resolved
Hide resolved
|
||
if isinstance(conversation, str): | ||
conversations = [{"role": "user", "content": conversation}] | ||
elif isinstance(conversation, list): | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. 此外,也测试过新旧 chat_template 在 Predictor 中的使用是否符合预期,同时还要测试一下 gradio_ui 能够使用新旧 chat_template。 |
||
assert len(conversation) > 0, "empty conversation is not allowed" | ||
if isinstance(conversation[0], list): | ||
conversations = convert_to_dict_message(conversation) | ||
elif isinstance(conversation[0], dict): | ||
conversations = conversation | ||
else: | ||
raise ValueError( | ||
"apply_chat_template do not support appling batch conversations, " | ||
"so you should apply the conversation one by one." | ||
) | ||
query = self.chat_template.render( | ||
messages=conversations, **self.special_tokens_map, add_generation_prompt=add_generation_prompt | ||
) | ||
return query | ||
|
||
def encode_chat_inputs(self, conversations: List[List[str, str]], context_data: Dict[str, Any] = {}): | ||
def encode_chat_inputs(self, conversations: List[List[str, str]], context_data: Dict[str, Any] = {}, **kwargs): | ||
"""Encodes conversation to pairs of token ids. | ||
Turn 0: bos + system + sep + user bot + eos | ||
Turn t: sep + bot + query bot + eos | ||
|
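`convert_to_dict_message` itself is imported from `.utils` and not shown in this diff; a rough sketch of the mapping it presumably performs on the legacy `[[user, bot], ..., [user]]` format (the real helper may differ):

```python
def convert_to_dict_message(conversation):
    """Map [[user, bot], ..., [user]] pairs to role-tagged dict messages."""
    messages = []
    for turn in conversation:
        messages.append({"role": "user", "content": turn[0]})
        if len(turn) > 1:
            messages.append({"role": "assistant", "content": turn[1]})
    return messages


assert convert_to_dict_message([["hi", "hello!"], ["how are you?"]]) == [
    {"role": "user", "content": "hi"},
    {"role": "assistant", "content": "hello!"},
    {"role": "user", "content": "how are you?"},
]
```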
```diff
@@ -668,6 +706,16 @@ def encode_chat_inputs(self, conversations: List[List[str, str]], context_data:
         Returns:
             List[list[int], list[int]]: the pair of input_ids and target_ids
         """
+        if not self.chat_template:
+            raise ValueError("chat_template is not set, please set chat_template first.")
+        elif isinstance(self.chat_template, Template):
+            add_generation_prompt = kwargs.pop("add_generation_prompt", True)
+            query = self._encode_chat_inputs(conversations, context_data, add_generation_prompt=add_generation_prompt)
+        elif isinstance(self.chat_template, ChatTemplate):
+            query = self._encode_chat_inputs_paddle(conversations, context_data)
+        return query
+
+    def _encode_chat_inputs_paddle(self, conversations: List[List[str, str]], context_data: Dict[str, Any] = {}):
         context_data = self.chat_template._init_context_data(context_data)
         # encode system
         result = {}
```
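Reusing the hypothetical tokenizer from the earlier sketch, the shape of the result is the same on both paths:

```python
# Each turn yields an [input_ids, target_ids] pair: the non-learnable prefix
# of the turn and the learnable answer ids (the answer includes the end token).
pairs = tokenizer.encode_chat_inputs(
    [["What is 1+1?", "2"], ["And doubled?", "4"]]
)["conversations"]
for input_ids, target_ids in pairs:
    print(len(input_ids), len(target_ids))
```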
```diff
@@ -692,6 +740,77 @@ def encode_chat_inputs(self, conversations: List[List[str, str]], context_data:
         result["conversations"] = conversation_ids
         return result

+    def _encode_chat_inputs(
+        self,
+        conversations: List[List[str, str]],
+        context_data: Dict[str, Any] = {},
+        system: str = None,
+        add_generation_prompt=True,
+    ):
+        result = {}
+
+        # Some templates do not support a system message, so check it first.
+        if system:
+            try:
+                self.chat_template.render(messages={"role": "system", "content": system})
+            except Exception as e:
+                raise ValueError("System is not supported in this tokenizer.", e)
+
+        # convert list messages to role-dict messages
+        conversation_dict = []
+        origin_msg = []
+        for round in conversations:
+            round_role = [
+                {"role": "user", "content": round[0]},
+                {"role": "assistant", "content": round[1]},
+            ]
+            origin_msg.extend(round_role)
+            conversation_dict.append(round_role)
+        ans = []
+
+        # get the answer of each single round, then render the chat as a whole
+        # and split it by the single-round answers
+        # attention: the answer should include the end token!
+        for conv in conversation_dict:
+            roundi = [system] + conv if system else conv
+            roundi_str = self.chat_template.render(
+                messages=roundi, add_generation_prompt=False, **self.special_tokens_map
+            )
+            roundi_no_ans = [system] + [conv[0]] if system else [conv[0]]
+            roundi_no_ans_str = self.chat_template.render(
+                messages=roundi_no_ans, add_generation_prompt=add_generation_prompt, **self.special_tokens_map
+            )
+            ans_roundi = roundi_str[len(roundi_no_ans_str) :]
+            ans.append(ans_roundi)
+
+        non_learnable_parts = self._extract_non_learnable_parts(origin_msg, ans)
+        assert len(non_learnable_parts) == len(ans)
+
+        conversation_ids = []
+        for i in range(len(non_learnable_parts)):
+            conversation_ids.append(
+                self.batch_encode(
+                    [non_learnable_parts[i], ans[i]],
+                    add_special_tokens=False,
+                    padding=False,
+                )["input_ids"]
+            )
+
+        result["conversations"] = conversation_ids
+        return result
+
+    def _extract_non_learnable_parts(self, origin_msg: List[Dict[str, str]], split_s: List[str]):
+        """Split the entire chat by the specified strings and extract the non-learnable parts."""
+        # escape regex metacharacters in the split strings to their literal form, e.g. | -> \|
+        regex_pattern = "|".join(map(re.escape, split_s))
+        # split on the escaped answer strings
+        non_learnable_parts = re.split(
+            r"(?:%s)" % regex_pattern,
+            self.chat_template.render(messages=origin_msg, add_generation_prompt=False, **self.special_tokens_map),
+        )
+        if non_learnable_parts[-1] == "":
+            non_learnable_parts.pop()
+        return non_learnable_parts
+
     @classmethod
     def from_pretrained(cls, pretrained_model_name_or_path, *args, **kwargs):
         cache_dir = kwargs.pop("cache_dir", None)
```

Review comment (on `_encode_chat_inputs`): Decoupled from the previously designed unified training/inference ChatTemplate, this function has very limited applicability; it basically cannot be used. So I would rather not put the encode_chat_inputs logic into the tokenizer; try to move it into preprocessing instead. That means the adjustment here could end up being fairly large.

Reply: Considering that currently …
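A standalone illustration of the splitting trick in `_extract_non_learnable_parts`, using toy markers rather than any real template output:

```python
import re

# Fully rendered chat and the per-round answers (answers include the end token).
rendered = "<user>hi</user><bot>hello!</bot><user>bye</user><bot>see you.</bot>"
answers = ["hello!</bot>", "see you.</bot>"]

# re.escape makes regex metacharacters in the answers literal, e.g. | -> \|;
# splitting the render on the answers leaves only the non-learnable scaffolding.
pattern = "|".join(map(re.escape, answers))
parts = re.split(r"(?:%s)" % pattern, rendered)
if parts[-1] == "":
    parts.pop()
print(parts)  # ['<user>hi</user><bot>', '<user>bye</user><bot>']
```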
```diff
@@ -713,6 +832,13 @@ def from_pretrained(cls, pretrained_model_name_or_path, *args, **kwargs):
         if not os.path.exists(chat_template_file):
             return tokenizer

+        if tokenizer.chat_template is not None:
+            logger.warning(
+                "Chat template already exists in the config file; it will be overwritten by the chat_template.json file."
+            )
+            logger.warning(
+                "`chat_template.json` will be deprecated in the future! Please set it in `tokenizer_config.json`."
+            )
         tokenizer.init_chat_template(chat_template_file)
         return tokenizer
```
```diff
@@ -724,9 +850,16 @@ def init_chat_template(self, chat_template: str | dict):
         """
         if isinstance(chat_template, str):
             if not os.path.exists(chat_template):
-                raise FileNotFoundError("The chat-template file does not exist: {}".format(chat_template))
-
-            self.chat_template = ChatTemplate.from_file(chat_template)
+                try:
+                    self.chat_template: Template = ChatTemplate._compile_jinja_template(chat_template)
+                except TemplateSyntaxError:
+                    # it is neither a jinja string nor a path string
+                    raise TemplateSyntaxError(
+                        "The chat-template in json is not a valid jinja string: {}".format(chat_template),
+                        lineno=0,  # fake lineno; required by the exception but unused
+                    )
+            else:
+                self.chat_template = ChatTemplate.from_file(chat_template)
         elif isinstance(chat_template, dict):
             self.chat_template = ChatTemplate.from_dict(chat_template)
         elif isinstance(chat_template, ChatTemplate):
```
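With this change, a string that is not an existing path is compiled directly as Jinja; a hypothetical example with a toy template, continuing the earlier tokenizer sketch:

```python
jinja_str = (
    "{% for message in messages %}"
    "{{ message['role'] }}: {{ message['content'] }}\n"
    "{% endfor %}"
    "{% if add_generation_prompt %}assistant: {% endif %}"
)
# Stores a compiled jinja2.Template on the tokenizer instead of a legacy ChatTemplate.
tokenizer.init_chat_template(jinja_str)
```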
```diff
@@ -737,7 +870,7 @@ def init_chat_template(self, chat_template: str | dict):
     def save_resources(self, save_directory):
         super().save_resources(save_directory)

-        if self.chat_template is not None:
+        if isinstance(self.chat_template, ChatTemplate):  # Future: remove once ChatTemplate is deprecated
             chat_template_file = os.path.join(save_directory, CHAT_TEMPLATE_CONFIG_NAME)
             with open(chat_template_file, "w", encoding="utf-8") as f:
                 json.dump(asdict(self.chat_template), f, ensure_ascii=False, indent=4)
```