# limitations under the License.

import math
- import os
from typing import List, Optional

import paddle
    RowParallelLinear,
)

- from .lora_quick_layers import quick_lora
+ try:
+     from paddle.distributed.fleet.utils.sequence_parallel_utils import (
+         AllGatherOp,
+         ColumnSequenceParallelLinear,
+         ReduceScatterOp,
+         RowSequenceParallelLinear,
+         mark_as_sequence_parallel_parameter,
+     )
+ except:
+     pass
+
+ from paddlenlp.transformers.mc2_parallel_linear import (
+     MC2ColumnParallelCoreLinear,
+     MC2ColumnSeqParallelCoreLinear,
+     MC2RowParallelCoreLinear,
+     MC2RowSeqParallelCoreLinear,
+ )

- if "npu" in paddle.device.get_all_custom_device_type():
-     from .mc2_lora_npu import MC2LoRaColumnParallelLinear, MC2LoRaRowParallelLinear
- else:
-     MC2LoRaRowParallelLinear = None
-     MC2LoRaColumnParallelLinear = None
+ from .lora_quick_layers import quick_lora


class LoRALinear(nn.Linear):
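Note on the imports above: the `MC2*CoreLinear` symbols from `paddlenlp.transformers.mc2_parallel_linear` are expected to be `None` when the fused MC2 kernels are unavailable, which is why the modified forward passes below branch on `... is None`. A minimal sketch of that dispatch idiom (the helper name is hypothetical and not part of this diff):

```python
import paddle.nn.functional as F
from paddlenlp.transformers.mc2_parallel_linear import MC2RowParallelCoreLinear

# Hypothetical helper (not part of this PR) mirroring the dispatch idiom used in
# the forward passes below: the MC2* symbols are assumed to be None whenever the
# fused MC2 kernels are unavailable, so plain F.linear is the fallback path.
def row_parallel_matmul(x, weight, group):
    if MC2RowParallelCoreLinear is None:
        return F.linear(x=x, weight=weight)
    return MC2RowParallelCoreLinear.apply(x, weight, group)
```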
@@ -266,16 +277,16 @@ def forward(self, x: paddle.Tensor):
            )
        else:
            # x @ W : [bz, in_f / ws] ===> [bz, out_f]
-             if "npu" in paddle.device.get_all_custom_device_type() and int(os.getenv("MC2", "0")):
-                 output = MC2LoRaRowParallelLinear.apply(input_mp, self.weight, self.model_parallel_group)
-             else:
+             if MC2RowParallelCoreLinear is None:
                result_mp = F.linear(x=input_mp, weight=self.weight, name=self.name)
                output = mp_ops._mp_allreduce(
                    result_mp,
                    group=self.model_parallel_group,
                    use_calc_stream=True,
                    use_model_parallel=True,
                )
+             else:
+                 output = MC2RowParallelCoreLinear.apply(input_mp, self.weight, self.model_parallel_group)

        if not self.merged:
            # x @ A: [bz, in_f/ ws] ===> [bz, r]
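For intuition, the non-MC2 branch above computes a partial matmul against the local weight shard and relies on `mp_ops._mp_allreduce` to sum the partials across ranks. A single-process sketch of that algebra, with the collective emulated by a plain addition (shapes are arbitrary and chosen only for illustration):

```python
import paddle

# The row-parallel weight is split along in_features; each "rank" computes a
# partial matmul and the allreduce sums the partials (emulated here in-process).
x = paddle.randn([2, 8])       # [batch, in_features]
weight = paddle.randn([8, 6])  # [in_features, out_features]

x_shards = paddle.split(x, 2, axis=1)       # per-rank input shards
w_shards = paddle.split(weight, 2, axis=0)  # per-rank weight shards (rows)
partial = x_shards[0] @ w_shards[0] + x_shards[1] @ w_shards[1]
assert paddle.allclose(partial, x @ weight).item()
```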
@@ -298,6 +309,120 @@ def extra_repr(self):
        return f"in_features={self.weight.shape[0]}, out_features={self.weight.shape[1]}, rank={self.r}{name}"


+ class RowSequenceParallelLoRALinear(RowSequenceParallelLinear):
+     def __init__(
+         self,
+         in_features: int,
+         out_features: int,
+         r: int = 0,
+         lora_alpha: int = 1,
+         lora_dropout: float = 0.0,
+         rslora: bool = False,
+         lora_plus_scale: float = 1.0,
+         merge_weights: bool = True,
+         use_quick_lora: bool = False,
+         **kwargs
+     ):
+         RowSequenceParallelLinear.__init__(self, in_features, out_features, **kwargs)
+         if not isinstance(r, int) or r <= 0:
+             raise ValueError("Lora rank r should be a positive integer")
+         self.r = r
+         self.lora_alpha = lora_alpha
+         # Optional dropout
+         if lora_dropout > 0.0:
+             self.lora_dropout = nn.Dropout(p=lora_dropout)
+         else:
+             self.lora_dropout = lambda x: x
+         # Mark the weight as unmerged
+         self.merged = False
+         self.merge_weights = merge_weights
+
+         # compatible
+         self.name = self._name
+
+         # Actual trainable parameters
+         self.lora_A = self.create_parameter(
+             shape=[self.input_size_per_partition, r],
+             dtype=self._dtype,
+             is_bias=False,
+             attr=paddle.ParamAttr(
+                 initializer=nn.initializer.KaimingUniform(negative_slope=math.sqrt(5), nonlinearity="leaky_relu")
+             ),
+         )
+         self.lora_B = self.create_parameter(
+             shape=[r, self.out_features],
+             dtype=self._dtype,
+             is_bias=False,
+             attr=paddle.ParamAttr(
+                 initializer=paddle.nn.initializer.Constant(value=0.0),
+                 learning_rate=lora_plus_scale,
+             ),
+         )
+
+         self.lora_A.is_distributed = True
+         self.lora_A.split_axis = 0
+         self.lora_B.is_distributed = False
+         mark_as_sequence_parallel_parameter(self.lora_B)
+         if not rslora:
+             self.scaling = self.lora_alpha / self.r
+         else:
+             self.scaling = self.lora_alpha / math.sqrt(self.r)
+
+         # Freezing the pre-trained weight matrix
+         self.weight.stop_gradient = True
+         self._use_quick_lora = use_quick_lora and lora_dropout == 0.0
+
+     @property
+     def use_quick_lora(self):
+         # TODO(@gexiao): support qlora
+         return False  # self._use_quick_lora and self.training and not self.merged
+
+     def train(self):
+         super().train()
+         if self.merge_weights and self.merged:
+             # Make sure that the weights are not merged
+             new_weight = self.weight - self.lora_A @ self.lora_B * self.scaling
+             self.weight.set_value(new_weight)
+             self.merged = False
+
+     def eval(self):
+         super().eval()
+         if self.merge_weights and not self.merged:
+             # Merge the weights and mark it
+             new_weight = self.weight + self.lora_A @ self.lora_B * self.scaling
+             self.weight.set_value(new_weight)
+             self.merged = True
+
+     def forward(self, x: paddle.Tensor):
+         if not self.input_is_parallel:
+             input_mp = mp_ops._c_split(x, group=self.model_parallel_group)
+         else:
+             input_mp = x
+
+         if MC2RowSeqParallelCoreLinear is None:
+             output_parallel = self.linear(input_mp, self.weight, name=self._name)
+             output_ = ReduceScatterOp.apply(output_parallel)
+             result_mp = output_ + self.bias if self.bias is not None else output_
+         else:
+             output_ = MC2RowSeqParallelCoreLinear.apply(input_mp, self.weight, self.model_parallel_group)
+             result_mp = output_ + self.bias if self.bias is not None else output_
+
+         if not self.merged:
+             input_mp = self.lora_dropout(input_mp)
+             if MC2RowSeqParallelCoreLinear is None:
+                 input_mp = input_mp @ self.lora_A
+                 input_mp = ReduceScatterOp.apply(input_mp)
+             else:
+                 input_mp = MC2RowSeqParallelCoreLinear.apply(input_mp, self.lora_A, self.model_parallel_group)
+             delta_mp = (input_mp @ self.lora_B) * self.scaling
+             result_mp += delta_mp
+         return result_mp
+
+     def extra_repr(self):
+         name = f", name={self.name}" if self.name else ""
+         return f"in_features={self.weight.shape[0]}, out_features={self.weight.shape[1]}, rank={self.r}{name}"
+
+
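The `train()`/`eval()` pair in the class above folds the adapter into the frozen weight and takes it back out. A standalone sketch of that arithmetic with plain tensors (shapes and values are arbitrary; the layer classes themselves are not involved):

```python
import paddle

# eval() merges W' = W + A @ B * scaling; train() restores W = W' - A @ B * scaling.
in_f, out_f, r, lora_alpha = 8, 16, 4, 8
weight = paddle.randn([in_f, out_f])
lora_A = paddle.randn([in_f, r])
lora_B = paddle.randn([r, out_f])
scaling = lora_alpha / r  # or lora_alpha / math.sqrt(r) when rslora=True

merged = weight + lora_A @ lora_B * scaling    # eval(): fold the adapter in
restored = merged - lora_A @ lora_B * scaling  # train(): take it back out
assert paddle.allclose(restored, weight, atol=1e-5).item()
```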
class ColumnParallelLoRALinear(ColumnParallelLinear):
    def __init__(
        self,
@@ -400,21 +525,21 @@ def forward(self, input: paddle.Tensor):
                world_size=self.world_size,
            )
        else:
-             if "npu" in paddle.device.get_all_custom_device_type() and int(os.getenv("MC2", "0")):
-                 res_mp = MC2LoRaColumnParallelLinear.apply(input, self.weight, self.model_parallel_group)
-                 result_mp = res_mp + self.bias
-             else:
+             if MC2ColumnParallelCoreLinear is None:
                input_mp = mp_ops._c_identity(input, group=self.model_parallel_group)
                result_mp = F.linear(x=input_mp, weight=self.weight, bias=self.bias, name=self.name)
+             else:
+                 res_mp = MC2ColumnParallelCoreLinear.apply(input, self.weight, self.model_parallel_group)
+                 result_mp = res_mp + self.bias

        if not self.merged:
            input_a = self.lora_dropout(input) @ self.lora_A
-             if "npu" in paddle.device.get_all_custom_device_type() and int(os.getenv("MC2", "0")):
-                 tmp = MC2LoRaColumnParallelLinear.apply(input_a, self.lora_B, self.model_parallel_group)
-                 delta_mp = tmp * self.scaling
-             else:
+             if MC2ColumnParallelCoreLinear is None:
                input_a_mp = mp_ops._c_identity(input_a, group=self.model_parallel_group)
                delta_mp = (input_a_mp @ self.lora_B) * self.scaling
+             else:
+                 tmp = MC2ColumnParallelCoreLinear.apply(input_a, self.lora_B, self.model_parallel_group)
+                 delta_mp = tmp * self.scaling
            result_mp += delta_mp

        if self.gather_output and self.is_mp:
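In the column-parallel LoRA path above, `lora_A` is replicated while `lora_B` is sharded along `out_features`, so each rank's `delta_mp` lines up with its shard of `result_mp`. A single-process sketch of why the per-shard computation matches the full product (shapes arbitrary, two "ranks" emulated with `paddle.split`):

```python
import paddle

# Computing the LoRA delta per column shard of lora_B and concatenating the
# shards is equivalent to computing the full delta and then splitting it.
x = paddle.randn([2, 8])        # [batch, in_features]
lora_A = paddle.randn([8, 4])   # [in_features, r], replicated on every rank
lora_B = paddle.randn([4, 16])  # [r, out_features], split by columns across ranks
scaling = 0.5

full = (x @ lora_A @ lora_B) * scaling
b0, b1 = paddle.split(lora_B, 2, axis=1)  # emulate two ranks' column shards
per_rank = [(x @ lora_A @ b) * scaling for b in (b0, b1)]
assert paddle.allclose(paddle.concat(per_rank, axis=1), full).item()
```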
@@ -428,6 +553,123 @@ def extra_repr(self):
        return f"in_features={self.weight.shape[0]}, out_features={self.weight.shape[1]}, rank={self.r}{name}"


+ class ColumnSequenceParallelLoRALinear(ColumnSequenceParallelLinear):
+     def __init__(
+         self,
+         in_features: int,
+         out_features: int,
+         r: int = 0,
+         lora_alpha: int = 1,
+         lora_dropout: float = 0.0,
+         rslora: bool = False,
+         lora_plus_scale: float = 1.0,
+         merge_weights: bool = True,
+         lora_A_weight_attr: Optional[paddle.ParamAttr] = None,
+         use_quick_lora: bool = False,
+         **kwargs
+     ):
+         ColumnSequenceParallelLinear.__init__(self, in_features, out_features, **kwargs)
+         if not isinstance(r, int) or r <= 0:
+             raise ValueError("Lora rank r should be a positive integer")
+         self.r = r
+         self.lora_alpha = lora_alpha
+         # Optional dropout
+         if lora_dropout > 0.0:
+             self.lora_dropout = nn.Dropout(p=lora_dropout)
+         else:
+             self.lora_dropout = lambda x: x
+         # Mark the weight as unmerged
+         self.merged = False
+         self.merge_weights = merge_weights
+
+         # compatible
+         self.name = self._name
+
+         # Actual trainable parameters
+         self.lora_A = self.create_parameter(
+             shape=[in_features, r],
+             dtype=self._dtype,
+             is_bias=False,
+             attr=lora_A_weight_attr,
+         )
+         self.lora_A.is_distributed = False
+         mark_as_sequence_parallel_parameter(self.lora_A)
+
+         self.lora_B = self.create_parameter(
+             shape=[r, self.output_size_per_partition],
+             dtype=self._dtype,
+             is_bias=False,
+             attr=paddle.ParamAttr(
+                 initializer=paddle.nn.initializer.Constant(value=0.0),
+                 learning_rate=lora_plus_scale,
+             ),
+         )
+
+         self.lora_B.is_distributed = True
+         self.lora_B.split_axis = 1
+         if not rslora:
+             self.scaling = self.lora_alpha / self.r
+         else:
+             self.scaling = self.lora_alpha / math.sqrt(self.r)
+
+         # Freezing the pre-trained weight matrix
+         self.weight.stop_gradient = True
+         self._use_quick_lora = use_quick_lora and lora_dropout == 0.0
+
+     @property
+     def use_quick_lora(self):
+         # TODO(@gexiao): support qlora
+         return False  # self._use_quick_lora and self.training and not self.merged
+
+     def train(self):
+         super().train()
+         if self.merge_weights and self.merged:
+             # Make sure that the weights are not merged
+             new_weight = self.weight - self.lora_A @ self.lora_B * self.scaling
+             self.weight.set_value(new_weight)
+             self.merged = False
+
+     def eval(self):
+         super().eval()
+         if self.merge_weights and not self.merged:
+             # Merge the weights and mark it
+             new_weight = self.weight + self.lora_A @ self.lora_B * self.scaling
+             self.weight.set_value(new_weight)
+             self.merged = True
+
+     def forward(self, x: paddle.Tensor):
+         if MC2ColumnSeqParallelCoreLinear is None:
+             if self.is_mp:
+                 input_parallel = AllGatherOp.apply(x)
+             else:
+                 input_parallel = x
+             result_mp = self.linear(input_parallel, self.weight, self.bias, name=self._name)
+         else:
+             result_mp = MC2ColumnSeqParallelCoreLinear.apply(x, self.weight, self.model_parallel_group)
+             if self.bias is not None:
+                 result_mp += self.bias
+
+         if not self.merged:
+             input_a = self.lora_dropout(x) @ self.lora_A
+             if MC2ColumnSeqParallelCoreLinear is None:
+                 input_a = AllGatherOp.apply(input_a)
+                 delta_mp = (input_a @ self.lora_B) * self.scaling
+             else:
+                 input_a = MC2ColumnSeqParallelCoreLinear.apply(input_a, self.lora_B, self.model_parallel_group)
+                 delta_mp = input_a * self.scaling
+             result_mp += delta_mp
+
+         if self.gather_output and self.is_mp:
+             result = mp_ops._c_concat(result_mp, group=self.model_parallel_group)
+         else:
+             result = result_mp
+         return result
+
+     def extra_repr(self):
+         name = f", name={self.name}" if self.name else ""
+         return f"in_features={self.weight.shape[0]}, out_features={self.weight.shape[1]}, rank={self.r}{name}"
+
+
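As a reading aid for the two sequence-parallel forward passes above: activations are sharded along the sequence axis, `AllGatherOp` restores the full sequence before the column matmul (and before `input_a @ lora_B` in the non-MC2 LoRA branch), and `ReduceScatterOp` sums partial results and re-shards them along the sequence axis in the row layer. A shape-bookkeeping sketch under assumed sizes (all values illustrative, batch dimension omitted):

```python
# Assumed sizes, purely illustrative: tensor-parallel world size 2,
# sequence length 8, hidden size 16, out_features == hidden for simplicity.
ws, seq, hidden = 2, 8, 16

# ColumnSequenceParallelLoRALinear, non-MC2 branch:
x_local = (seq // ws, hidden)        # sequence-sharded input x
x_full = (seq, hidden)               # after AllGatherOp.apply(x)
col_out_local = (seq, hidden // ws)  # result_mp: out_features sharded per rank

# RowSequenceParallelLoRALinear, non-MC2 branch:
row_in_local = (seq, hidden // ws)   # input_mp: in_features sharded per rank
row_partial = (seq, hidden)          # per-rank partial of input_mp @ weight
row_out_local = (seq // ws, hidden)  # after ReduceScatterOp: summed and seq-sharded
```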
class LoRAMergedLinear(nn.Linear):
    # LoRA implemented in a dense layer with merged linear weights for q, k, v
    def __init__(