
Commit 7c7fd17

Manually fused padding with convolution to resolve slice_forward compilation issue.

1 parent: 7caeb6e

File tree (4 files changed: +110 −88 lines)

  examples/stable_diffusion/run_sdxl.py
  tensorrt_llm/functional.py
  tensorrt_llm/models/unet/embeddings.py
  tensorrt_llm/models/unet/pp/conv2d.py

examples/stable_diffusion/run_sdxl.py (file mode 100644 → 100755)
65 additions, 46 deletions
@@ -10,49 +10,68 @@
 world_size = tensorrt_llm.mpi_world_size()
 rank = tensorrt_llm.mpi_rank()
 
-parser = argparse.ArgumentParser(
-    description='run SDXL with the UNet TensorRT engine.')
-parser.add_argument('--size', type=int, default=1024)
-parser.add_argument('--seed', type=int, default=233)
-parser.add_argument('--num_inference_steps', type=int, default=50)
-parser.add_argument(
-    '--prompt',
-    type=str,
-    default=
-    "masterpiece, gouache painting, 1girl, distant view, lone boat, willow trees"
-)
-parser.add_argument('--model_dir',
-                    type=str,
-                    default=None,
-                    help='model directory')
-
-args = parser.parse_args()
-size = args.size
-seed = args.seed
-prompt = args.prompt
-num_inference_steps = args.num_inference_steps
-model_dir = f'sdxl_s{size}_w{world_size}' if args.model_dir is None else args.model_dir
-
-pipeline = StableDiffusionXLPipeline.from_pretrained(
-    "stabilityai/stable-diffusion-xl-base-1.0",
-    torch_dtype=torch.float16,
-    use_safetensors=True,
-)
-pipeline.set_progress_bar_config(disable=rank != 0)
-pipeline.prepare(f'sdxl_s{size}_w{world_size}', size)
-pipeline.to('cuda')
-
-li = []
-for i in range(10):
-    st = time.time()
-    image = pipeline(num_inference_steps=num_inference_steps,
-                     prompt=prompt,
-                     generator=torch.Generator(device="cuda").manual_seed(seed),
-                     height=size,
-                     width=size).images[0]
-    ed = time.time()
-    li.append(ed - st)
-
-if rank == 0:
-    print(f'Avg latency: {np.sum(li[-7:]) / 7.0}s')
-    image.save(f"output.png")
+
+def parseArgs():
+    parser = argparse.ArgumentParser(
+        description='run SDXL with the UNet TensorRT engine.')
+    parser.add_argument('--size', type=int, default=1024)
+    parser.add_argument('--seed', type=int, default=233)
+    parser.add_argument('--num_inference_steps', type=int, default=50)
+    parser.add_argument(
+        '--prompt',
+        type=str,
+        default=
+        "masterpiece, gouache painting, 1girl, distant view, lone boat, willow trees"
+    )
+    parser.add_argument('--model_dir',
+                        type=str,
+                        default=None,
+                        help='model directory')
+    parser.add_argument('--num-warmup-runs', type=int, default=3)
+    parser.add_argument('--avg-runs', type=int, default=10)
+    return parser.parse_args()
+
+
+if __name__ == "__main__":
+    args = parseArgs()
+    size = args.size
+    seed = args.seed
+    prompt = args.prompt
+    num_inference_steps = args.num_inference_steps
+    model_dir = f'sdxl_s{size}_w{world_size}' if args.model_dir is None else args.model_dir
+    num_warmup_runs = args.num_warmup_runs
+    avg_runs = args.avg_runs
+
+    pipeline = StableDiffusionXLPipeline.from_pretrained(
+        "stabilityai/stable-diffusion-xl-base-1.0",
+        torch_dtype=torch.float16,
+        use_safetensors=True,
+    )
+    pipeline.set_progress_bar_config(disable=rank != 0)
+    pipeline.prepare(f'sdxl_s{size}_w{world_size}', size)
+    pipeline.to('cuda')
+
+    # warm up
+    for i in range(num_warmup_runs):
+        image = pipeline(
+            num_inference_steps=num_inference_steps,
+            prompt=prompt,
+            generator=torch.Generator(device="cuda").manual_seed(seed),
+            height=size,
+            width=size).images[0]
+
+    latency = []
+    for i in range(avg_runs):
+        st = time.time()
+        image = pipeline(
+            num_inference_steps=num_inference_steps,
+            prompt=prompt,
+            generator=torch.Generator(device="cuda").manual_seed(seed),
+            height=size,
+            width=size,).images[0]
+        ed = time.time()
+        latency.append(ed - st)
+
+    if rank == 0:
+        print(f'Avg latency: {np.sum(latency) / avg_runs}s')
+        image.save(f"output.png")
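The restructured script separates argument parsing from the benchmark loop and discards a configurable number of warm-up iterations before timing. A minimal sketch of that warm-up-then-average pattern, with a hypothetical run_once callable standing in for one full pipeline call:

import time

import numpy as np


def benchmark(run_once, num_warmup_runs=3, avg_runs=10):
    # Warm-up iterations absorb one-time costs (engine loading, kernel
    # autotuning, allocator growth) so they do not skew the measurement.
    for _ in range(num_warmup_runs):
        run_once()
    # Time only steady-state iterations and report their mean.
    latency = []
    for _ in range(avg_runs):
        st = time.time()
        run_once()
        latency.append(time.time() - st)
    return np.sum(latency) / avg_runs

Compared with the old script, which averaged the last 7 of its 10 runs, the warm-up count and the number of averaged runs are now explicit --num-warmup-runs and --avg-runs options.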

tensorrt_llm/functional.py (file mode 100644 → 100755)
15 additions, 7 deletions
@@ -3441,13 +3441,17 @@ def conv1d(input: Tensor,
     return output_1d
 
 
-def conv2d(input: Tensor,
-           weight: Tensor,
-           bias: Optional[Tensor] = None,
-           stride: Tuple[int, int] = (1, 1),
-           padding: Tuple[int, int] = (0, 0),
-           dilation: Tuple[int, int] = (1, 1),
-           groups: int = 1) -> Tensor:
+def conv2d(
+    input: Tensor,
+    weight: Tensor,
+    bias: Optional[Tensor] = None,
+    stride: Tuple[int, int] = (1, 1),
+    padding: Tuple[int, int] = (0, 0),
+    dilation: Tuple[int, int] = (1, 1),
+    groups: int = 1,
+    pre_padding: Optional[Tuple[int, int]] = None,
+    post_padding: Optional[Tuple[int, int]] = None
+) -> Tensor:
     ##
     ## TODO: Document that function!
     ##
@@ -3475,6 +3479,10 @@ def conv2d(input: Tensor,
     layer.dilation_nd = dilation
     layer.num_groups = groups
     layer.dilation_nd = dilation
+    if pre_padding:
+        layer.pre_padding = pre_padding
+    if post_padding:
+        layer.post_padding = post_padding
 
     if not is_weight_constant:
         layer.set_input(1, weight.trt_tensor)
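The two new keyword arguments map onto TensorRT's per-side convolution padding attributes (pre_padding pads the start of each spatial axis, post_padding the end), so asymmetric padding can be fused into the convolution instead of being materialized by a separate slice/pad op. A hedged sketch of a call site, assuming the usual (H, W) ordering of the padding tuples; the tensor arguments and padding amounts are placeholders, not taken from this commit:

from tensorrt_llm.functional import conv2d


def conv2d_top_padded(x, weight, bias, pad_h, pad_w):
    # Zero-pad pad_h rows above the input only (none below) and pad_w columns
    # on both sides of the width axis, all applied by the convolution itself.
    return conv2d(x, weight, bias,
                  stride=(1, 1),
                  padding=(0, 0),              # no symmetric padding
                  pre_padding=(pad_h, pad_w),  # (top, left)
                  post_padding=(0, pad_w))     # (bottom, right)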

tensorrt_llm/models/unet/embeddings.py
7 additions, 2 deletions
@@ -14,7 +14,9 @@
 # limitations under the License.
 import math
 
-from ..._utils import fp32_array
+import tensorrt as trt
+
+from ..._utils import fp16_array, fp32_array
 from ...functional import concat, constant, cos, exp, silu, sin
 from ...layers import Linear
 from ...module import Module
@@ -43,7 +45,10 @@ def get_timestep_embedding(timesteps,
         for i in range(half_dim)
     ]
 
-    emb = exp(constant(fp32_array(exponent)))
+    if dtype is trt.float16:
+        emb = exp(constant(fp16_array(exponent)))
+    else:
+        emb = exp(constant(fp32_array(exponent)))
 
     ts_shape = list(timesteps.size())
     ts_shape.append(1)
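The frequency table for the sinusoidal timestep embedding is now created as an fp16 constant when the embedding is requested in float16, so the constant matches the dtype of the rest of the graph. For intuition, a rough numpy sketch of that table; the exact exponent formula, max_period and downscale_freq_shift follow the usual diffusers convention and are assumptions here, not read from this diff:

import math

import numpy as np


def timestep_frequencies(half_dim, max_period=10000.0,
                         downscale_freq_shift=1.0, dtype=np.float32):
    # Exponent table as in get_timestep_embedding (formula assumed from the
    # diffusers convention), built directly in the requested dtype
    # (np.float16 for an fp16 engine, np.float32 otherwise).
    exponent = np.array(
        [-math.log(max_period) * i / (half_dim - downscale_freq_shift)
         for i in range(half_dim)],
        dtype=dtype)
    return np.exp(exponent)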

tensorrt_llm/models/unet/pp/conv2d.py
23 additions, 33 deletions
@@ -12,25 +12,13 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
-import tensorrt as trt
 
-from ....functional import allgather, concat, conv2d, slice, stack, unsqueeze
+from ....functional import (allgather, concat, conv2d, slice, stack, unsqueeze)
 from ....layers import Conv2d
 from ....mapping import Mapping
 from ....module import Module
 
 
-def pad(input, pad):
-    assert input.ndim() == 4
-    n, c, h, w = input.shape
-    padded_input = slice(input,
-                         starts=[0, 0, -pad[2], -pad[0]],
-                         sizes=[n, c, pad[2] + h + pad[3], pad[0] + w + pad[1]],
-                         mode=trt.SampleMode.FILL,
-                         fill_value=0.0)
-    return padded_input
-
-
 class DistriConv2dPP(Module):
 
     def __init__(self,
@@ -54,20 +42,22 @@ def sliced_forward(self, x):
         idx = mapping.tp_rank
         h_begin = output_h * idx * stride - padding
         h_end = output_h * (idx + 1) * stride + padding
-        final_padding = [padding, padding, 0, 0]
+        pre_padding = [0, padding]
+        post_padding = [0, padding]
         if h_begin < 0:
             h_begin = 0
-            final_padding[2] = padding
+            pre_padding[0] = padding
         if h_end > h:
             h_end = h
-            final_padding[3] = padding
+            post_padding[0] = padding
         sliced_input = slice(x, [0, 0, h_begin, 0], [b, c, h_end - h_begin, w])
-        padded_input = pad(sliced_input, final_padding)
-        return conv2d(padded_input,
+        return conv2d(sliced_input,
                       self.conv.weight.value,
                       None if self.conv.bias is None else self.conv.bias.value,
                       stride=self.conv.stride,
-                      padding=(0, 0))
+                      padding=(0, 0),
+                      pre_padding=tuple(pre_padding),
+                      post_padding=tuple(post_padding))
 
     def forward(self, x, *args, **kwargs):
         mapping = self.mapping
@@ -78,14 +68,16 @@ def forward(self, x, *args, **kwargs):
         boundary_size = self.conv.padding[0]
 
         def create_padded_x(x, boundaries):
+            preH = 0
+            postH = 0
             if mapping.tp_rank == 0:
                 b = boundaries.select(0, mapping.tp_rank + 1).select(0, 0)
-                concat_x = concat([x, b], dim=2)
-                padded_x = pad(concat_x, [0, 0, boundary_size, 0])
+                padded_x = concat([x, b], dim=2)
+                preH = boundary_size
             elif mapping.tp_rank == mapping.tp_size - 1:
                 b = boundaries.select(0, mapping.tp_rank - 1).select(0, 1)
-                concat_x = concat([b, x], dim=2)
-                padded_x = pad(concat_x, [0, 0, 0, boundary_size])
+                padded_x = concat([b, x], dim=2)
+                postH = boundary_size
             else:
                 b0 = boundaries.select(0, mapping.tp_rank - 1).select(0, 1)
                 b1 = boundaries.select(0, mapping.tp_rank + 1).select(0, 0)
@@ -97,7 +89,7 @@ def create_padded_x(x, boundaries):
                     ],
                     dim=2,
                 )
-            return padded_x
+            return padded_x, preH, postH
 
         n, c, h, w = x.shape
         b0 = slice(x, [0, 0, 0, 0], [n, c, boundary_size, w])
@@ -107,13 +99,11 @@ def create_padded_x(x, boundaries):
 
         boundaries = allgather(unsqueeze(boundary, 0),
                                group=mapping.tp_group)
-        padded_x = create_padded_x(x, boundaries)
-        output = conv2d(
-            padded_x,
-            self.conv.weight.value,
-            self.conv.bias.value,
-            stride=self.conv.stride,
-            padding=(0, self.conv.padding[1]),
-        )
-
+        padded_x, preH, postH = create_padded_x(x, boundaries)
+        output = conv2d(padded_x,
+                        self.conv.weight.value,
+                        self.conv.bias.value,
+                        stride=self.conv.stride,
+                        pre_padding=(preH, self.conv.padding[1]),
+                        post_padding=(postH, self.conv.padding[1]))
         return output