diff --git a/paddlenlp/transformers/llama/modeling_auto.py b/paddlenlp/transformers/llama/modeling_auto.py
index 21635da46cca..05cf64d831d0 100644
--- a/paddlenlp/transformers/llama/modeling_auto.py
+++ b/paddlenlp/transformers/llama/modeling_auto.py
@@ -854,7 +854,7 @@ def get_layer_pp_info(layer_index):
         self.next_pp_stage_indexes = []
         for i in range(config.num_hidden_layers):
             pp_stage_id, input_need_reshard = get_layer_pp_info(i)
-            decoder_layers.append(LlamaDecoderLayerAuto(config, False, pp_stage_id))
+            decoder_layers.append(LlamaDecoderLayerAuto(config, i not in self.no_recompute_layers, pp_stage_id))
             if input_need_reshard:
                 self.next_pp_stage_indexes.append(i)
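The change above replaces a hard-coded `False` with `i not in self.no_recompute_layers` as the second constructor argument, so the per-layer flag now reflects whether that layer should be recomputed. Below is a minimal, self-contained sketch of that selection pattern; `DecoderLayerSketch`, `no_recompute_layers`, and `num_hidden_layers` are illustrative stand-ins, not the PaddleNLP classes or config fields themselves.

```python
# Sketch only: shows how a per-layer recompute flag can be derived from a
# list of exempted layer indexes, mirroring the patched constructor call.

class DecoderLayerSketch:
    def __init__(self, layer_idx, enable_recompute):
        self.layer_idx = layer_idx
        self.enable_recompute = enable_recompute

no_recompute_layers = [0, 31]   # assumed: layers exempted from recompute
num_hidden_layers = 32

decoder_layers = [
    # Before the patch every layer received False; after it, recompute is
    # enabled for every layer NOT listed in no_recompute_layers.
    DecoderLayerSketch(i, enable_recompute=(i not in no_recompute_layers))
    for i in range(num_hidden_layers)
]

assert decoder_layers[0].enable_recompute is False
assert decoder_layers[1].enable_recompute is True
```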