
Commit 18b5946

add fuse_attention_ffn support for qwen

1 parent 0087c4a · commit 18b5946

File tree

2 files changed: +58 −20 lines changed

paddlenlp/transformers/qwen/configuration.py

Lines changed: 4 additions & 0 deletions
@@ -43,6 +43,8 @@ def __init__(
         use_flash_attention=False,
         use_fused_rms_norm=False,
         use_fused_rope=False,
+        fuse_attention_ffn=False,
+        sequence_parallel=False,
         intermediate_size=22016,
         tensor_parallel_output=True,
         no_bias=True,
@@ -77,6 +79,8 @@ def __init__(
         self.use_flash_attention = use_flash_attention
         self.use_fused_rms_norm = use_fused_rms_norm
         self.use_fused_rope = use_fused_rope
+        self.fuse_attention_ffn = fuse_attention_ffn
+        self.sequence_parallel = sequence_parallel
         self.no_bias = no_bias

         self.long_sequence_strategy_type = long_sequence_strategy_type
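
With the two new flags wired into the config, enabling the fused FFN path is just a matter of constructing QWenConfig with fuse_attention_ffn=True; all other arguments keep the defaults shown above. A minimal usage sketch (illustrative only; whether the fused swiglu kernel is actually available is decided at import time in modeling.py, see the try/except below):

    from paddlenlp.transformers.qwen.configuration import QWenConfig

    # Only fuse_attention_ffn and sequence_parallel are new in this commit;
    # everything else falls back to its default value.
    config = QWenConfig(fuse_attention_ffn=True, sequence_parallel=False)
    print(config.fuse_attention_ffn)  # True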

paddlenlp/transformers/qwen/modeling.py

Lines changed: 54 additions & 20 deletions
@@ -26,6 +26,16 @@
 from paddle.distributed.fleet.utils import recompute
 from paddle.utils import try_import

+try:
+    from paddle.incubate.nn.functional import swiglu
+except ImportError:
+
+    def swiglu(x, y=None):
+        if y is None:
+            x, y = paddle.chunk(x, chunks=2, axis=-1)
+        return F.silu(x) * y
+
+
 from paddlenlp.transformers.long_sequence_strategies import LongSequenceStrategies
 from paddlenlp.transformers.model_outputs import (
     BaseModelOutputWithPast,
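
The fallback mirrors the fused kernel's calling convention: given a single tensor it splits it in half along the last axis into (gate, up) and returns silu(gate) * up; given two tensors it treats the first as the gate. A quick standalone check of that behaviour, copying the fallback definition so the snippet runs without the incubate kernel (shapes and names here are arbitrary, for illustration only):

    import paddle
    import paddle.nn.functional as F

    def swiglu(x, y=None):  # same fallback as in the diff above
        if y is None:
            x, y = paddle.chunk(x, chunks=2, axis=-1)
        return F.silu(x) * y

    gate = paddle.randn([2, 8])
    up = paddle.randn([2, 8])
    fused_input = paddle.concat([gate, up], axis=-1)

    # The one-argument and two-argument forms agree.
    assert paddle.allclose(swiglu(fused_input), swiglu(gate, up))
    assert paddle.allclose(swiglu(gate, up), F.silu(gate) * up)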
@@ -35,6 +45,7 @@
 from paddlenlp.utils.log import logger

 from ...utils.converter import StateDictNameMapping, init_name_mappings
+from .. import linear_utils
 from ..model_outputs import ModelOutput
 from .configuration import QWenConfig

@@ -329,37 +340,60 @@ class QWenMLP(nn.Layer):
     def __init__(self, config):
         super().__init__()
         ff_dim_in = config.intermediate_size // 2
+        self.fuse_attention_ffn = config.fuse_attention_ffn
+
+        if config.sequence_parallel:
+            ColumnParallelLinear = linear_utils.ColumnSequenceParallelLinear
+            RowParallelLinear = linear_utils.RowSequenceParallelLinear
+        else:
+            ColumnParallelLinear = linear_utils.ColumnParallelLinear
+            RowParallelLinear = linear_utils.RowParallelLinear
+
         if config.tensor_parallel_degree > 1:
-            self.w1 = mpu.ColumnParallelLinear(
-                config.hidden_size,
-                ff_dim_in,
-                gather_output=False,
-                has_bias=False,
-            )
-            self.w2 = mpu.ColumnParallelLinear(
-                config.hidden_size,
-                ff_dim_in,
-                gather_output=False,
-                has_bias=False,
-            )
-            self.c_proj = mpu.RowParallelLinear(
+            if self.fuse_attention_ffn:
+                self.gate_up_fused_proj = ColumnParallelLinear(
+                    config.hidden_size,
+                    ff_dim_in * 2,
+                    gather_output=False,
+                    has_bias=False,
+                )
+            else:
+                self.w1 = ColumnParallelLinear(
+                    config.hidden_size,
+                    ff_dim_in,
+                    gather_output=False,
+                    has_bias=False,
+                )
+                self.w2 = ColumnParallelLinear(
+                    config.hidden_size,
+                    ff_dim_in,
+                    gather_output=False,
+                    has_bias=False,
+                )
+            self.c_proj = RowParallelLinear(
                 ff_dim_in,
                 config.hidden_size,
                 input_is_parallel=True,
                 has_bias=False,
             )
         else:
-            self.w1 = nn.Linear(config.hidden_size, ff_dim_in, bias_attr=not config.no_bias)
-            self.w2 = nn.Linear(config.hidden_size, ff_dim_in, bias_attr=not config.no_bias)
+            if self.fuse_attention_ffn:
+                self.gate_up_fused_proj = nn.Linear(config.hidden_size, ff_dim_in * 2, bias_attr=not config.no_bias)
+            else:
+                self.w1 = nn.Linear(config.hidden_size, ff_dim_in, bias_attr=not config.no_bias)
+                self.w2 = nn.Linear(config.hidden_size, ff_dim_in, bias_attr=not config.no_bias)
             self.c_proj = nn.Linear(ff_dim_in, config.hidden_size, bias_attr=not config.no_bias)

     def forward(self, hidden_states):
         # up
-        a1 = self.w1(hidden_states)
-        # gate
-        a2 = self.w2(hidden_states)
-        intermediate_parallel = a1 * F.silu(a2)
-        # down
+        # a1 = self.w1(hidden_states)
+        # # gate
+        # a2 = self.w2(hidden_states)
+        # intermediate_parallel = a1 * F.silu(a2)
+        if self.fuse_attention_ffn:
+            intermediate_parallel = swiglu(self.gate_up_fused_proj(hidden_states))
+        else:
+            intermediate_parallel = swiglu(self.w2(hidden_states), self.w1(hidden_states))
         output = self.c_proj(intermediate_parallel)
         return output

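
For intuition, the fused path reproduces the two-projection path exactly when gate_up_fused_proj's weight is the column-wise concatenation of the gate (w2) and up (w1) weights, since swiglu chunks the fused output back into those two halves. A small numerical sketch of that equivalence using plain paddle.nn.Linear layers and the fallback swiglu; the names and the concatenation order follow the code above and are meant as an illustration, not PaddleNLP's actual checkpoint-conversion logic:

    import paddle
    import paddle.nn as nn
    import paddle.nn.functional as F

    def swiglu(x, y=None):  # fallback definition from the diff
        if y is None:
            x, y = paddle.chunk(x, chunks=2, axis=-1)
        return F.silu(x) * y

    hidden, ff_dim_in = 16, 32
    w1 = nn.Linear(hidden, ff_dim_in, bias_attr=False)   # "up" projection
    w2 = nn.Linear(hidden, ff_dim_in, bias_attr=False)   # "gate" projection
    fused = nn.Linear(hidden, ff_dim_in * 2, bias_attr=False)

    # Fused weight = [gate | up] along the output dimension, matching the
    # chunk order inside swiglu (first half -> silu, second half -> multiplier).
    fused.weight.set_value(paddle.concat([w2.weight, w1.weight], axis=-1))

    x = paddle.randn([2, hidden])
    unfused_out = swiglu(w2(x), w1(x))   # unfused forward path
    fused_out = swiglu(fused(x))         # fused forward path
    assert paddle.allclose(unfused_out, fused_out, atol=1e-6)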
