# CogView4 (supports different length c and uc) #10649

Status: Merged

## Commits (88)
- 2640bcf init (zRzRzRzRzRzRzR)
- eba11fa Merge branch 'huggingface:main' into cogview4 (zRzRzRzRzRzRzR)
- 6163679 encode with glm (zRzRzRzRzRzRzR)
- 6090ea7 draft schedule (zRzRzRzRzRzRzR)
- c7d1227 feat(scheduler): Add CogView scheduler implementation (OleehyO)
- e9f6626 Merge remote-tracking branch 'origin/cogview4' into cogview4 (OleehyO)
- 549b357 Merge branch 'huggingface:main' into cogview4 (zRzRzRzRzRzRzR)
- 004d002 Merge branch 'huggingface:main' into cogview4 (zRzRzRzRzRzRzR)
- f4457fb feat(embeddings): add CogView 2D rotary positional embedding (OleehyO)
- 5f8d33b Merge remote-tracking branch 'origin/cogview4' into cogview4 (OleehyO)
- 9a93218 1 (zRzRzRzRzRzRzR)
- ca000dd Update pipeline_cogview4.py (zRzRzRzRzRzRzR)
- 7ab4a3f fix the timestep init and sigma (zRzRzRzRzRzRzR)
- 56ceaa6 update latent (zRzRzRzRzRzRzR)
- a7179a2 draft patch(not work) (zRzRzRzRzRzRzR)
- c9ddf50 Merge branch 'cogview4' (zRzRzRzRzRzRzR)
- 2f30cc1 Merge pull request #2 from zRzRzRzRzRzRzR/main (zRzRzRzRzRzRzR)
- e6b8907 fix (zRzRzRzRzRzRzR)
- 0ab7260 [WIP][cogview4]: implement initial CogView4 pipeline (OleehyO)
- f608f82 [WIP][cogview4][refactor]: Split condition/uncondition forward pass i… (OleehyO)
- b86bfd4 use with -2 hidden state (zRzRzRzRzRzRzR)
- c4d1e69 remove text_projector (zRzRzRzRzRzRzR)
- 7916140 1 (zRzRzRzRzRzRzR)
- f8945ce [WIP] Add tensor-reload to align input from transformer block (OleehyO)
- bf7f322 [WIP] for older glm (zRzRzRzRzRzRzR)
- dd6568b use with cogview4 transformers forward twice of u and uc (zRzRzRzRzRzRzR)
- 6f5407e Merge branch 'huggingface:main' into cogview4 (zRzRzRzRzRzRzR)
- 9e5b991 Update convert_cogview4_to_diffusers.py (zRzRzRzRzRzRzR)
- 36b1682 remove this (zRzRzRzRzRzRzR)
- 804f5cc Merge pull request #3 from zRzRzRzRzRzRzR/main (zRzRzRzRzRzRzR)
- 16c2397 use main example (zRzRzRzRzRzRzR)
- 601696d change back (zRzRzRzRzRzRzR)
- 84115dc reset (zRzRzRzRzRzRzR)
- 95a103f setback (zRzRzRzRzRzRzR)
- d932f67 back (zRzRzRzRzRzRzR)
- b04f15d back 4 (zRzRzRzRzRzRzR)
- 5d33f3f Fix qkv conversion logic for CogView4 to Diffusers format (zRzRzRzRzRzRzR)
- b889b37 back5 (zRzRzRzRzRzRzR)
- e239c3c revert to sat to cogview4 version (zRzRzRzRzRzRzR)
- 310da29 update a new convert from megatron (zRzRzRzRzRzRzR)
- 3bd6d30 [WIP][cogview4]: implement CogView4 attention processor (OleehyO)
- f826aec [cogview4] implement CogView4 transformer block (OleehyO)
- 8d8ed8b Merge branch 'huggingface:main' into cogview4 (zRzRzRzRzRzRzR)
- bf1fdc8 with new attn (zRzRzRzRzRzRzR)
- 6a3a07f [bugfix] fix dimension mismatch in CogView4 attention (OleehyO)
- de274f3 [cogview4][WIP]: update final normalization in CogView4 transformer (OleehyO)
- e94999e Merge remote-tracking branch 'origin/cogview4' into cogview4 (OleehyO)
- e238284 Merge branch 'huggingface:main' into cogview4 (zRzRzRzRzRzRzR)
- a9b1e16 Merge branch 'huggingface:main' into cogview4 (zRzRzRzRzRzRzR)
- 46277b2 1 (zRzRzRzRzRzRzR)
- ebbaa5b put back (zRzRzRzRzRzRzR)
- f1ccdd2 Update transformer_cogview4.py (zRzRzRzRzRzRzR)
- 030a467 change time_shift (zRzRzRzRzRzRzR)
- ad40575 Update pipeline_cogview4.py (zRzRzRzRzRzRzR)
- 81d39ee change timesteps (zRzRzRzRzRzRzR)
- 45f9e88 fix (zRzRzRzRzRzRzR)
- 1dbeaa8 change text_encoder_id (zRzRzRzRzRzRzR)
- f209600 [cogview4][rope] align RoPE implementation with Megatron (OleehyO)
- 992f5a3 [cogview4][bugfix] apply silu activation to time embeddings in CogView4 (OleehyO)
- 03a1c3b [cogview4][chore] clean up pipeline code (OleehyO)
- dd34794 Merge remote-tracking branch 'origin/cogview4' into cogview4 (OleehyO)
- 3dab073 [cogview4][scheduler] Implement CogView4 scheduler and pipeline (OleehyO)
- 63982d6 now It work (zRzRzRzRzRzRzR)
- 90a5706 Merge branch 'huggingface:main' into cogview4 (zRzRzRzRzRzRzR)
- d4748e0 add timestep (zRzRzRzRzRzRzR)
- 95f851d batch (zRzRzRzRzRzRzR)
- cb56282 change convert scipt (zRzRzRzRzRzRzR)
- fedf325 refactor pt. 1; make style (a-r-r-o-w)
- 90d29c7 Merge branch 'huggingface:main' into cogview4 (zRzRzRzRzRzRzR)
- 4c01c9d refactor pt. 2 (a-r-r-o-w)
- c1b8004 refactor pt. 3 (a-r-r-o-w)
- 9d55d0a add tests (a-r-r-o-w)
- 5e6de42 make fix-copies (a-r-r-o-w)
- 30dd0ad Merge branch 'main' into cogview4 (a-r-r-o-w)
- 2046cf2 update toctree.yml (a-r-r-o-w)
- 39e1198 use flow match scheduler instead of custom (a-r-r-o-w)
- b566a9f Merge branch 'main' into cogview4 (a-r-r-o-w)
- b4c9fde remove scheduling_cogview.py (a-r-r-o-w)
- a137e17 add tiktoken to test dependencies (a-r-r-o-w)
- da420fb Update src/diffusers/models/embeddings.py (a-r-r-o-w)
- 4003b9c apply suggestions from review (a-r-r-o-w)
- 35c0ec6 use diffusers apply_rotary_emb (a-r-r-o-w)
- d328c5e update flow match scheduler to accept timesteps (a-r-r-o-w)
- d637d3a Merge branch 'main' into cogview4 (a-r-r-o-w)
- 4c37ef0 fix comment (a-r-r-o-w)
- 90c240b apply review sugestions (a-r-r-o-w)
- 5c11298 Merge branch 'main' into cogview4 (a-r-r-o-w)
- 2f12b7a Update src/diffusers/schedulers/scheduling_flow_match_euler_discrete.py (a-r-r-o-w)
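The feature in the PR title, conditional and unconditional prompt embeddings of different lengths, shows up in commits such as f608f82 (split condition/uncondition forward pass) and dd6568b (forward twice for c and uc): rather than padding the two prompts into one batch, the transformer is invoked once per branch. A schematic sketch of that classifier-free-guidance flow, with hypothetical shapes and a dummy stand-in for the model (not the pipeline's actual code):

```python
import torch

guidance_scale = 3.5
latents = torch.randn(1, 16, 64, 64)  # hypothetical latent shape

# Conditional and unconditional prompt embeddings with *different* sequence
# lengths; they cannot be stacked into a single batched forward pass.
cond_embeds = torch.randn(1, 77, 4096)
uncond_embeds = torch.randn(1, 16, 4096)


def denoise(latents, prompt_embeds):
    """Dummy stand-in for CogView4Transformer2DModel.forward (shapes only)."""
    return latents * prompt_embeds.mean()


# Run the model once per branch, then combine with the usual CFG formula.
noise_cond = denoise(latents, cond_embeds)
noise_uncond = denoise(latents, uncond_embeds)
noise_pred = noise_uncond + guidance_scale * (noise_cond - noise_uncond)
```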
## Files changed

The first file added is a standalone verification script: it loads the same CogView4 weights from three sources (the original SAT checkpoint, the converted Diffusers checkpoint, and the Megatron checkpoint), converts the Megatron weights into SAT layout, and compares the per-layer fused QKV projections pairwise.
```python
import torch
from collections import OrderedDict

from tqdm import tqdm

from diffusers import CogView4Transformer2DModel


def load_state_dict_sat(file_path):
    """Load the SAT state dictionary from a given file path."""
    # Typically, the stored SAT ckpt is in the format: {'module': {...}}
    ckpt = torch.load(file_path, map_location="cuda")
    return ckpt["module"]


def extract_qkv_from_sat(state_dict, layer_idx):
    """
    Extract QKV weights and biases from a SAT state_dict.

    Expects keys like:
        model.diffusion_model.transformer.layers.{layer_idx}.attention.query_key_value
    """
    prefix = f"model.diffusion_model.transformer.layers.{layer_idx}.attention.query_key_value"
    w = state_dict[f"{prefix}.weight"].clone()
    b = state_dict[f"{prefix}.bias"].clone()
    return (w, b)


def load_state_dict_cogview(cogview_path):
    """
    Load the CogView4 model from diffusers and return its state_dict().

    NOTE: Adjust 'torch_dtype' and 'device_map' as appropriate.
    """
    cogview_model = CogView4Transformer2DModel.from_pretrained(
        cogview_path, torch_dtype=torch.bfloat16, device_map="auto"
    )
    return cogview_model.state_dict()


def extract_qkv_from_cogview(state_dict, layer_idx, num_heads, head_dim, hidden_dim):
    """
    Extract Q, K, V from a CogView4 checkpoint and reshape them into the same shape as SAT's fused QKV.

    For each layer i:
        Q prefix: transformer_blocks.{layer_idx}.attn1.to_q
        K prefix: transformer_blocks.{layer_idx}.attn1.to_k
        V prefix: transformer_blocks.{layer_idx}.attn1.to_v
    The final shapes must match SAT's [3*hidden_dim, hidden_dim] weight and [3*hidden_dim] bias.
    """
    q_prefix = f"transformer_blocks.{layer_idx}.attn1.to_q"
    k_prefix = f"transformer_blocks.{layer_idx}.attn1.to_k"
    v_prefix = f"transformer_blocks.{layer_idx}.attn1.to_v"

    q_weight = state_dict[f"{q_prefix}.weight"].clone()
    k_weight = state_dict[f"{k_prefix}.weight"].clone()
    v_weight = state_dict[f"{v_prefix}.weight"].clone()

    q_bias = state_dict[f"{q_prefix}.bias"].clone()
    k_bias = state_dict[f"{k_prefix}.bias"].clone()
    v_bias = state_dict[f"{v_prefix}.bias"].clone()

    # Reshape weights: [hidden_dim, hidden_dim] -> [num_heads, head_dim, hidden_dim],
    # then concatenate along the first dimension (which becomes 3*num_heads*head_dim).
    q_weight = q_weight.view(num_heads, head_dim, hidden_dim)
    k_weight = k_weight.view(num_heads, head_dim, hidden_dim)
    v_weight = v_weight.view(num_heads, head_dim, hidden_dim)

    qkv_weight = torch.cat([q_weight, k_weight, v_weight], dim=0)  # (3*num_heads, head_dim, hidden_dim)
    qkv_weight = qkv_weight.view(3 * num_heads * head_dim, hidden_dim)  # flatten

    # Reshape biases: [hidden_dim] -> [num_heads, head_dim]
    q_bias = q_bias.view(num_heads, head_dim)
    k_bias = k_bias.view(num_heads, head_dim)
    v_bias = v_bias.view(num_heads, head_dim)

    qkv_bias = torch.cat([q_bias, k_bias, v_bias], dim=0)  # (3*num_heads, head_dim)
    qkv_bias = qkv_bias.view(3 * num_heads * head_dim)

    return (qkv_weight, qkv_bias)


def create_sat_state_dict_from_megatron(megatron_ckpt_dict, num_layers=48, num_heads=32, hidden_size=3072):
    """
    Convert a loaded Megatron checkpoint's 'model' dictionary into the format used by SAT,
    returning an OrderedDict that can be compared directly against a SAT checkpoint.
    """
    hidden_size_per_head = hidden_size // num_heads
    mega_weight = megatron_ckpt_dict["model"]
    sat_weight = {}

    # --- patch embedding ---
    sat_weight["model.diffusion_model.mixins.patch_embed.proj.weight"] = (
        mega_weight["encoder_expand_linear.weight"].reshape(hidden_size, 64).clone()
    )
    sat_weight["model.diffusion_model.mixins.patch_embed.proj.bias"] = (
        mega_weight["encoder_expand_linear.bias"].clone()
    )
    sat_weight["model.diffusion_model.mixins.patch_embed.text_proj.weight"] = (
        mega_weight["text_projector.weight"].clone()
    )
    sat_weight["model.diffusion_model.mixins.patch_embed.text_proj.bias"] = (
        mega_weight["text_projector.bias"].clone()
    )

    # --- time embedding ---
    sat_weight["model.diffusion_model.time_embed.0.weight"] = (
        mega_weight["time_embedding.time_embed.0.weight"].clone()
    )
    sat_weight["model.diffusion_model.time_embed.0.bias"] = (
        mega_weight["time_embedding.time_embed.0.bias"].clone()
    )
    sat_weight["model.diffusion_model.time_embed.2.weight"] = (
        mega_weight["time_embedding.time_embed.2.weight"].clone()
    )
    sat_weight["model.diffusion_model.time_embed.2.bias"] = (
        mega_weight["time_embedding.time_embed.2.bias"].clone()
    )

    # --- label embedding ---
    sat_weight["model.diffusion_model.label_emb.0.0.weight"] = (
        mega_weight["label_embedding.label_embed.0.weight"].clone()
    )
    sat_weight["model.diffusion_model.label_emb.0.0.bias"] = (
        mega_weight["label_embedding.label_embed.0.bias"].clone()
    )
    sat_weight["model.diffusion_model.label_emb.0.2.weight"] = (
        mega_weight["label_embedding.label_embed.2.weight"].clone()
    )
    sat_weight["model.diffusion_model.label_emb.0.2.bias"] = (
        mega_weight["label_embedding.label_embed.2.bias"].clone()
    )

    # --- transformer layers ---
    for i in tqdm(range(num_layers), desc="Converting Megatron->SAT"):
        # attention output projection
        sat_weight[f"model.diffusion_model.transformer.layers.{i}.attention.dense.weight"] = (
            mega_weight[f"decoder.layers.{i}.self_attention.linear_proj.weight"].clone()
        )
        sat_weight[f"model.diffusion_model.transformer.layers.{i}.attention.dense.bias"] = (
            mega_weight[f"decoder.layers.{i}.self_attention.linear_proj.bias"].clone()
        )

        # fused QKV
        qkv_weight = mega_weight[f"decoder.layers.{i}.self_attention.linear_qkv.weight"].clone()
        qkv_bias = mega_weight[f"decoder.layers.{i}.self_attention.linear_qkv.bias"].clone()

        # Reshape QKV from Megatron's head-interleaved layout into SAT's Q/K/V-blocked layout:
        # [3*hidden_size, hidden_size] -> [num_heads, 3, head_dim, hidden_size]
        # -> permute to [3, num_heads, head_dim, hidden_size] -> [3*hidden_size, hidden_size]
        sat_weight[f"model.diffusion_model.transformer.layers.{i}.attention.query_key_value.weight"] = (
            qkv_weight.view(num_heads, 3, hidden_size_per_head, hidden_size)
            .permute(1, 0, 2, 3)
            .reshape(3 * hidden_size, hidden_size)
            .clone()
        )
        sat_weight[f"model.diffusion_model.transformer.layers.{i}.attention.query_key_value.bias"] = (
            qkv_bias.view(num_heads, 3, hidden_size_per_head)
            .permute(1, 0, 2)
            .reshape(3 * hidden_size)
            .clone()
        )

        # MLP
        sat_weight[f"model.diffusion_model.transformer.layers.{i}.mlp.dense_h_to_4h.weight"] = (
            mega_weight[f"decoder.layers.{i}.mlp.linear_fc1.weight"].clone()
        )
        sat_weight[f"model.diffusion_model.transformer.layers.{i}.mlp.dense_h_to_4h.bias"] = (
            mega_weight[f"decoder.layers.{i}.mlp.linear_fc1.bias"].clone()
        )
        sat_weight[f"model.diffusion_model.transformer.layers.{i}.mlp.dense_4h_to_h.weight"] = (
            mega_weight[f"decoder.layers.{i}.mlp.linear_fc2.weight"].clone()
        )
        sat_weight[f"model.diffusion_model.transformer.layers.{i}.mlp.dense_4h_to_h.bias"] = (
            mega_weight[f"decoder.layers.{i}.mlp.linear_fc2.bias"].clone()
        )

        # AdaLN
        adaln_weight = mega_weight[f"decoder.layers.{i}.adaln.weight"].clone()
        adaln_bias = mega_weight[f"decoder.layers.{i}.adaln.bias"].clone()

        sat_weight[f"model.diffusion_model.mixins.adaln.adaln_modules.{i}.1.weight"] = adaln_weight.clone()
        sat_weight[f"model.diffusion_model.mixins.adaln.adaln_modules.{i}.1.bias"] = adaln_bias.clone()

    # --- final layers ---
    sat_weight["model.diffusion_model.mixins.final_layer.adaln.1.weight"] = mega_weight["adaln_final.weight"].clone()
    sat_weight["model.diffusion_model.mixins.final_layer.adaln.1.bias"] = mega_weight["adaln_final.bias"].clone()
    sat_weight["model.diffusion_model.mixins.final_layer.linear.weight"] = mega_weight["output_projector.weight"].clone()
    sat_weight["model.diffusion_model.mixins.final_layer.linear.bias"] = mega_weight["output_projector.bias"].clone()

    return OrderedDict(sat_weight)


def load_state_dict_megatron_and_convert_to_sat(megatron_ckpt_path, num_layers, num_heads, hidden_size):
    """
    Load a Megatron checkpoint from <megatron_ckpt_path>, then convert it into
    a SAT-style OrderedDict for direct QKV comparison.

    Typically, <megatron_ckpt_path> = ".../iter_0287500/mp_rank_00/model_optim_rng.pt"
    """
    ckpt = torch.load(megatron_ckpt_path, map_location="cuda")
    return create_sat_state_dict_from_megatron(
        ckpt, num_layers=num_layers, num_heads=num_heads, hidden_size=hidden_size
    )


def compute_l2_difference(tensor1, tensor2):
    """Compute the L2 norm of the difference between two tensors."""
    return torch.norm(tensor1 - tensor2, p=2).item()


def compare_qkv(qkv1, qkv2, name1="Model1", name2="Model2", atol=1e-6):
    """
    Compare QKV from two different sources (each a tuple of (weight, bias)).
    Returns (weight_match, bias_match, weight_l2, bias_l2).
    """
    w1, b1 = qkv1
    w2, b2 = qkv2

    weight_match = torch.allclose(w1, w2, atol=atol)
    bias_match = torch.allclose(b1, b2, atol=atol)
    weight_l2_diff = compute_l2_difference(w1, w2)
    bias_l2_diff = compute_l2_difference(b1, b2)

    if not (weight_match and bias_match):
        print(f"[QKV Mismatch] {name1} vs {name2}")
        print(f"  Weight L2: {weight_l2_diff:.6f}, Bias L2: {bias_l2_diff:.6f}")
    else:
        print(f"[QKV Match] {name1} vs {name2} (Weight L2={weight_l2_diff:.6f}, Bias L2={bias_l2_diff:.6f})")

    return weight_match, bias_match, weight_l2_diff, bias_l2_diff


if __name__ == "__main__":
    num_layers = 28
    num_heads = 32
    hidden_dim = 4096
    head_dim = hidden_dim // num_heads

    sat_ckpt_path = "/share/home/zyx/Models/Megatron-VLM/examples/dit/ckpts/pt_sat/0287500/mp_rank_00_model_states.pt"
    sat_state_dict = load_state_dict_sat(sat_ckpt_path)

    cogview_path = "/share/zyx/CogView4-6B-0128/transformer"  # directory containing the diffusers model index
    cogview_state_dict = load_state_dict_cogview(cogview_path)

    megatron_ckpt_path = "/share/home/zyx/Models/Megatron-VLM/examples/dit/ckpts/pt_ema/iter_0287500/mp_rank_00/model_optim_rng.pt"
    mega_as_sat_state_dict = load_state_dict_megatron_and_convert_to_sat(
        megatron_ckpt_path, num_layers=num_layers, num_heads=num_heads, hidden_size=hidden_dim
    )

    print("\n==== Start QKV Comparison ====\n")
    for layer_idx in range(num_layers):
        print(f"--- Layer {layer_idx} ---")

        # Extract QKV from each source.
        sat_qkv = extract_qkv_from_sat(sat_state_dict, layer_idx)
        cogview_qkv = extract_qkv_from_cogview(cogview_state_dict, layer_idx, num_heads, head_dim, hidden_dim)
        mega_qkv = extract_qkv_from_sat(mega_as_sat_state_dict, layer_idx)

        # Pairwise comparisons.
        compare_qkv(sat_qkv, cogview_qkv, name1="SAT", name2="CogView4")
        compare_qkv(sat_qkv, mega_qkv, name1="SAT", name2="Megatron")
        compare_qkv(cogview_qkv, mega_qkv, name1="CogView4", name2="Megatron")
        print()

    print("=== Done ===")
```
The PR also adds a documentation page for the new transformer model:
<!--Copyright 2024 The HuggingFace Team. All rights reserved. | ||
|
||
Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with | ||
the License. You may obtain a copy of the License at | ||
|
||
http://www.apache.org/licenses/LICENSE-2.0 | ||
|
||
Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on | ||
an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the | ||
specific language governing permissions and limitations under the License. --> | ||
|
||
# CogView4Transformer2DModel | ||
|
||
A Diffusion Transformer model for 2D data from [CogView4]() | ||
|
||
The model can be loaded with the following code snippet. | ||
|
||
```python | ||
from diffusers import CogView3PlusTransformer2DModel | ||
zRzRzRzRzRzRzR marked this conversation as resolved.
Show resolved
Hide resolved
|
||
|
||
transformer = CogView3PlusTransformer2DModel.from_pretrained("THUDM/CogView4-6B", subfolder="transformer", torch_dtype=torch.bfloat16).to("cuda") | ||
zRzRzRzRzRzRzR marked this conversation as resolved.
Show resolved
Hide resolved
|
||
``` | ||
|
||
## CogView4Transformer2DModel | ||
|
||
[[autodoc]] CogView4Transformer2DModel | ||
|
||
## Transformer2DModelOutput | ||
|
||
[[autodoc]] models.modeling_outputs.Transformer2DModelOutput |
Finally, a documentation page for the pipeline itself:
<!--Copyright 2024 The HuggingFace Team. All rights reserved. | ||
# | ||
# Licensed under the Apache License, Version 2.0 (the "License"); | ||
# you may not use this file except in compliance with the License. | ||
# You may obtain a copy of the License at | ||
# | ||
# http://www.apache.org/licenses/LICENSE-2.0 | ||
# | ||
# Unless required by applicable law or agreed to in writing, software | ||
# distributed under the License is distributed on an "AS IS" BASIS, | ||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
# See the License for the specific language governing permissions and | ||
# limitations under the License. | ||
--> | ||
|
||
# CogView4 | ||
|
||
<Tip> | ||
|
||
Make sure to check out the Schedulers [guide](../../using-diffusers/schedulers) to learn how to explore the tradeoff between scheduler speed and quality, and see the [reuse components across pipelines](../../using-diffusers/loading#reuse-a-pipeline) section to learn how to efficiently load the same components into multiple pipelines. | ||
|
||
</Tip> | ||
|
||
This pipeline was contributed by [zRzRzRzRzRzRzR](https://github.com/zRzRzRzRzRzRzR). The original codebase can be found [here](https://huggingface.co/THUDM). The original weights can be found under [hf.co/THUDM](https://huggingface.co/THUDM). | ||
|
||
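A minimal text-to-image sketch, assuming the standard `DiffusionPipeline` loading interface and the `THUDM/CogView4-6B` weights linked above; the generation arguments are illustrative, not tuned:

```python
import torch

from diffusers import CogView4Pipeline

# Load the pipeline in bfloat16 and move it to the GPU.
pipe = CogView4Pipeline.from_pretrained("THUDM/CogView4-6B", torch_dtype=torch.bfloat16)
pipe.to("cuda")

# Generate a single image with classifier-free guidance.
image = pipe(
    prompt="A photo of an astronaut riding a horse on Mars",
    guidance_scale=3.5,
    num_inference_steps=50,
).images[0]
image.save("cogview4.png")
```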
## CogView4Pipeline

[[autodoc]] CogView4Pipeline
  - all
  - __call__

## CogView4PipelineOutput

[[autodoc]] pipelines.cogview4.pipeline_output.CogView4PipelineOutput