Skip to content

model : add hunyuan moe #14425

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Draft
wants to merge 8 commits into
base: master
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
123 changes: 122 additions & 1 deletion convert_hf_to_gguf.py
Original file line number Diff line number Diff line change
Expand Up @@ -607,7 +607,6 @@ def get_vocab_base(self) -> tuple[list[str], list[int], str]:

from transformers import AutoTokenizer
tokenizer = AutoTokenizer.from_pretrained(self.dir_model)
vocab_size = self.hparams.get("vocab_size", len(tokenizer.vocab))
assert max(tokenizer.vocab.values()) < vocab_size

tokpre = self.get_vocab_base_pre(tokenizer)
Expand Down Expand Up @@ -815,6 +814,9 @@ def get_vocab_base_pre(self, tokenizer) -> str:
if chkhsh == "1431a23e583c97432bc230bff598d103ddb5a1f89960c8f1d1051aaa944d0b35":
# ref: https://huggingface.co/sapienzanlp/Minerva-7B-base-v1.0
res = "minerva-7b"
if chkhsh == "7e57df22b1fe23a7b1e1c7f3dc4e3f96d43a4eb0836d0c6bdc3436d7b2f1c664":
# ref: https://huggingface.co/tencent/Hunyuan-A13B-Instruct
res = "hunyuan"

if res is None:
logger.warning("\n")
Expand Down Expand Up @@ -6390,6 +6392,125 @@ def set_gguf_parameters(self):
super().set_gguf_parameters()
self.gguf_writer.add_audio_stack_factor(self.global_config["stack_factor"])


@ModelBase.register("HunYuanMoEV1ForCausalLM")
class HunYuanMoEModel(LlamaModel):
model_arch = gguf.MODEL_ARCH.HUNYUAN_MOE
undo_permute = False

def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)

def set_vocab(self):
self._set_vocab_gpt2()

def get_vocab_base(self) -> tuple[list[str], list[int], str]:
tokens: list[str] = []
toktypes: list[int] = []

from transformers import AutoTokenizer
tokenizer = AutoTokenizer.from_pretrained(self.dir_model, trust_remote_code=True)

# merge logic is copied from QwenModel, maybe incorrect
merges = []
vocab = {}
mergeable_ranks = tokenizer.mergeable_ranks
for token, rank in mergeable_ranks.items():
vocab[QwenModel.token_bytes_to_string(token)] = rank
if len(token) == 1:
continue
merged = QwenModel.bpe(mergeable_ranks, token, max_rank=rank)
if len(merged) == 2:
merges.append(' '.join(map(QwenModel.token_bytes_to_string, merged)))
Comment on lines +6418 to +6424
Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

quite doubt if this is correct. if someone knows or has time to do tokenizer test, please feel free to leave a comment

self.gguf_writer.add_token_merges(merges)

reverse_vocab = tokenizer.decoder
assert max(reverse_vocab.keys()) < tokenizer.vocab_size

tokpre = self.get_vocab_base_pre(tokenizer)
added_vocab = tokenizer.get_added_vocab()

added_tokens_decoder = tokenizer.added_tokens_decoder

for i in range(tokenizer.vocab_size):
if i not in reverse_vocab:
tokens.append(f"[PAD{i}]")
toktypes.append(gguf.TokenType.UNUSED)
else:
token: str = reverse_vocab[i]
if token in added_vocab:
# The tokenizer in llama.cpp assumes the CONTROL and USER_DEFINED tokens are pre-normalized.
# To avoid unexpected issues - we make sure to normalize non-normalized tokens
if not added_tokens_decoder[i].normalized:
previous_token = token
token = tokenizer.decode(tokenizer.encode(token, add_special_tokens=False))
if previous_token != token:
logger.info(f"{repr(previous_token)} is encoded and decoded back to {repr(token)} using AutoTokenizer")

