
Commit 4c45eb6

ADD SD3 batch_parallel

2 parents cc90caa + e42dfc8

File tree

6 files changed: +120 / -33 lines

paddlemix/triton_ops/triton_ops.py

Lines changed: 11 additions & 2 deletions

@@ -1711,9 +1711,18 @@ def split_concat(x, y):
     out1 = paddle.empty(shape=[batch, seq_qkv + seq_eqkv, ouput_hidden], dtype=x.dtype)
     out2 = paddle.empty(shape=[batch, seq_qkv + seq_eqkv, ouput_hidden], dtype=x.dtype)
     grid = ("3", "batch", "seq_qkv + seq_eqkv")
-
+    # -1 means this value does not matter for triton compilation
     split_concat_kernel[(op_name, grid)](
-        out0, out1, out2, x, y, batch, seq_qkv, seq_eqkv, ouput_hidden, BLOCK_SIZE=BLOCK_SIZE
+        out0,
+        out1,
+        out2,
+        x,
+        y,
+        -1,  # batch,
+        seq_qkv,
+        seq_eqkv,
+        ouput_hidden,
+        BLOCK_SIZE=BLOCK_SIZE
     )
 
     if in_dynamic_or_pir_mode():
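A note on the `-1` placeholder above: the in-file comment suggests compiled kernels are cached against their integer arguments, so passing the real `batch` would force a recompilation for every new batch size, while a value the generated code never specializes on can be pinned to `-1`. A toy sketch of that caching idea (the `get_kernel` helper is hypothetical, not the PaddleMIX API):

```python
# Toy sketch (hypothetical helper, not the PaddleMIX API): a kernel cache
# keyed on integer arguments, showing why a don't-care value is pinned to -1.
_kernel_cache = {}

def get_kernel(op_name, *int_args):
    # stand-in for triton compilation: one cache entry per distinct key
    key = (op_name, int_args)
    if key not in _kernel_cache:
        _kernel_cache[key] = f"compiled<{op_name}{int_args}>"
    return _kernel_cache[key]

get_kernel("split_concat", 2, 333, 77)   # batch=2 -> compiles
get_kernel("split_concat", 4, 333, 77)   # batch=4 -> compiles again
get_kernel("split_concat", -1, 333, 77)  # placeholder -> compiles once...
get_kernel("split_concat", -1, 333, 77)  # ...and is reused for any batch
assert len(_kernel_cache) == 3
```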

ppdiffusers/deploy/sd3/README.md

Lines changed: 26 additions & 0 deletions

@@ -29,3 +29,29 @@ python text_to_image_generation-stable_diffusion_3.py --dtype float16 --height
 | Paddle Inference | PyTorch | Paddle dygraph |
 | ---------------- | ------- | -------------- |
 | 1.2 s | 1.78 s | 4.202 s |
+
+
+## Multi-GPU inference for the Paddle Stable Diffusion 3 model
+### How batch parallel works
+- In SD3, when the input is a single prompt, CFG has to generate the unconditional guide and the text guide at the same time, so the MM-DiT blocks see an input with batch_size=2.
+In the multi-GPU scheme we therefore split this batch of 2 across two cards, halving the floating-point workload each card has to carry.
+Once both cards finish, the two partial results are gathered back together; the output matches the single-card computation exactly (a concrete sketch follows this diff).
+### How to enable multi-GPU inference
+- Paddle Inference provides multi-GPU inference for the SD3 model. Enable it by setting `--inference_optimize_bp 1`,
+and use `python -m paddle.distributed.launch --gpus 0,1` to choose which cards run the inference.
+High-performance multi-GPU inference command:
+```shell
+# run multi-GPU inference
+python -m paddle.distributed.launch --gpus 0,1 text_to_image_generation-stable_diffusion_3.py \
+--dtype float16 \
+--height 512 --width 512 \
+--num-inference-steps 50 \
+--inference_optimize 1 \
+--inference_optimize_bp 1 \
+--benchmark 1
+```
+## Performance measured on an NVIDIA A800-SXM4-80GB:
+
+| Paddle batch parallel | Paddle Single Card | PyTorch | Paddle dygraph |
+| --------------------- | ------------------ | ------- | -------------- |
+| 0.86 s | 1.2 s | 1.78 s | 4.202 s |
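The exactness claim in the README holds because the MM-DiT stack processes batch elements independently, so splitting the CFG batch of 2 across cards and concatenating afterwards loses nothing. A minimal sketch of that property (a plain `paddle.nn.Linear` stands in for the transformer; this snippet is not part of the commit):

```python
# Hedged sketch: batch rows pass through the network independently, so the
# per-card halves concatenate to the same result as the single-card batch.
import paddle

layer = paddle.nn.Linear(8, 8)                # stand-in for the MM-DiT stack
x = paddle.randn([2, 8])                      # CFG batch: [uncond, text]
single_card = layer(x)                        # batch of 2 on one card
card0, card1 = layer(x[0:1]), layer(x[1:2])   # batch of 1 per card
assert paddle.allclose(single_card, paddle.concat([card0, card1], axis=0)).item()
```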

ppdiffusers/examples/inference/text_to_image_generation-stable_diffusion_3.py

Lines changed: 41 additions & 11 deletions

@@ -12,16 +12,8 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 import os
-
-os.environ["FLAGS_use_cuda_managed_memory"] = "true"
 import argparse
-import datetime
-
 import paddle
-
-from ppdiffusers import StableDiffusion3Pipeline
-
-
 def parse_args():
     parser = argparse.ArgumentParser(
         description=" Use PaddleMIX to accelerate the Stable Diffusion3 image generation model."
@@ -30,13 +22,19 @@ def parse_args():
         "--benchmark",
         type=(lambda x: str(x).lower() in ["true", "1", "yes"]),
         default=False,
-        help="if benchmark is set to True, measure inference performance",
+        help="if set to True, measure inference performance",
     )
     parser.add_argument(
         "--inference_optimize",
         type=(lambda x: str(x).lower() in ["true", "1", "yes"]),
        default=False,
-        help="If inference_optimize is set to True, all optimizations except Triton are enabled.",
+        help="If set to True, all optimizations except Triton are enabled.",
+    )
+    parser.add_argument(
+        "--inference_optimize_bp",
+        type=(lambda x: str(x).lower() in ["true", "1", "yes"]),
+        default=False,
+        help="If set to True, batch parallel is enabled in DIT and dual-GPU acceleration is used.",
     )
     parser.add_argument("--height", type=int, default=512, help="Height of the generated image.")
     parser.add_argument("--width", type=int, default=512, help="Width of the generated image.")
@@ -51,11 +49,38 @@ def parse_args():
 if args.inference_optimize:
     os.environ["INFERENCE_OPTIMIZE"] = "True"
     os.environ["INFERENCE_OPTIMIZE_TRITON"] = "True"
+if args.inference_optimize_bp:
+    os.environ["INFERENCE_OPTIMIZE_BP"] = "True"
 if args.dtype == "float32":
     inference_dtype = paddle.float32
 elif args.dtype == "float16":
     inference_dtype = paddle.float16
 
+
+if args.inference_optimize_bp:
+    from paddle.distributed import fleet
+    from paddle.distributed.fleet.utils import recompute
+    import numpy as np
+    import random
+    import paddle.distributed as dist
+    import paddle.distributed.fleet as fleet
+    strategy = fleet.DistributedStrategy()
+    model_parallel_size = 2
+    data_parallel_size = 1
+    strategy.hybrid_configs = {
+        "dp_degree": data_parallel_size,
+        "mp_degree": model_parallel_size,
+        "pp_degree": 1
+    }
+    fleet.init(is_collective=True, strategy=strategy)
+    hcg = fleet.get_hybrid_communicate_group()
+    mp_id = hcg.get_model_parallel_rank()
+    rank_id = dist.get_rank()
+
+import datetime
+from ppdiffusers import StableDiffusion3Pipeline
+
+
 pipe = StableDiffusion3Pipeline.from_pretrained(
     "stabilityai/stable-diffusion-3-medium-diffusers",
     paddle_dtype=inference_dtype,
@@ -67,6 +92,7 @@ def parse_args():
     enable_new_ir=True,
     cache_static_model=True,
     exp_enable_use_cutlass=True,
+    delete_pass_lists=["add_norm_fuse_pass"],
 )
 
 generator = paddle.Generator().manual_seed(42)
@@ -111,4 +137,8 @@ def parse_args():
     cuda_mem_after_used = paddle.device.cuda.max_memory_allocated() / (1024**3)
     print(f"Max used CUDA memory : {cuda_mem_after_used:.3f} GiB")
 
-image.save("text_to_image_generation-stable_diffusion_3-result.png")
+if args.inference_optimize_bp:
+    if rank_id == 0:
+        image.save("text_to_image_generation-stable_diffusion_3-result.png")
+else:
+    image.save("text_to_image_generation-stable_diffusion_3-result.png")

ppdiffusers/ppdiffusers/models/simplified_sd3.py

Lines changed: 6 additions & 7 deletions

@@ -106,12 +106,13 @@ def forward(self, hidden_states, encoder_hidden_states, temb):
             qkv = self.qkv[i](norm_hidden_states)
             eqkv = self.eqkv[i](norm_encoder_hidden_states)
             q, k, v = paddlemix.triton_ops.split_concat(qkv, eqkv)
-            q = q.reshape([2, -1, 24, 64])
-            k = k.reshape([2, -1, 24, 64])
-            v = v.reshape([2, -1, 24, 64])
+            bs = hidden_states.shape[0]
+            q = q.reshape([bs, -1, 24, 64])
+            k = k.reshape([bs, -1, 24, 64])
+            v = v.reshape([bs, -1, 24, 64])
 
             norm_hidden_states1 = F.scaled_dot_product_attention_(q, k, v, dropout_p=0.0, is_causal=False)
-            norm_hidden_states1 = norm_hidden_states1.reshape([2, -1, self.dim])
+            norm_hidden_states1 = norm_hidden_states1.reshape([bs, -1, self.dim])
             attn_output, context_attn_output = paddle.split(norm_hidden_states1, num_or_sections=[seq1, seq2], axis=1)
 
             # attn_output, context_attn_output = paddlemix.triton_ops.triton_split(
@@ -155,7 +156,5 @@ def forward(self, hidden_states, encoder_hidden_states, temb):
                 last_context_ffn_output = context_ffn_output
                 last_context_hidden_states = encoder_hidden_states
                 last_context_gate_mlp = c_gate_mlp
-            else:
-                encoder_hidden_states = None
 
-        return encoder_hidden_states, hidden_states
+        return hidden_states
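The point of replacing the hard-coded `2` with `bs = hidden_states.shape[0]`: with CFG on a single card the block always saw batch 2, but after the batch-parallel scatter each card sees batch 1, so the reshape has to follow the runtime batch size. A small sketch of the shape arithmetic (the sequence length is made up; 24 heads x 64 head-dim mirror the code above):

```python
# Hedged sketch: the reshape works for both the single-card CFG batch (2)
# and the per-card half under batch parallel (1).
import paddle

for bs in (2, 1):
    q = paddle.randn([bs, 16, 24 * 64])   # [batch, seq, heads * head_dim]
    q = q.reshape([bs, -1, 24, 64])       # hard-coding 2 here breaks bs=1
    assert q.shape == [bs, 16, 24, 64]
```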

ppdiffusers/ppdiffusers/models/transformer_sd3.py

Lines changed: 1 addition & 2 deletions

@@ -329,10 +329,9 @@ def forward(
         encoder_hidden_states = self.context_embedder(encoder_hidden_states)
 
         if self.inference_optimize:
-            out = self.simplified_sd3(
+            hidden_states = self.simplified_sd3(
                 hidden_states=hidden_states, encoder_hidden_states=encoder_hidden_states, temb=temb
             )
-            hidden_states = out[1]
             encoder_hidden_states = None
         else:
             encoder_hidden_states, hidden_states = self.sd3_origin_transformer(

ppdiffusers/ppdiffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion_3.py

Lines changed: 35 additions & 11 deletions

@@ -12,11 +12,12 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
-
+import os
 import inspect
 from typing import Any, Callable, Dict, List, Optional, Union
 
 import paddle
+import paddle.distributed as dist
 
 from ppdiffusers.transformers import ( # T5TokenizerFast,
     CLIPTextModelWithProjection,
@@ -195,6 +196,7 @@ def __init__(
             if hasattr(self, "transformer") and self.transformer is not None
             else 128
         )
+        self.inference_optimize_bp = os.getenv("INFERENCE_OPTIMIZE_BP") == "True"
 
     def _get_t5_prompt_embeds(
         self,
@@ -229,7 +231,6 @@ def _get_t5_prompt_embeds(
                 "The following part of your input was truncated because CLIP can only handle sequences up to"
                 f" {self.tokenizer_max_length} tokens: {removed_text}"
             )
-        # breakpoint()
         prompt_embeds = self.text_encoder_3(text_input_ids)[0]
 
         dtype = self.text_encoder_3.dtype
@@ -395,7 +396,6 @@ def encode_prompt(
 
         prompt_embeds = paddle.concat([clip_prompt_embeds, t5_prompt_embed], axis=-2)
        pooled_prompt_embeds = paddle.concat([pooled_prompt_embed, pooled_prompt_2_embed], axis=-1)
-
        if do_classifier_free_guidance and negative_prompt_embeds is None:
            negative_prompt = negative_prompt or ""
            negative_prompt_2 = negative_prompt_2 or negative_prompt
@@ -707,7 +707,6 @@ def __call__(
                The list of tensor inputs for the `callback_on_step_end` function. The tensors specified in the list
                will be passed as `callback_kwargs` argument. You will only be able to include variables listed in the
                `._callback_tensor_inputs` attribute of your pipeline class.
-
        Examples:
 
        Returns:
@@ -801,22 +800,47 @@
                 latent_model_input = paddle.concat([latents] * 2) if self.do_classifier_free_guidance else latents
                 # broadcast to batch dimension in a way that's compatible with ONNX/Core ML
                 timestep = t.expand(latent_model_input.shape[0])
+                if self.inference_optimize_bp and self.do_classifier_free_guidance:
+                    # split the CFG batch of 2 and scatter one half to each card
+                    latent_input, latent_model_input_ = paddle.split(latent_model_input, 2, axis=0)
+                    timestep_input, timestep_ = paddle.split(timestep, 2, axis=0)
+                    prompt_embeds_input, prompt_embeds_ = paddle.split(prompt_embeds, 2, axis=0)
+                    pooled_prompt_embeds_input, pooled_prompt_embeds_ = paddle.split(pooled_prompt_embeds, 2, axis=0)
+
+                    dist.scatter(latent_input, [latent_input, latent_model_input_])
+                    dist.scatter(timestep_input, [timestep_input, timestep_])
+                    dist.scatter(prompt_embeds_input, [prompt_embeds_input, prompt_embeds_])
+                    dist.scatter(pooled_prompt_embeds_input, [pooled_prompt_embeds_input, pooled_prompt_embeds_])
 
+                else:
+                    latent_input = latent_model_input
+                    timestep_input = timestep
+                    prompt_embeds_input = prompt_embeds
+                    pooled_prompt_embeds_input = pooled_prompt_embeds
+
                 model_output = self.transformer(
-                    hidden_states=latent_model_input,
-                    timestep=timestep,
-                    encoder_hidden_states=prompt_embeds,
-                    pooled_projections=pooled_prompt_embeds,
+                    hidden_states=latent_input,
+                    timestep=timestep_input,
+                    encoder_hidden_states=prompt_embeds_input,
+                    pooled_projections=pooled_prompt_embeds_input,
                     joint_attention_kwargs=self.joint_attention_kwargs,
                     return_dict=False,
                 )
-
                 if is_inference_mode(self.transformer):
                     # NOTE:(changwenbin,zhoukangkang)
                     # This is for paddle inference mode
-                    noise_pred = model_output
+                    output = model_output
+                else:
+                    output = model_output[0]
+
+                if self.inference_optimize_bp:
+                    # gather the half-batch predictions from both cards into the full CFG batch
+                    tmp_shape = output.shape
+                    tmp_shape[0] *= 2
+                    noise_pred = paddle.zeros(tmp_shape, dtype=output.dtype)
+                    dist.all_gather(noise_pred, output)
                 else:
-                    noise_pred = model_output[0]
+                    noise_pred = output
 
                 # perform guidance
                 if self.do_classifier_free_guidance:
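Taken together, the per-step communication is a scatter of the split CFG inputs followed by an all_gather of the half-batch predictions. A self-contained toy of that round trip (assumes two ranks started with `paddle.distributed.launch`; shapes are illustrative, not the pipeline's):

```python
# Hedged sketch of the batch-parallel round trip; run with
#   python -m paddle.distributed.launch --gpus 0,1 sketch.py
import paddle
import paddle.distributed as dist

dist.init_parallel_env()

# both ranks build the CFG batch of 2, then keep only their half
full = paddle.arange(2 * 4, dtype="float32").reshape([2, 4])
uncond, text = paddle.split(full, 2, axis=0)
local = paddle.zeros_like(uncond)            # buffer for this rank's half
dist.scatter(local, [uncond, text], src=0)   # rank 0 keeps uncond, rank 1 gets text

local = local * 2.0                          # stand-in for the transformer forward

# gather both halves so every rank holds the full batch-of-2 prediction
halves = []
dist.all_gather(halves, local)
noise_pred = paddle.concat(halves, axis=0)   # shape [2, 4] on both ranks
print(dist.get_rank(), noise_pred.shape)
```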
