Skip to content

Commit 0e6a840

Browse files
authored
[core] Support VideoToVideo with CogVideoX (#9333)
* add vid2vid pipeline for CogVideoX
* make fix-copies
* update docs
* fake context parallel cache, VAE encode tiling
* add test for Cog vid2vid
* use video link from HF docs repo
* add copied-from comments; correctly rename test class
1 parent af6c0fb commit 0e6a840

File tree

9 files changed

+1190
-20
lines changed

9 files changed

+1190
-20
lines changed

docs/source/en/api/pipelines/cogvideox.md

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -98,6 +98,12 @@ It is also worth noting that torchao quantization is fully compatible with [torc
9898
- all
9999
- __call__
100100

101+
## CogVideoXVideoToVideoPipeline
102+
103+
[[autodoc]] CogVideoXVideoToVideoPipeline
104+
- all
105+
- __call__
106+
101107
## CogVideoXPipelineOutput
102108

103-
[[autodoc]] pipelines.cogvideo.pipeline_cogvideox.CogVideoXPipelineOutput
109+
[[autodoc]] pipelines.cogvideo.pipeline_output.CogVideoXPipelineOutput

src/diffusers/__init__.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -255,6 +255,7 @@
255255
"BlipDiffusionPipeline",
256256
"CLIPImageProjection",
257257
"CogVideoXPipeline",
258+
"CogVideoXVideoToVideoPipeline",
258259
"CycleDiffusionPipeline",
259260
"FluxControlNetPipeline",
260261
"FluxPipeline",
@@ -699,6 +700,7 @@
699700
AuraFlowPipeline,
700701
CLIPImageProjection,
701702
CogVideoXPipeline,
703+
CogVideoXVideoToVideoPipeline,
702704
CycleDiffusionPipeline,
703705
FluxControlNetPipeline,
704706
FluxPipeline,

src/diffusers/pipelines/__init__.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -132,7 +132,7 @@
132132
"AudioLDM2UNet2DConditionModel",
133133
]
134134
_import_structure["blip_diffusion"] = ["BlipDiffusionPipeline"]
135-
_import_structure["cogvideo"] = ["CogVideoXPipeline"]
135+
_import_structure["cogvideo"] = ["CogVideoXPipeline", "CogVideoXVideoToVideoPipeline"]
136136
_import_structure["controlnet"].extend(
137137
[
138138
"BlipDiffusionControlNetPipeline",
@@ -454,7 +454,7 @@
454454
)
455455
from .aura_flow import AuraFlowPipeline
456456
from .blip_diffusion import BlipDiffusionPipeline
457-
from .cogvideo import CogVideoXPipeline
457+
from .cogvideo import CogVideoXPipeline, CogVideoXVideoToVideoPipeline
458458
from .controlnet import (
459459
BlipDiffusionControlNetPipeline,
460460
StableDiffusionControlNetImg2ImgPipeline,

src/diffusers/pipelines/cogvideo/__init__.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,7 @@
2323
_dummy_objects.update(get_objects_from_module(dummy_torch_and_transformers_objects))
2424
else:
2525
_import_structure["pipeline_cogvideox"] = ["CogVideoXPipeline"]
26+
_import_structure["pipeline_cogvideox_video2video"] = ["CogVideoXVideoToVideoPipeline"]
2627

2728
if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
2829
try:
@@ -33,6 +34,7 @@
3334
from ...utils.dummy_torch_and_transformers_objects import *
3435
else:
3536
from .pipeline_cogvideox import CogVideoXPipeline
37+
from .pipeline_cogvideox_video2video import CogVideoXVideoToVideoPipeline
3638

3739
else:
3840
import sys

src/diffusers/pipelines/cogvideo/pipeline_cogvideox.py

Lines changed: 2 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,6 @@
1515

1616
import inspect
1717
import math
18-
from dataclasses import dataclass
1918
from typing import Callable, Dict, List, Optional, Tuple, Union
2019

2120
import torch
@@ -26,9 +25,10 @@
2625
from ...models.embeddings import get_3d_rotary_pos_embed
2726
from ...pipelines.pipeline_utils import DiffusionPipeline
2827
from ...schedulers import CogVideoXDDIMScheduler, CogVideoXDPMScheduler
29-
from ...utils import BaseOutput, logging, replace_example_docstring
28+
from ...utils import logging, replace_example_docstring
3029
from ...utils.torch_utils import randn_tensor
3130
from ...video_processor import VideoProcessor
31+
from .pipeline_output import CogVideoXPipelineOutput
3232

3333

3434
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
@@ -136,21 +136,6 @@ def retrieve_timesteps(
136136
return timesteps, num_inference_steps
137137

138138

139-
@dataclass
140-
class CogVideoXPipelineOutput(BaseOutput):
141-
r"""
142-
Output class for CogVideo pipelines.
143-
144-
Args:
145-
frames (`torch.Tensor`, `np.ndarray`, or List[List[PIL.Image.Image]]):
146-
List of video outputs - It can be a nested list of length `batch_size,` with each sub-list containing
147-
denoised PIL image sequences of length `num_frames.` It can also be a NumPy array or Torch tensor of shape
148-
`(batch_size, num_frames, channels, height, width)`.
149-
"""
150-
151-
frames: torch.Tensor
152-
153-
154139
class CogVideoXPipeline(DiffusionPipeline):
155140
r"""
156141
Pipeline for text-to-video generation using CogVideoX.

0 commit comments

Comments
 (0)