if added_tokens_decoder[i].special or self.does_token_look_special(token):
toktypes.append(gguf.TokenType.CONTROL)
else:
# NOTE: this was added for Gemma.
# Encoding and decoding the tokens above isn't sufficient for this case.
token = token.replace(b"\xe2\x96\x81".decode("utf-8"), " ") # pre-normalize user-defined spaces
toktypes.append(gguf.TokenType.USER_DEFINED)
else:
toktypes.append(gguf.TokenType.NORMAL)
tokens.append(token)

return tokens, toktypes, tokpre

def set_gguf_parameters(self):
super().set_gguf_parameters()

self.gguf_writer.add_expert_count(self.hparams["num_experts"])
self.gguf_writer.add_expert_shared_feed_forward_length(self.hparams["intermediate_size"])

moe_intermediate_size = self.hparams["moe_intermediate_size"]
assert all(n == moe_intermediate_size[0] for n in moe_intermediate_size)
self.gguf_writer.add_expert_feed_forward_length(moe_intermediate_size[0])

moe_topk = self.hparams["moe_topk"]
assert all(topk == moe_topk[0] for topk in moe_topk)
self.gguf_writer.add_expert_used_count(moe_topk[0])

def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iterable[tuple[str, Tensor]]:
# process the experts separately
if name.find("mlp.experts") != -1:
n_experts = self.hparams["num_experts"]
assert bid is not None

tensors: list[tuple[str, Tensor]] = []

if self._experts is None:
self._experts = [{} for _ in range(self.block_count)]

self._experts[bid][name] = data_torch

if len(self._experts[bid]) >= n_experts * 3:
# merge the experts into a single 3d tensor
for w_name in ["down_proj", "gate_proj", "up_proj"]:
datas: list[Tensor] = []

for xid in range(n_experts):
ename = f"model.layers.{bid}.mlp.experts.{xid}.{w_name}.weight"
datas.append(self._experts[bid][ename])
del self._experts[bid][ename]

data_torch = torch.stack(datas, dim=0)

merged_name = f"model.layers.{bid}.mlp.experts.{w_name}.weight"

new_name = self.map_tensor_name(merged_name)

tensors.append((new_name, data_torch))

return tensors
else:
return []

return [(self.map_tensor_name(name), data_torch)]

###### CONVERSION LOGIC ######


Expand Down
1 change: 1 addition & 0 deletions convert_hf_to_gguf_update.py
Original file line number Diff line number Diff line change
Expand Up @@ -137,6 +137,7 @@ class TOKENIZER_TYPE(IntEnum):
{"name": "chatglm-bpe", "tokt": TOKENIZER_TYPE.BPE, "repo": "https://huggingface.co/THUDM/glm-4-9b-chat", "chkhsh": "81d72c7348a9f0ebe86f23298d37debe0a5e71149e29bd283904c02262b27516"},
{"name": "glm4", "tokt": TOKENIZER_TYPE.BPE, "repo": "https://huggingface.co/THUDM/glm-4-9b-hf", "chkhsh": "a1336059768a55c99a734006ffb02203cd450fed003e9a71886c88acf24fdbc2"},
{"name": "minerva-7b", "tokt": TOKENIZER_TYPE.BPE, "repo": "https://huggingface.co/sapienzanlp/Minerva-7B-base-v1.0", "chkhsh": "1431a23e583c97432bc230bff598d103ddb5a1f89960c8f1d1051aaa944d0b35"},
{"name": "hunyuan", "tokt": TOKENIZER_TYPE.BPE, "repo": "https://huggingface.co/tencent/Hunyuan-A13B-Instruct", "chkhsh": "7e57df22b1fe23a7b1e1c7f3dc4e3f96d43a4eb0836d0c6bdc3436d7b2f1c664"},
]


Expand Down
23 changes: 23 additions & 0 deletions gguf-py/gguf/constants.py
Original file line number Diff line number Diff line change
Expand Up @@ -354,6 +354,7 @@ class MODEL_ARCH(IntEnum):
BAILINGMOE = auto()
DOTS1 = auto()
ARCEE = auto()
HUNYUAN_MOE = auto()


