
【PIR】modify dy2static SD3 and Grounding DINO model #689


Merged · 3 commits · Aug 27, 2024
7 changes: 4 additions & 3 deletions paddlemix/models/groundingdino/fuse_modules.py
@@ -15,6 +15,7 @@
 import paddle
 import paddle.nn as nn
 import paddle.nn.functional as F
+from paddle.framework import in_dynamic_mode
 from paddle.nn.initializer import Constant
 from paddlenlp.utils.initializer import constant_, xavier_uniform_

@@ -183,7 +184,7 @@ def forward(self, v, l, attention_mask_v=None, attention_mask_l=None):
         src_len = key_states.shape[1]
         attn_weights = paddle.bmm(query_states, key_states.transpose([0, 2, 1]))  # bs*nhead, nimg, ntxt

-        if attn_weights.shape != [bsz * self.num_heads, tgt_len, src_len]:
+        if in_dynamic_mode() and attn_weights.shape != [bsz * self.num_heads, tgt_len, src_len]:
             raise ValueError(
                 f"Attention weights should be of size {(bsz * self.num_heads, tgt_len, src_len)}, but is {attn_weights.shape}"
             )
@@ -236,12 +237,12 @@ def forward(self, v, l, attention_mask_v=None, attention_mask_l=None):
         attn_output_v = paddle.bmm(attn_probs_v, value_l_states)
         attn_output_l = paddle.bmm(attn_probs_l, value_v_states)

-        if attn_output_v.shape != [bsz * self.num_heads, tgt_len, self.head_dim]:
+        if in_dynamic_mode() and attn_output_v.shape != [bsz * self.num_heads, tgt_len, self.head_dim]:
             raise ValueError(
                 f"`attn_output_v` should be of size {(bsz, self.num_heads, tgt_len, self.head_dim)}, but is {attn_output_v.shape}"
             )

-        if attn_output_l.shape != [bsz * self.num_heads, src_len, self.head_dim]:
+        if in_dynamic_mode() and attn_output_l.shape != [bsz * self.num_heads, src_len, self.head_dim]:
             raise ValueError(
                 f"`attn_output_l` should be of size {(bsz, self.num_heads, src_len, self.head_dim)}, but is {attn_output_l.shape}"
             )
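The change above is the whole pattern for this file: each eager-mode shape assertion in Grounding DINO's fused attention is wrapped in in_dynamic_mode(), so it still runs in dynamic (eager) mode but is skipped while the model is traced to a static PIR program, where shapes can contain unknown (-1) dimensions and an exact comparison could raise spuriously. A minimal sketch of the guarded check, with assumed argument names rather than the model's real attributes:

import paddle
from paddle.framework import in_dynamic_mode


def bmm_with_shape_check(query_states, key_states, bsz, num_heads, tgt_len, src_len):
    # Batched attention scores, expected shape (bsz * num_heads, tgt_len, src_len).
    attn_weights = paddle.bmm(query_states, key_states.transpose([0, 2, 1]))
    # Validate only in eager mode; under dy2static / PIR tracing the check is skipped.
    if in_dynamic_mode() and attn_weights.shape != [bsz * num_heads, tgt_len, src_len]:
        raise ValueError(
            f"Attention weights should be of size {[bsz * num_heads, tgt_len, src_len]}, "
            f"but is {attn_weights.shape}"
        )
    return attn_weights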
9 changes: 5 additions & 4 deletions ppdiffusers/ppdiffusers/transformers/clip/modeling.py
@@ -20,6 +20,7 @@
 import paddle
 import paddle.nn.functional as F
 from paddle import nn
+from paddle.framework import in_dynamic_mode
 from paddlenlp.transformers.activations import ACT2FN
 from paddlenlp.transformers.model_outputs import (
     BaseModelOutput,
@@ -279,15 +280,15 @@ def forward(
         src_len = key_states.shape[1]
         attn_weights = paddle.matmul(query_states, key_states, transpose_y=True)

-        if attn_weights.shape != [bsz * self.num_heads, tgt_len, src_len]:
+        if in_dynamic_mode() and attn_weights.shape != [bsz * self.num_heads, tgt_len, src_len]:
             raise ValueError(
                 f"Attention weights should be of size {[bsz * self.num_heads, tgt_len, src_len]}, but is"
                 f" {attn_weights.shape}"
             )

         # apply the causal_attention_mask first
         if causal_attention_mask is not None:
-            if causal_attention_mask.shape != [bsz, 1, tgt_len, src_len]:
+            if in_dynamic_mode() and causal_attention_mask.shape != [bsz, 1, tgt_len, src_len]:
                 raise ValueError(
                     f"Attention mask should be of size {[bsz, 1, tgt_len, src_len]}, but is"
                     f" {causal_attention_mask.shape}"
@@ -296,7 +297,7 @@ def forward(
             attn_weights = attn_weights.reshape([bsz * self.num_heads, tgt_len, src_len])

         if attention_mask is not None:
-            if attention_mask.shape != [bsz, 1, tgt_len, src_len]:
+            if in_dynamic_mode() and attention_mask.shape != [bsz, 1, tgt_len, src_len]:
                 raise ValueError(
                     f"Attention mask should be of size {[bsz, 1, tgt_len, src_len]}, but is {attention_mask.shape}"
                 )
@@ -319,7 +320,7 @@ def forward(

         attn_output = paddle.matmul(attn_probs, value_states)

-        if attn_output.shape != [bsz * self.num_heads, tgt_len, self.head_dim]:
+        if in_dynamic_mode() and attn_output.shape != [bsz * self.num_heads, tgt_len, self.head_dim]:
             raise ValueError(
                 f"`attn_output` should be of size {[bsz, self.num_heads, tgt_len, self.head_dim]}, but is"
                 f" {attn_output.shape}"
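The same guard is applied in ppdiffusers' CLIP attention. For context, here is a self-contained sketch (a hypothetical layer, not code from this PR) of why such a guard is needed: when a layer is converted with paddle.jit.to_static using a dynamic InputSpec, the traced tensors report -1 for unknown dimensions, so an unguarded exact-shape comparison could raise even for valid inputs, while the guarded check only runs in eager mode.

import paddle
import paddle.nn as nn
from paddle.framework import in_dynamic_mode
from paddle.static import InputSpec


class TinyCheckedLayer(nn.Layer):
    # Hypothetical layer that mirrors the guarded shape check used in this PR.
    def forward(self, x):
        bsz, seq_len, _ = x.shape  # bsz/seq_len show up as -1 while tracing with a dynamic InputSpec
        scores = paddle.matmul(x, x, transpose_y=True)
        # Guarded check: enforced in eager mode, skipped while building the static graph.
        if in_dynamic_mode() and scores.shape != [bsz, seq_len, seq_len]:
            raise ValueError(f"expected {[bsz, seq_len, seq_len]}, but got {scores.shape}")
        return scores


layer = TinyCheckedLayer()
static_layer = paddle.jit.to_static(
    layer, input_spec=[InputSpec(shape=[None, None, 64], dtype="float32")]
)
out = static_layer(paddle.rand([2, 8, 64]))  # runs; the shape check only fires in eager mode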