From cf69b0b0d33dc7620dfd778cb15b82bf6c96e9a9 Mon Sep 17 00:00:00 2001 From: GuoxiaWang Date: Wed, 12 Jun 2024 17:58:51 +0800 Subject: [PATCH 1/3] [LLM] support sparse attention for LLAMA --- paddlenlp/transformers/llama/fusion_ops.py | 25 +++++--- paddlenlp/transformers/llama/modeling.py | 23 ++++++-- paddlenlp/transformers/llama/modeling_pp.py | 63 +++++++++++++++++++-- 3 files changed, 95 insertions(+), 16 deletions(-) diff --git a/paddlenlp/transformers/llama/fusion_ops.py b/paddlenlp/transformers/llama/fusion_ops.py index 182663bdbc73..15ef066df810 100644 --- a/paddlenlp/transformers/llama/fusion_ops.py +++ b/paddlenlp/transformers/llama/fusion_ops.py @@ -149,6 +149,7 @@ def fusion_flash_attention( sequence_parallel=False, reshard_layer=None, npu_is_casual=False, + attn_mask_start_row_indices=None, ): bsz, q_len, num_heads, head_dim = query_states.shape _, kv_seq_len, _, _ = value_states.shape @@ -208,13 +209,23 @@ def fusion_flash_attention( is_causal=True, ) else: - attn_output = F.scaled_dot_product_attention( - query_states, - key_states, - value_states, - attn_mask=attention_mask, - is_causal=attention_mask is None, - ) + if attn_mask_start_row_indices is not None: + assert alibi is None, "flash_attention_with_sparse_mask not support alibi" + attn_output = F.flash_attention_with_sparse_mask( + query_states, + key_states, + value_states, + attn_mask_start_row_indices=attn_mask_start_row_indices, + is_causal=True, + ) + else: + attn_output = F.scaled_dot_product_attention( + query_states, + key_states, + value_states, + attn_mask=attention_mask, + is_causal=attention_mask is None, + ) attn_weights = None if reshard_layer is not None: diff --git a/paddlenlp/transformers/llama/modeling.py b/paddlenlp/transformers/llama/modeling.py index 98808a327a07..7b404f548448 100755 --- a/paddlenlp/transformers/llama/modeling.py +++ b/paddlenlp/transformers/llama/modeling.py @@ -212,6 +212,7 @@ def scaled_dot_product_attention( attention_mask, output_attentions, alibi=None, + attn_mask_start_row_indices=None, sequence_parallel=False, reshard_layer=None, npu_is_casual=False, @@ -228,6 +229,7 @@ def scaled_dot_product_attention( attention_mask, output_attentions, alibi, + attn_mask_start_row_indices, sequence_parallel, reshard_layer, npu_is_casual, @@ -815,6 +817,7 @@ def forward( output_attentions: bool = False, use_cache: bool = False, alibi: Optional[paddle.Tensor] = None, + attn_mask_start_row_indices: Optional[paddle.Tensor] = None, npu_is_casual: bool = False, ) -> Tuple[paddle.Tensor, Optional[paddle.Tensor], Optional[Tuple[paddle.Tensor]]]: """Input shape: Batch x Time x Channel""" @@ -1013,6 +1016,7 @@ def forward( attention_mask, output_attentions, alibi, + attn_mask_start_row_indices, self.sequence_parallel, reshard_layer=self.reshard_layer, use_reentrant=self.config.recompute_use_reentrant, @@ -1026,6 +1030,7 @@ def forward( attention_mask, output_attentions, alibi, + attn_mask_start_row_indices, self.sequence_parallel, reshard_layer=self.reshard_layer, npu_is_casual=npu_is_casual, @@ -1081,6 +1086,7 @@ def forward( past_key_value: Optional[Tuple[paddle.Tensor]] = None, use_cache: Optional[bool] = False, alibi: Optional[paddle.Tensor] = None, + attn_mask_start_row_indices: Optional[paddle.Tensor] = None, npu_is_casual: bool = False, ) -> Tuple[paddle.Tensor, Optional[Tuple[paddle.Tensor, paddle.Tensor]]]: """ @@ -1118,6 +1124,7 @@ def forward( output_attentions, use_cache, alibi, + attn_mask_start_row_indices, use_reentrant=self.config.recompute_use_reentrant, ) else: @@ -1129,6 +1136,7 
@@ def forward( output_attentions, use_cache, alibi, + attn_mask_start_row_indices=attn_mask_start_row_indices, npu_is_casual=npu_is_casual, ) @@ -1458,6 +1466,7 @@ def recompute_training_full( past_key_value: Tensor, use_cache: bool, alibi=None, + attn_mask_start_row_indices=None, ): def create_custom_forward(module): def custom_forward(*inputs): @@ -1474,6 +1483,7 @@ def custom_forward(*inputs): past_key_value, use_cache, alibi, + attn_mask_start_row_indices, use_reentrant=self.config.recompute_use_reentrant, ) @@ -1490,6 +1500,7 @@ def forward( output_attentions=False, output_hidden_states=None, return_dict=False, + attn_mask_start_row_indices=None, **kwargs, ): if self.sequence_parallel and use_cache: @@ -1536,10 +1547,10 @@ def forward( if self.config.context_parallel_degree > 1 and (attention_mask is not None or self.config.alibi): raise NotImplementedError("Ring FlashAttention dosen't support attention_mask or alibi") # embed positions - if attention_mask is None: + if attn_mask_start_row_indices is None and attention_mask is None: # [bs, seq_len] attention_mask = paddle.ones((batch_size, seq_length_with_past), dtype=paddle.bool) - if self.config.alibi: + if attn_mask_start_row_indices is None and self.config.alibi: if self.config.use_long_sequence_strategies: alibi_layer = LongSequenceStrategies.build_long_sequence_strategy( self.config.long_sequence_strategy_type, @@ -1570,14 +1581,14 @@ def forward( if use_casual_mask: attention_mask = None - else: + elif attn_mask_start_row_indices is None: attention_mask = self._prepare_decoder_attention_mask( attention_mask, (batch_size, seq_length), cache_length, inputs_embeds.dtype ) # [bs, 1, seq_len, seq_len] is_casual = False - if self.config.use_flash_attention and get_env_device() != "gcu": + if attn_mask_start_row_indices is None and self.config.use_flash_attention and get_env_device() != "gcu": if use_casual_mask: is_casual = True else: @@ -1614,6 +1625,7 @@ def forward( past_key_value, use_cache, alibi=alibi, + attn_mask_start_row_indices=attn_mask_start_row_indices, ) else: layer_outputs = decoder_layer( @@ -1624,6 +1636,7 @@ def forward( past_key_value, use_cache, alibi=alibi, + attn_mask_start_row_indices=attn_mask_start_row_indices, npu_is_casual=is_casual, ) @@ -1881,6 +1894,7 @@ def forward( output_attentions=None, output_hidden_states=None, return_dict=None, + attn_mask_start_row_indices=None, ): output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions output_hidden_states = ( @@ -1897,6 +1911,7 @@ def forward( output_attentions=output_attentions, output_hidden_states=output_hidden_states, return_dict=return_dict, + attn_mask_start_row_indices=attn_mask_start_row_indices, ) hidden_states = outputs[0] # [bs, seq_len, dim] diff --git a/paddlenlp/transformers/llama/modeling_pp.py b/paddlenlp/transformers/llama/modeling_pp.py index a00d8fc01f76..30d7dbf72042 100644 --- a/paddlenlp/transformers/llama/modeling_pp.py +++ b/paddlenlp/transformers/llama/modeling_pp.py @@ -12,6 +12,8 @@ # See the License for the specific language governing permissions and # limitations under the License. 
+from collections import OrderedDict + import paddle import paddle.distributed.fleet as fleet import paddle.nn as nn @@ -149,7 +151,7 @@ def forward(self, args): alibi = alibi.reshape([batch_size * self.config.num_attention_heads, 1, seq_length]) alibi.stop_gradient = True - if attention_mask is not None: + if attention_mask is not None and attention_mask.dtype != paddle.int32: attention_mask = LlamaModel._prepare_decoder_attention_mask( attention_mask, (batch_size, seq_length), 0, input_embeds.dtype ) @@ -175,23 +177,44 @@ def forward(self, args): # we can't distinguish # hidden_states, attention_mask, position_ids or # hidden_states, attention_mask, alibi + + if attention_mask is None: + attn_mask_start_row_indices = None + elif attention_mask.dtype == paddle.int32: + attn_mask_start_row_indices = attention_mask + else: + attn_mask_start_row_indices = None + if self.config.alibi and alibi is None and position_ids is not None: alibi = position_ids position_ids = None has_gradient = not hidden_states.stop_gradient if self.enable_recompute and self.config.recompute_granularity == "full" and has_gradient: - if attention_mask is not None or alibi is not None: + if attention_mask is not None or alibi is not None or attn_mask_start_row_indices is not None: hidden_states = recompute( - super().forward, hidden_states, attention_mask=attention_mask, alibi=alibi, use_reentrant=False + super().forward, + hidden_states, + attention_mask=attention_mask, + alibi=alibi, + attn_mask_start_row_indices=attn_mask_start_row_indices, + use_reentrant=False, ) else: # for pretrain hidden_states = recompute( - super().forward, hidden_states, use_reentrant=self.config.recompute_use_reentrant + super().forward, + hidden_states, + attn_mask_start_row_indices=attn_mask_start_row_indices, + use_reentrant=self.config.recompute_use_reentrant, ) else: - hidden_states = super().forward(hidden_states, attention_mask=attention_mask, alibi=alibi) + hidden_states = super().forward( + hidden_states, + attention_mask=attention_mask, + alibi=alibi, + attn_mask_start_row_indices=attn_mask_start_row_indices, + ) return return_args(hidden_states, attention_mask, position_ids, alibi) @@ -222,6 +245,36 @@ class LlamaForCausalLMPipe(PipelinePretrainedModel, PipelineLayer): # DONOT Add base_model_prefix !!!! 
+ @classmethod + def _prepare_pipeline_inputs_func(cls, inputs): + first_stage_keys = ["input_ids", "attn_mask_start_row_indices", "position_ids"] + if type(inputs) is dict or type(inputs) is OrderedDict: + if "attention_mask" in inputs: + first_stage_keys = ["input_ids", "attention_mask", "position_ids"] + else: # inputs is list + if "attention_mask" in inputs[0]: + first_stage_keys = ["input_ids", "attention_mask", "position_ids"] + last_stage_keys = ["labels"] + + def get_expected_keys(inputs, keys): + ret = tuple([inputs.pop(k) for k in keys if k in inputs]) + if len(ret) == 1: + ret = ret[0] + return ret + + if type(inputs) is dict or type(inputs) is OrderedDict: + return [ + get_expected_keys(inputs, first_stage_keys), + get_expected_keys(inputs, last_stage_keys), + ] + + keys = list(inputs[0].keys()) + inputs_batch = {key: [data.pop(key) for data in inputs] for key in keys} + return [ + get_expected_keys(inputs_batch, first_stage_keys), + get_expected_keys(inputs_batch, last_stage_keys), + ] + def __init__(self, config): self.config = config From 7f3014cbf72f0d9d46482a202012265ec1f52c70 Mon Sep 17 00:00:00 2001 From: GuoxiaWang Date: Tue, 18 Jun 2024 20:31:20 +0800 Subject: [PATCH 2/3] fix modeling_pp.py and rename attn_mask_start_row_indices --- paddlenlp/transformers/llama/fusion_ops.py | 6 +- paddlenlp/transformers/llama/modeling.py | 38 +++++------ paddlenlp/transformers/llama/modeling_pp.py | 70 +++++++++++---------- 3 files changed, 60 insertions(+), 54 deletions(-) diff --git a/paddlenlp/transformers/llama/fusion_ops.py b/paddlenlp/transformers/llama/fusion_ops.py index 4e19f7b8f7e7..2a273489e59b 100644 --- a/paddlenlp/transformers/llama/fusion_ops.py +++ b/paddlenlp/transformers/llama/fusion_ops.py @@ -146,7 +146,7 @@ def fusion_flash_attention( attention_mask, output_attentions, alibi=None, - attn_mask_start_row_indices=None, + attn_mask_startend_row_indices=None, sequence_parallel=False, reshard_layer=None, npu_is_casual=False, @@ -209,13 +209,13 @@ def fusion_flash_attention( is_causal=True, ) else: - if attn_mask_start_row_indices is not None: + if attn_mask_startend_row_indices is not None: assert alibi is None, "flash_attention_with_sparse_mask not support alibi" attn_output = F.flash_attention_with_sparse_mask( query_states, key_states, value_states, - attn_mask_start_row_indices=attn_mask_start_row_indices, + attn_mask_start_row_indices=attn_mask_startend_row_indices, is_causal=True, ) else: diff --git a/paddlenlp/transformers/llama/modeling.py b/paddlenlp/transformers/llama/modeling.py index afdae162c74c..56c9713f0118 100755 --- a/paddlenlp/transformers/llama/modeling.py +++ b/paddlenlp/transformers/llama/modeling.py @@ -212,7 +212,7 @@ def scaled_dot_product_attention( attention_mask, output_attentions, alibi=None, - attn_mask_start_row_indices=None, + attn_mask_startend_row_indices=None, sequence_parallel=False, reshard_layer=None, npu_is_casual=False, @@ -229,7 +229,7 @@ def scaled_dot_product_attention( attention_mask, output_attentions, alibi, - attn_mask_start_row_indices, + attn_mask_startend_row_indices, sequence_parallel, reshard_layer, npu_is_casual, @@ -818,7 +818,7 @@ def forward( output_attentions: bool = False, use_cache: bool = False, alibi: Optional[paddle.Tensor] = None, - attn_mask_start_row_indices: Optional[paddle.Tensor] = None, + attn_mask_startend_row_indices: Optional[paddle.Tensor] = None, npu_is_casual: bool = False, ) -> Tuple[paddle.Tensor, Optional[paddle.Tensor], Optional[Tuple[paddle.Tensor]]]: """Input shape: Batch x Time x Channel""" 
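Aside on the renamed argument: attn_mask_startend_row_indices is a compact, per-key-column description of the sparse causal mask that flash_attention_with_sparse_mask consumes, so packed sequences can share one kernel call without materializing a [seq_len, seq_len] attention mask. The sketch below (plain NumPy, illustrative function name, not part of the patch) expands such indices into the dense boolean mask they are assumed to describe, under the convention that the stored value for a key column is the first query row that may no longer attend to it.

import numpy as np

def dense_mask_from_start_rows(start_row_indices, seq_len):
    """Illustrative only: expand per-column start-row indices into the
    boolean mask they are assumed to encode (True = may attend).

    start_row_indices[j] is taken to be the first query row that must NOT
    attend to key column j; combined with the usual causal rule this gives
    a block-diagonal causal mask for packed sequences."""
    rows = np.arange(seq_len)[:, None]          # query index i
    cols = np.arange(seq_len)[None, :]          # key index j
    causal = cols <= rows                       # standard causal constraint
    cutoff = rows < start_row_indices[None, :]  # rows >= start index are masked
    return causal & cutoff

# Two packed sequences of length 3 and 2 inside a row of length 5:
# keys of the first segment stop being visible from row 3 onward,
# keys of the second segment from row 5 onward.
start_rows = np.array([3, 3, 3, 5, 5])
print(dense_mask_from_start_rows(start_rows, 5).astype(int))

With these indices the printed mask is block-diagonal causal, which is the intended effect of passing attn_mask_startend_row_indices instead of a full attention_mask.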
@@ -1017,7 +1017,7 @@ def forward( attention_mask, output_attentions, alibi, - attn_mask_start_row_indices, + attn_mask_startend_row_indices, self.sequence_parallel, reshard_layer=self.reshard_layer, use_reentrant=self.config.recompute_use_reentrant, @@ -1031,7 +1031,7 @@ def forward( attention_mask, output_attentions, alibi, - attn_mask_start_row_indices, + attn_mask_startend_row_indices, self.sequence_parallel, reshard_layer=self.reshard_layer, npu_is_casual=npu_is_casual, @@ -1087,7 +1087,7 @@ def forward( past_key_value: Optional[Tuple[paddle.Tensor]] = None, use_cache: Optional[bool] = False, alibi: Optional[paddle.Tensor] = None, - attn_mask_start_row_indices: Optional[paddle.Tensor] = None, + attn_mask_startend_row_indices: Optional[paddle.Tensor] = None, npu_is_casual: bool = False, ) -> Tuple[paddle.Tensor, Optional[Tuple[paddle.Tensor, paddle.Tensor]]]: """ @@ -1125,7 +1125,7 @@ def forward( output_attentions, use_cache, alibi, - attn_mask_start_row_indices, + attn_mask_startend_row_indices, use_reentrant=self.config.recompute_use_reentrant, ) else: @@ -1137,7 +1137,7 @@ def forward( output_attentions, use_cache, alibi, - attn_mask_start_row_indices=attn_mask_start_row_indices, + attn_mask_startend_row_indices=attn_mask_startend_row_indices, npu_is_casual=npu_is_casual, ) @@ -1467,7 +1467,7 @@ def recompute_training_full( past_key_value: Tensor, use_cache: bool, alibi=None, - attn_mask_start_row_indices=None, + attn_mask_startend_row_indices=None, ): def create_custom_forward(module): def custom_forward(*inputs): @@ -1484,7 +1484,7 @@ def custom_forward(*inputs): past_key_value, use_cache, alibi, - attn_mask_start_row_indices, + attn_mask_startend_row_indices, use_reentrant=self.config.recompute_use_reentrant, ) @@ -1501,7 +1501,7 @@ def forward( output_attentions=False, output_hidden_states=None, return_dict=False, - attn_mask_start_row_indices=None, + attn_mask_startend_row_indices=None, **kwargs, ): if self.sequence_parallel and use_cache: @@ -1548,10 +1548,10 @@ def forward( if self.config.context_parallel_degree > 1 and (attention_mask is not None or self.config.alibi): raise NotImplementedError("Ring FlashAttention dosen't support attention_mask or alibi") # embed positions - if attn_mask_start_row_indices is None and attention_mask is None: + if attn_mask_startend_row_indices is None and attention_mask is None: # [bs, seq_len] attention_mask = paddle.ones((batch_size, seq_length_with_past), dtype=paddle.bool) - if attn_mask_start_row_indices is None and self.config.alibi: + if attn_mask_startend_row_indices is None and self.config.alibi: if self.config.use_long_sequence_strategies: alibi_layer = LongSequenceStrategies.build_long_sequence_strategy( self.config.long_sequence_strategy_type, @@ -1582,14 +1582,14 @@ def forward( if use_casual_mask: attention_mask = None - elif attn_mask_start_row_indices is None: + elif attn_mask_startend_row_indices is None: attention_mask = self._prepare_decoder_attention_mask( attention_mask, (batch_size, seq_length), cache_length, inputs_embeds.dtype ) # [bs, 1, seq_len, seq_len] is_casual = False - if attn_mask_start_row_indices is None and self.config.use_flash_attention and get_env_device() != "gcu": + if attn_mask_startend_row_indices is None and self.config.use_flash_attention and get_env_device() != "gcu": if use_casual_mask: is_casual = True else: @@ -1626,7 +1626,7 @@ def forward( past_key_value, use_cache, alibi=alibi, - attn_mask_start_row_indices=attn_mask_start_row_indices, + 
attn_mask_startend_row_indices=attn_mask_startend_row_indices, ) else: layer_outputs = decoder_layer( @@ -1637,7 +1637,7 @@ def forward( past_key_value, use_cache, alibi=alibi, - attn_mask_start_row_indices=attn_mask_start_row_indices, + attn_mask_startend_row_indices=attn_mask_startend_row_indices, npu_is_casual=is_casual, ) @@ -1899,7 +1899,7 @@ def forward( output_attentions=None, output_hidden_states=None, return_dict=None, - attn_mask_start_row_indices=None, + attn_mask_startend_row_indices=None, ): output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions output_hidden_states = ( @@ -1916,7 +1916,7 @@ def forward( output_attentions=output_attentions, output_hidden_states=output_hidden_states, return_dict=return_dict, - attn_mask_start_row_indices=attn_mask_start_row_indices, + attn_mask_startend_row_indices=attn_mask_startend_row_indices, ) hidden_states = outputs[0] # [bs, seq_len, dim] diff --git a/paddlenlp/transformers/llama/modeling_pp.py b/paddlenlp/transformers/llama/modeling_pp.py index 5445f7d8d279..c9811d1bc43b 100644 --- a/paddlenlp/transformers/llama/modeling_pp.py +++ b/paddlenlp/transformers/llama/modeling_pp.py @@ -49,18 +49,23 @@ def __repr__(self): def parse_args(args): if isinstance(args, tuple): - if len(args) == 4: - hidden_states, attention_mask, position_ids, alibi = args - if len(args) == 3: - hidden_states, attention_mask, position_ids = args + if len(args) == 5: + hidden_states, attention_mask, attn_mask_startend_row_indices, position_ids, alibi = args + elif len(args) == 4: + hidden_states, attention_mask, attn_mask_startend_row_indices, position_ids = args + alibi = None + elif len(args) == 3: + hidden_states, attention_mask, attn_mask_startend_row_indices = args + position_ids = None alibi = None elif len(args) == 2: hidden_states, attention_mask = args + attn_mask_startend_row_indices = None position_ids = None alibi = None else: hidden_states = args - attention_mask, position_ids, alibi = None, None, None + attention_mask, attn_mask_startend_row_indices, position_ids, alibi = None, None, None, None if position_ids is not None: position_ids.stop_gradient = True @@ -68,17 +73,24 @@ def parse_args(args): if attention_mask is not None: attention_mask.stop_gradient = True + if attn_mask_startend_row_indices is not None: + attn_mask_startend_row_indices.stop_gradient = True + if alibi is not None: alibi.stop_gradient = True - return hidden_states, attention_mask, position_ids, alibi + return hidden_states, attention_mask, attn_mask_startend_row_indices, position_ids, alibi -def return_args(hidden_states, attention_mask=None, position_ids=None, alibi=None): +def return_args( + hidden_states, attention_mask=None, attn_mask_startend_row_indices=None, position_ids=None, alibi=None +): ret = (hidden_states,) if attention_mask is not None: ret += (attention_mask.clone(),) + if attn_mask_startend_row_indices is not None: + ret += (attn_mask_startend_row_indices.clone(),) if position_ids is not None: ret += (position_ids.clone(),) if alibi is not None: @@ -116,7 +128,7 @@ def forward(self, args): Returns: _type_: _description_ """ - input_ids, attention_mask, position_ids, alibi = parse_args(args) + input_ids, attention_mask, attn_mask_startend_row_indices, position_ids, alibi = parse_args(args) input_embeds = self.embed_tokens(input_ids) if self.sequence_parallel: from paddlenlp.transformers import ScatterOp @@ -130,6 +142,9 @@ def forward(self, args): batch_size, seq_length = input_ids.shape alibi = None if 
self.config.alibi: + assert ( + attn_mask_startend_row_indices is not None + ), "alibi and attn_mask_startend_row_indices can not be set at same time" # embed positions mask = ( attention_mask @@ -151,7 +166,10 @@ def forward(self, args): alibi = alibi.reshape([batch_size * self.config.num_attention_heads, 1, seq_length]) alibi.stop_gradient = True - if attention_mask is not None and attention_mask.dtype != paddle.int32: + if attention_mask is not None: + assert ( + attn_mask_startend_row_indices is not None + ), "attention_mask and attn_mask_startend_row_indices can not be set at same time" attention_mask = LlamaModel._prepare_decoder_attention_mask( attention_mask, (batch_size, seq_length), 0, input_embeds.dtype ) @@ -168,37 +186,30 @@ def forward(self, args): ) attention_mask.stop_gradient = True - return return_args(input_embeds, attention_mask, position_ids, alibi) + return return_args(input_embeds, attention_mask, attn_mask_startend_row_indices, position_ids, alibi) class LlamaDecoderLayerPipe(LlamaDecoderLayer): def forward(self, args): - hidden_states, attention_mask, position_ids, alibi = parse_args(args) + hidden_states, attention_mask, attn_mask_startend_row_indices, position_ids, alibi = parse_args(args) # we can't distinguish # hidden_states, attention_mask, position_ids or # hidden_states, attention_mask, alibi - if attention_mask is None: - attn_mask_start_row_indices = None - elif attention_mask.dtype == paddle.int32: - attn_mask_start_row_indices = attention_mask - else: - attn_mask_start_row_indices = None - if self.config.alibi and alibi is None and position_ids is not None: alibi = position_ids position_ids = None has_gradient = not hidden_states.stop_gradient if self.enable_recompute and self.config.recompute_granularity == "full" and has_gradient: - if attention_mask is not None or alibi is not None or attn_mask_start_row_indices is not None: + if attention_mask is not None or alibi is not None or attn_mask_startend_row_indices is not None: hidden_states = recompute( super().forward, hidden_states, position_ids=position_ids, attention_mask=attention_mask, alibi=alibi, - attn_mask_start_row_indices=attn_mask_start_row_indices, + attn_mask_startend_row_indices=attn_mask_startend_row_indices, use_reentrant=False, ) else: @@ -207,7 +218,7 @@ def forward(self, args): super().forward, hidden_states, position_ids=position_ids, - attn_mask_start_row_indices=attn_mask_start_row_indices, + attn_mask_startend_row_indices=attn_mask_startend_row_indices, use_reentrant=self.config.recompute_use_reentrant, ) else: @@ -216,10 +227,10 @@ def forward(self, args): position_ids=position_ids, attention_mask=attention_mask, alibi=alibi, - attn_mask_start_row_indices=attn_mask_start_row_indices, + attn_mask_startend_row_indices=attn_mask_startend_row_indices, ) - return return_args(hidden_states, attention_mask, position_ids, alibi) + return return_args(hidden_states, attention_mask, attn_mask_startend_row_indices, position_ids, alibi) class LlamaRMSNormPipe(nn.Layer): @@ -228,7 +239,7 @@ def __init__(self, config): self.norm = LlamaRMSNorm(config) def forward(self, args): - hidden_states, attention_mask, position_ids, alibi = parse_args(args) + hidden_states, attention_mask, attn_mask_startend_row_indices, position_ids, alibi = parse_args(args) return self.norm(hidden_states) @@ -250,17 +261,12 @@ class LlamaForCausalLMPipe(PipelinePretrainedModel, PipelineLayer): @classmethod def _prepare_pipeline_inputs_func(cls, inputs): - first_stage_keys = ["input_ids", "attn_mask_start_row_indices", 
"position_ids"] - if type(inputs) is dict or type(inputs) is OrderedDict: - if "attention_mask" in inputs: - first_stage_keys = ["input_ids", "attention_mask", "position_ids"] - else: # inputs is list - if "attention_mask" in inputs[0]: - first_stage_keys = ["input_ids", "attention_mask", "position_ids"] + + first_stage_keys = ["input_ids", "attention_mask", "attn_mask_startend_row_indices", "position_ids"] last_stage_keys = ["labels"] def get_expected_keys(inputs, keys): - ret = tuple([inputs.pop(k) for k in keys if k in inputs]) + ret = tuple([inputs.pop(k) if k in inputs else None for k in keys]) if len(ret) == 1: ret = ret[0] return ret From 84cf5a89170004dc02fc9d3f94c2124320979bff Mon Sep 17 00:00:00 2001 From: GuoxiaWang Date: Wed, 19 Jun 2024 11:02:19 +0800 Subject: [PATCH 3/3] fix assert --- paddlenlp/transformers/llama/modeling_pp.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/paddlenlp/transformers/llama/modeling_pp.py b/paddlenlp/transformers/llama/modeling_pp.py index c9811d1bc43b..72ecdf20e8e5 100644 --- a/paddlenlp/transformers/llama/modeling_pp.py +++ b/paddlenlp/transformers/llama/modeling_pp.py @@ -143,7 +143,7 @@ def forward(self, args): alibi = None if self.config.alibi: assert ( - attn_mask_startend_row_indices is not None + attn_mask_startend_row_indices is None ), "alibi and attn_mask_startend_row_indices can not be set at same time" # embed positions mask = ( @@ -168,7 +168,7 @@ def forward(self, args): if attention_mask is not None: assert ( - attn_mask_startend_row_indices is not None + attn_mask_startend_row_indices is None ), "attention_mask and attn_mask_startend_row_indices can not be set at same time" attention_mask = LlamaModel._prepare_decoder_attention_mask( attention_mask, (batch_size, seq_length), 0, input_embeds.dtype