From 7c4eea7cbece2e8ce09326dfedf97636f3b44bdf Mon Sep 17 00:00:00 2001 From: patil-suraj Date: Mon, 22 Aug 2022 11:59:26 +0530 Subject: [PATCH 01/11] boom boom --- .../stable_diffusion/image_to_image.py | 177 ++++++++++++++++++ 1 file changed, 177 insertions(+) create mode 100644 src/diffusers/pipelines/stable_diffusion/image_to_image.py diff --git a/src/diffusers/pipelines/stable_diffusion/image_to_image.py b/src/diffusers/pipelines/stable_diffusion/image_to_image.py new file mode 100644 index 000000000000..831e436b91a7 --- /dev/null +++ b/src/diffusers/pipelines/stable_diffusion/image_to_image.py @@ -0,0 +1,177 @@ +import inspect +import warnings +from typing import List, Optional, Union + +import torch + +from tqdm.auto import tqdm +from transformers import CLIPFeatureExtractor, CLIPTextModel, CLIPTokenizer + +from diffusers import AutoencoderKL, UNet2DConditionModel +from ...pipeline_utils import DiffusionPipeline +from ...schedulers import DDIMScheduler, LMSDiscreteScheduler, PNDMScheduler +from .safety_checker import StableDiffusionSafetyChecker + + +class StableDiffusionPipeline(DiffusionPipeline): + def __init__( + self, + vae: AutoencoderKL, + text_encoder: CLIPTextModel, + tokenizer: CLIPTokenizer, + unet: UNet2DConditionModel, + scheduler: Union[DDIMScheduler, PNDMScheduler, LMSDiscreteScheduler], + safety_checker: StableDiffusionSafetyChecker, + feature_extractor: CLIPFeatureExtractor, + ): + super().__init__() + scheduler = scheduler.set_format("pt") + self.register_modules( + vae=vae, + text_encoder=text_encoder, + tokenizer=tokenizer, + unet=unet, + scheduler=scheduler, + safety_checker=safety_checker, + feature_extractor=feature_extractor, + ) + + @torch.no_grad() + def __call__( + self, + prompt: Union[str, List[str]], + init_image: torch.FloatTensor, + strength: float = 0.8, + height: Optional[int] = 512, + width: Optional[int] = 512, + num_inference_steps: Optional[int] = 50, + guidance_scale: Optional[float] = 7.5, + eta: Optional[float] = 0.0, + generator: Optional[torch.Generator] = None, + output_type: Optional[str] = "pil", + **kwargs, + ): + if "torch_device" in kwargs: + device = kwargs.pop("torch_device") + warnings.warn( + "`torch_device` is deprecated as an input argument to `__call__` and will be removed in v0.3.0." + " Consider using `pipe.to(torch_device)` instead." + ) + + # Set device as before (to be removed in 0.3.0) + if device is None: + device = "cuda" if torch.cuda.is_available() else "cpu" + self.to(device) + + if isinstance(prompt, str): + batch_size = 1 + elif isinstance(prompt, list): + batch_size = len(prompt) + else: + raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}") + + if height % 8 != 0 or width % 8 != 0: + raise ValueError(f"`height` and `width` have to be divisible by 8 but are {height} and {width}.") + + # set timesteps + accepts_offset = "offset" in set(inspect.signature(self.scheduler.set_timesteps).parameters.keys()) + extra_set_kwargs = {} + if accepts_offset: + extra_set_kwargs["offset"] = 1 + + self.scheduler.set_timesteps(num_inference_steps, **extra_set_kwargs) + + init_latents = self.vae.encode(init_image.to(device)).sample() + init_latents = 0.18215 * latents + + # add noise to latents + noise = torch.randn(init_latents.shape).to(init_latents.device) + t = int(num_inference_steps * strength) + timesteps = torch.tensor([t] * batch_size, dtype=torch.long, device=device) + init_latents = self.scheduler.add_noise(init_latents, noise, timesteps) + + # get prompt text embeddings + text_input = self.tokenizer( + prompt, + padding="max_length", + max_length=self.tokenizer.model_max_length, + truncation=True, + return_tensors="pt", + ) + text_embeddings = self.text_encoder(text_input.input_ids.to(self.device))[0] + + # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2) + # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1` + # corresponds to doing no classifier free guidance. + do_classifier_free_guidance = guidance_scale > 1.0 + # get unconditional embeddings for classifier free guidance + if do_classifier_free_guidance: + max_length = text_input.input_ids.shape[-1] + uncond_input = self.tokenizer( + [""] * batch_size, padding="max_length", max_length=max_length, return_tensors="pt" + ) + uncond_embeddings = self.text_encoder(uncond_input.input_ids.to(self.device))[0] + + # For classifier free guidance, we need to do two forward passes. + # Here we concatenate the unconditional and text embeddings into a single batch + # to avoid doing two forward passes + text_embeddings = torch.cat([uncond_embeddings, text_embeddings]) + + # get the intial random noise + # latents = torch.randn( + # (batch_size, self.unet.in_channels, height // 8, width // 8), + # generator=generator, + # device=self.device, + # ) + latents = init_latents + + + # if we use LMSDiscreteScheduler, let's make sure latents are mulitplied by sigmas + if isinstance(self.scheduler, LMSDiscreteScheduler): + latents = latents * self.scheduler.sigmas[0] + + # prepare extra kwargs for the scheduler step, since not all schedulers have the same signature + # eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers. + # eta corresponds to η in DDIM paper: https://arxiv.org/abs/2010.02502 + # and should be between [0, 1] + accepts_eta = "eta" in set(inspect.signature(self.scheduler.step).parameters.keys()) + extra_step_kwargs = {} + if accepts_eta: + extra_step_kwargs["eta"] = eta + + for i, t in tqdm(enumerate(self.scheduler.timesteps)): + # expand the latents if we are doing classifier free guidance + latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents + if isinstance(self.scheduler, LMSDiscreteScheduler): + sigma = self.scheduler.sigmas[i] + latent_model_input = latent_model_input / ((sigma**2 + 1) ** 0.5) + + # predict the noise residual + noise_pred = self.unet(latent_model_input, t, encoder_hidden_states=text_embeddings)["sample"] + + # perform guidance + if do_classifier_free_guidance: + noise_pred_uncond, noise_pred_text = noise_pred.chunk(2) + noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond) + + # compute the previous noisy sample x_t -> x_t-1 + if isinstance(self.scheduler, LMSDiscreteScheduler): + latents = self.scheduler.step(noise_pred, i, latents, **extra_step_kwargs)["prev_sample"] + else: + latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs)["prev_sample"] + + # scale and decode the image latents with vae + latents = 1 / 0.18215 * latents + image = self.vae.decode(latents) + + image = (image / 2 + 0.5).clamp(0, 1) + image = image.cpu().permute(0, 2, 3, 1).numpy() + + # run safety checker + safety_cheker_input = self.feature_extractor(self.numpy_to_pil(image), return_tensors="pt").to(self.device) + image, has_nsfw_concept = self.safety_checker(images=image, clip_input=safety_cheker_input.pixel_values) + + if output_type == "pil": + image = self.numpy_to_pil(image) + + return {"sample": image, "nsfw_content_detected": has_nsfw_concept} From e47b5f9e8a4fd324d8eed12adc92a964cb926ad5 Mon Sep 17 00:00:00 2001 From: patil-suraj Date: Mon, 22 Aug 2022 22:36:20 +0530 Subject: [PATCH 02/11] reorganise examples --- examples/{ => training}/README.md | 0 examples/{ => training}/train_unconditional.py | 0 2 files changed, 0 insertions(+), 0 deletions(-) rename examples/{ => training}/README.md (100%) rename examples/{ => training}/train_unconditional.py (100%) diff --git a/examples/README.md b/examples/training/README.md similarity index 100% rename from examples/README.md rename to examples/training/README.md diff --git a/examples/train_unconditional.py b/examples/training/train_unconditional.py similarity index 100% rename from examples/train_unconditional.py rename to examples/training/train_unconditional.py From f8b01da60ce45ed0b23a991f0da7b157a874c329 Mon Sep 17 00:00:00 2001 From: patil-suraj Date: Mon, 22 Aug 2022 22:36:29 +0530 Subject: [PATCH 03/11] add image2image in example inference --- .../inference/image2image.py | 79 ++++++------------- 1 file changed, 25 insertions(+), 54 deletions(-) rename src/diffusers/pipelines/stable_diffusion/image_to_image.py => examples/inference/image2image.py (67%) diff --git a/src/diffusers/pipelines/stable_diffusion/image_to_image.py b/examples/inference/image2image.py similarity index 67% rename from src/diffusers/pipelines/stable_diffusion/image_to_image.py rename to examples/inference/image2image.py index 831e436b91a7..4d34aff133d3 100644 --- a/src/diffusers/pipelines/stable_diffusion/image_to_image.py +++ b/examples/inference/image2image.py @@ -1,26 +1,22 @@ import inspect -import warnings from typing import List, Optional, Union import torch +from diffusers import AutoencoderKL, DDIMScheduler, DiffusionPipeline, PNDMScheduler, UNet2DConditionModel +from diffusers.pipelines.stable_diffusion import StableDiffusionSafetyChecker from tqdm.auto import tqdm from transformers import CLIPFeatureExtractor, CLIPTextModel, CLIPTokenizer -from diffusers import AutoencoderKL, UNet2DConditionModel -from ...pipeline_utils import DiffusionPipeline -from ...schedulers import DDIMScheduler, LMSDiscreteScheduler, PNDMScheduler -from .safety_checker import StableDiffusionSafetyChecker - -class StableDiffusionPipeline(DiffusionPipeline): +class StableDiffusionImg2ImgPipeline(DiffusionPipeline): def __init__( self, vae: AutoencoderKL, text_encoder: CLIPTextModel, tokenizer: CLIPTokenizer, unet: UNet2DConditionModel, - scheduler: Union[DDIMScheduler, PNDMScheduler, LMSDiscreteScheduler], + scheduler: Union[DDIMScheduler, PNDMScheduler], safety_checker: StableDiffusionSafetyChecker, feature_extractor: CLIPFeatureExtractor, ): @@ -42,26 +38,12 @@ def __call__( prompt: Union[str, List[str]], init_image: torch.FloatTensor, strength: float = 0.8, - height: Optional[int] = 512, - width: Optional[int] = 512, num_inference_steps: Optional[int] = 50, guidance_scale: Optional[float] = 7.5, eta: Optional[float] = 0.0, generator: Optional[torch.Generator] = None, output_type: Optional[str] = "pil", - **kwargs, ): - if "torch_device" in kwargs: - device = kwargs.pop("torch_device") - warnings.warn( - "`torch_device` is deprecated as an input argument to `__call__` and will be removed in v0.3.0." - " Consider using `pipe.to(torch_device)` instead." - ) - - # Set device as before (to be removed in 0.3.0) - if device is None: - device = "cuda" if torch.cuda.is_available() else "cpu" - self.to(device) if isinstance(prompt, str): batch_size = 1 @@ -70,24 +52,30 @@ def __call__( else: raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}") - if height % 8 != 0 or width % 8 != 0: - raise ValueError(f"`height` and `width` have to be divisible by 8 but are {height} and {width}.") - # set timesteps accepts_offset = "offset" in set(inspect.signature(self.scheduler.set_timesteps).parameters.keys()) extra_set_kwargs = {} + offset = 0 if accepts_offset: + offset = 1 extra_set_kwargs["offset"] = 1 self.scheduler.set_timesteps(num_inference_steps, **extra_set_kwargs) - - init_latents = self.vae.encode(init_image.to(device)).sample() - init_latents = 0.18215 * latents - - # add noise to latents - noise = torch.randn(init_latents.shape).to(init_latents.device) - t = int(num_inference_steps * strength) - timesteps = torch.tensor([t] * batch_size, dtype=torch.long, device=device) + + # encode the init image into latents and scale the latents + init_latents = self.vae.encode(init_image.to(self.device)).sample() + init_latents = 0.18215 * init_latents + + # prepare init_latents noise to latents + init_latents = torch.cat([init_latents] * batch_size) + + # get the original timestep using init_timestep + init_timestep = int(num_inference_steps * strength) + offset + timesteps = self.scheduler.timesteps[max(-init_timestep, 0)] + timesteps = torch.tensor([timesteps] * batch_size, dtype=torch.long, device=self.device) + + # add noise to latents using the timesteps + noise = torch.randn(init_latents.shape, generator=generator, device=self.device) init_latents = self.scheduler.add_noise(init_latents, noise, timesteps) # get prompt text embeddings @@ -117,19 +105,6 @@ def __call__( # to avoid doing two forward passes text_embeddings = torch.cat([uncond_embeddings, text_embeddings]) - # get the intial random noise - # latents = torch.randn( - # (batch_size, self.unet.in_channels, height // 8, width // 8), - # generator=generator, - # device=self.device, - # ) - latents = init_latents - - - # if we use LMSDiscreteScheduler, let's make sure latents are mulitplied by sigmas - if isinstance(self.scheduler, LMSDiscreteScheduler): - latents = latents * self.scheduler.sigmas[0] - # prepare extra kwargs for the scheduler step, since not all schedulers have the same signature # eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers. # eta corresponds to η in DDIM paper: https://arxiv.org/abs/2010.02502 @@ -139,12 +114,11 @@ def __call__( if accepts_eta: extra_step_kwargs["eta"] = eta - for i, t in tqdm(enumerate(self.scheduler.timesteps)): + latents = init_latents + t_start = max(num_inference_steps - init_timestep + offset, 0) + for i, t in tqdm(enumerate(self.scheduler.timesteps[t_start:])): # expand the latents if we are doing classifier free guidance latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents - if isinstance(self.scheduler, LMSDiscreteScheduler): - sigma = self.scheduler.sigmas[i] - latent_model_input = latent_model_input / ((sigma**2 + 1) ** 0.5) # predict the noise residual noise_pred = self.unet(latent_model_input, t, encoder_hidden_states=text_embeddings)["sample"] @@ -155,10 +129,7 @@ def __call__( noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond) # compute the previous noisy sample x_t -> x_t-1 - if isinstance(self.scheduler, LMSDiscreteScheduler): - latents = self.scheduler.step(noise_pred, i, latents, **extra_step_kwargs)["prev_sample"] - else: - latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs)["prev_sample"] + latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs)["prev_sample"] # scale and decode the image latents with vae latents = 1 / 0.18215 * latents From d5f644ab4c2f34482d44676293b1a2b1345aa99f Mon Sep 17 00:00:00 2001 From: patil-suraj Date: Mon, 22 Aug 2022 22:51:54 +0530 Subject: [PATCH 04/11] add readme --- .../{image2image.py => image_to_image.py} | 12 +++++ examples/inference/readme.md | 48 +++++++++++++++++++ 2 files changed, 60 insertions(+) rename examples/inference/{image2image.py => image_to_image.py} (94%) create mode 100644 examples/inference/readme.md diff --git a/examples/inference/image2image.py b/examples/inference/image_to_image.py similarity index 94% rename from examples/inference/image2image.py rename to examples/inference/image_to_image.py index 4d34aff133d3..89f56851f445 100644 --- a/examples/inference/image2image.py +++ b/examples/inference/image_to_image.py @@ -1,14 +1,26 @@ import inspect from typing import List, Optional, Union +import numpy as np import torch +import PIL from diffusers import AutoencoderKL, DDIMScheduler, DiffusionPipeline, PNDMScheduler, UNet2DConditionModel from diffusers.pipelines.stable_diffusion import StableDiffusionSafetyChecker from tqdm.auto import tqdm from transformers import CLIPFeatureExtractor, CLIPTextModel, CLIPTokenizer +def preprocess(image): + w, h = image.size + w, h = map(lambda x: x - x % 32, (w, h)) # resize to integer multiple of 32 + image = image.resize((w, h), resample=PIL.Image.LANCZOS) + image = np.array(image).astype(np.float32) / 255.0 + image = image[None].transpose(0, 3, 1, 2) + image = torch.from_numpy(image) + return 2.0 * image - 1.0 + + class StableDiffusionImg2ImgPipeline(DiffusionPipeline): def __init__( self, diff --git a/examples/inference/readme.md b/examples/inference/readme.md new file mode 100644 index 000000000000..43f22105156e --- /dev/null +++ b/examples/inference/readme.md @@ -0,0 +1,48 @@ +# Inference Examples + +## Installing the dependencies + +Before running the scipts, make sure to install the library's training dependencies: + +```bash +pip install diffusers transformers ftfy +``` + +## Image-to-Image text-guided generation with Stable Diffusion + +The `image_to_image.py` implements `StableDiffusionImg2ImgPipeline`, it let's you pass a text prompt and an initial image to condition on to generate new images. This examples also showcases how you can write custom diffusion pipelines using `diffusers`. + +### How to use it + + +```python +from image_to_image import StableDiffusionImg2ImgPipeline, preprocess +import requests +from PIL import Image +from io import BytesIO + +# load the pipeline +device = "cuda" +pipe = StableDiffusionImg2ImgPipeline.from_pretrained( + model_path, + revision="fp16", + torch_dtype=torch.float16, + use_auth_token=True +).to(device) + +# let's download an initial image +url = "https://raw.githubusercontent.com/CompVis/stable-diffusion/main/assets/stable-samples/img2img/sketch-mountains-input.jpg" + +response = requests.get(url) +init_image = Image.open(BytesIO(response.content)).convert("RGB") +init_image = init_image.resize((768, 512)) +init_image = preprocess(init_image) + +prompt = "A fantasy landscape, trending on artstation" + +with autocast("cuda"): + images = pipe(prompt=prompt, init_image=init_image, strength=0.75, guidance_scale=7.5)["sample"] + +images[0].save("fantasy_landscape.png") +``` +You can also run this example on colab [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/drive/1yHBat74l_fvS9f4IDLvquvJSyY7G36G0?usp=sharing) \ No newline at end of file From 7ced86dcb4dffbb225389a7aef3c061700a857b3 Mon Sep 17 00:00:00 2001 From: patil-suraj Date: Mon, 22 Aug 2022 22:55:06 +0530 Subject: [PATCH 05/11] fix example --- examples/inference/readme.md | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/examples/inference/readme.md b/examples/inference/readme.md index 43f22105156e..d065e407f090 100644 --- a/examples/inference/readme.md +++ b/examples/inference/readme.md @@ -16,15 +16,17 @@ The `image_to_image.py` implements `StableDiffusionImg2ImgPipeline`, it let's yo ```python -from image_to_image import StableDiffusionImg2ImgPipeline, preprocess +from torch import autocast import requests from PIL import Image from io import BytesIO +from image_to_image import StableDiffusionImg2ImgPipeline, preprocess + # load the pipeline device = "cuda" pipe = StableDiffusionImg2ImgPipeline.from_pretrained( - model_path, + "CompVis/stable-diffusion-v1-4", revision="fp16", torch_dtype=torch.float16, use_auth_token=True From 88a68c5bd42f1288210c528a094defe3d2bba5fc Mon Sep 17 00:00:00 2001 From: patil-suraj Date: Mon, 22 Aug 2022 23:05:46 +0530 Subject: [PATCH 06/11] update colab url --- examples/inference/readme.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/examples/inference/readme.md b/examples/inference/readme.md index d065e407f090..ff96101b63c9 100644 --- a/examples/inference/readme.md +++ b/examples/inference/readme.md @@ -47,4 +47,4 @@ with autocast("cuda"): images[0].save("fantasy_landscape.png") ``` -You can also run this example on colab [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/drive/1yHBat74l_fvS9f4IDLvquvJSyY7G36G0?usp=sharing) \ No newline at end of file +You can also run this example on colab [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/patil-suraj/Notebooks/blob/master/diffusers_image_2image.ipynb) \ No newline at end of file From 338ecd50a58d1136c090396978c7cd3863837029 Mon Sep 17 00:00:00 2001 From: Suraj Patil Date: Tue, 23 Aug 2022 00:08:56 +0530 Subject: [PATCH 07/11] Apply suggestions from code review Co-authored-by: Pedro Cuenca --- examples/inference/readme.md | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/examples/inference/readme.md b/examples/inference/readme.md index ff96101b63c9..541099d78dff 100644 --- a/examples/inference/readme.md +++ b/examples/inference/readme.md @@ -2,7 +2,7 @@ ## Installing the dependencies -Before running the scipts, make sure to install the library's training dependencies: +Before running the scipts, make sure to install the library's dependencies: ```bash pip install diffusers transformers ftfy @@ -10,7 +10,7 @@ pip install diffusers transformers ftfy ## Image-to-Image text-guided generation with Stable Diffusion -The `image_to_image.py` implements `StableDiffusionImg2ImgPipeline`, it let's you pass a text prompt and an initial image to condition on to generate new images. This examples also showcases how you can write custom diffusion pipelines using `diffusers`. +The `image_to_image.py` script implements `StableDiffusionImg2ImgPipeline`. It lets you pass a text prompt and an initial image to condition the generation of new images. This example also showcases how you can write custom diffusion pipelines using `diffusers`! ### How to use it From 17c22123a650462a3e75aae3e4533b2924119db7 Mon Sep 17 00:00:00 2001 From: patil-suraj Date: Tue, 23 Aug 2022 13:57:48 +0530 Subject: [PATCH 08/11] fix init_timestep --- examples/inference/image_to_image.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/examples/inference/image_to_image.py b/examples/inference/image_to_image.py index 89f56851f445..4b88112c4852 100644 --- a/examples/inference/image_to_image.py +++ b/examples/inference/image_to_image.py @@ -83,7 +83,8 @@ def __call__( # get the original timestep using init_timestep init_timestep = int(num_inference_steps * strength) + offset - timesteps = self.scheduler.timesteps[max(-init_timestep, 0)] + init_timestep = min(init_timestep, num_inference_steps) + timesteps = self.scheduler.timesteps[-init_timestep] timesteps = torch.tensor([timesteps] * batch_size, dtype=torch.long, device=self.device) # add noise to latents using the timesteps From e4a1a8f1392d8340cd2982ec60e05e01accc003d Mon Sep 17 00:00:00 2001 From: patil-suraj Date: Tue, 23 Aug 2022 15:52:03 +0530 Subject: [PATCH 09/11] update colab url --- examples/inference/readme.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/examples/inference/readme.md b/examples/inference/readme.md index 541099d78dff..e61004e47ec2 100644 --- a/examples/inference/readme.md +++ b/examples/inference/readme.md @@ -47,4 +47,4 @@ with autocast("cuda"): images[0].save("fantasy_landscape.png") ``` -You can also run this example on colab [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/patil-suraj/Notebooks/blob/master/diffusers_image_2image.ipynb) \ No newline at end of file +You can also run this example on colab [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/patil-suraj/Notebooks/blob/master/image_2_image_using_diffusers.ipynb) \ No newline at end of file From c0c3a19ed9a4f1c758e9be8ec5b1b6e774ef0976 Mon Sep 17 00:00:00 2001 From: patil-suraj Date: Tue, 23 Aug 2022 16:21:57 +0530 Subject: [PATCH 10/11] update main readme --- README.md | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/README.md b/README.md index 2b1322266279..5f65b8d3d15f 100644 --- a/README.md +++ b/README.md @@ -23,7 +23,8 @@ More precisely, 🤗 Diffusers offers: - State-of-the-art diffusion pipelines that can be run in inference with just a couple of lines of code (see [src/diffusers/pipelines](https://github.com/huggingface/diffusers/tree/main/src/diffusers/pipelines)). - Various noise schedulers that can be used interchangeably for the prefered speed vs. quality trade-off in inference (see [src/diffusers/schedulers](https://github.com/huggingface/diffusers/tree/main/src/diffusers/schedulers)). - Multiple types of models, such as UNet, can be used as building blocks in an end-to-end diffusion system (see [src/diffusers/models](https://github.com/huggingface/diffusers/tree/main/src/diffusers/models)). -- Training examples to show how to train the most popular diffusion models (see [examples](https://github.com/huggingface/diffusers/tree/main/examples)). +- Training examples to show how to train the most popular diffusion models (see [examples/training](https://github.com/huggingface/diffusers/tree/main/examples/training)). +- Inference examples to show how to create custom pipelines for advanced tasks such as image2image, in-painting (see [examples/inference](https://github.com/huggingface/diffusers/tree/main/examples/inference)) ## Quickstart From a311826f1a8fdbf1265ac86b157990a37df8e3b7 Mon Sep 17 00:00:00 2001 From: patil-suraj Date: Tue, 23 Aug 2022 16:25:47 +0530 Subject: [PATCH 11/11] rename readme