
Commit 1b9c230

[XPU] Qwen2_vl optimization for xpu
1 parent 4212497 commit 1b9c230

File tree: 2 files changed (+62, -8 lines)


paddlemix/examples/qwen2_vl/qwen2vl_finetune.py

Lines changed: 14 additions & 3 deletions
```diff
@@ -52,6 +52,7 @@
 MegaByte = 2**20
 PngImagePlugin.MAX_TEXT_CHUNK = MaximumDecompressedSize * MegaByte
 from paddlenlp.utils.log import logger
+from paddlenlp.utils.tools import get_env_device

 # Set constants for image processing and logging
 IGNORE_INDEX = -100
```
```diff
@@ -360,7 +361,7 @@ def pure_text_get_item(self, data_item):
             attention_mask=attention_mask,
             images=[],
         )
-
+
         return ret

     def __getitem__(self, i) -> Dict[str, paddle.Tensor]:
```
```diff
@@ -473,11 +474,11 @@ def __call__(self, features: Sequence[Dict[str, Any]]) -> Dict[str, "paddle.Tensor"]:
             batch_videos.extend(videos)
             batch_imglens.append(len(images))
             batch_vidlens.append(len(videos))
-            batch_input_ids.append(feature["input_ids"])
+            batch_input_ids.append(feature["input_ids"])

         if (
             self.processor is not None and sum(batch_imglens) == 0 and sum(batch_vidlens) == 0
-        ):
+        ):
             fake_messages = [{"role": "user", "content": IMAGE_PLACEHOLDER}]
             fake_images = [Image.new("RGB", (64, 64), (255, 255, 255))]
             fake_messages = self.template.mm_plugin.process_messages(fake_messages, fake_images, [], self.processor)
```
```diff
@@ -679,6 +680,16 @@ def main():
             "the `--output_dir` or add `--overwrite_output_dir` to train from scratch."
         )

+    if get_env_device() == "xpu" and training_args.gradient_accumulation_steps > 1:
+        try:
+            from paddle_xpu.layers.nn.linear import LinearConfig  # noqa: F401
+
+            LinearConfig.enable_accumulate_steps_opt()
+            LinearConfig.set_accumulate_steps(training_args.gradient_accumulation_steps)
+        except ImportError:
+            # paddle_xpu is not installed; proceed without the accumulate-steps optimization.
+            pass
+
     # Load model
     if "npu" in paddle.get_device():
         is_bfloat16_supported = True
```
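A minimal sketch of the device-gated setup above, factored into a reusable helper. The helper name is illustrative and not part of the commit; `LinearConfig.enable_accumulate_steps_opt` and `set_accumulate_steps` are the `paddle_xpu` API the commit itself calls:

```python
from paddlenlp.utils.tools import get_env_device


def maybe_enable_xpu_accumulate_opt(gradient_accumulation_steps: int) -> bool:
    """Enable paddle_xpu's accumulate-steps optimization when it applies.

    Returns True only when running on XPU, accumulating gradients over more
    than one step, and paddle_xpu is importable; otherwise training falls
    back to the default linear layers unchanged.
    """
    if get_env_device() != "xpu" or gradient_accumulation_steps <= 1:
        return False
    try:
        from paddle_xpu.layers.nn.linear import LinearConfig
    except ImportError:
        return False  # paddle_xpu not installed; skip the optimization
    LinearConfig.enable_accumulate_steps_opt()
    LinearConfig.set_accumulate_steps(gradient_accumulation_steps)
    return True
```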

paddlemix/models/qwen2_vl/modeling_qwen2_vl.py

Lines changed: 48 additions & 5 deletions
```diff
@@ -37,6 +37,7 @@
 from paddlenlp.transformers.linear_utils import Linear
 from paddlenlp.transformers.model_outputs import BaseModelOutputWithPast, ModelOutput
 from paddlenlp.transformers.model_utils import PretrainedModel
+from paddlenlp.utils.tools import get_env_device

 from paddlemix.models.flash_attn_utils import (
     create_attention_module,
```
```diff
@@ -48,6 +49,11 @@
 from .bert_padding import index_first_axis, pad_input, unpad_input
 from .configuration_qwen2_vl import Qwen2VLConfig, Qwen2VLVisionConfig

+try:
+    from paddle.incubate.nn.functional import fused_rotary_position_embedding
+except ImportError:
+    fused_rotary_position_embedding = None
+
 logger = logging.get_logger(__name__)

 flash_attn_func, flash_attn_varlen_func = has_flash_attn_func()
```
```diff
@@ -407,7 +413,12 @@ def apply_rotary_pos_emb_vision(tensor: paddle.Tensor, freqs: paddle.Tensor) -> paddle.Tensor:
     sin = freqs.sin()
     cos = cos.unsqueeze(1).tile(repeat_times=[1, 1, 2]).unsqueeze(0).astype(dtype="float32")
     sin = sin.unsqueeze(1).tile(repeat_times=[1, 1, 2]).unsqueeze(0).astype(dtype="float32")
-    output = tensor * cos + rotate_half(tensor) * sin
+    if get_env_device() == "xpu" and fused_rotary_position_embedding is not None:
+        output, _, _ = fused_rotary_position_embedding(
+            tensor, sin=sin, cos=cos, use_neox_rotary_style=False
+        )
+    else:
+        output = tensor * cos + rotate_half(tensor) * sin
     output = paddle.cast(output, orig_dtype)
     return output
```
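On XPU, the fused call is expected to reproduce the eager rotate-half formula that other devices keep using. For reference, a self-contained sketch of that eager computation; `rotate_half` is reproduced here as Qwen2-VL implementations commonly define it, and the shapes are illustrative:

```python
import paddle


def rotate_half(x):
    # Swap the two halves of the last dimension, negating the second half.
    x1 = x[..., : x.shape[-1] // 2]
    x2 = x[..., x.shape[-1] // 2 :]
    return paddle.concat([-x2, x1], axis=-1)


def apply_rope_eager(t, cos, sin):
    # Eager rotary embedding: the fused XPU kernel is expected to produce the
    # same result for a matching sin/cos layout.
    return t * cos + rotate_half(t) * sin


x = paddle.randn([1, 16, 4, 64])
cos = paddle.randn([1, 16, 1, 64])
sin = paddle.randn([1, 16, 1, 64])
print(apply_rope_eager(x, cos, sin).shape)  # [1, 16, 4, 64]
```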

```diff
@@ -463,6 +474,12 @@ def __init__(self, dim: int, context_dim: int, spatial_merge_size: int = 2) -> None:
             nn.GELU(),
             nn.Linear(self.hidden_size, dim),
         )
+        if get_env_device() == "xpu":
+            self.mlp = nn.Sequential(
+                Linear(self.hidden_size, self.hidden_size),
+                nn.GELU(),
+                Linear(self.hidden_size, dim),
+            )

     def forward(self, x: paddle.Tensor) -> paddle.Tensor:
         x = self.mlp(self.ln_q(x).reshape([-1, self.hidden_size]))
```
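This hunk, like the three analogous ones below (`VisionMlp` and two vision attention blocks), rebuilds the projection layers with `paddlenlp.transformers.linear_utils.Linear` on XPU so the device-aware linear implementation is used. A hedged sketch of the same idea expressed as a one-time class selection; the class and helper names here are illustrative, not from the commit:

```python
import paddle.nn as nn
from paddlenlp.transformers.linear_utils import Linear  # device-aware Linear
from paddlenlp.utils.tools import get_env_device

# Choose the linear layer class once; reuse it wherever a projection is built.
LinearCls = Linear if get_env_device() == "xpu" else nn.Linear


class VisionMlpSketch(nn.Layer):
    def __init__(self, dim: int, hidden_dim: int):
        super().__init__()
        self.fc1 = LinearCls(dim, hidden_dim)
        self.act = nn.GELU()
        self.fc2 = LinearCls(hidden_dim, dim)

    def forward(self, x):
        return self.fc2(self.act(self.fc1(x)))
```

Selecting the class up front also avoids allocating the default `nn.Linear` weights only to overwrite them, which is what the in-place reassignment in the diff does.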
```diff
@@ -475,6 +492,9 @@ def __init__(self, dim: int, hidden_dim: int, hidden_act: str) -> None:
         self.fc1 = nn.Linear(dim, hidden_dim)
         self.act = ACT2FN[hidden_act]
         self.fc2 = nn.Linear(hidden_dim, dim)
+        if get_env_device() == "xpu":
+            self.fc1 = Linear(dim, hidden_dim)
+            self.fc2 = Linear(hidden_dim, dim)

     def forward(self, x) -> paddle.Tensor:
         return self.fc2(self.act(self.fc1(x)))
```
```diff
@@ -486,6 +506,9 @@ def __init__(self, dim: int, num_heads: int = 16) -> None:
         self.num_heads = num_heads
         self.qkv = nn.Linear(dim, dim * 3, bias_attr=True)
         self.proj = nn.Linear(dim, dim)
+        if get_env_device() == "xpu":
+            self.qkv = Linear(dim, dim * 3, bias_attr=True)
+            self.proj = Linear(dim, dim)
         self.head_dim = dim // num_heads  # must added

     def forward(
```
```diff
@@ -525,6 +548,9 @@ def __init__(self, dim: int, num_heads: int = 16) -> None:
         self.num_heads = num_heads
         self.qkv = nn.Linear(dim, dim * 3, bias_attr=True)
         self.proj = nn.Linear(dim, dim)
+        if get_env_device() == "xpu":
+            self.qkv = Linear(dim, dim * 3, bias_attr=True)
+            self.proj = Linear(dim, dim)
         self.head_dim = dim // num_heads  # must added

     def forward(
```
```diff
@@ -657,6 +683,15 @@ def __init__(self, config: Qwen2VLConfig, hidden_size, eps=1e-6):
         self.variance_epsilon = eps

     def forward(self, hidden_states):
+        if get_env_device() == "xpu":
+            try:
+                import paddle_xpu_nn  # noqa: F821
+
+                return paddle_xpu_nn.xpu_rms_norm(hidden_states, self.weight, self.variance_epsilon)[0]
+            except ImportError:
+                raise NotImplementedError(
+                    "fused_rms_norm is not available on XPU. Please install paddle_xpu to use this feature."
+                )
         if paddle.in_dynamic_mode():
             with paddle.amp.auto_cast(False):
                 variance = hidden_states.astype("float32").pow(2).mean(-1, keepdim=True)
```
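A minimal eager reference for what `xpu_rms_norm` is expected to compute, mirroring (up to cast ordering) the float32 fallback path below the XPU branch; the function name is illustrative:

```python
import paddle


def rms_norm_reference(x, weight, eps):
    # Accumulate in float32 for numerical stability, as the fallback path does,
    # then cast back to the input dtype.
    variance = x.astype("float32").pow(2).mean(-1, keepdim=True)
    out = x.astype("float32") * paddle.rsqrt(variance + eps)
    return (out * weight.astype("float32")).astype(x.dtype)


x = paddle.randn([2, 8, 16], dtype="float32")
w = paddle.ones([16])
print(rms_norm_reference(x, w, 1e-6).shape)  # [2, 8, 16]
```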
```diff
@@ -1193,7 +1228,7 @@ class Qwen2VLPreTrainedModel(PretrainedModel):

     def _init_weights(self, layer):
         std = 0.2
-        if isinstance(layer, (nn.Linear, nn.Conv3D)):
+        if isinstance(layer, (nn.Linear, nn.Conv3D, Linear)):
             nn.initializer.Normal(mean=0.0, std=std)(layer.weight)
             if layer.bias is not None:
                 nn.initializer.Constant(0.0)(layer.bias)
```
```diff
@@ -1558,6 +1593,9 @@ def __init__(self, config, embedding_weights=None, transpose_y=False):
                 shape=[config.hidden_size, vocab_size],
                 dtype=paddle.get_default_dtype(),
             )
+        if get_env_device() == "xpu":
+            import paddle_xpu.layers.nn.linear as xpu_linear
+            self.xpu_parallel_matmul = xpu_linear.parallel_matmul()

         # Must set distributed attr for Tensor Parallel !
         self.weight.is_distributed = True if (vocab_size != config.vocab_size) else False
```
```diff
@@ -1573,9 +1611,14 @@ def forward(self, hidden_states, tensor_parallel_output=None):
         if self.weight.dtype != hidden_states.dtype:
            hidden_states = paddle.cast(hidden_states, self.weight.dtype)

-        logits = parallel_matmul(
-            hidden_states, self.weight, transpose_y=self.transpose_y, tensor_parallel_output=tensor_parallel_output
-        )
+        if get_env_device() == "xpu":
+            logits = self.xpu_parallel_matmul.forward(
+                hidden_states, self.weight, transpose_y=self.transpose_y, tensor_parallel_output=tensor_parallel_output
+            )
+        else:
+            logits = parallel_matmul(
+                hidden_states, self.weight, transpose_y=self.transpose_y, tensor_parallel_output=tensor_parallel_output
+            )
         return logits
```
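The XPU matmul object is built once in `__init__` (previous hunk) so that `forward` only branches on its presence. A condensed, hedged sketch of that dispatch; all names here are illustrative except the `paddle_xpu` import and `parallel_matmul`, which the commit itself uses:

```python
from paddlenlp.utils.tools import get_env_device


class LogitsHead:
    def __init__(self):
        self.xpu_matmul = None
        if get_env_device() == "xpu":
            # Built once at construction, as in the commit, so forward() pays
            # no per-step import or setup cost.
            import paddle_xpu.layers.nn.linear as xpu_linear

            self.xpu_matmul = xpu_linear.parallel_matmul()

    def logits(self, h, weight, transpose_y, tp_output, fallback_matmul):
        # Dispatch on the attribute set at construction time.
        if self.xpu_matmul is not None:
            return self.xpu_matmul.forward(h, weight, transpose_y=transpose_y, tensor_parallel_output=tp_output)
        return fallback_matmul(h, weight, transpose_y=transpose_y, tensor_parallel_output=tp_output)
```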