Closed
Description
Describe the bug
Running `diffusers.utils.export_to_video()` on the output of `HunyuanVideoPipeline` results in the following warning:
/app/diffusers/src/diffusers/image_processor.py:147: RuntimeWarning: invalid value encountered in cast
images = (images * 255).round().astype("uint8")
After adding some checks to `numpy_to_pil()` in `image_processor.py`, I have confirmed that the output contains NaN values:
File "/app/pipeline.py", line 37, in <module>
output = pipe(
^^^^^
File "/usr/local/lib/python3.11/site-packages/torch/utils/_contextlib.py", line 116, in decorate_context
return func(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^
File "/app/diffusers/src/diffusers/pipelines/hunyuan_video/pipeline_hunyuan_video.py", line 677, in __call__
video = self.video_processor.postprocess_video(video, output_type=output_type)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/app/diffusers/src/diffusers/video_processor.py", line 103, in postprocess_video
batch_output = self.postprocess(batch_vid, output_type)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/app/diffusers/src/diffusers/image_processor.py", line 823, in postprocess
return self.numpy_to_pil(image)
^^^^^^^^^^^^^^^^^^^^^^^^
File "/app/diffusers/src/diffusers/image_processor.py", line 158, in numpy_to_pil
raise ValueError("Image array contains NaN values")
ValueError: Image array contains NaN values
Reproduction
# Reproduction script: run HunyuanVideoPipeline under the PyTorch profiler,
# write a Chrome trace, then export the generated frames to an MP4 file.
#
# The trace location is taken from the PROFILE_OUT_PATH and
# PROFILE_OUT_FILE_NAME environment variables (defaults: "./" and
# "hunyuan_profile.json").
import os
import time

import torch
from diffusers import HunyuanVideoPipeline, HunyuanVideoTransformer3DModel
from diffusers.utils import export_to_video
from huggingface_hub import snapshot_download
from torch.profiler import ProfilerActivity, profile, record_function

os.environ["TOKENIZERS_PARALLELISM"] = "false"

MODEL_ID = "tencent/HunyuanVideo"
PROMPT = "a whale shark floating through outer space"

profile_dir = os.environ.get("PROFILE_OUT_PATH", "./")
profile_file_name = os.environ.get("PROFILE_OUT_FILE_NAME", "hunyuan_profile.json")
profile_path = os.path.join(profile_dir, profile_file_name)

# NOTE(review): the transformer and the rest of the pipeline are both loaded
# in float16 here. The Diffusers HunyuanVideo documentation loads the
# transformer in bfloat16 instead -- possibly relevant to the NaN frames
# reported; confirm before changing, since this script is the reproduction.
transformer = HunyuanVideoTransformer3DModel.from_pretrained(
    MODEL_ID, subfolder="transformer", torch_dtype=torch.float16, revision="refs/pr/18"
)
pipe = HunyuanVideoPipeline.from_pretrained(
    MODEL_ID, transformer=transformer, torch_dtype=torch.float16, revision="refs/pr/18"
)
pipe.vae.enable_tiling()
pipe.to("cuda")

print(f"\nStarting profiling of {MODEL_ID}\n")
with profile(
    activities=[ProfilerActivity.CPU, ProfilerActivity.CUDA], record_shapes=True
) as prof:
    # Label the whole pipeline call so it shows up as one named range in the trace.
    with record_function("model_inference"):
        output = pipe(
            prompt=PROMPT,
            height=320,
            width=512,
            num_frames=61,
            num_inference_steps=30,
        )

# Export and print profiling results
print(prof.key_averages().table(sort_by="cuda_time_total", row_limit=10))
prof.export_chrome_trace(profile_path)
print(f"{profile_file_name} ready")

# export video
video = output.frames[0]
print(" ====== raw video matrix =====")
print(video)
print()
print(" ====== Exporting video =====")
export_to_video(video, "hunyuan_example.mp4", fps=15)
print()
Logs
No response
System Info
GPU: AMD MI300X
# Image that runs a single profiling script against a from-source diffusers
# checkout with the ROCm build of PyTorch.
#
# Build arguments:
#   BASE_IMAGE     base image to build on (default: python:3.11-slim)
#   PIPELINE_FILE  path, in the build context, of the script copied to
#                  /app/pipeline.py; it has no default and must be supplied.
ARG BASE_IMAGE=python:3.11-slim
FROM ${BASE_IMAGE}

# Python reads PYTHONUNBUFFERED; the previous spelling PYTHONBUFFERED is not a
# variable the interpreter recognises, so it had no effect on log buffering.
ENV PYTHONUNBUFFERED=true
ENV CUDA_VISIBLE_DEVICES=0

WORKDIR /app

# Install tools
RUN apt-get update && \
    apt-get install -y --no-install-recommends \
        git \
        libgl1-mesa-glx \
        libglib2.0-0 \
        libsm6 \
        libxext6 \
        libxrender-dev \
        libfontconfig1 \
        ffmpeg \
        build-essential && \
    rm -rf /var/lib/apt/lists/*

# install ROCm pytorch and python dependencies
RUN python -m pip install --no-cache-dir \
        torch torchvision --index-url https://download.pytorch.org/whl/rocm6.2 && \
    python -m pip install --no-cache-dir \
        accelerate transformers sentencepiece protobuf opencv-python imageio imageio-ffmpeg

# install diffusers from source to include newest pipeline classes
COPY diffusers diffusers
RUN cd diffusers && \
    python -m pip install -e .

# Copy the profiling script
ARG PIPELINE_FILE
COPY ${PIPELINE_FILE} pipeline.py

# run the script
CMD ["python", "pipeline.py"]