diff --git a/paddlenlp/peft/lora/loraga_utils.py b/paddlenlp/peft/lora/loraga_utils.py
index 7400e2e3b88d..72b4baac1de2 100644
--- a/paddlenlp/peft/lora/loraga_utils.py
+++ b/paddlenlp/peft/lora/loraga_utils.py
@@ -61,7 +61,6 @@ def estimate_gradient(self, model: PretrainedModel):
         logger.info("Estimating gradient for LoraGA.")
         model = self._wrap_model(model)
-        model.train()
         dataloader = self.get_train_dataloader()
         iters = 0
@@ -75,9 +74,7 @@ def estimate_gradient(self, model: PretrainedModel):
             for batch in dataloader:
                 iters += 1
                 # Pipeline parallel not supported currently
-                with paddle.amp.auto_cast(enable=True, custom_black_list=self.args.amp_custom_black_list):
-                    loss, _ = model(**batch)
-                    loss.backward()
+                self.training_step(model, batch)
                 if iters == self.loraga_init_iters:
                     break
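
For context, the manual forward/backward block is now delegated to the Trainer's training_step. The snippet below is a minimal sketch, not the actual PaddleNLP implementation: it only approximates the behaviour that self.training_step(model, batch) is assumed to subsume, reconstructed from the lines removed in this diff (the helper name training_step_sketch and the trainer argument are illustrative; the real training_step presumably also covers loss scaling, gradient accumulation, and distributed wrapping).

import paddle

def training_step_sketch(trainer, model, batch):
    # Hedged sketch only: mirrors the removed lines, not the real Trainer.training_step.
    model.train()  # stands in for the explicit model.train() call removed above
    with paddle.amp.auto_cast(enable=True, custom_black_list=trainer.args.amp_custom_black_list):
        loss, _ = model(**batch)  # forward pass under AMP, as in the removed block
        loss.backward()           # accumulate gradients used for LoRA-GA's estimation
    return loss.detach()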