Commit d007615

fix loraga amp (#9699)
1 parent 40fa402 commit d007615

File tree

1 file changed

paddlenlp/peft/lora/loraga_utils.py

Lines changed: 1 addition & 4 deletions
@@ -61,7 +61,6 @@ def estimate_gradient(self, model: PretrainedModel):
         logger.info("Estimating gradient for LoraGA.")

         model = self._wrap_model(model)
-        model.train()
         dataloader = self.get_train_dataloader()
         iters = 0

@@ -75,9 +74,7 @@ def estimate_gradient(self, model: PretrainedModel):
             for batch in dataloader:
                 iters += 1
                 # Pipeline parallel not supported currently
-                with paddle.amp.auto_cast(enable=True, custom_black_list=self.args.amp_custom_black_list):
-                    loss, _ = model(**batch)
-                    loss.backward()
+                self.training_step(model, batch)

                 if iters == self.loraga_init_iters:
                     break
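For context, the change swaps the hand-rolled forward/backward pass (autocast only, no loss scaling, plus an explicit model.train() call) for the trainer's own training_step, so the LoraGA gradient estimation goes through the same AMP path as regular training. The sketch below is illustrative only, assuming a GradScaler-based fp16 setup; the helper names manual_step and trainer_style_step are made up for the comparison and are not PaddleNLP APIs.

import paddle

# Minimal sketch (not PaddleNLP's actual Trainer code): what a trainer-style
# training_step typically covers that the removed hand-rolled loop did not.
scaler = paddle.amp.GradScaler(init_loss_scaling=2.0**15)  # loss scaling for fp16

def manual_step(model, batch, amp_black_list):
    # The removed pattern: forward under autocast, plain backward.
    # With fp16 this skips loss scaling, so small gradients can underflow.
    with paddle.amp.auto_cast(enable=True, custom_black_list=amp_black_list):
        loss, _ = model(**batch)
    loss.backward()
    return loss

def trainer_style_step(model, batch, amp_black_list):
    # The replacement pattern: put the model in train mode, run the forward
    # pass under the same autocast settings as real training, and scale the
    # loss before backward so fp16 gradients do not underflow.
    model.train()
    with paddle.amp.auto_cast(enable=True, custom_black_list=amp_black_list):
        loss, _ = model(**batch)
    scaler.scale(loss).backward()
    return loss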
