Commit fa9c067
[Distributed] [CustomDevices] adapt lora sp && polish MC2 APIs
1 parent 0dac72d commit fa9c067

6 files changed: +254 additions, -281 deletions

paddlenlp/peft/lora/lora_layers.py
Lines changed: 27 additions & 34 deletions

@@ -13,7 +13,6 @@
 # limitations under the License.

 import math
-import os
 from typing import List, Optional

 import paddle
@@ -24,28 +23,22 @@
     ColumnParallelLinear,
     RowParallelLinear,
 )
-
-from paddlenlp.transformers.sequence_parallel_utils import (
+from paddle.distributed.fleet.utils.sequence_parallel_utils import (
     AllGatherOp,
     ColumnSequenceParallelLinear,
-    MC2ColumnSeqParallelLinear,
-    MC2RowSeqParallelLinear,
     ReduceScatterOp,
     RowSequenceParallelLinear,
     mark_as_sequence_parallel_parameter,
 )

-from .lora_quick_layers import quick_lora
-
-if "npu" in paddle.device.get_all_custom_device_type():
-    from .mc2_lora_npu import MC2LoRaColumnParallelLinear, MC2LoRaRowParallelLinear
-else:
-    MC2LoRaRowParallelLinear = None
-    MC2LoRaColumnParallelLinear = None
-
+from paddlenlp.transformers.mc2_parallel_linear import (
+    MC2ColumnParallelCoreLinear,
+    MC2ColumnSeqParallelCoreLinear,
+    MC2RowParallelCoreLinear,
+    MC2RowSeqParallelCoreLinear,
+)

-def is_mc2_valid():
-    return "npu" in paddle.device.get_all_custom_device_type() and int(os.getenv("MC2", "0"))
+from .lora_quick_layers import quick_lora


 class LoRALinear(nn.Linear):
@@ -280,16 +273,16 @@ def forward(self, x: paddle.Tensor):
             )
         else:
             # x @ W : [bz, in_f / ws] ===> [bz, out_f]
-            if "npu" in paddle.device.get_all_custom_device_type() and int(os.getenv("MC2", "0")):
-                output = MC2LoRaRowParallelLinear.apply(input_mp, self.weight, self.model_parallel_group)
-            else:
+            if MC2RowParallelCoreLinear is None:
                 result_mp = F.linear(x=input_mp, weight=self.weight, name=self.name)
                 output = mp_ops._mp_allreduce(
                     result_mp,
                     group=self.model_parallel_group,
                     use_calc_stream=True,
                     use_model_parallel=True,
                 )
+            else:
+                output = MC2RowParallelCoreLinear.apply(input_mp, self.weight, self.model_parallel_group)

         if not self.merged:
             # x @ A: [bz, in_f/ ws] ===> [bz, r]
@@ -402,21 +395,21 @@ def forward(self, x: paddle.Tensor):
         else:
             input_mp = x

-        if not is_mc2_valid():
+        if MC2RowSeqParallelCoreLinear is None:
             output_parallel = self.linear(input_mp, self.weight, name=self._name)
             output_ = ReduceScatterOp.apply(output_parallel)
             result_mp = output_ + self.bias if self.bias is not None else output_
         else:
-            output_ = MC2RowSeqParallelLinear.apply(input_mp, self.weight, self.model_parallel_group)
+            output_ = MC2RowSeqParallelCoreLinear.apply(input_mp, self.weight, self.model_parallel_group)
             result_mp = output_ + self.bias if self.bias is not None else output_

         if not self.merged:
             input_mp = self.lora_dropout(input_mp)
-            if not is_mc2_valid():
+            if MC2RowSeqParallelCoreLinear is None:
                 input_mp = input_mp @ self.lora_A
                 input_mp = ReduceScatterOp.apply(input_mp)
             else:
-                input_mp = MC2RowSeqParallelLinear.apply(input_mp, self.lora_A, self.model_parallel_group)
+                input_mp = MC2RowSeqParallelCoreLinear.apply(input_mp, self.lora_A, self.model_parallel_group)
             delta_mp = (input_mp @ self.lora_B) * self.scaling
             result_mp += delta_mp
         return result_mp
@@ -528,21 +521,21 @@ def forward(self, input: paddle.Tensor):
                 world_size=self.world_size,
             )
         else:
-            if "npu" in paddle.device.get_all_custom_device_type() and int(os.getenv("MC2", "0")):
-                res_mp = MC2LoRaColumnParallelLinear.apply(input, self.weight, self.model_parallel_group)
-                result_mp = res_mp + self.bias
-            else:
+            if MC2ColumnParallelCoreLinear is None:
                 input_mp = mp_ops._c_identity(input, group=self.model_parallel_group)
                 result_mp = F.linear(x=input_mp, weight=self.weight, bias=self.bias, name=self.name)
+            else:
+                res_mp = MC2ColumnParallelCoreLinear.apply(input, self.weight, self.model_parallel_group)
+                result_mp = res_mp + self.bias

         if not self.merged:
             input_a = self.lora_dropout(input) @ self.lora_A
-            if "npu" in paddle.device.get_all_custom_device_type() and int(os.getenv("MC2", "0")):
-                tmp = MC2LoRaColumnParallelLinear.apply(input_a, self.lora_B, self.model_parallel_group)
-                delta_mp = tmp * self.scaling
-            else:
+            if MC2ColumnParallelCoreLinear is None:
                 input_a_mp = mp_ops._c_identity(input_a, group=self.model_parallel_group)
                 delta_mp = (input_a_mp @ self.lora_B) * self.scaling
+            else:
+                tmp = MC2ColumnParallelCoreLinear.apply(input_a, self.lora_B, self.model_parallel_group)
+                delta_mp = tmp * self.scaling
             result_mp += delta_mp

         if self.gather_output and self.is_mp:
@@ -641,24 +634,24 @@ def eval(self):
         self.merged = True

     def forward(self, x: paddle.Tensor):
-        if not is_mc2_valid():
+        if MC2ColumnSeqParallelCoreLinear is None:
             if self.is_mp:
                 input_parallel = AllGatherOp.apply(x)
             else:
                 input_parallel = x
             result_mp = self.linear(input_parallel, self.weight, self.bias, name=self._name)
         else:
-            result_mp = MC2ColumnSeqParallelLinear.apply(x, self.weight, self.model_parallel_group)
+            result_mp = MC2ColumnSeqParallelCoreLinear.apply(x, self.weight, self.model_parallel_group)
             if self.bias is not None:
                 result_mp += self.bias

         if not self.merged:
             input_a = self.lora_dropout(x) @ self.lora_A
-            if not is_mc2_valid():
+            if MC2ColumnSeqParallelCoreLinear is None:
                 input_a = AllGatherOp.apply(input_a)
                 delta_mp = (input_a @ self.lora_B) * self.scaling
             else:
-                input_a = MC2ColumnSeqParallelLinear.apply(input_a, self.lora_B, self.model_parallel_group)
+                input_a = MC2ColumnSeqParallelCoreLinear.apply(input_a, self.lora_B, self.model_parallel_group)
                 delta_mp = input_a * self.scaling
             result_mp += delta_mp
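Note: all branches above compute the same LoRA result; only the communication collectives differ. As a reading aid, a minimal single-device reference is sketched below (illustrative only, not code from this commit; it ignores tensor/sequence parallelism, bias handling, and the merged fast path).

import paddle
import paddle.nn.functional as F


def lora_forward_reference(x, weight, lora_A, lora_B, scaling, lora_dropout):
    # Frozen base projection: x @ W
    result = F.linear(x=x, weight=weight)
    # Low-rank update: (dropout(x) @ A) @ B, scaled by the LoRA scaling factor
    delta = (lora_dropout(x) @ lora_A) @ lora_B
    return result + delta * scaling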

paddlenlp/peft/lora/lora_model.py
Lines changed: 1 addition & 2 deletions

@@ -31,8 +31,7 @@
     PipelineLayer,
     RowParallelLinear,
 )
-
-from paddlenlp.transformers.sequence_parallel_utils import (
+from paddle.distributed.fleet.utils.sequence_parallel_utils import (
     ColumnSequenceParallelLinear,
     RowSequenceParallelLinear,
 )
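The only change here is the import path for the sequence-parallel linears. A guarded import along these lines (hypothetical, not part of this commit) would prefer the new Paddle location and fall back to the legacy PaddleNLP module on releases that do not ship it:

try:
    from paddle.distributed.fleet.utils.sequence_parallel_utils import (
        ColumnSequenceParallelLinear,
        RowSequenceParallelLinear,
    )
except ImportError:
    # Legacy location used before this commit.
    from paddlenlp.transformers.sequence_parallel_utils import (
        ColumnSequenceParallelLinear,
        RowSequenceParallelLinear,
    )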

paddlenlp/peft/lora/mc2_lora_npu.py
Lines changed: 0 additions & 80 deletions
This file was deleted: the NPU-specific MC2LoRaColumnParallelLinear / MC2LoRaRowParallelLinear wrappers it provided are replaced by the classes now imported from paddlenlp/transformers/mc2_parallel_linear.py.
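The call sites in lora_layers.py and modeling.py no longer check the device or environment variables themselves; they only test whether the imported MC2 classes are None. The sketch below shows one way such a guarded-export module can work. It is an assumption about the pattern, not the actual contents of paddlenlp/transformers/mc2_parallel_linear.py, and the availability check (an NPU custom device plus an opt-in flag such as FLAGS_NPU_MC2) is inferred from the code removed in this commit.

import os

import paddle


def _mc2_available():
    # Assumed gate: an NPU custom device is registered and MC2 is opted in.
    return "npu" in paddle.device.get_all_custom_device_type() and bool(
        int(os.getenv("FLAGS_NPU_MC2", "0"))
    )


if _mc2_available():

    class MC2RowParallelCoreLinear(paddle.autograd.PyLayer):
        # The fused matmul + communication forward/backward would be defined here.
        ...

else:
    # Callers branch on `MC2RowParallelCoreLinear is None` instead of
    # re-checking the device and environment at every call site.
    MC2RowParallelCoreLinear = None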

paddlenlp/transformers/llama/modeling.py
Lines changed: 6 additions & 19 deletions

@@ -62,6 +62,10 @@ def swiglu(x, y=None):
     init_name_mappings,
 )
 from paddlenlp.transformers.long_sequence_strategies import LongSequenceStrategies
+from paddlenlp.transformers.mc2_parallel_linear import (
+    MC2ColumnSeqParallelLinear,
+    MC2RowSeqParallelLinear,
+)
 from paddlenlp.transformers.model_outputs import (
     BaseModelOutputWithPastAndCrossAttentions,
     CausalLMOutputWithCrossAttentions,
@@ -96,13 +100,6 @@ def swiglu(x, y=None):
 ]


-def is_mc2_valid():
-    current_device = get_env_device()
-    if current_device == "npu":
-        return True
-    return False
-
-
 def _get_interleave(n):
     def _get_interleave_power_of_2(n):
         start = 2 ** (-(2 ** -(math.log2(n) - 3)))
@@ -574,12 +571,7 @@ def __init__(self, config):
         self.fuse_attention_ffn = config.fuse_attention_ffn

         if config.sequence_parallel:
-            if is_mc2_valid and int(os.getenv("FLAGS_NPU_MC2", 0)):
-                from paddlenlp.transformers.mc2_seqence_parallel_linear import (
-                    MC2ColumnSeqParallelLinear,
-                    MC2RowSeqParallelLinear,
-                )
-
+            if MC2ColumnSeqParallelLinear is not None and MC2RowSeqParallelLinear is not None:
                 ColumnParallelLinear = MC2ColumnSeqParallelLinear
                 RowParallelLinear = MC2RowSeqParallelLinear
             else:
@@ -697,12 +689,7 @@ def __init__(self, config: LlamaConfig, layerwise_recompute: bool = False):
         self.use_fused_rope = False

         if config.sequence_parallel:
-            if is_mc2_valid and int(os.getenv("FLAGS_NPU_MC2", 0)):
-                from paddlenlp.transformers.mc2_seqence_parallel_linear import (
-                    MC2ColumnSeqParallelLinear,
-                    MC2RowSeqParallelLinear,
-                )
-
+            if MC2ColumnSeqParallelLinear is not None and MC2RowSeqParallelLinear is not None:
                 ColumnParallelLinear = MC2ColumnSeqParallelLinear
                 RowParallelLinear = MC2RowSeqParallelLinear
             else:
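Both __init__ hunks above now pick the linear implementation purely from whether the MC2 symbols were imported as real classes rather than None. The helper below mirrors that selection so it can be illustrated or tested in isolation; it is hypothetical and not part of this commit.

from paddle.distributed.fleet.utils.sequence_parallel_utils import (
    ColumnSequenceParallelLinear,
    RowSequenceParallelLinear,
)

from paddlenlp.transformers.mc2_parallel_linear import (
    MC2ColumnSeqParallelLinear,
    MC2RowSeqParallelLinear,
)


def select_sequence_parallel_linears():
    # Mirrors the branch in the two __init__ hunks above: use the MC2 fused
    # classes when the module exported them, otherwise fall back to the
    # generic sequence-parallel linears shipped with Paddle.
    if MC2ColumnSeqParallelLinear is not None and MC2RowSeqParallelLinear is not None:
        return MC2ColumnSeqParallelLinear, MC2RowSeqParallelLinear
    return ColumnSequenceParallelLinear, RowSequenceParallelLinear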
