Skip to content

Commit 94e798f

Browse files
authored
fix loraga merge (#9765)
* fix loraga merge
* change sign
1 parent 027b530 commit 94e798f

File tree

1 file changed

+10
-4
lines changed

1 file changed

+10
-4
lines changed

paddlenlp/peft/lora/lora_model.py

Lines changed: 10 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -327,12 +327,18 @@ def set_state_dict(self, state_dict):
327327
model_state_dict = self.model.state_dict()
328328
if self.lora_config.loraga:
329329

330-
def process_split_and_assign(name, concat_tensor, axis, init_dict, state_dict):
330+
def process_split_and_assign(name, concat_tensor, init_dict, state_dict):
331+
if "lora_A" in name:
332+
axis = 1
333+
else:
334+
axis = 0
331335
if isinstance(concat_tensor, np.ndarray):
332336
final_lora, init_lora = np.split(concat_tensor, 2, axis=axis)
333337
init_lora = paddle.to_tensor(init_lora)
334338
else:
335339
final_lora, init_lora = paddle.split(concat_tensor, 2, axis=axis)
340+
if "lora_B" in name:
341+
init_lora *= -1
336342
init_dict[name] = init_lora
337343
state_dict[name] = final_lora
338344
return init_lora
@@ -341,13 +347,13 @@ def process_split_and_assign(name, concat_tensor, axis, init_dict, state_dict):
341347
if "lora_A" in name:
342348
concat_lora_A = state_dict[name]
343349
init_loraA = process_split_and_assign(
344-
name, concat_lora_A, axis=1, init_dict=self.loraga_init_dict, state_dict=state_dict
350+
name, concat_lora_A, init_dict=self.loraga_init_dict, state_dict=state_dict
345351
)
346352

347353
loraB_name = name.replace("lora_A", "lora_B")
348354
concat_lora_B = state_dict[loraB_name]
349355
init_loraB = process_split_and_assign(
350-
loraB_name, concat_lora_B, axis=0, init_dict=self.loraga_init_dict, state_dict=state_dict
356+
loraB_name, concat_lora_B, init_dict=self.loraga_init_dict, state_dict=state_dict
351357
)
352358

353359
base_name = name.replace("lora_A", "weight")
@@ -690,7 +696,7 @@ def get_trainable_state_dict(self, concat_init_lora=False):
690696
if "lora_A" in name:
691697
trainable_state_dict[name] = paddle.concat([weight, self.loraga_init_dict[name]], axis=1)
692698
else:
693-
trainable_state_dict[name] = paddle.concat([weight, self.loraga_init_dict[name]], axis=0)
699+
trainable_state_dict[name] = paddle.concat([weight, -self.loraga_init_dict[name]], axis=0)
694700
else:
695701
trainable_state_dict[name] = weight
696702

0 commit comments

Comments
 (0)