
[Tokenizer] Add Chat template #8226


Merged
merged 21 commits on Apr 28, 2024
Changes from 14 commits
4 changes: 2 additions & 2 deletions llm/data.py
@@ -90,7 +90,7 @@ def tokenize_example(tokenizer, example, data_args):
return tokenized_source, tokenized_target_input_ids


def tokenize_rounds_example(tokenizer, example, data_args):
def tokenize_rounds_example(tokenizer, example, data_args, **kwargs):
"""tokenize multi-rounds examples with chat_template.json

Args:
@@ -117,7 +117,7 @@ def tokenize_rounds_example(tokenizer, example, data_args):

# 1. only tokenize input_ids
conversation_result: list[tuple[list[int], list[int]]] = tokenizer.encode_chat_inputs(
conversations, context_data=context_data
conversations, context_data=context_data, **kwargs
)
system_ids = conversation_result.pop("system", []) or []
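As a quick illustration of the new passthrough: any extra keyword arguments given to tokenize_rounds_example are forwarded verbatim to tokenizer.encode_chat_inputs. A minimal sketch, assuming a local qwen chat tokenizer and a DataArg stub modeled on the one in tests/transformers/test_chat_template.py below:

from dataclasses import dataclass

from paddlenlp.transformers import AutoTokenizer

from data import tokenize_rounds_example  # import path as used in the test file below


@dataclass
class DataArg:
    # minimal stand-in for the real data_args object used by the llm/ scripts
    max_length: int = 50
    src_length: int = 50


tokenizer = AutoTokenizer.from_pretrained("qwen/qwen-7b-chat")
example = {"src": ["你好"], "tgt": ["您好,我是个人人工智能助手"]}
# add_generation_prompt is not consumed by tokenize_rounds_example itself;
# it flows through **kwargs into tokenizer.encode_chat_inputs.
tokenized_source, labels = tokenize_rounds_example(
    tokenizer, example, DataArg(), add_generation_prompt=True
)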

164 changes: 150 additions & 14 deletions paddlenlp/transformers/tokenizer_utils.py
@@ -31,7 +31,8 @@
import numpy as np
import paddle
import six
from jinja2.exceptions import TemplateError
from jinja2 import Template
from jinja2.exceptions import TemplateError, TemplateSyntaxError
from jinja2.sandbox import ImmutableSandboxedEnvironment
from paddle.utils import try_import

Expand Down Expand Up @@ -516,7 +517,7 @@ class ChatTemplate:

@staticmethod
@lru_cache()
def _compile_jinja_template(chat_template):
def _compile_jinja_template(chat_template) -> Template:
def raise_exception(message):
raise TemplateError(message)

@@ -598,7 +599,6 @@ def __call__(self, conversations: list[list[str]] | str, context_data: Dict[str,
raise ValueError(
"The length of last conversation must be one, eg: [[user-query, bot-answer], [user-query, bot-answer], ..., [user-query]]"
)

if len(conversations[-1]) > 1:
logger.warning(
f"The last conversation is not a single-round, chat-template will skip the conversation: {conversations[-1][1:]}"
@@ -623,7 +623,7 @@ class ChatTemplateMixin:

def apply_chat_template(
self,
conversation: List[List[str, str]] | str,
conversation: List[List[str, str] | Dict[str, str]] | str,
tokenize: bool = True,
context_data: Dict[str, Any] = {},
**tokenizer_kwargs
@@ -638,6 +638,26 @@ def apply_chat_template(
Returns:
str | dict[str, numpy.ndarray | paddle.Tensor]: return the result of applied data
"""
if not self.chat_template:
raise ValueError("chat_template is not set, please set chat_template first.")
elif isinstance(self.chat_template, Template):
add_generation_prompt = tokenizer_kwargs.pop("add_generation_prompt", True)
query = self._apply_chat_template(conversation, context_data, add_generation_prompt=add_generation_prompt)
elif isinstance(self.chat_template, ChatTemplate):
query = self._apply_chat_template_paddle(conversation, context_data)

if not tokenize:
return query

# chat_template should not add special tokens
tokenizer_kwargs["add_special_tokens"] = False
return self(query, **tokenizer_kwargs)
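For readers skimming the dispatch above: once init_chat_template has registered a raw Jinja string, self.chat_template is a jinja2.Template and rendering goes through _apply_chat_template; a chat_template.json-style template keeps the legacy _apply_chat_template_paddle path. A minimal usage sketch with a toy template (the template string is illustrative, not part of this PR):

from paddlenlp.transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("qwen/qwen-7b-chat")
toy_jinja = (
    "{% for message in messages %}"
    "{{ message['role'] }}: {{ message['content'] }}\n"
    "{% endfor %}"
    "{% if add_generation_prompt %}assistant: {% endif %}"
)
tokenizer.init_chat_template(toy_jinja)  # compiled into a jinja2.Template

text = tokenizer.apply_chat_template("你好", tokenize=False)
# text == "user: 你好\nassistant: "
ids = tokenizer.apply_chat_template("你好")  # tokenized, with add_special_tokens=False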

def _apply_chat_template_paddle(
self,
conversation: List[List[str, str]] | str,
context_data: Dict[str, Any] = {},
) -> str | dict[str, numpy.ndarray | paddle.Tensor]:
context_data = self.chat_template._init_context_data(context_data)

if isinstance(conversation, str):
@@ -649,14 +669,45 @@ def apply_chat_template(
)

query = self.chat_template(conversation, context_data=context_data)
if not tokenize:
return query
return query

# chat_template should not add special tokens
tokenizer_kwargs["add_special_tokens"] = False
return self(query, **tokenizer_kwargs)
def _apply_chat_template(
self,
conversation: List[List[str, str] | Dict[str, str]] | str,
context_data: Dict[str, Any] = {},
add_generation_prompt=True,
) -> str | dict[str, numpy.ndarray | paddle.Tensor]:
if isinstance(conversation, str):
conversations = [{"role": "user", "content": conversation}]
elif isinstance(conversation, list):
Contributor

In addition, the use of both the old and new chat_template in Predictor has been tested and matches expectations; gradio_ui should also be tested to confirm it can use both the old and new chat_template.
conversations = []
for index, item in enumerate(conversation):
if isinstance(item, dict):
conversations = conversation
break
elif isinstance(item, list):
assert 1 <= len(item) <= 2
if isinstance(item[0], str):
conversations.append({"role": "user", "content": item[0]})
if len(item) == 2 and isinstance(item[1], str):
conversations.append({"role": "assistant", "content": item[1]})
else:
# item has only one element, so it must be the final round
if index != len(conversation) - 1:
raise ValueError(f"Round {index} is invalid: only the final round may omit the assistant reply")
else:
raise ValueError("Each element of a round must be a string")
else:
raise ValueError(
"apply_chat_template does not support applying batch conversations, "
"so you should apply the conversation one by one."
)
query = self.chat_template.render(
messages=conversations, **self.special_tokens_map, add_generation_prompt=add_generation_prompt
)
return query
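To make the normalization above concrete: a paired-list conversation is flattened into role dicts, while a list of dicts passes through unchanged (this mirrors the multi-turn test below):

conversation = [["你好", "您好,我是个人人工智能助手"], ["今天吃啥"]]
# After the loop above, conversations ==
# [
#     {"role": "user", "content": "你好"},
#     {"role": "assistant", "content": "您好,我是个人人工智能助手"},
#     {"role": "user", "content": "今天吃啥"},
# ]
# A single-element round is only accepted as the final round; anywhere
# else it raises ValueError.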

def encode_chat_inputs(self, conversations: List[List[str, str]], context_data: Dict[str, Any] = {}):
def encode_chat_inputs(self, conversations: List[List[str, str]], context_data: Dict[str, Any] = {}, **kwargs):
"""Encodes conversation to pairs of token ids.
Turn 0: bos + system + sep + user bot + eos
Turn t: sep + bot + query bot + eos
@@ -668,6 +719,16 @@ def encode_chat_inputs(self, conversations: List[List[str, str]], context_data:
Returns:
List[list[int], list[int]]: the pair of input_ids and target_ids
"""
if not self.chat_template:
raise ValueError("chat_template is not set, please set chat_template first.")
elif isinstance(self.chat_template, Template):
add_generation_prompt = kwargs.pop("add_generation_prompt", True)
query = self._encode_chat_inputs(conversations, context_data, add_generation_prompt=add_generation_prompt)
elif isinstance(self.chat_template, ChatTemplate):
query = self._encode_chat_inputs_paddle(conversations, context_data)
return query

def _encode_chat_inputs_paddle(self, conversations: List[List[str, str]], context_data: Dict[str, Any] = {}):
context_data = self.chat_template._init_context_data(context_data)
# encode system
result = {}
Expand All @@ -692,6 +753,70 @@ def encode_chat_inputs(self, conversations: List[List[str, str]], context_data:
result["conversations"] = conversation_ids
return result

def _encode_chat_inputs(
Contributor

Once you move away from the previously designed train-inference-unified ChatTemplate, this function has very limited applicability; it essentially cannot be used elsewhere.

So I would rather not put the encode_chat_inputs logic into the tokenizer; try to move it into the preprocessing code instead.

Which means the adjustment here could be fairly large.

Contributor Author

Considering that encode_chat_inputs is currently widely used, removing it could affect quite a lot of code. Could the following strategy be considered instead:
make the default tgt/src split one where src does not contain the bot start token, i.e. tgt carries the complete user round plus the bot start token (the render-twice split this PR implements is sketched after _encode_chat_inputs below);
if a tokenizer needs to override this, define it separately in that tokenizer class, e.g. qwen.

self,
conversations: List[List[str, str]],
context_data: Dict[str, Any] = {},
system: str = None,
add_generation_prompt=True,
):
result = {}
# conversation = []
# if origin_msg[0]['role'] == 'system':
# system = origin_msg.pop(0)
if system:
try:
self.chat_template.render(messages=[{"role": "system", "content": system}])
except Exception as e:
raise ValueError("System role is not supported by this chat template.", e)

conversation_dict = []
origin_msg = []
for round in conversations:
round_role = [
{"role": "user", "content": round[0]},
{"role": "assistant", "content": round[1]},
]
origin_msg.extend(round_role)
conversation_dict.append(round_role)
ans = []

for conv in conversation_dict:
roundi = [system] + conv if system else conv
roundi_str = self.chat_template.render(
messages=roundi, add_generation_prompt=False, **self.special_tokens_map
)
roundi_no_ans = [system] + [conv[0]] if system else [conv[0]]
roundi_no_ans_str = self.chat_template.render(
messages=roundi_no_ans, add_generation_prompt=add_generation_prompt, **self.special_tokens_map
)
ans_roundi = roundi_str[len(roundi_no_ans_str) :]
ans.append(ans_roundi)

regex_pattern = "|".join(map(re.escape, ans))
non_learnable_parts = re.split(
r"(?:%s)" % regex_pattern,
self.chat_template.render(messages=origin_msg, add_generation_prompt=False, **self.special_tokens_map),
)
if non_learnable_parts[-1] == "":
non_learnable_parts.pop()

assert len(non_learnable_parts) == len(ans)

conversation_ids = []
for i in range(len(non_learnable_parts)):
conversation_ids.append(
self.batch_encode(
[non_learnable_parts[i], ans[i]],
add_special_tokens=False,
padding=False,
)["input_ids"]
)
# print(self.batch_decode(conversation_ids[i]))

result["conversations"] = conversation_ids
return result
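The core trick in _encode_chat_inputs is to render every round twice, once with and once without the assistant reply, take the string suffix as the learnable answer, then regex-split the full rendering by those answers to recover the non-learnable prompt parts. A self-contained sketch of the same technique with a toy template (template and identifiers are illustrative, not part of this PR):

import re

from jinja2 import Template

# toy stand-in for a real chat template
toy = Template(
    "{% for m in messages %}<{{ m.role }}>{{ m.content }}</{{ m.role }}>{% endfor %}"
    "{% if add_generation_prompt %}<assistant>{% endif %}"
)

rounds = [
    [{"role": "user", "content": "hi"}, {"role": "assistant", "content": "hello"}],
]
origin_msg = [m for r in rounds for m in r]

answers = []
for conv in rounds:
    full = toy.render(messages=conv, add_generation_prompt=False)
    prompt_only = toy.render(messages=conv[:1], add_generation_prompt=True)
    answers.append(full[len(prompt_only):])  # learnable suffix of this round

pattern = "|".join(map(re.escape, answers))
prompts = re.split(r"(?:%s)" % pattern, toy.render(messages=origin_msg, add_generation_prompt=False))
if prompts[-1] == "":
    prompts.pop()  # each answer ends its round, so the split leaves a trailing ""

assert prompts == ["<user>hi</user><assistant>"]
assert answers == ["hello</assistant>"]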

@classmethod
def from_pretrained(cls, pretrained_model_name_or_path, *args, **kwargs):
cache_dir = kwargs.pop("cache_dir", None)
@@ -713,6 +838,10 @@ def from_pretrained(cls, pretrained_model_name_or_path, *args, **kwargs):
if not os.path.exists(chat_template_file):
return tokenizer

if tokenizer.chat_template is not None:
logger.warning(
"Chat-template already exists in config file, it will be overwritten by chat_template.json file."
)
tokenizer.init_chat_template(chat_template_file)
return tokenizer

@@ -724,9 +853,16 @@ def init_chat_template(self, chat_template: str | dict):
"""
if isinstance(chat_template, str):
if not os.path.exists(chat_template):
raise FileNotFoundError("The chat-template file does not exist: {}".format(chat_template))

self.chat_template = ChatTemplate.from_file(chat_template)
try:
self.chat_template: Template = ChatTemplate._compile_jinja_template(chat_template)
except TemplateSyntaxError:
# it is neither a valid jinja string nor a path to an existing template file
raise TemplateSyntaxError(
"The chat-template is not a valid jinja string: {}".format(chat_template),
lineno=0,  # placeholder; TemplateSyntaxError requires a lineno but it is unused here
)
else:
self.chat_template = ChatTemplate.from_file(chat_template)
elif isinstance(chat_template, dict):
self.chat_template = ChatTemplate.from_dict(chat_template)
elif isinstance(chat_template, ChatTemplate):
@@ -737,7 +873,7 @@ def init_chat_template(self, chat_template: str | dict):
def save_resources(self, save_directory):
super().save_resources(save_directory)

if self.chat_template is not None:
if isinstance(self.chat_template, ChatTemplate):  # TODO: remove once ChatTemplate is deprecated
chat_template_file = os.path.join(save_directory, CHAT_TEMPLATE_CONFIG_NAME)
with open(chat_template_file, "w", encoding="utf-8") as f:
json.dump(asdict(self.chat_template), f, ensure_ascii=False, indent=4)
3 changes: 3 additions & 0 deletions paddlenlp/transformers/tokenizer_utils_base.py
@@ -1574,6 +1574,9 @@ def convert_added_tokens(obj):

# TODO(guosheng): avoid reduplication of position args and key word args
tokenizer = cls(*init_args, **init_kwargs)
chat_template = init_kwargs.pop("chat_template", None)
if chat_template is not None:
tokenizer.init_chat_template(chat_template)
special_tokens_map_file = resolved_vocab_files.pop("special_tokens_map_file", None)
if special_tokens_map_file is not None:
with open(special_tokens_map_file, encoding="utf-8") as special_tokens_map_handle:
72 changes: 72 additions & 0 deletions tests/transformers/test_chat_template.py
@@ -282,3 +282,75 @@ def test_inference_template_with_context_data(self):
final_query = tokenizer.apply_chat_template(query, context_data=context_data, tokenize=False)
expected_query = "你是一个人工智能助手<<SYSTEM-MESSAGE>>-<<INSTRUCTION-MESSAGE>>\nHuman: 你好<sep> Bot:"
self.assertEqual(final_query, expected_query)


class TemplateIntegrationTest(unittest.TestCase):
class DataArg:
def __init__(self, max_length, src_length: Optional[int] = None):
self.max_length: int = max_length
if src_length is None:
src_length = self.max_length - 8

self.src_length: int = src_length

def setUp(self) -> None:
self.tokenizer = AutoTokenizer.from_pretrained("qwen/qwen-7b-chat")
qwen_jinja = "{% for message in messages %}{% if loop.first and messages[0]['role'] != 'system' %}{{ '<|im_start|>system\nYou are a helpful assistant<|im_end|>\n' }}{% endif %}{{'<|im_start|>' + message['role'] + '\n' + message['content'] + '<|im_end|>' + '\n'}}{% endfor %}{% if add_generation_prompt %}{{ '<|im_start|>assistant\n' }}{% endif %}"
self.tokenizer.init_chat_template(qwen_jinja)
return super().setUp()

def test_chat_template(self):
# test single turn
query = "你好"
final_query = self.tokenizer.apply_chat_template(query, tokenize=False)
expected_query = f"<|im_start|>system\nYou are a helpful assistant<|im_end|>\n<|im_start|>user\n{query}<|im_end|>\n<|im_start|>assistant\n"
self.assertEqual(final_query, expected_query)

# test multi turns conversation
query = [["你好", "您好,我是个人人工智能助手"], ["今天吃啥"]]
final_query = self.tokenizer.apply_chat_template(query, tokenize=False)
expected_query = "<|im_start|>system\nYou are a helpful assistant<|im_end|>\n<|im_start|>user\n你好<|im_end|>\n<|im_start|>assistant\n您好,我是个人人工智能助手<|im_end|>\n<|im_start|>user\n今天吃啥<|im_end|>\n<|im_start|>assistant\n"
self.assertEqual(final_query, expected_query)

def test_system_error(self):
# test that a system message triggers the template's raise_exception
error_jinja = "{{ bos_token }}{% if messages[0]['role'] == 'system' %}{{ raise_exception('System role not supported') }}{% endif %}"
self.tokenizer.init_chat_template(error_jinja)
from jinja2.exceptions import TemplateError

with self.assertRaises(TemplateError):
self.tokenizer.apply_chat_template([{"role": "system", "content": ""}])

def test_round_error(self):
# invalid conversation: a single-element round appears before the final round
query = [["你好", "您好,我是个人人工智能助手"], ["今天吃啥"], ["你好", "您好"]]
with self.assertRaises(ValueError):
self.tokenizer.apply_chat_template(query, tokenize=False)

def test_jinja_syntax_error(self):
# test that an invalid jinja string raises TemplateSyntaxError
error_jinja = (
"{ bos_token }{% if messages[0]['role'] == 'system' %}{ raise_exception('System role not supported')}"
)
from jinja2.exceptions import TemplateSyntaxError

with self.assertRaises(TemplateSyntaxError):
self.tokenizer.init_chat_template(error_jinja)

def test_train_format(self):
from data import tokenize_rounds_example

fake_data_args = self.DataArg(50, src_length=50)
example = {"src": ["你好"], "tgt": ["您好,我是个人人工智能助手"]}
result, tgt_id = tokenize_rounds_example(self.tokenizer, example, fake_data_args, add_generation_prompt=True)
sentence_result = self.tokenizer.decode(result["input_ids"])
expected_sentence = "<|im_start|>system\nYou are a helpful assistant<|im_end|>\n<|im_start|>user\n你好<|im_end|>\n<|im_start|>assistant\n您好,我是个人人工智能助手<|im_end|>\n"
self.assertEqual(expected_sentence, sentence_result)

tgt_idx = len(
self.tokenizer.encode(
"<|im_start|>system\nYou are a helpful assistant<|im_end|>\n<|im_start|>user\n你好<|im_end|>\n<|im_start|>assistant\n"
)["input_ids"]
)
self.assertEqual(tgt_id[tgt_idx - 1], -100)
self.assertNotEqual(tgt_id[tgt_idx], -100)
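The last two assertions rely on the standard causal-LM label-masking convention that tokenize_rounds_example follows: label positions belonging to the prompt are set to -100 so the loss ignores them, while answer positions keep their real token ids. A minimal sketch of that convention, independent of PaddleNLP:

def build_labels(prompt_ids, answer_ids, ignore_index=-100):
    """Labels aligned with input_ids = prompt_ids + answer_ids."""
    return [ignore_index] * len(prompt_ids) + list(answer_ids)

labels = build_labels(prompt_ids=[1, 2, 3], answer_ids=[4, 5])
assert labels[2] == -100  # last prompt position: excluded from the loss
assert labels[3] == 4     # first answer token: learned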