Commit 0f428bb

SylarTiaNII authored and JunnYu committed
[Distributed] [CustomDevices] Adapt SP on lora && polish MC2 APIs (#8303)
* [Distributed] adapt sequence parallel on LoRA (#8235)
* [Distributed] [CustomDevices] adapt lora sp && polish MC2 APIs
1 parent 871070d commit 0f428bb

File tree

6 files changed, +572 -272 lines changed
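
The diff below drops the per-call "npu" in paddle.device.get_all_custom_device_type() and os.getenv("MC2") checks; instead, the MC2 core-linear ops are imported once from paddlenlp.transformers.mc2_parallel_linear and are expected to be None when the fused kernels are unavailable, so each forward simply branches on that. A minimal sketch of the row sequence-parallel dispatch, assuming the imports resolve on the current Paddle build (the helper name row_seq_parallel_matmul is illustrative, not part of the commit):

import paddle

try:
    from paddle.distributed.fleet.utils.sequence_parallel_utils import ReduceScatterOp
except ImportError:  # older Paddle builds may not ship the sequence-parallel utils
    ReduceScatterOp = None

from paddlenlp.transformers.mc2_parallel_linear import MC2RowSeqParallelCoreLinear


def row_seq_parallel_matmul(input_mp, weight, mp_group):
    # Fused path: a single MC2 kernel performs the matmul and the reduce-scatter.
    if MC2RowSeqParallelCoreLinear is not None:
        return MC2RowSeqParallelCoreLinear.apply(input_mp, weight, mp_group)
    # Fallback path: local matmul on the row-sliced weight, then reduce-scatter
    # the partial results across the model-parallel group.
    partial = paddle.matmul(input_mp, weight)
    return ReduceScatterOp.apply(partial)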

paddlenlp/peft/lora/lora_layers.py

Lines changed: 260 additions & 18 deletions
@@ -13,7 +13,6 @@
 # limitations under the License.
 
 import math
-import os
 from typing import List, Optional
 
 import paddle
@@ -25,13 +24,25 @@
     RowParallelLinear,
 )
 
-from .lora_quick_layers import quick_lora
+try:
+    from paddle.distributed.fleet.utils.sequence_parallel_utils import (
+        AllGatherOp,
+        ColumnSequenceParallelLinear,
+        ReduceScatterOp,
+        RowSequenceParallelLinear,
+        mark_as_sequence_parallel_parameter,
+    )
+except:
+    pass
+
+from paddlenlp.transformers.mc2_parallel_linear import (
+    MC2ColumnParallelCoreLinear,
+    MC2ColumnSeqParallelCoreLinear,
+    MC2RowParallelCoreLinear,
+    MC2RowSeqParallelCoreLinear,
+)
 
-if "npu" in paddle.device.get_all_custom_device_type():
-    from .mc2_lora_npu import MC2LoRaColumnParallelLinear, MC2LoRaRowParallelLinear
-else:
-    MC2LoRaRowParallelLinear = None
-    MC2LoRaColumnParallelLinear = None
+from .lora_quick_layers import quick_lora
 
 
 class LoRALinear(nn.Linear):
@@ -266,16 +277,16 @@ def forward(self, x: paddle.Tensor):
             )
         else:
             # x @ W : [bz, in_f / ws] ===> [bz, out_f]
-            if "npu" in paddle.device.get_all_custom_device_type() and int(os.getenv("MC2", "0")):
-                output = MC2LoRaRowParallelLinear.apply(input_mp, self.weight, self.model_parallel_group)
-            else:
+            if MC2RowParallelCoreLinear is None:
                 result_mp = F.linear(x=input_mp, weight=self.weight, name=self.name)
                 output = mp_ops._mp_allreduce(
                     result_mp,
                     group=self.model_parallel_group,
                     use_calc_stream=True,
                     use_model_parallel=True,
                 )
+            else:
+                output = MC2RowParallelCoreLinear.apply(input_mp, self.weight, self.model_parallel_group)
 
         if not self.merged:
             # x @ A: [bz, in_f/ ws] ===> [bz, r]
@@ -298,6 +309,120 @@ def extra_repr(self):
         return f"in_features={self.weight.shape[0]}, out_features={self.weight.shape[1]}, rank={self.r}{name}"
 
 
+class RowSequenceParallelLoRALinear(RowSequenceParallelLinear):
+    def __init__(
+        self,
+        in_features: int,
+        out_features: int,
+        r: int = 0,
+        lora_alpha: int = 1,
+        lora_dropout: float = 0.0,
+        rslora: bool = False,
+        lora_plus_scale: float = 1.0,
+        merge_weights: bool = True,
+        use_quick_lora: bool = False,
+        **kwargs
+    ):
+        RowSequenceParallelLinear.__init__(self, in_features, out_features, **kwargs)
+        if not isinstance(r, int) or r <= 0:
+            raise ValueError("Lora rank r should be a positive integer")
+        self.r = r
+        self.lora_alpha = lora_alpha
+        # Optional dropout
+        if lora_dropout > 0.0:
+            self.lora_dropout = nn.Dropout(p=lora_dropout)
+        else:
+            self.lora_dropout = lambda x: x
+        # Mark the weight as unmerged
+        self.merged = False
+        self.merge_weights = merge_weights
+
+        # compatible
+        self.name = self._name
+
+        # Actual trainable parameters
+        self.lora_A = self.create_parameter(
+            shape=[self.input_size_per_partition, r],
+            dtype=self._dtype,
+            is_bias=False,
+            attr=paddle.ParamAttr(
+                initializer=nn.initializer.KaimingUniform(negative_slope=math.sqrt(5), nonlinearity="leaky_relu")
+            ),
+        )
+        self.lora_B = self.create_parameter(
+            shape=[r, self.out_features],
+            dtype=self._dtype,
+            is_bias=False,
+            attr=paddle.ParamAttr(
+                initializer=paddle.nn.initializer.Constant(value=0.0),
+                learning_rate=lora_plus_scale,
+            ),
+        )
+
+        self.lora_A.is_distributed = True
+        self.lora_A.split_axis = 0
+        self.lora_B.is_distributed = False
+        mark_as_sequence_parallel_parameter(self.lora_B)
+        if not rslora:
+            self.scaling = self.lora_alpha / self.r
+        else:
+            self.scaling = self.lora_alpha / math.sqrt(self.r)
+
+        # Freezing the pre-trained weight matrix
+        self.weight.stop_gradient = True
+        self._use_quick_lora = use_quick_lora and lora_dropout == 0.0
+
+    @property
+    def use_quick_lora(self):
+        # TODO(@gexiao): support qlora
+        return False  # self._use_quick_lora and self.training and not self.merged
+
+    def train(self):
+        super().train()
+        if self.merge_weights and self.merged:
+            # Make sure that the weights are not merged
+            new_weight = self.weight - self.lora_A @ self.lora_B * self.scaling
+            self.weight.set_value(new_weight)
+            self.merged = False
+
+    def eval(self):
+        super().eval()
+        if self.merge_weights and not self.merged:
+            # Merge the weights and mark it
+            new_weight = self.weight + self.lora_A @ self.lora_B * self.scaling
+            self.weight.set_value(new_weight)
+            self.merged = True
+
+    def forward(self, x: paddle.Tensor):
+        if not self.input_is_parallel:
+            input_mp = mp_ops._c_split(x, group=self.model_parallel_group)
+        else:
+            input_mp = x
+
+        if MC2RowSeqParallelCoreLinear is None:
+            output_parallel = self.linear(input_mp, self.weight, name=self._name)
+            output_ = ReduceScatterOp.apply(output_parallel)
+            result_mp = output_ + self.bias if self.bias is not None else output_
+        else:
+            output_ = MC2RowSeqParallelCoreLinear.apply(input_mp, self.weight, self.model_parallel_group)
+            result_mp = output_ + self.bias if self.bias is not None else output_
+
+        if not self.merged:
+            input_mp = self.lora_dropout(input_mp)
+            if MC2RowSeqParallelCoreLinear is None:
+                input_mp = input_mp @ self.lora_A
+                input_mp = ReduceScatterOp.apply(input_mp)
+            else:
+                input_mp = MC2RowSeqParallelCoreLinear.apply(input_mp, self.lora_A, self.model_parallel_group)
+            delta_mp = (input_mp @ self.lora_B) * self.scaling
+            result_mp += delta_mp
+        return result_mp
+
+    def extra_repr(self):
+        name = f", name={self.name}" if self.name else ""
+        return f"in_features={self.weight.shape[0]}, out_features={self.weight.shape[1]}, rank={self.r}{name}"
+
+
 class ColumnParallelLoRALinear(ColumnParallelLinear):
     def __init__(
         self,
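
The train()/eval() hooks in the new sequence-parallel LoRA classes fold the low-rank update into the frozen weight and back out again. A minimal standalone sketch of that identity, with no parallelism and purely illustrative shapes (not code from this commit):

import paddle

# Merging W' = W + A @ B * scaling makes the merged forward x @ W' equal to the
# unmerged forward x @ W + (x @ A @ B) * scaling.
paddle.seed(0)
W = paddle.randn([64, 32])
A = paddle.randn([64, 8])
B = paddle.randn([8, 32])
scaling = 16 / 8  # lora_alpha / r
x = paddle.randn([4, 64])

unmerged = x @ W + (x @ A @ B) * scaling
merged = x @ (W + A @ B * scaling)
print(paddle.allclose(unmerged, merged, atol=1e-5))  # Tensor(True), up to float tolerance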
@@ -400,21 +525,21 @@ def forward(self, input: paddle.Tensor):
                 world_size=self.world_size,
             )
         else:
-            if "npu" in paddle.device.get_all_custom_device_type() and int(os.getenv("MC2", "0")):
-                res_mp = MC2LoRaColumnParallelLinear.apply(input, self.weight, self.model_parallel_group)
-                result_mp = res_mp + self.bias
-            else:
+            if MC2ColumnParallelCoreLinear is None:
                 input_mp = mp_ops._c_identity(input, group=self.model_parallel_group)
                 result_mp = F.linear(x=input_mp, weight=self.weight, bias=self.bias, name=self.name)
+            else:
+                res_mp = MC2ColumnParallelCoreLinear.apply(input, self.weight, self.model_parallel_group)
+                result_mp = res_mp + self.bias
 
         if not self.merged:
             input_a = self.lora_dropout(input) @ self.lora_A
-            if "npu" in paddle.device.get_all_custom_device_type() and int(os.getenv("MC2", "0")):
-                tmp = MC2LoRaColumnParallelLinear.apply(input_a, self.lora_B, self.model_parallel_group)
-                delta_mp = tmp * self.scaling
-            else:
+            if MC2ColumnParallelCoreLinear is None:
                 input_a_mp = mp_ops._c_identity(input_a, group=self.model_parallel_group)
                 delta_mp = (input_a_mp @ self.lora_B) * self.scaling
+            else:
+                tmp = MC2ColumnParallelCoreLinear.apply(input_a, self.lora_B, self.model_parallel_group)
+                delta_mp = tmp * self.scaling
             result_mp += delta_mp
 
         if self.gather_output and self.is_mp:
@@ -428,6 +553,123 @@ def extra_repr(self):
         return f"in_features={self.weight.shape[0]}, out_features={self.weight.shape[1]}, rank={self.r}{name}"
 
 
+class ColumnSequenceParallelLoRALinear(ColumnSequenceParallelLinear):
+    def __init__(
+        self,
+        in_features: int,
+        out_features: int,
+        r: int = 0,
+        lora_alpha: int = 1,
+        lora_dropout: float = 0.0,
+        rslora: bool = False,
+        lora_plus_scale: float = 1.0,
+        merge_weights: bool = True,
+        lora_A_weight_attr: Optional[paddle.ParamAttr] = None,
+        use_quick_lora: bool = False,
+        **kwargs
+    ):
+        ColumnSequenceParallelLinear.__init__(self, in_features, out_features, **kwargs)
+        if not isinstance(r, int) or r <= 0:
+            raise ValueError("Lora rank r should be a positive integer")
+        self.r = r
+        self.lora_alpha = lora_alpha
+        # Optional dropout
+        if lora_dropout > 0.0:
+            self.lora_dropout = nn.Dropout(p=lora_dropout)
+        else:
+            self.lora_dropout = lambda x: x
+        # Mark the weight as unmerged
+        self.merged = False
+        self.merge_weights = merge_weights
+
+        # compatible
+        self.name = self._name
+
+        # Actual trainable parameters
+        self.lora_A = self.create_parameter(
+            shape=[in_features, r],
+            dtype=self._dtype,
+            is_bias=False,
+            attr=lora_A_weight_attr,
+        )
+        self.lora_A.is_distributed = False
+        mark_as_sequence_parallel_parameter(self.lora_A)
+
+        self.lora_B = self.create_parameter(
+            shape=[r, self.output_size_per_partition],
+            dtype=self._dtype,
+            is_bias=False,
+            attr=paddle.ParamAttr(
+                initializer=paddle.nn.initializer.Constant(value=0.0),
+                learning_rate=lora_plus_scale,
+            ),
+        )
+
+        self.lora_B.is_distributed = True
+        self.lora_B.split_axis = 1
+        if not rslora:
+            self.scaling = self.lora_alpha / self.r
+        else:
+            self.scaling = self.lora_alpha / math.sqrt(self.r)
+
+        # Freezing the pre-trained weight matrix
+        self.weight.stop_gradient = True
+        self._use_quick_lora = use_quick_lora and lora_dropout == 0.0
+
+    @property
+    def use_quick_lora(self):
+        # TODO(@gexiao): support qlora
+        return False  # self._use_quick_lora and self.training and not self.merged
+
+    def train(self):
+        super().train()
+        if self.merge_weights and self.merged:
+            # Make sure that the weights are not merged
+            new_weight = self.weight - self.lora_A @ self.lora_B * self.scaling
+            self.weight.set_value(new_weight)
+            self.merged = False
+
+    def eval(self):
+        super().eval()
+        if self.merge_weights and not self.merged:
+            # Merge the weights and mark it
+            new_weight = self.weight + self.lora_A @ self.lora_B * self.scaling
+            self.weight.set_value(new_weight)
+            self.merged = True
+
+    def forward(self, x: paddle.Tensor):
+        if MC2ColumnSeqParallelCoreLinear is None:
+            if self.is_mp:
+                input_parallel = AllGatherOp.apply(x)
+            else:
+                input_parallel = x
+            result_mp = self.linear(input_parallel, self.weight, self.bias, name=self._name)
+        else:
+            result_mp = MC2ColumnSeqParallelCoreLinear.apply(x, self.weight, self.model_parallel_group)
+            if self.bias is not None:
+                result_mp += self.bias
+
+        if not self.merged:
+            input_a = self.lora_dropout(x) @ self.lora_A
+            if MC2ColumnSeqParallelCoreLinear is None:
+                input_a = AllGatherOp.apply(input_a)
+                delta_mp = (input_a @ self.lora_B) * self.scaling
+            else:
+                input_a = MC2ColumnSeqParallelCoreLinear.apply(input_a, self.lora_B, self.model_parallel_group)
+                delta_mp = input_a * self.scaling
+            result_mp += delta_mp
+
+        if self.gather_output and self.is_mp:
+            result = mp_ops._c_concat(result_mp, group=self.model_parallel_group)
+        else:
+            result = result_mp
+        return result
+
+    def extra_repr(self):
+        name = f", name={self.name}" if self.name else ""
+        return f"in_features={self.weight.shape[0]}, out_features={self.weight.shape[1]}, rank={self.r}{name}"
+
+
 class LoRAMergedLinear(nn.Linear):
     # LoRA implemented in a dense layer with merged linear weights for q, k, v
     def __init__(
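
For context, a hedged usage sketch of one of the new layers. It assumes a multi-process run launched with paddle.distributed.launch where tensor parallelism (and sequence parallelism in the trainer) is enabled; the gather_output and has_bias keywords are forwarded to Paddle's ColumnSequenceParallelLinear and, like the shapes, are illustrative assumptions rather than part of this commit:

import paddle
import paddle.distributed.fleet as fleet
from paddlenlp.peft.lora.lora_layers import ColumnSequenceParallelLoRALinear

# Set up a hybrid-parallel environment; mp_degree must be > 1 for the layer
# to actually shard its weight.
strategy = fleet.DistributedStrategy()
strategy.hybrid_configs = {"dp_degree": 1, "mp_degree": 2, "pp_degree": 1}
fleet.init(is_collective=True, strategy=strategy)

layer = ColumnSequenceParallelLoRALinear(
    in_features=1024,
    out_features=4096,
    r=8,
    lora_alpha=16,
    lora_dropout=0.05,
    gather_output=False,  # forwarded to ColumnSequenceParallelLinear (assumed kwarg)
    has_bias=True,        # forwarded to ColumnSequenceParallelLinear (assumed kwarg)
)

# With sequence parallelism the input arrives already split along the sequence
# axis, e.g. flattened to [seq_len / mp_degree * batch, hidden] (illustrative).
x = paddle.randn([128, 1024], dtype=layer.weight.dtype)
out = layer(x)  # only lora_A / lora_B receive gradients; layer.weight stays frozen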
