[Tokenizer] Add Chat template #8226
Changes from 14 commits
@@ -31,7 +31,8 @@
 import numpy as np
 import paddle
 import six
-from jinja2.exceptions import TemplateError
+from jinja2 import Template
+from jinja2.exceptions import TemplateError, TemplateSyntaxError
 from jinja2.sandbox import ImmutableSandboxedEnvironment
 from paddle.utils import try_import
@@ -516,7 +517,7 @@ class ChatTemplate:

     @staticmethod
     @lru_cache()
-    def _compile_jinja_template(chat_template):
+    def _compile_jinja_template(chat_template) -> Template:
         def raise_exception(message):
             raise TemplateError(message)
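For context on what `_compile_jinja_template` does, here is a minimal sketch of sandboxed, cached jinja compilation; the environment flags are illustrative assumptions, not necessarily the exact options used in PaddleNLP.

```python
# Minimal sketch of sandboxed template compilation (assumed flags, for
# illustration; mirrors the cached helper in the hunk above).
from functools import lru_cache

from jinja2 import Template
from jinja2.exceptions import TemplateError
from jinja2.sandbox import ImmutableSandboxedEnvironment


@lru_cache()
def compile_jinja_template(chat_template: str) -> Template:
    def raise_exception(message):
        # Exposed to templates so they can fail loudly via raise_exception("...")
        raise TemplateError(message)

    env = ImmutableSandboxedEnvironment(trim_blocks=True, lstrip_blocks=True)
    env.globals["raise_exception"] = raise_exception
    return env.from_string(chat_template)
```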
@@ -598,7 +599,6 @@ def __call__(self, conversations: list[list[str]] | str, context_data: Dict[str,
             raise ValueError(
                 "The length of last conversation must be one, eg: [[user-query, bot-answer], [user-query, bot-answer], ..., [user-query]]"
             )
-
         if len(conversations[-1]) > 1:
             logger.warning(
                 f"The last conversation is not a single-round, chat-template will skip the conversation: {conversations[-1][1:]}"
@@ -623,7 +623,7 @@ class ChatTemplateMixin:

     def apply_chat_template(
         self,
-        conversation: List[List[str, str]] | str,
+        conversation: List[List[str, str] | Dict[str, str]] | str,
         tokenize: bool = True,
         context_data: Dict[str, Any] = {},
         **tokenizer_kwargs
@@ -638,6 +638,26 @@ def apply_chat_template(
         Returns:
             str | dict[str, numpy.ndarray | paddle.Tensor]: return the result of applied data
         """
+        if not self.chat_template:
+            raise ValueError("chat_template is not set, please set chat_template first.")
+        elif isinstance(self.chat_template, Template):
+            add_generation_prompt = tokenizer_kwargs.pop("add_generation_prompt", True)
+            query = self._apply_chat_template(conversation, context_data, add_generation_prompt=add_generation_prompt)
+        elif isinstance(self.chat_template, ChatTemplate):
+            query = self._apply_chat_template_paddle(conversation, context_data)
+
+        if not tokenize:
+            return query
+
+        # chat_template should not add special tokens
+        tokenizer_kwargs["add_special_tokens"] = False
+        return self(query, **tokenizer_kwargs)
+
+    def _apply_chat_template_paddle(
+        self,
+        conversation: List[List[str, str]] | str,
+        context_data: Dict[str, Any] = {},
+    ) -> str | dict[str, numpy.ndarray | paddle.Tensor]:
+        context_data = self.chat_template._init_context_data(context_data)
+
         if isinstance(conversation, str):
@@ -649,14 +669,45 @@
             )

         query = self.chat_template(conversation, context_data=context_data)
-        if not tokenize:
-            return query
-
-        # chat_template should not add special tokens
-        tokenizer_kwargs["add_special_tokens"] = False
-        return self(query, **tokenizer_kwargs)
+        return query
+
+    def _apply_chat_template(
+        self,
+        conversation: List[List[str, str] | Dict[str, str]] | str,
+        context_data: Dict[str, Any] = {},
+        add_generation_prompt=True,
+    ) -> str | dict[str, numpy.ndarray | paddle.Tensor]:
+        if isinstance(conversation, str):
+            conversations = [{"role": "user", "content": conversation}]
+        elif isinstance(conversation, list):
+            conversations = []
+            for index, item in enumerate(conversation):
+                if isinstance(item, dict):
+                    conversations = conversation
+                    break
+                elif isinstance(item, list):
+                    assert 1 <= len(item) <= 2
+                    if isinstance(item[0], str):
+                        conversations.append({"role": "user", "content": item[0]})
+                        if len(item) == 2 and isinstance(item[1], str):
+                            conversations.append({"role": "assistant", "content": item[1]})
+                        else:
+                            # item has only one element, so it must be the last round
+                            if index != len(conversation) - 1:
+                                raise ValueError(f"Round {index} is invalid: only the last round may omit the answer")
+                    else:
+                        raise ValueError("Each round in list should be string")
+        else:
+            raise ValueError(
+                "apply_chat_template does not support applying batch conversations, "
+                "so you should apply the conversation one by one."
+            )
+        query = self.chat_template.render(
+            messages=conversations, **self.special_tokens_map, add_generation_prompt=add_generation_prompt
+        )
+        return query

Review comment (resolved): In addition, the old and new chat_template have been tested in Predictor and behave as expected; gradio_ui should also be verified to work with both the old and new chat_template.
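As a usage sketch, both conversation shapes now route through the same entry point. The model name and prompts below are placeholders, not from this PR:

```python
# Illustrative only: the two conversation shapes accepted after this change.
# "some-chat-model" is a placeholder; any tokenizer with a chat template works.
from paddlenlp.transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("some-chat-model")

# OpenAI-style role/content dicts (newly supported in this PR)
messages = [
    {"role": "user", "content": "Hello"},
    {"role": "assistant", "content": "Hi, how can I help?"},
    {"role": "user", "content": "Tell me a joke."},
]
print(tokenizer.apply_chat_template(messages, tokenize=False))

# Paddle-style [query, answer] rounds; the last round holds only the query
rounds = [["Hello", "Hi, how can I help?"], ["Tell me a joke."]]
print(tokenizer.apply_chat_template(rounds, tokenize=False))
```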
-    def encode_chat_inputs(self, conversations: List[List[str, str]], context_data: Dict[str, Any] = {}):
+    def encode_chat_inputs(self, conversations: List[List[str, str]], context_data: Dict[str, Any] = {}, **kwargs):
         """Encodes conversation to pairs of token ids.
         Turn 0: bos + system + sep + user bot + eos
         Turn t: sep + bot + query bot + eos
@@ -668,6 +719,16 @@ def encode_chat_inputs(self, conversations: List[List[str, str]], context_data:
         Returns:
             List[list[int], list[int]]: the pair of input_ids and target_ids
         """
+        if not self.chat_template:
+            raise ValueError("chat_template is not set, please set chat_template first.")
+        elif isinstance(self.chat_template, Template):
+            add_generation_prompt = kwargs.pop("add_generation_prompt", True)
+            query = self._encode_chat_inputs(conversations, context_data, add_generation_prompt=add_generation_prompt)
+        elif isinstance(self.chat_template, ChatTemplate):
+            query = self._encode_chat_inputs_paddle(conversations, context_data)
+        return query
+
+    def _encode_chat_inputs_paddle(self, conversations: List[List[str, str]], context_data: Dict[str, Any] = {}):
+        context_data = self.chat_template._init_context_data(context_data)
         # encode system
         result = {}
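A hypothetical call to the dispatch above might look like this; the tokenizer instance and conversation text are illustrative, not taken from the PR:

```python
# Hypothetical usage; assumes `tokenizer` already carries a jinja chat template.
# Each pair holds the non-learnable prompt ids and the learnable answer ids.
pairs = tokenizer.encode_chat_inputs(
    [["Hello", "Hi, how can I help?"], ["Tell me a joke.", "Why did the chicken cross the road?"]]
)["conversations"]
for prompt_ids, answer_ids in pairs:
    print(len(prompt_ids), len(answer_ids))
```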
Review comment on `_encode_chat_inputs`: Once it is detached from the train-inference unified ChatTemplate designed earlier, this function has fairly low applicability and can hardly be used on its own. I would therefore not recommend putting the encode_chat_inputs logic inside the tokenizer; it should live in the preprocessing code as much as possible, so the adjustment here could be fairly large.

Follow-up comment (truncated): Considering that currently…

@@ -692,6 +753,70 @@ def encode_chat_inputs(self, conversations: List[List[str, str]], context_data:
         result["conversations"] = conversation_ids
         return result

+    def _encode_chat_inputs(
+        self,
+        conversations: List[List[str, str]],
+        context_data: Dict[str, Any] = {},
+        system: str = None,
+        add_generation_prompt=True,
+    ):
+        result = {}
+
+        if system:
+            try:
+                self.chat_template.render(messages={"role": "system", "content": system})
+            except Exception as e:
+                raise ValueError("System prompt is not supported by this tokenizer's chat template.", e)
+
+        conversation_dict = []
+        origin_msg = []
+        for round in conversations:
+            round_role = [
+                {"role": "user", "content": round[0]},
+                {"role": "assistant", "content": round[1]},
+            ]
+            origin_msg.extend(round_role)
+            conversation_dict.append(round_role)
+        ans = []
+
+        for conv in conversation_dict:
+            roundi = [system] + conv if system else conv
+            roundi_str = self.chat_template.render(
+                messages=roundi, add_generation_prompt=False, **self.special_tokens_map
+            )
+            roundi_no_ans = [system] + [conv[0]] if system else [conv[0]]
+            roundi_no_ans_str = self.chat_template.render(
+                messages=roundi_no_ans, add_generation_prompt=add_generation_prompt, **self.special_tokens_map
+            )
+            ans_roundi = roundi_str[len(roundi_no_ans_str) :]
+            ans.append(ans_roundi)
+
+        regex_pattern = "|".join(map(re.escape, ans))
+        non_learnable_parts = re.split(
+            r"(?:%s)" % regex_pattern,
+            self.chat_template.render(messages=origin_msg, add_generation_prompt=False, **self.special_tokens_map),
+        )
+        if non_learnable_parts[-1] == "":
+            non_learnable_parts.pop()
+
+        assert len(non_learnable_parts) == len(ans)
+
+        conversation_ids = []
+        for i in range(len(non_learnable_parts)):
+            conversation_ids.append(
+                self.batch_encode(
+                    [non_learnable_parts[i], ans[i]],
+                    add_special_tokens=False,
+                    padding=False,
+                )["input_ids"]
+            )
+
+        result["conversations"] = conversation_ids
+        return result

     @classmethod
     def from_pretrained(cls, pretrained_model_name_or_path, *args, **kwargs):
         cache_dir = kwargs.pop("cache_dir", None)
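The core trick in `_encode_chat_inputs` deserves a note: each round is rendered twice, with and without the assistant answer, and the string difference isolates the learnable answer span; splitting the full rendering on those spans then yields the non-learnable prompt parts. A self-contained sketch of that mechanism, using a toy template rather than a real PaddleNLP one:

```python
# Self-contained demo of the render-diff trick used in _encode_chat_inputs.
import re

from jinja2.sandbox import ImmutableSandboxedEnvironment

# Toy template standing in for a real chat template (an assumption for the demo).
template = ImmutableSandboxedEnvironment().from_string(
    "{% for m in messages %}<{{ m.role }}>: {{ m.content }}\n{% endfor %}"
)

rounds = [
    [{"role": "user", "content": "Hello"}, {"role": "assistant", "content": "Hi!"}],
    [{"role": "user", "content": "Bye"}, {"role": "assistant", "content": "See you."}],
]
messages = [m for r in rounds for m in r]

# Render each round twice: with and without the answer. The string difference
# is exactly the learnable answer span for that round.
answers = []
for r in rounds:
    with_answer = template.render(messages=r)
    without_answer = template.render(messages=r[:1])
    answers.append(with_answer[len(without_answer):])

# Splitting the full conversation on the answers yields the non-learnable parts.
pattern = "|".join(map(re.escape, answers))
prompts = re.split("(?:%s)" % pattern, template.render(messages=messages))
if prompts[-1] == "":
    prompts.pop()
assert len(prompts) == len(answers)  # one prompt per answer, as in the diff
```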
@@ -713,6 +838,10 @@ def from_pretrained(cls, pretrained_model_name_or_path, *args, **kwargs):
         if not os.path.exists(chat_template_file):
             return tokenizer

+        if tokenizer.chat_template is not None:
+            logger.warning(
+                "A chat template already exists in the config file; it will be overwritten by the chat_template.json file."
+            )
         tokenizer.init_chat_template(chat_template_file)
         return tokenizer
@@ -724,9 +853,16 @@ def init_chat_template(self, chat_template: str | dict):
         """
         if isinstance(chat_template, str):
             if not os.path.exists(chat_template):
-                raise FileNotFoundError("The chat-template file does not exist: {}".format(chat_template))
-
-            self.chat_template = ChatTemplate.from_file(chat_template)
+                try:
+                    self.chat_template: Template = ChatTemplate._compile_jinja_template(chat_template)
+                except TemplateSyntaxError:
+                    # The string is neither a valid jinja template nor an existing path
+                    raise TemplateSyntaxError(
+                        "The chat-template in json is not a valid jinja string: {}".format(chat_template),
+                        lineno=0,  # fake lineno; required by TemplateSyntaxError but unused here
+                    )
+            else:
+                self.chat_template = ChatTemplate.from_file(chat_template)
         elif isinstance(chat_template, dict):
             self.chat_template = ChatTemplate.from_dict(chat_template)
         elif isinstance(chat_template, ChatTemplate):
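To illustrate the accepted inputs, a hypothetical snippet follows; the template string, file path, and dict fields are made up for the example:

```python
# Hypothetical inputs to init_chat_template; nothing here ships with PaddleNLP.

# 1. A raw jinja string (not a path on disk) is compiled into a jinja2.Template.
tokenizer.init_chat_template(
    "{% for m in messages %}{{ m['role'] }}: {{ m['content'] }}\n{% endfor %}"
    "{% if add_generation_prompt %}assistant: {% endif %}"
)

# 2. A path to a chat_template.json file is parsed into the legacy ChatTemplate.
tokenizer.init_chat_template("./chat_template.json")

# 3. A dict is parsed via ChatTemplate.from_dict (field names depend on
#    ChatTemplate's schema; "system" here is an assumption).
tokenizer.init_chat_template({"system": "You are a helpful assistant."})
```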
@@ -737,7 +873,7 @@ def init_chat_template(self, chat_template: str | dict):
     def save_resources(self, save_directory):
         super().save_resources(save_directory)

-        if self.chat_template is not None:
+        if isinstance(self.chat_template, ChatTemplate):  # Future: remove once ChatTemplate is deprecated
             chat_template_file = os.path.join(save_directory, CHAT_TEMPLATE_CONFIG_NAME)
             with open(chat_template_file, "w", encoding="utf-8") as f:
                 json.dump(asdict(self.chat_template), f, ensure_ascii=False, indent=4)