@@ -132,7 +132,10 @@ class PredictorArgument:
 
     @property
     def total_max_length(self):
-        return 8192  # Maximum sequence length.
+        if self.device == "npu":
+            return self.src_length + self.max_length
+        else:
+            return 8192  # Maximum sequence length.
 
 
 @dataclass
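For context: on NPU the cap now tracks the actual request shape (prompt length plus generation budget) instead of the fixed 8192 ceiling, which matters in the second hunk below, where masks of shape total_max_length × total_max_length are allocated. A minimal runnable sketch of the behavior, with the dataclass trimmed to the fields the property reads and the defaults assumed purely for illustration:

```python
from dataclasses import dataclass

@dataclass
class PredictorArgument:
    device: str = "gpu"
    src_length: int = 1024  # assumed default, for illustration only
    max_length: int = 512   # assumed default, for illustration only

    @property
    def total_max_length(self) -> int:
        # NPU: bound buffers by prompt length + generation budget.
        if self.device == "npu":
            return self.src_length + self.max_length
        return 8192  # maximum sequence length on other devices

assert PredictorArgument(device="npu").total_max_length == 1024 + 512
assert PredictorArgument(device="gpu").total_max_length == 8192
```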
@@ -859,6 +862,35 @@ def init_model_inputs(self, config: PredictorArgument):
             self.model_inputs["tgt_mask"] = (
                 alibi_decoder + (1 - self.model_inputs["tgt_mask"]) * paddle.finfo(self.dtype).min
             ).cast(self.dtype)
+        elif config.device == "npu" and self.model_config.get("alibi", False):
+            lower_one_tril = paddle.tril(
+                paddle.ones(shape=(config.total_max_length, config.total_max_length), dtype=self.dtype)
+            )
+            lower_one_tril = lower_one_tril[None, None, :, :]
+            src_mask = lower_one_tril.tile([config.batch_size, 1, 1, 1])
+            tgt_mask = paddle.full(
+                shape=[config.batch_size, 1, 1, config.total_max_length], fill_value=1, dtype=self.dtype
+            )
+            arange_tensor_encoder = paddle.arange(config.total_max_length).astype(self.dtype)
+            alibi_slopes = llm_utils.get_alibi_slopes(self.num_attention_heads)
+            alibi = alibi_slopes[None, :, None, None] * arange_tensor_encoder
+            alibi_encoder = alibi.tile([config.batch_size, 1, config.total_max_length, 1])
+            alibi_decoder = alibi.tile(
+                [
+                    config.batch_size,
+                    1,
+                    1,
+                    1,
+                ]
+            )
+            # self.model_inputs["src_mask/tgt_mask"] is read only, will not be updated!
+            src_mask = (
+                alibi_encoder + (1 - src_mask) * paddle.finfo(self.dtype).min
+            ).cast(self.dtype)
+            tgt_mask = (
+                alibi_decoder + (1 - tgt_mask) * paddle.finfo(self.dtype).min
+            ).cast(self.dtype)
+            self.model_inputs["rope_emb"] = paddle.concat([src_mask.reshape([-1]), tgt_mask.reshape([-1])])
 
     def _preprocess(self, input_text: list[str]):
         if self.tokenizer.chat_template is not None:
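The new NPU branch precomputes ALiBi attention biases: a causal lower-triangular src_mask for prefill and a single-row tgt_mask for decode steps, each offset per head by a linear position penalty, and flattens both into model_inputs["rope_emb"] (ALiBi models do not use rotary embeddings, so the buffer is apparently repurposed to ferry the masks to the NPU kernel). Below is a self-contained NumPy sketch of the same mask construction; alibi_slopes and build_alibi_masks are illustrative names, not part of this patch, and the slope formula is the standard ALiBi geometric sequence rather than a copy of llm_utils.get_alibi_slopes:

```python
import math
import numpy as np

def alibi_slopes(num_heads: int) -> np.ndarray:
    # Standard ALiBi recipe: slope 2^(-8i/n) per head, interpolated
    # when num_heads is not a power of two.
    closest = 2 ** math.floor(math.log2(num_heads))
    base = 2.0 ** (-(2.0 ** -(math.log2(closest) - 3)))
    slopes = base ** np.arange(1, closest + 1)
    if closest < num_heads:
        extra_base = 2.0 ** (-(2.0 ** -(math.log2(2 * closest) - 4)))
        extra = extra_base ** np.arange(1, 2 * (num_heads - closest) + 1, 2)
        slopes = np.concatenate([slopes, extra])
    return slopes

def build_alibi_masks(batch_size, num_heads, max_len, dtype=np.float32):
    neg_inf = np.finfo(dtype).min
    # Causal lower-triangular mask for prefill: [B, 1, L, L].
    tril = np.tril(np.ones((max_len, max_len), dtype=dtype))[None, None]
    src_mask = np.tile(tril, (batch_size, 1, 1, 1))
    # Decode-step mask, every past position visible: [B, 1, 1, L].
    tgt_mask = np.ones((batch_size, 1, 1, max_len), dtype=dtype)
    # Per-head linear position bias: [1, H, 1, L].
    positions = np.arange(max_len, dtype=dtype)
    alibi = alibi_slopes(num_heads).astype(dtype)[None, :, None, None] * positions
    alibi_encoder = np.tile(alibi, (batch_size, 1, max_len, 1))
    alibi_decoder = np.tile(alibi, (batch_size, 1, 1, 1))
    # Disallowed positions collapse to a large negative bias.
    src_mask = alibi_encoder + (1 - src_mask) * neg_inf
    tgt_mask = alibi_decoder + (1 - tgt_mask) * neg_inf
    return src_mask, tgt_mask

src, tgt = build_alibi_masks(batch_size=2, num_heads=8, max_len=16)
assert src.shape == (2, 8, 16, 16) and tgt.shape == (2, 8, 1, 16)
```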