
Commit d8f8b9a

Merge commit d8f8b9a (2 parents: 4d1b1b4 + 1f196a0)

File tree (3 files changed, +15 -17 lines):

- README.md
- src/diffusers/models/resnet.py
- src/diffusers/pipelines/README.md

README.md

Lines changed: 1 addition & 1 deletion
```diff
@@ -152,7 +152,7 @@ images[0].save("cat_on_bench.png")
 
 ### Tweak prompts reusing seeds and latents
 
-You can generate your own latents to reproduce results, or tweak your prompt on a specific result you liked. [This notebook](stable-diffusion-seeds.ipynb) shows how to do it step by step. You can also run it in Google Colab [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/pcuenca/diffusers-examples/blob/main/notebooks/stable-diffusion-seeds.ipynb).
+You can generate your own latents to reproduce results, or tweak your prompt on a specific result you liked. [This notebook](https://github.com/pcuenca/diffusers-examples/blob/main/notebooks/stable-diffusion-seeds.ipynb) shows how to do it step by step. You can also run it in Google Colab [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/pcuenca/diffusers-examples/blob/main/notebooks/stable-diffusion-seeds.ipynb).
 
 
 For more details, check out [the Stable Diffusion notebook](https://colab.research.google.com/github/huggingface/notebooks/blob/main/diffusers/stable_diffusion.ipynb) [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/huggingface/notebooks/blob/main/diffusers/stable_diffusion.ipynb)
```
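The linked notebook walks through reusing seeds and latents; as a rough sketch of that workflow (not taken from the notebook or from this commit), fixing the latents once and varying only the prompt looks roughly like the snippet below. The model id, the latent shape, and the `generator`/`latents` keyword arguments are assumptions about the usual `StableDiffusionPipeline` interface and may differ between diffusers versions.

```python
import torch
from diffusers import StableDiffusionPipeline

# Illustrative model id; any Stable Diffusion checkpoint with the same interface works.
pipe = StableDiffusionPipeline.from_pretrained("CompVis/stable-diffusion-v1-4")
pipe = pipe.to("cuda")

# Build the latents yourself from a fixed seed so the result can be reproduced later.
generator = torch.Generator(device="cuda").manual_seed(42)
latents = torch.randn(
    (1, pipe.unet.config.in_channels, 64, 64),  # (batch, latent channels, height/8, width/8)
    generator=generator,
    device="cuda",
)

# Reusing the same latents while tweaking only the prompt keeps the composition comparable.
image_photo = pipe("a photograph of a cat sitting on a bench", latents=latents).images[0]
image_paint = pipe("an oil painting of a cat sitting on a bench", latents=latents).images[0]
image_photo.save("cat_photo.png")
image_paint.save("cat_painting.png")
```

Recording the seed (here 42) alongside the output is usually enough to regenerate the same latents later.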

src/diffusers/models/resnet.py

Lines changed: 13 additions & 13 deletions
```diff
@@ -328,39 +328,39 @@ def __init__(
         if self.use_nin_shortcut:
             self.conv_shortcut = torch.nn.Conv2d(in_channels, out_channels, kernel_size=1, stride=1, padding=0)
 
-    def forward(self, x, temb, hey=False):
-        h = x
+    def forward(self, x, temb):
+        hidden_states = x
 
         # make sure hidden states is in float32
         # when running in half-precision
-        h = self.norm1(h.float()).type(h.dtype)
-        h = self.nonlinearity(h)
+        hidden_states = self.norm1(hidden_states.float()).type(hidden_states.dtype)
+        hidden_states = self.nonlinearity(hidden_states)
 
         if self.upsample is not None:
             x = self.upsample(x)
-            h = self.upsample(h)
+            hidden_states = self.upsample(hidden_states)
         elif self.downsample is not None:
             x = self.downsample(x)
-            h = self.downsample(h)
+            hidden_states = self.downsample(hidden_states)
 
-        h = self.conv1(h)
+        hidden_states = self.conv1(hidden_states)
 
         if temb is not None:
             temb = self.time_emb_proj(self.nonlinearity(temb))[:, :, None, None]
-            h = h + temb
+            hidden_states = hidden_states + temb
 
         # make sure hidden states is in float32
         # when running in half-precision
-        h = self.norm2(h.float()).type(h.dtype)
-        h = self.nonlinearity(h)
+        hidden_states = self.norm2(hidden_states.float()).type(hidden_states.dtype)
+        hidden_states = self.nonlinearity(hidden_states)
 
-        h = self.dropout(h)
-        h = self.conv2(h)
+        hidden_states = self.dropout(hidden_states)
+        hidden_states = self.conv2(hidden_states)
 
         if self.conv_shortcut is not None:
             x = self.conv_shortcut(x)
 
-        out = (x + h) / self.output_scale_factor
+        out = (x + hidden_states) / self.output_scale_factor
 
         return out
 
```
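For readers skimming the rename above (`h` to `hidden_states`, plus removal of the unused `hey` argument), here is a condensed, self-contained sketch of the data flow this `forward` implements: two norm/activation/convolution stages, an additive time-embedding injection, an optional 1x1 shortcut, and a scaled residual sum. It is not the diffusers implementation; channel sizes, group counts, and the dropout rate are illustrative, and the float32 cast and up/downsampling branches are omitted.

```python
import torch

class TinyResnetBlock(torch.nn.Module):
    # Condensed version of the block's data flow; not the library code.
    def __init__(self, in_channels=64, out_channels=64, temb_channels=128, output_scale_factor=1.0):
        super().__init__()
        self.norm1 = torch.nn.GroupNorm(32, in_channels)
        self.conv1 = torch.nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=1)
        self.time_emb_proj = torch.nn.Linear(temb_channels, out_channels)
        self.norm2 = torch.nn.GroupNorm(32, out_channels)
        self.dropout = torch.nn.Dropout(0.0)
        self.conv2 = torch.nn.Conv2d(out_channels, out_channels, kernel_size=3, padding=1)
        self.nonlinearity = torch.nn.SiLU()
        self.conv_shortcut = (
            torch.nn.Conv2d(in_channels, out_channels, kernel_size=1) if in_channels != out_channels else None
        )
        self.output_scale_factor = output_scale_factor

    def forward(self, x, temb):
        hidden_states = self.conv1(self.nonlinearity(self.norm1(x)))
        if temb is not None:
            # Project the time embedding and add it channel-wise.
            hidden_states = hidden_states + self.time_emb_proj(self.nonlinearity(temb))[:, :, None, None]
        hidden_states = self.nonlinearity(self.norm2(hidden_states))
        hidden_states = self.conv2(self.dropout(hidden_states))
        if self.conv_shortcut is not None:
            x = self.conv_shortcut(x)
        # Scaled residual connection, as in the diff above.
        return (x + hidden_states) / self.output_scale_factor

block = TinyResnetBlock()
out = block(torch.randn(2, 64, 32, 32), torch.randn(2, 128))  # -> torch.Size([2, 64, 32, 32])
```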

src/diffusers/pipelines/README.md

Lines changed: 1 addition & 3 deletions
```diff
@@ -77,9 +77,7 @@ all of our pipelines to be **self-contained**, **easy-to-tweak**, **beginner-fr
 - **Easy-to-use**: Pipelines should be extremely easy to use - one should be able to load the pipeline and
 use it for its designated task, *e.g.* text-to-image generation, in just a couple of lines of code. Most
 logic including pre-processing, an unrolled diffusion loop, and post-processing should all happen inside the `__call__` method.
-- **Easy-to-tweak**: Certain pipelines will not be able to handle all use cases and tasks that you might like them to. If you want to use a certain pipeline for a specific use case that is not yet supported, you might have to copy the pipeline file and tweak the code to your needs.
-
-We try to make the pipeline code as readable as possible so that each part –from pre-processing to diffusing to post-processing– can easily be adapted. If you would like the community to benefit from your customized pipeline, we would love to see a contribution to our [community-examples](https://github.com/huggingface/diffusers/tree/main/examples/commmunity). If you feel that an important pipeline should be part of the official pipelines but isn't, a contribution to the [official pipelines](https://github.com/huggingface/diffusers/blob/main/src/diffusers/pipelines) would be even better.
+- **Easy-to-tweak**: Certain pipelines will not be able to handle all use cases and tasks that you might like them to. If you want to use a certain pipeline for a specific use case that is not yet supported, you might have to copy the pipeline file and tweak the code to your needs. We try to make the pipeline code as readable as possible so that each part –from pre-processing to diffusing to post-processing– can easily be adapted. If you would like the community to benefit from your customized pipeline, we would love to see a contribution to our [community-examples](https://github.com/huggingface/diffusers/tree/main/examples/commmunity). If you feel that an important pipeline should be part of the official pipelines but isn't, a contribution to the [official pipelines](https://github.com/huggingface/diffusers/blob/main/src/diffusers/pipelines) would be even better.
 - **One-purpose-only**: Pipelines should be used for one task and one task only. Even if two tasks are very similar from a modeling point of view, *e.g.* image2image translation and in-painting, pipelines shall be used for one task only to keep them *easy-to-tweak* and *readable*.
 
 ## Examples
```
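As a concrete illustration of the "easy-to-use" point in the list above, a pipeline call can look roughly like the snippet below. This is a hypothetical sketch, not part of this commit; the model id and the `.images` output attribute are assumptions that may vary across diffusers versions.

```python
from diffusers import StableDiffusionPipeline

# Load a pipeline and run its single designated task (text-to-image) in a few lines;
# pre-processing, the diffusion loop, and post-processing all happen inside __call__.
pipe = StableDiffusionPipeline.from_pretrained("CompVis/stable-diffusion-v1-4").to("cuda")
image = pipe("a photograph of an astronaut riding a horse").images[0]
image.save("astronaut.png")
```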
