1 parent 40fa402 · commit d007615
paddlenlp/peft/lora/loraga_utils.py
@@ -61,7 +61,6 @@ def estimate_gradient(self, model: PretrainedModel):
         logger.info("Estimating gradient for LoraGA.")

         model = self._wrap_model(model)
-        model.train()
         dataloader = self.get_train_dataloader()
         iters = 0

@@ -75,9 +74,7 @@ def estimate_gradient(self, model: PretrainedModel):
             for batch in dataloader:
                 iters += 1
                 # Pipeline parallel not supported currently
-                with paddle.amp.auto_cast(enable=True, custom_black_list=self.args.amp_custom_black_list):
-                    loss, _ = model(**batch)
-                    loss.backward()
+                self.training_step(model, batch)

                 if iters == self.loraga_init_iters:
                     break
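For context, here is a minimal sketch of what a Trainer-style training_step typically does; this is an assumption about the API shape, not the actual PaddleNLP Trainer implementation. It illustrates why both the explicit model.train() call and the hand-rolled autocast/forward/backward block become redundant once the gradient-estimation loop delegates to self.training_step(model, batch). The SketchTrainer class and its args.amp_custom_black_list attribute are hypothetical, mirroring the names used in the removed lines.

    # Minimal sketch, assuming Trainer-like semantics; not the actual
    # PaddleNLP Trainer.training_step implementation.
    import paddle


    class SketchTrainer:
        def __init__(self, args):
            # `args` is assumed to carry amp_custom_black_list, as in the removed block.
            self.args = args

        def training_step(self, model, inputs):
            # Switch to train mode here, so callers no longer need an explicit model.train().
            model.train()
            # Run the forward pass under AMP autocast, as the removed inline block did by hand.
            with paddle.amp.auto_cast(enable=True, custom_black_list=self.args.amp_custom_black_list):
                loss, _ = model(**inputs)
            # Backward pass; any gradient hooks registered on the model still fire here,
            # so gradient estimation for LoRA-GA observes the same gradients as before.
            loss.backward()
            return loss.detach()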