
Commit 54b7d5a

add sd3 test
1 parent e869c1c commit 54b7d5a

3 files changed: +84 -5 lines changed


ppdiffusers/examples/dreambooth/README_sd3.md

Lines changed: 2 additions & 2 deletions
@@ -55,7 +55,7 @@ python train_dreambooth_sd3.py \
    --checkpointing_steps=250
    ```

-   fp16 training requires 47000MiB of GPU memory. To better track our training experiments, we used the following flags in the command above:
+   fp16 training requires 67000MiB of GPU memory. To better track our training experiments, we used the following flags in the command above:

    * `report_to="wandb"` will make sure the training runs are tracked on Weights and Biases. To use it, be sure to install `wandb` with `pip install wandb`.
    * `validation_prompt` and `validation_epochs` allow the script to do a few validation inference runs, letting us qualitatively check whether training is progressing as expected.
@@ -118,7 +118,7 @@ python train_dreambooth_lora_sd3.py \
    --checkpointing_steps=250
    ```

-   Once training is finished, we can run inference with the following Python script:
+   fp16 training requires 47000MiB of GPU memory. Once training is finished, we can run inference with the following Python script:
    ```python
    from ppdiffusers import StableDiffusion3Pipeline
    from ppdiffusers import (
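
The diff cuts off at the start of that inference script. For orientation, here is a minimal sketch of what such a script typically looks like after LoRA DreamBooth training; the base checkpoint name, the output directory, and the `load_lora_weights`/`paddle_dtype` usage are assumptions based on the usual ppdiffusers pipeline API, not taken from this commit:

```python
# Minimal sketch (not from this commit): load the SD3 pipeline, attach the LoRA
# weights produced by train_dreambooth_lora_sd3.py, and sample one image.
# The base checkpoint name and "your-lora-output-dir" are placeholders.
import paddle

from ppdiffusers import StableDiffusion3Pipeline

pipe = StableDiffusion3Pipeline.from_pretrained(
    "stabilityai/stable-diffusion-3-medium-diffusers",  # assumed base checkpoint
    paddle_dtype=paddle.float16,  # assumed fp16 inference, matching the fp16 training note
)
pipe.load_lora_weights("your-lora-output-dir")  # assumed LoRA API, mirroring diffusers

image = pipe(
    "A photo of sks dog in a bucket",  # example DreamBooth-style prompt
    num_inference_steps=28,
    guidance_scale=7.0,
).images[0]
image.save("dreambooth_sd3_lora.png")
```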

ppdiffusers/ppdiffusers/models/transformer_sd3.py

Lines changed: 3 additions & 3 deletions
@@ -145,7 +145,7 @@ def attn_processors(self) -> Dict[str, AttentionProcessor]:
        # set recursively
        processors = {}

-       def fn_recursive_add_processors(name: str, module: torch.nn.Module, processors: Dict[str, AttentionProcessor]):
+       def fn_recursive_add_processors(name: str, module: paddle.nn.Layer, processors: Dict[str, AttentionProcessor]):
            if hasattr(module, "get_processor"):
                processors[f"{name}.processor"] = module.get_processor(return_deprecated_lora=True)

@@ -178,7 +178,7 @@ def set_attn_processor(self, processor: Union[AttentionProcessor, Dict[str, Atte
                f" number of attention layers: {count}. Please make sure to pass {count} processor classes."
            )

-       def fn_recursive_attn_processor(name: str, module: torch.nn.Module, processor):
+       def fn_recursive_attn_processor(name: str, module: paddle.nn.Layer, processor):
            if hasattr(module, "set_processor"):
                if not isinstance(processor, dict):
                    module.set_processor(processor)
@@ -267,7 +267,7 @@ def forward(
            # weight the lora layers by setting `lora_scale` for each PEFT layer
            scale_lora_layers(self, lora_scale)
        else:
-           logger.info(
+           logger.debug(
                "Passing `scale` via `joint_attention_kwargs` when not using the PEFT backend is ineffective."
            )
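
The edits above only swap the type annotations from `torch.nn.Module` to `paddle.nn.Layer` (and lower the log level); the recursive traversal logic itself is untouched. As a simplified illustration of the pattern `fn_recursive_add_processors` implements over a `paddle.nn.Layer` tree (a sketch, not the exact method body):

```python
# Simplified sketch of the recursion that fn_recursive_add_processors performs,
# now typed against paddle.nn.Layer; the real method also handles deprecated
# LoRA processors and is defined inside the model class.
from typing import Dict

import paddle


def collect_attn_processors(
    name: str, module: paddle.nn.Layer, processors: Dict[str, object]
) -> Dict[str, object]:
    # a sub-layer that exposes get_processor() contributes its attention processor
    if hasattr(module, "get_processor"):
        processors[f"{name}.processor"] = module.get_processor()
    # walk the children, building dotted names such as "transformer_blocks.0.attn"
    for sub_name, child in module.named_children():
        collect_attn_processors(f"{name}.{sub_name}", child, processors)
    return processors
```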

Lines changed: 79 additions & 0 deletions
@@ -0,0 +1,79 @@
# Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import unittest

import paddle

from ppdiffusers import SD3Transformer2DModel
from ppdiffusers.utils.testing_utils import (
    enable_full_determinism,
    paddle_device,
)

from .test_modeling_common import ModelTesterMixin

enable_full_determinism()


class SD3TransformerTests(ModelTesterMixin, unittest.TestCase):
    model_class = SD3Transformer2DModel
    main_input_name = "hidden_states"

    @property
    def dummy_input(self):
        batch_size = 2
        num_channels = 4
        height = width = embedding_dim = 32
        pooled_embedding_dim = embedding_dim * 2
        sequence_length = 154

        hidden_states = paddle.randn((batch_size, num_channels, height, width))
        encoder_hidden_states = paddle.randn((batch_size, sequence_length, embedding_dim))
        pooled_prompt_embeds = paddle.randn((batch_size, pooled_embedding_dim))
        timestep = paddle.randint(0, 1000, shape=(batch_size,))

        return {
            "hidden_states": hidden_states,
            "encoder_hidden_states": encoder_hidden_states,
            "pooled_projections": pooled_prompt_embeds,
            "timestep": timestep,
        }

    @property
    def input_shape(self):
        return (4, 32, 32)

    @property
    def output_shape(self):
        return (4, 32, 32)

    def prepare_init_args_and_inputs_for_common(self):
        init_dict = {
            "sample_size": 32,
            "patch_size": 1,
            "in_channels": 4,
            "num_layers": 1,
            "attention_head_dim": 8,
            "num_attention_heads": 4,
            "caption_projection_dim": 32,
            "joint_attention_dim": 32,
            "pooled_projection_dim": 64,
            "out_channels": 4,
        }
        inputs_dict = self.dummy_input
        return init_dict, inputs_dict

    @unittest.skip("SD3Transformer2DModel uses a dedicated attention processor. This test doesn't apply")
    def test_from_save_pretrained(self):
        pass

    @unittest.skip("SD3Transformer2DModel uses a dedicated attention processor. This test doesn't apply")
    def test_outputs_equivalence(self):
        pass

    @unittest.skip("SD3Transformer2DModel uses a dedicated attention processor. This test doesn't apply")
    def test_set_attn_processor_for_determinism(self):
        pass
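
For context on what `ModelTesterMixin` effectively checks with this configuration, the following standalone sketch builds the model from the same `init_dict`, runs the same dummy inputs, and inspects the output shape. The `.sample` field on the returned output and this eager-call usage are assumptions about the ppdiffusers API, not something asserted by the commit:

```python
# Standalone sketch of the smoke test ModelTesterMixin effectively runs with the
# configuration above; the `.sample` attribute is an assumption based on the
# usual ppdiffusers transformer output class.
import paddle

from ppdiffusers import SD3Transformer2DModel

init_dict = {
    "sample_size": 32,
    "patch_size": 1,
    "in_channels": 4,
    "num_layers": 1,
    "attention_head_dim": 8,
    "num_attention_heads": 4,
    "caption_projection_dim": 32,
    "joint_attention_dim": 32,
    "pooled_projection_dim": 64,
    "out_channels": 4,
}
model = SD3Transformer2DModel(**init_dict)
model.eval()

batch_size = 2
inputs = {
    "hidden_states": paddle.randn((batch_size, 4, 32, 32)),
    "encoder_hidden_states": paddle.randn((batch_size, 154, 32)),
    "pooled_projections": paddle.randn((batch_size, 64)),
    "timestep": paddle.randint(0, 1000, shape=(batch_size,)),
}
with paddle.no_grad():
    output = model(**inputs)

print(output.sample.shape)  # expected: [2, 4, 32, 32], matching output_shape in the test
```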