class VISION_PROJECTOR_TYPE(IntEnum):
Expand Down Expand Up @@ -654,6 +655,7 @@ class MODEL_TENSOR(IntEnum):
MODEL_ARCH.BAILINGMOE: "bailingmoe",
MODEL_ARCH.DOTS1: "dots1",
MODEL_ARCH.ARCEE: "arcee",
MODEL_ARCH.HUNYUAN_MOE: "hunyuan-moe",
}

VISION_PROJECTOR_TYPE_NAMES: dict[VISION_PROJECTOR_TYPE, str] = {
Expand Down Expand Up @@ -2177,6 +2179,27 @@ class MODEL_TENSOR(IntEnum):
MODEL_TENSOR.FFN_DOWN,
MODEL_TENSOR.FFN_UP,
],
MODEL_ARCH.HUNYUAN_MOE: [
MODEL_TENSOR.TOKEN_EMBD,
MODEL_TENSOR.OUTPUT_NORM,
MODEL_TENSOR.OUTPUT,
MODEL_TENSOR.ROPE_FREQS,
MODEL_TENSOR.ATTN_NORM,
MODEL_TENSOR.ATTN_Q,
MODEL_TENSOR.ATTN_Q_NORM,
MODEL_TENSOR.ATTN_K,
MODEL_TENSOR.ATTN_K_NORM,
MODEL_TENSOR.ATTN_V,
MODEL_TENSOR.ATTN_OUT,
MODEL_TENSOR.FFN_GATE_INP,
MODEL_TENSOR.FFN_NORM,
MODEL_TENSOR.FFN_GATE_EXP,
MODEL_TENSOR.FFN_DOWN_EXP,
MODEL_TENSOR.FFN_UP_EXP,
MODEL_TENSOR.FFN_GATE_SHEXP,
MODEL_TENSOR.FFN_DOWN_SHEXP,
MODEL_TENSOR.FFN_UP_SHEXP,
],
# TODO
}

Expand Down
6 changes: 6 additions & 0 deletions gguf-py/gguf/tensor_mapping.py
Original file line number Diff line number Diff line change
Expand Up @@ -303,6 +303,7 @@ class TensorNameMap:
"model.layers.{bid}.block_sparse_moe.router.layer", # granitemoe
"model.layers.{bid}.feed_forward.router", # llama4
"encoder.layers.{bid}.mlp.router.layer", # nomic-bert-moe
"model.layers.{bid}.mlp.gate.wg", # hunyuan
),

MODEL_TENSOR.FFN_GATE_INP_SHEXP: (
Expand Down Expand Up @@ -362,6 +363,7 @@ class TensorNameMap:
"model.layers.{bid}.mlp.shared_expert.up_proj", # qwen2moe
"model.layers.{bid}.mlp.shared_experts.up_proj", # deepseek deepseek2
"model.layers.{bid}.feed_forward.shared_expert.up_proj", # llama4
"model.layers.{bid}.mlp.shared_mlp.up_proj", # hunyuan
),

# AWQ-activation gate
Expand Down Expand Up @@ -398,6 +400,7 @@ class TensorNameMap:
"model.layers.{bid}.mlp.shared_expert.gate_proj", # qwen2moe
"model.layers.{bid}.mlp.shared_experts.gate_proj", # deepseek deepseek2
"model.layers.{bid}.feed_forward.shared_expert.gate_proj", # llama4
"model.layers.{bid}.mlp.shared_mlp.gate_proj", # hunyuan
),

# Feed-forward down
Expand Down Expand Up @@ -447,11 +450,13 @@ class TensorNameMap:
"model.layers.{bid}.mlp.shared_experts.down_proj", # deepseek deepseek2
"model.layers.{bid}.feed_forward.shared_expert.down_proj", # llama4
"model.layers.{bid}.shared_mlp.output_linear", # granitemoe
"model.layers.{bid}.mlp.shared_mlp.down_proj", # hunyuan
),

