Skip to content

Commit be957f0

Browse files
authored
[NPU] Fix sequence parallel lib import (#8760)
1 parent 7a1c439 commit be957f0

File tree

1 file changed

+3
-4
lines changed

1 file changed

+3
-4
lines changed

paddlenlp/peft/lora/lora_layers.py

Lines changed: 3 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -26,20 +26,19 @@
2626

2727
from ...transformers import linear_utils
2828

29+
ColumnSequenceParallelLinear = linear_utils.ColumnSequenceParallelLinear
30+
RowSequenceParallelLinear = linear_utils.RowSequenceParallelLinear
31+
2932
try:
3033
from paddle.distributed.fleet.utils.sequence_parallel_utils import (
3134
AllGatherOp,
32-
ColumnSequenceParallelLinear,
3335
ReduceScatterOp,
34-
RowSequenceParallelLinear,
3536
mark_as_sequence_parallel_parameter,
3637
)
3738
except:
3839
AllGatherOp = None
3940
ReduceScatterOp = None
4041
mark_as_sequence_parallel_parameter = None
41-
ColumnSequenceParallelLinear = linear_utils.ColumnSequenceParallelLinear
42-
RowSequenceParallelLinear = linear_utils.RowSequenceParallelLinear
4342

4443

4544
from ...transformers.mc2_parallel_linear import (

0 commit comments

Comments
 (0)