@@ -13,7 +13,6 @@
 # limitations under the License.

 import math
-import os
 from typing import List, Optional

 import paddle
@@ -24,28 +23,22 @@
     ColumnParallelLinear,
     RowParallelLinear,
 )
-
-from paddlenlp.transformers.sequence_parallel_utils import (
+from paddle.distributed.fleet.utils.sequence_parallel_utils import (
     AllGatherOp,
     ColumnSequenceParallelLinear,
-    MC2ColumnSeqParallelLinear,
-    MC2RowSeqParallelLinear,
     ReduceScatterOp,
     RowSequenceParallelLinear,
     mark_as_sequence_parallel_parameter,
 )

-from .lora_quick_layers import quick_lora
-
-if "npu" in paddle.device.get_all_custom_device_type():
-    from .mc2_lora_npu import MC2LoRaColumnParallelLinear, MC2LoRaRowParallelLinear
-else:
-    MC2LoRaRowParallelLinear = None
-    MC2LoRaColumnParallelLinear = None
-
+from paddlenlp.transformers.mc2_parallel_linear import (
+    MC2ColumnParallelCoreLinear,
+    MC2ColumnSeqParallelCoreLinear,
+    MC2RowParallelCoreLinear,
+    MC2RowSeqParallelCoreLinear,
+)

-def is_mc2_valid():
-    return "npu" in paddle.device.get_all_custom_device_type() and int(os.getenv("MC2", "0"))
+from .lora_quick_layers import quick_lora


 class LoRALinear(nn.Linear):
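The hunks that follow replace every `is_mc2_valid()` / `MC2` env-var branch with a plain `... is None` test. That works only if `paddlenlp.transformers.mc2_parallel_linear` binds each `MC2*CoreLinear` name to `None` whenever the fused MC2 kernels are unusable, mirroring the removed `mc2_lora_npu` fallback. A minimal sketch of that assumed convention (placeholder body, not the actual module):

# Sketch only: the real paddlenlp.transformers.mc2_parallel_linear may differ in detail.
import paddle

if "npu" in paddle.device.get_all_custom_device_type():
    class MC2RowParallelCoreLinear(paddle.autograd.PyLayer):
        ...  # the fused matmul + tensor-parallel collective would be implemented here
else:
    # Importers check `MC2RowParallelCoreLinear is None` and fall back to
    # F.linear followed by mp_ops._mp_allreduce when the fused kernel is absent.
    MC2RowParallelCoreLinear = None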
@@ -280,16 +273,16 @@ def forward(self, x: paddle.Tensor):
             )
         else:
             # x @ W : [bz, in_f / ws] ===> [bz, out_f]
-            if "npu" in paddle.device.get_all_custom_device_type() and int(os.getenv("MC2", "0")):
-                output = MC2LoRaRowParallelLinear.apply(input_mp, self.weight, self.model_parallel_group)
-            else:
+            if MC2RowParallelCoreLinear is None:
                 result_mp = F.linear(x=input_mp, weight=self.weight, name=self.name)
                 output = mp_ops._mp_allreduce(
                     result_mp,
                     group=self.model_parallel_group,
                     use_calc_stream=True,
                     use_model_parallel=True,
                 )
+            else:
+                output = MC2RowParallelCoreLinear.apply(input_mp, self.weight, self.model_parallel_group)

             if not self.merged:
                 # x @ A: [bz, in_f/ ws] ===> [bz, r]
@@ -402,21 +395,21 @@ def forward(self, x: paddle.Tensor):
         else:
             input_mp = x

-        if not is_mc2_valid():
+        if MC2RowSeqParallelCoreLinear is None:
             output_parallel = self.linear(input_mp, self.weight, name=self._name)
             output_ = ReduceScatterOp.apply(output_parallel)
             result_mp = output_ + self.bias if self.bias is not None else output_
         else:
-            output_ = MC2RowSeqParallelLinear.apply(input_mp, self.weight, self.model_parallel_group)
+            output_ = MC2RowSeqParallelCoreLinear.apply(input_mp, self.weight, self.model_parallel_group)
             result_mp = output_ + self.bias if self.bias is not None else output_

         if not self.merged:
             input_mp = self.lora_dropout(input_mp)
-            if not is_mc2_valid():
+            if MC2RowSeqParallelCoreLinear is None:
                 input_mp = input_mp @ self.lora_A
                 input_mp = ReduceScatterOp.apply(input_mp)
             else:
-                input_mp = MC2RowSeqParallelLinear.apply(input_mp, self.lora_A, self.model_parallel_group)
+                input_mp = MC2RowSeqParallelCoreLinear.apply(input_mp, self.lora_A, self.model_parallel_group)
             delta_mp = (input_mp @ self.lora_B) * self.scaling
             result_mp += delta_mp
         return result_mp
@@ -528,21 +521,21 @@ def forward(self, input: paddle.Tensor):
                 world_size=self.world_size,
             )
         else:
-            if "npu" in paddle.device.get_all_custom_device_type() and int(os.getenv("MC2", "0")):
-                res_mp = MC2LoRaColumnParallelLinear.apply(input, self.weight, self.model_parallel_group)
-                result_mp = res_mp + self.bias
-            else:
+            if MC2ColumnParallelCoreLinear is None:
                 input_mp = mp_ops._c_identity(input, group=self.model_parallel_group)
                 result_mp = F.linear(x=input_mp, weight=self.weight, bias=self.bias, name=self.name)
+            else:
+                res_mp = MC2ColumnParallelCoreLinear.apply(input, self.weight, self.model_parallel_group)
+                result_mp = res_mp + self.bias

             if not self.merged:
                 input_a = self.lora_dropout(input) @ self.lora_A
-                if "npu" in paddle.device.get_all_custom_device_type() and int(os.getenv("MC2", "0")):
-                    tmp = MC2LoRaColumnParallelLinear.apply(input_a, self.lora_B, self.model_parallel_group)
-                    delta_mp = tmp * self.scaling
-                else:
+                if MC2ColumnParallelCoreLinear is None:
                     input_a_mp = mp_ops._c_identity(input_a, group=self.model_parallel_group)
                     delta_mp = (input_a_mp @ self.lora_B) * self.scaling
+                else:
+                    tmp = MC2ColumnParallelCoreLinear.apply(input_a, self.lora_B, self.model_parallel_group)
+                    delta_mp = tmp * self.scaling
                 result_mp += delta_mp

         if self.gather_output and self.is_mp:
@@ -641,24 +634,24 @@ def eval(self):
             self.merged = True

     def forward(self, x: paddle.Tensor):
-        if not is_mc2_valid():
+        if MC2ColumnSeqParallelCoreLinear is None:
             if self.is_mp:
                 input_parallel = AllGatherOp.apply(x)
             else:
                 input_parallel = x
             result_mp = self.linear(input_parallel, self.weight, self.bias, name=self._name)
         else:
-            result_mp = MC2ColumnSeqParallelLinear.apply(x, self.weight, self.model_parallel_group)
+            result_mp = MC2ColumnSeqParallelCoreLinear.apply(x, self.weight, self.model_parallel_group)
             if self.bias is not None:
                 result_mp += self.bias

         if not self.merged:
             input_a = self.lora_dropout(x) @ self.lora_A
-            if not is_mc2_valid():
+            if MC2ColumnSeqParallelCoreLinear is None:
                 input_a = AllGatherOp.apply(input_a)
                 delta_mp = (input_a @ self.lora_B) * self.scaling
             else:
-                input_a = MC2ColumnSeqParallelLinear.apply(input_a, self.lora_B, self.model_parallel_group)
+                input_a = MC2ColumnSeqParallelCoreLinear.apply(input_a, self.lora_B, self.model_parallel_group)
                 delta_mp = input_a * self.scaling
             result_mp += delta_mp

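Viewed together, the rewritten forward paths share one dispatch shape: if the fused MC2 kernel could not be imported, the layer composes the matmul and the collective explicitly; otherwise it hands both to the fused op in a single call. A condensed, illustrative sketch of the row-parallel case, mirroring the patched code (the helper name is hypothetical and not part of this change):

import paddle.nn.functional as F
from paddle.distributed.fleet.layers.mpu import mp_ops


def row_parallel_matmul(input_mp, weight, group, fused_op=None):
    # Fallback path: local matmul, then an explicit all-reduce over the
    # model-parallel group (the same calls the non-MC2 branch above uses).
    if fused_op is None:
        out = F.linear(x=input_mp, weight=weight)
        return mp_ops._mp_allreduce(
            out, group=group, use_calc_stream=True, use_model_parallel=True
        )
    # Fused path: the MC2 kernel performs the matmul and the collective together.
    return fused_op.apply(input_mp, weight, group)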