We read every piece of feedback, and take your input very seriously.
To see all available qualifiers, see our documentation.
There was an error while loading. Please reload this page.
1 parent f4a8f4c commit 1ef7503 — Copy full SHA for 1ef7503
paddlenlp/transformers/llama/modeling_auto.py
@@ -854,7 +854,7 @@ def get_layer_pp_info(layer_index):
854
self.next_pp_stage_indexes = []
855
for i in range(config.num_hidden_layers):
856
pp_stage_id, input_need_reshard = get_layer_pp_info(i)
857
- decoder_layers.append(LlamaDecoderLayerAuto(config, False, pp_stage_id))
+ decoder_layers.append(LlamaDecoderLayerAuto(config, i not in self.no_recompute_layers, pp_stage_id))
858
if input_need_reshard:
859
self.next_pp_stage_indexes.append(i)
860
0 commit comments