26 | 26 | from paddle.distributed.fleet.utils import recompute
27 | 27 | from paddle.utils import try_import
28 | 28 |
| 29 | +try:
| 30 | +    from paddle.incubate.nn.functional import swiglu
| 31 | +except ImportError:
| 32 | +
| 33 | +    def swiglu(x, y=None):
| 34 | +        if y is None:
| 35 | +            x, y = paddle.chunk(x, chunks=2, axis=-1)
| 36 | +        return F.silu(x) * y
| 37 | +
| 38 | +
29 | 39 | from paddlenlp.transformers.long_sequence_strategies import LongSequenceStrategies
30 | 40 | from paddlenlp.transformers.model_outputs import (
31 | 41 |     BaseModelOutputWithPast,
35 | 45 | from paddlenlp.utils.log import logger
36 | 46 |
37 | 47 | from ...utils.converter import StateDictNameMapping, init_name_mappings
| 48 | +from .. import linear_utils
38 | 49 | from ..model_outputs import ModelOutput
39 | 50 | from .configuration import QWenConfig
40 | 51 |
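For reference, a minimal, self-contained sketch (not part of the diff) of what the pure-Paddle fallback above computes: called with one tensor, `swiglu` splits it in half along the last axis and returns `silu(first_half) * second_half`; called with two tensors, it returns `silu(x) * y`. The helper name `swiglu_fallback` is illustrative only.

```python
import paddle
import paddle.nn.functional as F


def swiglu_fallback(x, y=None):
    # Mirrors the fallback in the diff: a single-tensor input is chunked into
    # (gate, up) halves along the last axis; a two-tensor input is used as-is.
    if y is None:
        x, y = paddle.chunk(x, chunks=2, axis=-1)
    return F.silu(x) * y


gate = paddle.randn([2, 8])
up = paddle.randn([2, 8])

out_split = swiglu_fallback(gate, up)                            # silu(gate) * up
out_fused = swiglu_fallback(paddle.concat([gate, up], axis=-1))  # same result

print(paddle.allclose(out_split, out_fused))  # Tensor(True)
```

When `paddle.incubate.nn.functional.swiglu` is importable, the fused kernel is expected to return the same values, just without the extra chunk and elementwise ops.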
@@ -329,37 +340,60 @@ class QWenMLP(nn.Layer):
329 | 340 |     def __init__(self, config):
330 | 341 |         super().__init__()
331 | 342 |         ff_dim_in = config.intermediate_size // 2
| 343 | +        self.fuse_attention_ffn = config.fuse_attention_ffn
| 344 | +
| 345 | +        if config.sequence_parallel:
| 346 | +            ColumnParallelLinear = linear_utils.ColumnSequenceParallelLinear
| 347 | +            RowParallelLinear = linear_utils.RowSequenceParallelLinear
| 348 | +        else:
| 349 | +            ColumnParallelLinear = linear_utils.ColumnParallelLinear
| 350 | +            RowParallelLinear = linear_utils.RowParallelLinear
| 351 | +
332 | 352 |         if config.tensor_parallel_degree > 1:
333 | | -            self.w1 = mpu.ColumnParallelLinear(
334 | | -                config.hidden_size,
335 | | -                ff_dim_in,
336 | | -                gather_output=False,
337 | | -                has_bias=False,
338 | | -            )
339 | | -            self.w2 = mpu.ColumnParallelLinear(
340 | | -                config.hidden_size,
341 | | -                ff_dim_in,
342 | | -                gather_output=False,
343 | | -                has_bias=False,
344 | | -            )
345 | | -            self.c_proj = mpu.RowParallelLinear(
| 353 | +            if self.fuse_attention_ffn:
| 354 | +                self.gate_up_fused_proj = ColumnParallelLinear(
| 355 | +                    config.hidden_size,
| 356 | +                    ff_dim_in * 2,
| 357 | +                    gather_output=False,
| 358 | +                    has_bias=False,
| 359 | +                )
| 360 | +            else:
| 361 | +                self.w1 = ColumnParallelLinear(
| 362 | +                    config.hidden_size,
| 363 | +                    ff_dim_in,
| 364 | +                    gather_output=False,
| 365 | +                    has_bias=False,
| 366 | +                )
| 367 | +                self.w2 = ColumnParallelLinear(
| 368 | +                    config.hidden_size,
| 369 | +                    ff_dim_in,
| 370 | +                    gather_output=False,
| 371 | +                    has_bias=False,
| 372 | +                )
| 373 | +            self.c_proj = RowParallelLinear(
346 | 374 |                 ff_dim_in,
347 | 375 |                 config.hidden_size,
348 | 376 |                 input_is_parallel=True,
349 | 377 |                 has_bias=False,
350 | 378 |             )
351 | 379 |         else:
352 | | -            self.w1 = nn.Linear(config.hidden_size, ff_dim_in, bias_attr=not config.no_bias)
353 | | -            self.w2 = nn.Linear(config.hidden_size, ff_dim_in, bias_attr=not config.no_bias)
| 380 | +            if self.fuse_attention_ffn:
| 381 | +                self.gate_up_fused_proj = nn.Linear(config.hidden_size, ff_dim_in * 2, bias_attr=not config.no_bias)
| 382 | +            else:
| 383 | +                self.w1 = nn.Linear(config.hidden_size, ff_dim_in, bias_attr=not config.no_bias)
| 384 | +                self.w2 = nn.Linear(config.hidden_size, ff_dim_in, bias_attr=not config.no_bias)
354 | 385 |             self.c_proj = nn.Linear(ff_dim_in, config.hidden_size, bias_attr=not config.no_bias)
355 | 386 |
356 | 387 |     def forward(self, hidden_states):
357 | 388 |         # up
358 | | -        a1 = self.w1(hidden_states)
359 | | -        # gate
360 | | -        a2 = self.w2(hidden_states)
361 | | -        intermediate_parallel = a1 * F.silu(a2)
362 | | -        # down
| 389 | +        # a1 = self.w1(hidden_states)
| 390 | +        # # gate
| 391 | +        # a2 = self.w2(hidden_states)
| 392 | +        # intermediate_parallel = a1 * F.silu(a2)
| 393 | +        if self.fuse_attention_ffn:
| 394 | +            intermediate_parallel = swiglu(self.gate_up_fused_proj(hidden_states))
| 395 | +        else:
| 396 | +            intermediate_parallel = swiglu(self.w2(hidden_states), self.w1(hidden_states))
363 | 397 |         output = self.c_proj(intermediate_parallel)
364 | 398 |         return output
365 | 399 |
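A hedged sketch of why the new `fuse_attention_ffn` branch can match the unfused branch of `forward`: if `gate_up_fused_proj` packs the gate (`w2`) output columns first and the up (`w1`) output columns second, a single matmul followed by `swiglu` reproduces `F.silu(w2(h)) * w1(h)`. That column ordering is an assumption made for illustration; the diff itself does not show how fused checkpoints are packed.

```python
import paddle
import paddle.nn as nn
import paddle.nn.functional as F

hidden_size, ff_dim_in = 16, 32
h = paddle.randn([2, hidden_size])

w1 = nn.Linear(hidden_size, ff_dim_in, bias_attr=False)  # "up" projection
w2 = nn.Linear(hidden_size, ff_dim_in, bias_attr=False)  # "gate" projection

# Unfused path, as in the else-branch: silu(gate) * up.
unfused = F.silu(w2(h)) * w1(h)

# Fused path: one matmul whose output is chunked into (gate, up).
# Assumed layout: gate columns first, then up columns.
fused_proj = nn.Linear(hidden_size, ff_dim_in * 2, bias_attr=False)
fused_proj.weight.set_value(paddle.concat([w2.weight, w1.weight], axis=-1))
gate, up = paddle.chunk(fused_proj(h), chunks=2, axis=-1)
fused = F.silu(gate) * up

print(paddle.allclose(unfused, fused))  # Tensor(True)
```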