
[Tokenizer] Add Chat template #8226

Merged · 21 commits · Apr 28, 2024
4 changes: 2 additions & 2 deletions llm/data.py
@@ -90,7 +90,7 @@ def tokenize_example(tokenizer, example, data_args):
return tokenized_source, tokenized_target_input_ids


-def tokenize_rounds_example(tokenizer, example, data_args):
+def tokenize_rounds_example(tokenizer, example, data_args, **kwargs):
"""tokenize multi-rounds examples with chat_template.json

Args:
@@ -117,7 +117,7 @@ def tokenize_rounds_example(tokenizer, example, data_args):

# 1. only tokenize input_ids
conversation_result: list[tuple[list[int], list[int]]] = tokenizer.encode_chat_inputs(
-        conversations, context_data=context_data
+        conversations, context_data=context_data, **kwargs
)
system_ids = conversation_result.pop("system", []) or []

163 changes: 148 additions & 15 deletions paddlenlp/transformers/tokenizer_utils.py
@@ -31,7 +31,8 @@
import numpy as np
import paddle
import six
-from jinja2.exceptions import TemplateError
+from jinja2 import Template
+from jinja2.exceptions import TemplateError, TemplateSyntaxError
from jinja2.sandbox import ImmutableSandboxedEnvironment
from paddle.utils import try_import

@@ -58,7 +59,7 @@
TextInputPair,
TruncationStrategy,
)
-from .utils import InitTrackerMeta, fn_args_to_dict
+from .utils import InitTrackerMeta, convert_to_dict_message, fn_args_to_dict

__all__ = [
"PretrainedTokenizer",
@@ -516,7 +517,7 @@ class ChatTemplate:

@staticmethod
@lru_cache()
-    def _compile_jinja_template(chat_template):
+    def _compile_jinja_template(chat_template) -> Template:
def raise_exception(message):
raise TemplateError(message)

@@ -598,7 +599,6 @@ def __call__(self, conversations: list[list[str]] | str, context_data: Dict[str,
raise ValueError(
"The length of last conversation must be one, eg: [[user-query, bot-answer], [user-query, bot-answer], ..., [user-query]]"
)

if len(conversations[-1]) > 1:
logger.warning(
f"The last conversation is not a single-round, chat-template will skip the conversation: {conversations[-1][1:]}"
@@ -623,7 +623,7 @@ class ChatTemplateMixin:

def apply_chat_template(
self,
-        conversation: List[List[str, str]] | str,
+        conversation: List[List[str, str] | Dict[str, str]] | str,
tokenize: bool = True,
context_data: Dict[str, Any] = {},
**tokenizer_kwargs
@@ -638,6 +638,26 @@ def apply_chat_template(
Returns:
str | dict[str, numpy.ndarray | paddle.Tensor]: return the result of applied data
"""
if not self.chat_template:
raise ValueError("chat_template is not set, please set chat_template first.")
elif isinstance(self.chat_template, Template):
add_generation_prompt = tokenizer_kwargs.pop("add_generation_prompt", True)
query = self._apply_chat_template(conversation, add_generation_prompt=add_generation_prompt)
elif isinstance(self.chat_template, ChatTemplate):
query = self._apply_chat_template_paddle(conversation, context_data)

if not tokenize:
return query

# chat_template should not add special tokens
tokenizer_kwargs["add_special_tokens"] = False
return self(query, **tokenizer_kwargs)
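
A minimal usage sketch of the dispatch above (illustrative only — the checkpoint name is an assumption, not part of this diff):

    from paddlenlp.transformers import AutoTokenizer

    # chat_template may be a jinja2.Template (loaded from tokenizer_config.json)
    # or the legacy ChatTemplate (loaded from chat_template.json);
    # apply_chat_template dispatches on its type.
    tokenizer = AutoTokenizer.from_pretrained("qwen/qwen-7b-chat")  # hypothetical checkpoint

    # A plain string is wrapped as a single user message and tokenized.
    ids = tokenizer.apply_chat_template("Hello!", return_tensors="pd")

    # Multi-round list form; with tokenize=False the rendered prompt string is returned.
    prompt = tokenizer.apply_chat_template(
        [["Hello!", "Hi, how can I help?"], ["What is PaddleNLP?"]],
        tokenize=False,
    )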

def _apply_chat_template_paddle(
self,
conversation: List[List[str, str]] | str,
context_data: Dict[str, Any] = {},
) -> str | dict[str, numpy.ndarray | paddle.Tensor]:
context_data = self.chat_template._init_context_data(context_data)

if isinstance(conversation, str):
@@ -649,14 +669,32 @@ def apply_chat_template(
)

        query = self.chat_template(conversation, context_data=context_data)
-        if not tokenize:
-            return query
+        return query

-        # chat_template should not add special tokens
-        tokenizer_kwargs["add_special_tokens"] = False
-        return self(query, **tokenizer_kwargs)
def _apply_chat_template(
self,
conversation: List[List[str, str] | Dict[str, str]] | str,
add_generation_prompt=True,
) -> str | dict[str, numpy.ndarray | paddle.Tensor]:
if isinstance(conversation, str):
conversations = [{"role": "user", "content": conversation}]
elif isinstance(conversation, list):
Contributor: In addition, we have also tested whether the old and new chat_template behave as expected in Predictor; we still need to test that gradio_ui can use both the old and new chat_template.

assert len(conversation) > 0, "empty conversation is not allowed"
if isinstance(conversation[0], list):
conversations = convert_to_dict_message(conversation)
elif isinstance(conversation[0], dict):
conversations = conversation
else:
raise ValueError(
"apply_chat_template do not support appling batch conversations, "
"so you should apply the conversation one by one."
)
query = self.chat_template.render(
messages=conversations, **self.special_tokens_map, add_generation_prompt=add_generation_prompt
)
return query

-    def encode_chat_inputs(self, conversations: List[List[str, str]], context_data: Dict[str, Any] = {}):
+    def encode_chat_inputs(self, conversations: List[List[str, str]], context_data: Dict[str, Any] = {}, **kwargs):
"""Encodes conversation to pairs of token ids.
Turn 0: bos + system + sep + user bot + eos
Turn t: sep + bot + query bot + eos
@@ -668,6 +706,16 @@ def encode_chat_inputs(self, conversations: List[List[str, str]], context_data:
Returns:
List[list[int], list[int]]: the pair of input_ids and target_ids
"""
if not self.chat_template:
raise ValueError("chat_template is not set, please set chat_template first.")
elif isinstance(self.chat_template, Template):
add_generation_prompt = kwargs.pop("add_generation_prompt", True)
query = self._encode_chat_inputs(conversations, context_data, add_generation_prompt=add_generation_prompt)
elif isinstance(self.chat_template, ChatTemplate):
query = self._encode_chat_inputs_paddle(conversations, context_data)
return query

def _encode_chat_inputs_paddle(self, conversations: List[List[str, str]], context_data: Dict[str, Any] = {}):
context_data = self.chat_template._init_context_data(context_data)
# encode system
result = {}
Expand All @@ -692,6 +740,77 @@ def encode_chat_inputs(self, conversations: List[List[str, str]], context_data:
result["conversations"] = conversation_ids
return result

Contributor: If this function is detached from the previously designed training-inference-unified ChatTemplate, its applicability is quite low — it can hardly be used at all. So I would not recommend putting the encode_chat_inputs logic into the tokenizer; move it into the preprocessing instead. The adjustment required here could therefore be fairly large.

Contributor Author: Considering that encode_chat_inputs is currently in wide use, removing it could have a fairly large impact. Could we consider the following strategy instead: by default, split tgt/src so that src does not contain the bot start token, i.e. tgt contains the complete user turn plus the bot start token; if a tokenizer needs to override this, define it separately in that tokenizer class (e.g. qwen). (A standalone sketch of this split appears after _extract_non_learnable_parts below.)

    def _encode_chat_inputs(
self,
conversations: List[List[str, str]],
context_data: Dict[str, Any] = {},
system: str = None,
add_generation_prompt=True,
):
result = {}

        # Some templates do not support a system message, so check for support first.
if system:
try:
                self.chat_template.render(messages=[{"role": "system", "content": system}])
except Exception as e:
raise ValueError("System is not supported in this tokenizer.", e)

        # Convert list-style messages into role-dict messages.
conversation_dict = []
origin_msg = []
for round in conversations:
round_role = [
{"role": "user", "content": round[0]},
{"role": "assistant", "content": round[1]},
]
origin_msg.extend(round_role)
conversation_dict.append(round_role)
ans = []

        # Get the answer for each single round, then render the chat as a whole
        # and split it by the per-round answers.
        # Note: each answer must include the end token!
for conv in conversation_dict:
roundi = [system] + conv if system else conv
roundi_str = self.chat_template.render(
messages=roundi, add_generation_prompt=False, **self.special_tokens_map
)
roundi_no_ans = [system] + [conv[0]] if system else [conv[0]]
roundi_no_ans_str = self.chat_template.render(
messages=roundi_no_ans, add_generation_prompt=add_generation_prompt, **self.special_tokens_map
)
ans_roundi = roundi_str[len(roundi_no_ans_str) :]
ans.append(ans_roundi)

non_learnable_parts = self._extract_non_learnable_parts(origin_msg, ans)
assert len(non_learnable_parts) == len(ans)

conversation_ids = []
for i in range(len(non_learnable_parts)):
conversation_ids.append(
self.batch_encode(
[non_learnable_parts[i], ans[i]],
add_special_tokens=False,
padding=False,
)["input_ids"]
)

result["conversations"] = conversation_ids
return result

def _extract_non_learnable_parts(self, origin_msg: List[Dict[str, str]], split_s: List[str]):
"""Split the entire chat by specified words. Extract the non-learnable parts."""
        # Escape regex metacharacters in the split words, e.g. | -> \|
        regex_pattern = "|".join(map(re.escape, split_s))
        # Split the full rendering on the escaped answer strings.
non_learnable_parts = re.split(
r"(?:%s)" % regex_pattern,
self.chat_template.render(messages=origin_msg, add_generation_prompt=False, **self.special_tokens_map),
)
if non_learnable_parts[-1] == "":
non_learnable_parts.pop()
return non_learnable_parts
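
To make the render-and-slice approach above concrete, here is a self-contained sketch (not part of the diff; the toy template and its tags are assumptions). It recovers each round's answer by rendering the round with and without the answer and slicing off the shared prefix, then regex-splits the full rendering into the non-learnable source parts, mirroring _encode_chat_inputs and _extract_non_learnable_parts:

    import re

    from jinja2 import Template

    toy_template = Template(
        "{% for m in messages %}"
        "{% if m['role'] == 'user' %}<user>{{ m['content'] }}</user>"
        "{% else %}<bot>{{ m['content'] }}</bot>{% endif %}"
        "{% endfor %}"
        "{% if add_generation_prompt %}<bot>{% endif %}"
    )

    rounds = [[{"role": "user", "content": "hi"}, {"role": "assistant", "content": "hello"}]]
    origin_msg = [m for r in rounds for m in r]

    answers = []
    for conv in rounds:
        # Render the round with its answer, and without the answer but with the
        # generation prompt; the leftover suffix is the answer text.
        full = toy_template.render(messages=conv, add_generation_prompt=False)
        prefix = toy_template.render(messages=conv[:1], add_generation_prompt=True)
        answers.append(full[len(prefix):])  # 'hello</bot>' -- includes the end token

    # Split the whole rendered chat on the (regex-escaped) answers to obtain
    # the non-learnable source parts.
    pattern = "|".join(map(re.escape, answers))
    sources = re.split(r"(?:%s)" % pattern, toy_template.render(messages=origin_msg, add_generation_prompt=False))
    if sources[-1] == "":
        sources.pop()

    print(sources)  # ['<user>hi</user><bot>']
    print(answers)  # ['hello</bot>']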

@classmethod
def from_pretrained(cls, pretrained_model_name_or_path, *args, **kwargs):
cache_dir = kwargs.pop("cache_dir", None)
@@ -713,6 +832,13 @@ def from_pretrained(cls, pretrained_model_name_or_path, *args, **kwargs):
if not os.path.exists(chat_template_file):
return tokenizer

if tokenizer.chat_template is not None:
            logger.warning(
                "A chat template already exists in the config file; it will be overwritten by the chat_template.json file."
            )
            logger.warning(
                "`chat_template.json` will be deprecated in the future! Please set the chat template in `tokenizer_config.json` instead."
            )
tokenizer.init_chat_template(chat_template_file)
return tokenizer

@@ -724,9 +850,16 @@ def init_chat_template(self, chat_template: str | dict):
"""
if isinstance(chat_template, str):
if not os.path.exists(chat_template):
raise FileNotFoundError("The chat-template file does not exist: {}".format(chat_template))

self.chat_template = ChatTemplate.from_file(chat_template)
try:
self.chat_template: Template = ChatTemplate._compile_jinja_template(chat_template)
except TemplateSyntaxError:
                # It is neither a valid jinja string nor an existing file path.
raise TemplateSyntaxError(
"The chat-template in json is not valid jinja string: {}".format(chat_template),
lineno=0, # fake lineno, useless required msg
)
else:
self.chat_template = ChatTemplate.from_file(chat_template)
elif isinstance(chat_template, dict):
self.chat_template = ChatTemplate.from_dict(chat_template)
elif isinstance(chat_template, ChatTemplate):
Expand All @@ -737,7 +870,7 @@ def init_chat_template(self, chat_template: str | dict):
def save_resources(self, save_directory):
super().save_resources(save_directory)

if self.chat_template is not None:
        if isinstance(self.chat_template, ChatTemplate):  # TODO: remove once ChatTemplate is deprecated
chat_template_file = os.path.join(save_directory, CHAT_TEMPLATE_CONFIG_NAME)
with open(chat_template_file, "w", encoding="utf-8") as f:
json.dump(asdict(self.chat_template), f, ensure_ascii=False, indent=4)
3 changes: 3 additions & 0 deletions paddlenlp/transformers/tokenizer_utils_base.py
@@ -1575,6 +1575,9 @@ def convert_added_tokens(obj):

# TODO(guosheng): avoid reduplication of position args and key word args
tokenizer = cls(*init_args, **init_kwargs)
chat_template = init_kwargs.pop("chat_template", None)
if chat_template is not None:
tokenizer.init_chat_template(chat_template)
special_tokens_map_file = resolved_vocab_files.pop("special_tokens_map_file", None)
if special_tokens_map_file is not None:
with open(special_tokens_map_file, encoding="utf-8") as special_tokens_map_handle:
19 changes: 19 additions & 0 deletions paddlenlp/transformers/utils.py
@@ -89,6 +89,25 @@ def convert_ndarray_dtype(np_array: np.ndarray, target_dtype: str) -> np.ndarray
return np_array.astype(target_dtype)


def convert_to_dict_message(conversation: List[List[str]]):
"""Convert the list of chat messages to a role dictionary chat messages."""
conversations = []
for index, item in enumerate(conversation):
        assert 1 <= len(item) <= 2, "Each round in the conversation should have 1 or 2 elements."
if isinstance(item[0], str):
conversations.append({"role": "user", "content": item[0]})
if len(item) == 2 and isinstance(item[1], str):
conversations.append({"role": "assistant", "content": item[1]})
else:
            # If there is only one element in the round, it must be the last round;
            # anywhere else it is an error.
            if index != len(conversation) - 1:
                raise ValueError(f"Round {index} is invalid: only the last round may omit the assistant reply")
else:
raise ValueError("Each round in list should be string")
return conversations
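
For illustration, a quick sketch of what convert_to_dict_message returns (the import path follows this diff):

    from paddlenlp.transformers.utils import convert_to_dict_message

    messages = convert_to_dict_message([["hello", "hi there"], ["how are you?"]])
    # [{'role': 'user', 'content': 'hello'},
    #  {'role': 'assistant', 'content': 'hi there'},
    #  {'role': 'user', 'content': 'how are you?'}]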


def get_scale_by_dtype(dtype: str = None, return_positive: bool = True) -> float:
"""get scale value by dtype
