@@ -327,12 +327,18 @@ def set_state_dict(self, state_dict):
327
327
model_state_dict = self .model .state_dict ()
328
328
if self .lora_config .loraga :
329
329
330
def process_split_and_assign(name, concat_tensor, init_dict, state_dict):
    """Split a concatenated LoRA tensor into its two halves and store them.

    ``concat_tensor`` is cut into two equal parts: along axis 1 when
    ``name`` contains ``"lora_A"``, along axis 0 otherwise.  The first part
    is written to ``state_dict[name]``; the second part is written to
    ``init_dict[name]`` and returned.

    Args:
        name: Parameter name; selects the split axis and whether the second
            half is negated.
        concat_tensor: Either a ``np.ndarray`` or a value accepted by
            ``paddle.split``.
        init_dict: Mapping that receives the second half under ``name``.
        state_dict: Mapping whose ``name`` entry is overwritten with the
            first half.

    Returns:
        The second half.  For ``np.ndarray`` input it has been converted
        with ``paddle.to_tensor``; the first half is left as an ndarray.
    """
    split_axis = 1 if "lora_A" in name else 0

    if isinstance(concat_tensor, np.ndarray):
        trained_part, init_part = np.split(concat_tensor, 2, axis=split_axis)
        # Only the init half is promoted to a paddle tensor here.
        init_part = paddle.to_tensor(init_part)
    else:
        trained_part, init_part = paddle.split(concat_tensor, 2, axis=split_axis)

    if "lora_B" in name:
        # Flip the sign of the stored lora_B init half.
        # NOTE(review): presumably undoes the negation applied when the
        # concatenated tensor was produced — confirm against the writer.
        init_part *= -1

    init_dict[name] = init_part
    state_dict[name] = trained_part
    return init_part
@@ -341,13 +347,13 @@ def process_split_and_assign(name, concat_tensor, axis, init_dict, state_dict):
341
347
if "lora_A" in name :
342
348
concat_lora_A = state_dict [name ]
343
349
init_loraA = process_split_and_assign (
344
- name , concat_lora_A , axis = 1 , init_dict = self .loraga_init_dict , state_dict = state_dict
350
+ name , concat_lora_A , init_dict = self .loraga_init_dict , state_dict = state_dict
345
351
)
346
352
347
353
loraB_name = name .replace ("lora_A" , "lora_B" )
348
354
concat_lora_B = state_dict [loraB_name ]
349
355
init_loraB = process_split_and_assign (
350
- loraB_name , concat_lora_B , axis = 0 , init_dict = self .loraga_init_dict , state_dict = state_dict
356
+ loraB_name , concat_lora_B , init_dict = self .loraga_init_dict , state_dict = state_dict
351
357
)
352
358
353
359
base_name = name .replace ("lora_A" , "weight" )
@@ -690,7 +696,7 @@ def get_trainable_state_dict(self, concat_init_lora=False):
690
696
if "lora_A" in name :
691
697
trainable_state_dict [name ] = paddle .concat ([weight , self .loraga_init_dict [name ]], axis = 1 )
692
698
else :
693
- trainable_state_dict [name ] = paddle .concat ([weight , self .loraga_init_dict [name ]], axis = 0 )
699
+ trainable_state_dict [name ] = paddle .concat ([weight , - self .loraga_init_dict [name ]], axis = 0 )
694
700
else :
695
701
trainable_state_dict [name ] = weight
696
702
0 commit comments