MODEL_TENSOR.ATTN_Q_NORM: (
"language_model.encoder.layers.{bid}.self_attention.q_layernorm",
"model.layers.{bid}.self_attn.q_layernorm", # persimmon
"model.layers.{bid}.self_attn.query_layernorm", # hunyuan
"model.layers.{bid}.self_attn.q_norm", # cohere olmoe chameleon olmo2
"transformer.blocks.{bid}.attn.q_ln", # sea-lion
"encoder.layer.{bid}.attention.self.layer_norm_q", # jina-bert-v2
Expand All @@ -461,6 +466,7 @@ class TensorNameMap:
MODEL_TENSOR.ATTN_K_NORM: (
"language_model.encoder.layers.{bid}.self_attention.k_layernorm",
"model.layers.{bid}.self_attn.k_layernorm", # persimmon
"model.layers.{bid}.self_attn.key_layernorm", # hunyuan
"model.layers.{bid}.self_attn.k_norm", # cohere olmoe chameleon olmo2
"transformer.blocks.{bid}.attn.k_ln", # sea-lion
"encoder.layer.{bid}.attention.self.layer_norm_k", # jina-bert-v2
Expand Down
1 change: 1 addition & 0 deletions include/llama.h
Original file line number Diff line number Diff line change
Expand Up @@ -117,6 +117,7 @@ extern "C" {
LLAMA_VOCAB_PRE_TYPE_LLAMA4 = 33,
LLAMA_VOCAB_PRE_TYPE_PIXTRAL = 34,
LLAMA_VOCAB_PRE_TYPE_SEED_CODER = 35,
LLAMA_VOCAB_PRE_TYPE_HUNYUAN = 36,
};

enum llama_rope_type {
Expand Down
25 changes: 25 additions & 0 deletions src/llama-arch.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -76,6 +76,7 @@ static const std::map<llm_arch, const char *> LLM_ARCH_NAMES = {
{ LLM_ARCH_BAILINGMOE, "bailingmoe" },
{ LLM_ARCH_DOTS1, "dots1" },
{ LLM_ARCH_ARCEE, "arcee" },
{ LLM_ARCH_HUNYUAN_MOE, "hunyuan-moe" },
{ LLM_ARCH_UNKNOWN, "(unknown)" },
};

Expand Down Expand Up @@ -1658,6 +1659,30 @@ static const std::map<llm_arch, std::map<llm_tensor, const char *>> LLM_TENSOR_N
{ LLM_TENSOR_FFN_EXP_PROBS_B, "blk.%d.exp_probs_b" },
}
},
{
LLM_ARCH_HUNYUAN_MOE,
{
{ LLM_TENSOR_TOKEN_EMBD, "token_embd" },
{ LLM_TENSOR_OUTPUT_NORM, "output_norm" },
{ LLM_TENSOR_OUTPUT, "output" },
{ LLM_TENSOR_ATTN_NORM, "blk.%d.attn_norm" },
{ LLM_TENSOR_ATTN_Q, "blk.%d.attn_q" },
{ LLM_TENSOR_ATTN_Q_NORM, "blk.%d.attn_q_norm" },
{ LLM_TENSOR_ATTN_K, "blk.%d.attn_k" },
{ LLM_TENSOR_ATTN_K_NORM, "blk.%d.attn_k_norm" },
{ LLM_TENSOR_ATTN_V, "blk.%d.attn_v" },
{ LLM_TENSOR_ATTN_OUT, "blk.%d.attn_output" },
{ LLM_TENSOR_ATTN_ROT_EMBD, "blk.%d.attn_rot_embd" },
{ LLM_TENSOR_FFN_GATE_INP, "blk.%d.ffn_gate_inp" },
{ LLM_TENSOR_FFN_NORM, "blk.%d.ffn_norm" },
{ LLM_TENSOR_FFN_GATE_SHEXP, "blk.%d.ffn_gate_shexp" },
{ LLM_TENSOR_FFN_DOWN_SHEXP, "blk.%d.ffn_down_shexp" },
{ LLM_TENSOR_FFN_UP_SHEXP, "blk.%d.ffn_up_shexp" },
{ LLM_TENSOR_FFN_GATE_EXPS, "blk.%d.ffn_gate_exps" },
{ LLM_TENSOR_FFN_DOWN_EXPS, "blk.%d.ffn_down_exps" },
{ LLM_TENSOR_FFN_UP_EXPS, "blk.%d.ffn_up_exps" },
},
},
{
LLM_ARCH_UNKNOWN,
{
Expand Down
1 change: 1 addition & 0 deletions src/llama-arch.h
Original file line number Diff line number Diff line change
Expand Up @@ -80,6 +80,7 @@ enum llm_arch {
LLM_ARCH_BAILINGMOE,
LLM_ARCH_DOTS1,
LLM_ARCH_ARCEE,
LLM_ARCH_HUNYUAN_MOE,
LLM_ARCH_UNKNOWN,
};

Expand Down
18 changes: 18 additions & 0 deletions src/llama-chat.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -64,6 +64,7 @@ static const std::map<std::string, llm_chat_template> LLM_CHAT_TEMPLATES = {
{ "bailing", LLM_CHAT_TEMPLATE_BAILING },
{ "llama4", LLM_CHAT_TEMPLATE_LLAMA4 },
{ "smolvlm", LLM_CHAT_TEMPLATE_SMOLVLM },
{ "hunyuan-moe", LLM_CHAT_TEMPLATE_HUNYUAN_MOE },
};

llm_chat_template llm_chat_template_from_str(const std::string & name) {
Expand Down Expand Up @@ -185,6 +186,8 @@ llm_chat_template llm_chat_detect_template(const std::string & tmpl) {
return LLM_CHAT_TEMPLATE_LLAMA4;
} else if (tmpl_contains("<|endofuserprompt|>")) {
return LLM_CHAT_TEMPLATE_DOTS1;
} else if (tmpl_contains("<|startoftext|>") && tmpl_contains("<|extra_4|>")) {
return LLM_CHAT_TEMPLATE_HUNYUAN_MOE;
}
return LLM_CHAT_TEMPLATE_UNKNOWN;
}
Expand Down Expand Up @@ -665,6 +668,21 @@ int32_t llm_chat_apply_template(
if (add_ass) {
ss << "<|response|>";
}
} else if (tmpl == LLM_CHAT_TEMPLATE_HUNYUAN_MOE) {
// tencent/Hunyuan-A13B-Instruct
for (auto message : chat) {
std::string role(message->role);
if (role == "system") {
ss << "<|startoftext|>" << message->content << "<|extra_4|>";
} else if (role == "assistant") {
ss << "<|startoftext|>" << message->content << "<|eos|>";
} else {
ss << "<|startoftext|>" << message->content << "<|extra_0|>";
}
}
if (add_ass) {
ss << "<|startoftext|>";
}
} else {
// template not supported
return -1;
Expand Down
1 change: 1 addition & 0 deletions src/llama-chat.h
Original file line number Diff line number Diff line change
Expand Up @@ -44,6 +44,7 @@ enum llm_chat_template {
LLM_CHAT_TEMPLATE_LLAMA4,
LLM_CHAT_TEMPLATE_SMOLVLM,
LLM_CHAT_TEMPLATE_DOTS1,
LLM_CHAT_TEMPLATE_HUNYUAN_MOE,
LLM_CHAT_TEMPLATE_UNKNOWN,
};

Expand Down
7 changes: 7 additions & 0 deletions src/llama-graph.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -705,6 +705,13 @@ ggml_tensor * llm_graph_context::build_moe_ffn(
ggml_reshape_3d(ctx0, probs, 1, n_expert, n_tokens), selected_experts); // [1, n_expert_used, n_tokens]
cb(weights, "ffn_moe_weights", il);

if (arch == LLM_ARCH_HUNYUAN_MOE) {
weights = ggml_reshape_2d(ctx0, weights, n_expert_used, n_tokens); // [n_expert_used, n_tokens]
weights = ggml_div(ctx0, weights, ggml_sum_rows(ctx0, weights)); // [1, n_tokens]
weights = ggml_reshape_3d(ctx0, weights, 1, n_expert_used, n_tokens); // [1, n_expert_used, n_tokens]
cb(weights, "ffn_moe_weights_scaled", il);
}

if (norm_w) {
weights = ggml_reshape_2d(ctx0, weights, n_expert_used, n_tokens);

Expand Down
Loading