31
31
import numpy as np
32
32
import paddle
33
33
import six
34
- from jinja2 .exceptions import TemplateError
34
+ from jinja2 import Template
35
+ from jinja2 .exceptions import TemplateError , TemplateSyntaxError
35
36
from jinja2 .sandbox import ImmutableSandboxedEnvironment
36
37
from paddle .utils import try_import
37
38
58
59
TextInputPair ,
59
60
TruncationStrategy ,
60
61
)
61
- from .utils import InitTrackerMeta , fn_args_to_dict
62
+ from .utils import InitTrackerMeta , convert_to_dict_message , fn_args_to_dict
62
63
63
64
__all__ = [
64
65
"PretrainedTokenizer" ,
@@ -516,7 +517,7 @@ class ChatTemplate:
516
517
517
518
@staticmethod
518
519
@lru_cache ()
519
- def _compile_jinja_template (chat_template ):
520
+ def _compile_jinja_template (chat_template ) -> Template :
520
521
def raise_exception (message ):
521
522
raise TemplateError (message )
522
523
@@ -598,7 +599,6 @@ def __call__(self, conversations: list[list[str]] | str, context_data: Dict[str,
598
599
raise ValueError (
599
600
"The length of last conversation must be one, eg: [[user-query, bot-answer], [user-query, bot-answer], ..., [user-query]]"
600
601
)
601
-
602
602
if len (conversations [- 1 ]) > 1 :
603
603
logger .warning (
604
604
f"The last conversation is not a single-round, chat-template will skip the conversation: { conversations [- 1 ][1 :]} "
@@ -623,7 +623,7 @@ class ChatTemplateMixin:
623
623
624
624
def apply_chat_template(
    self,
    conversation: List[List[str] | Dict[str, str]] | str,
    tokenize: bool = True,
    context_data: Dict[str, Any] | None = None,
    **tokenizer_kwargs
):
    """Render a conversation with the configured chat template and optionally tokenize it.

    Args:
        conversation (List[List[str] | Dict[str, str]] | str): a single query string,
            a list of [user-query, bot-answer] round pairs, or a list of role dicts
            ({"role": ..., "content": ...}).
        tokenize (bool): whether to tokenize the rendered prompt. Defaults to True.
        context_data (Dict[str, Any] | None): extra context forwarded to the legacy
            dataclass-based ChatTemplate renderer. Defaults to None (treated as {}).
        **tokenizer_kwargs: forwarded to the tokenizer call; ``add_generation_prompt``
            is popped here and forwarded to the jinja renderer instead.

    Returns:
        str | dict[str, numpy.ndarray | paddle.Tensor]: return the result of applied data

    Raises:
        ValueError: when no chat template is configured, or its type is unsupported.
    """
    # Avoid the shared-mutable-default pitfall: build a fresh dict per call.
    if context_data is None:
        context_data = {}

    if not self.chat_template:
        raise ValueError("chat_template is not set, please set chat_template first.")
    elif isinstance(self.chat_template, Template):
        add_generation_prompt = tokenizer_kwargs.pop("add_generation_prompt", True)
        query = self._apply_chat_template(conversation, add_generation_prompt=add_generation_prompt)
    elif isinstance(self.chat_template, ChatTemplate):
        query = self._apply_chat_template_paddle(conversation, context_data)
    else:
        # Previously an unexpected template type fell through and crashed later
        # with a confusing NameError on `query`; fail fast instead.
        raise ValueError(f"Unsupported chat_template type: {type(self.chat_template)!r}")

    if not tokenize:
        return query

    # chat_template should not add special tokens
    tokenizer_kwargs["add_special_tokens"] = False
    return self(query, **tokenizer_kwargs)
655
+
656
+ def _apply_chat_template_paddle (
657
+ self ,
658
+ conversation : List [List [str , str ]] | str ,
659
+ context_data : Dict [str , Any ] = {},
660
+ ) -> str | dict [str , numpy .ndarray | paddle .Tensor ]:
641
661
context_data = self .chat_template ._init_context_data (context_data )
642
662
643
663
if isinstance (conversation , str ):
@@ -649,14 +669,32 @@ def apply_chat_template(
649
669
)
650
670
651
671
query = self .chat_template (conversation , context_data = context_data )
652
- if not tokenize :
653
- return query
672
+ return query
654
673
655
- # chat_template should not add special tokens
656
- tokenizer_kwargs ["add_special_tokens" ] = False
657
- return self (query , ** tokenizer_kwargs )
674
+ def _apply_chat_template (
675
+ self ,
676
+ conversation : List [List [str , str ] | Dict [str , str ]] | str ,
677
+ add_generation_prompt = True ,
678
+ ) -> str | dict [str , numpy .ndarray | paddle .Tensor ]:
679
+ if isinstance (conversation , str ):
680
+ conversations = [{"role" : "user" , "content" : conversation }]
681
+ elif isinstance (conversation , list ):
682
+ assert len (conversation ) > 0 , "empty conversation is not allowed"
683
+ if isinstance (conversation [0 ], list ):
684
+ conversations = convert_to_dict_message (conversation )
685
+ elif isinstance (conversation [0 ], dict ):
686
+ conversations = conversation
687
+ else :
688
+ raise ValueError (
689
+ "apply_chat_template do not support appling batch conversations, "
690
+ "so you should apply the conversation one by one."
691
+ )
692
+ query = self .chat_template .render (
693
+ messages = conversations , ** self .special_tokens_map , add_generation_prompt = add_generation_prompt
694
+ )
695
+ return query
658
696
659
def encode_chat_inputs(self, conversations: List[List[str]], context_data: Dict[str, Any] | None = None, **kwargs):
    """Encodes conversation to pairs of token ids.

    Turn 0: bos + system + sep + user bot + eos
    Turn t: sep + bot + query bot + eos

    Args:
        conversations (List[List[str]]): rounds of [user-query, bot-answer] pairs.
        context_data (Dict[str, Any] | None): extra render context for the legacy
            ChatTemplate path. Defaults to None (treated as {}).
        **kwargs: ``add_generation_prompt`` is popped and forwarded on the jinja path.

    Returns:
        List[list[int], list[int]]: the pair of input_ids and target_ids

    Raises:
        ValueError: when no chat template is configured, or its type is unsupported.
    """
    # Avoid the shared-mutable-default pitfall: build a fresh dict per call.
    if context_data is None:
        context_data = {}

    if not self.chat_template:
        raise ValueError("chat_template is not set, please set chat_template first.")
    elif isinstance(self.chat_template, Template):
        add_generation_prompt = kwargs.pop("add_generation_prompt", True)
        query = self._encode_chat_inputs(conversations, context_data, add_generation_prompt=add_generation_prompt)
    elif isinstance(self.chat_template, ChatTemplate):
        query = self._encode_chat_inputs_paddle(conversations, context_data)
    else:
        # Previously an unexpected template type fell through to an unbound `query`.
        raise ValueError(f"Unsupported chat_template type: {type(self.chat_template)!r}")
    return query
718
+ def _encode_chat_inputs_paddle (self , conversations : List [List [str , str ]], context_data : Dict [str , Any ] = {}):
671
719
context_data = self .chat_template ._init_context_data (context_data )
672
720
# encode system
673
721
result = {}
@@ -692,6 +740,77 @@ def encode_chat_inputs(self, conversations: List[List[str, str]], context_data:
692
740
result ["conversations" ] = conversation_ids
693
741
return result
694
742
743
+ def _encode_chat_inputs (
744
+ self ,
745
+ conversations : List [List [str , str ]],
746
+ context_data : Dict [str , Any ] = {},
747
+ system : str = None ,
748
+ add_generation_prompt = True ,
749
+ ):
750
+ result = {}
751
+
752
+ # Some template do not support system msg, so we need to check it first.
753
+ if system :
754
+ try :
755
+ self .chat_template .render (messages = {"role" : "system" , "content" : system })
756
+ except Exception as e :
757
+ raise ValueError ("System is not supported in this tokenizer." , e )
758
+
759
+ # convert list msg to role dict msg
760
+ conversation_dict = []
761
+ origin_msg = []
762
+ for round in conversations :
763
+ round_role = [
764
+ {"role" : "user" , "content" : round [0 ]},
765
+ {"role" : "assistant" , "content" : round [1 ]},
766
+ ]
767
+ origin_msg .extend (round_role )
768
+ conversation_dict .append (round_role )
769
+ ans = []
770
+
771
+ # get answer in single round, then compile the chat entirely and split by single round ans
772
+ # attention: answer should include end token!
773
+ for conv in conversation_dict :
774
+ roundi = [system ] + conv if system else conv
775
+ roundi_str = self .chat_template .render (
776
+ messages = roundi , add_generation_prompt = False , ** self .special_tokens_map
777
+ )
778
+ roundi_no_ans = [system ] + [conv [0 ]] if system else [conv [0 ]]
779
+ roundi_no_ans_str = self .chat_template .render (
780
+ messages = roundi_no_ans , add_generation_prompt = add_generation_prompt , ** self .special_tokens_map
781
+ )
782
+ ans_roundi = roundi_str [len (roundi_no_ans_str ) :]
783
+ ans .append (ans_roundi )
784
+
785
+ non_learnable_parts = self ._extract_non_learnable_parts (origin_msg , ans )
786
+ assert len (non_learnable_parts ) == len (ans )
787
+
788
+ conversation_ids = []
789
+ for i in range (len (non_learnable_parts )):
790
+ conversation_ids .append (
791
+ self .batch_encode (
792
+ [non_learnable_parts [i ], ans [i ]],
793
+ add_special_tokens = False ,
794
+ padding = False ,
795
+ )["input_ids" ]
796
+ )
797
+
798
+ result ["conversations" ] = conversation_ids
799
+ return result
800
+
801
+ def _extract_non_learnable_parts (self , origin_msg : List [Dict [str , str ]], split_s : List [str ]):
802
+ """Split the entire chat by specified words. Extract the non-learnable parts."""
803
+ # distingish and replace the special words in original string to an uncompiled form: Like | -> \|
804
+ regex_pattern = "|" .join (map (re .escape , split_s ))
805
+ # splited by replaced specified words
806
+ non_learnable_parts = re .split (
807
+ r"(?:%s)" % regex_pattern ,
808
+ self .chat_template .render (messages = origin_msg , add_generation_prompt = False , ** self .special_tokens_map ),
809
+ )
810
+ if non_learnable_parts [- 1 ] == "" :
811
+ non_learnable_parts .pop ()
812
+ return non_learnable_parts
813
+
695
814
@classmethod
696
815
def from_pretrained (cls , pretrained_model_name_or_path , * args , ** kwargs ):
697
816
cache_dir = kwargs .pop ("cache_dir" , None )
@@ -713,6 +832,13 @@ def from_pretrained(cls, pretrained_model_name_or_path, *args, **kwargs):
713
832
if not os .path .exists (chat_template_file ):
714
833
return tokenizer
715
834
835
+ if tokenizer .chat_template is not None :
836
+ logger .warning (
837
+ "Chat-template already exists in config file, it will be overwritten by chat_template.json file."
838
+ )
839
+ logger .warning (
840
+ "`chat_template.json` will be deprecated in the future! Please set it in `tokenizer_config.json`."
841
+ )
716
842
tokenizer .init_chat_template (chat_template_file )
717
843
return tokenizer
718
844
@@ -724,9 +850,16 @@ def init_chat_template(self, chat_template: str | dict):
724
850
"""
725
851
if isinstance (chat_template , str ):
726
852
if not os .path .exists (chat_template ):
727
- raise FileNotFoundError ("The chat-template file does not exist: {}" .format (chat_template ))
728
-
729
- self .chat_template = ChatTemplate .from_file (chat_template )
853
+ try :
854
+ self .chat_template : Template = ChatTemplate ._compile_jinja_template (chat_template )
855
+ except TemplateSyntaxError :
856
+ # It is neither jinjia string nor path string
857
+ raise TemplateSyntaxError (
858
+ "The chat-template in json is not valid jinja string: {}" .format (chat_template ),
859
+ lineno = 0 , # fake lineno, useless required msg
860
+ )
861
+ else :
862
+ self .chat_template = ChatTemplate .from_file (chat_template )
730
863
elif isinstance (chat_template , dict ):
731
864
self .chat_template = ChatTemplate .from_dict (chat_template )
732
865
elif isinstance (chat_template , ChatTemplate ):
@@ -737,7 +870,7 @@ def init_chat_template(self, chat_template: str | dict):
737
870
def save_resources(self, save_directory):
    """Save tokenizer resources; additionally dump a legacy ChatTemplate to JSON."""
    super().save_resources(save_directory)

    # Only the dataclass-based ChatTemplate gets its own chat_template.json file;
    # jinja string templates are persisted with the tokenizer config instead.
    if isinstance(self.chat_template, ChatTemplate):  # Future remove if ChatTemplate is deprecated
        template_path = os.path.join(save_directory, CHAT_TEMPLATE_CONFIG_NAME)
        with open(template_path, "w", encoding="utf-8") as fp:
            json.dump(asdict(self.chat_template), fp, ensure_ascii=False, indent=4)
0 commit comments