model : add hunyuan moe #14425
Conversation
for token, rank in mergeable_ranks.items():
    vocab[QwenModel.token_bytes_to_string(token)] = rank
    if len(token) == 1:
        continue
    merged = QwenModel.bpe(mergeable_ranks, token, max_rank=rank)
    if len(merged) == 2:
        merges.append(' '.join(map(QwenModel.token_bytes_to_string, merged)))
I quite doubt that this is correct. If someone knows, or has time to do a tokenizer test, please feel free to leave a comment.
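For whoever picks up that tokenizer test: here is a minimal sketch of the kind of spot-check that could help, using the HF tokenizer as the reference. The model path and test strings are placeholders; the idea is to print token IDs and decoded pieces for a few strings and compare them against what llama.cpp produces for the same input (e.g. via the llama-tokenize tool).

# Hypothetical tokenizer spot-check against the HF reference tokenizer.
from transformers import AutoTokenizer

model_path = "tencent/Hunyuan-A13B-Instruct"  # or a local checkout
tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True)

tests = [
    "Hello world",
    "Tell a funny joke in English.",
    "新旧交替",  # mixed CJK, the case that breaks below
]
for text in tests:
    ids = tokenizer.encode(text, add_special_tokens=False)
    pieces = [tokenizer.decode([i]) for i in ids]
    print(text, ids, pieces)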
Ok, getting somewhere now. The model runs, but outputs gibberish.
Thanks for working on this! I got the same-looking output trying it myself. The only odd things I noticed were:
Tested on an AMD 7965WX 24-core, 256GB DDR5@4800 + dual RTX A6000 (96GB total VRAM) rig.

👈 a few more commands and logs fwiw

convert

python \
convert_hf_to_gguf.py \
--outtype bf16 \
--split-max-size 50G \
--outfile /mnt/raid/models/ubergarm/Hunyuan-A13B-Instruct-GGUF/ \
/mnt/raid/models/tencent/Hunyuan-A13B-Instruct/
...

llama-server

model=/mnt/raid/models/ubergarm/Hunyuan-A13B-Instruct-GGUF/Hunyuan-A13B-Instruct-BF16-00001-of-00004.gguf
./build/bin/llama-server \
--model "$model" \
-fa \
-ctk f16 -ctv f16 \
-c 8192 \
-ts 48,48 \
-ngl 10 \
--threads 24 \
--host 127.0.0.1 \
--port 8080
...

client

>>> User:
Tell a funny joke in English.
>>> Assistant:
[UNK_BYTE_0xe696b0新旧][UNK_BYTE_0xe697a7新旧][UNK_BYTE_0xe696b0新旧][UNK_BYTE_0xe697a7新旧][UNK_BYTE_0xe696b0新旧][UNK_BYTE_0xe697a7新旧][UNK_BYTE_0xe696b0新旧][UNK_BYTE_0xe697a7新旧][UNK_BYTE_0xe696b0新旧][UNK_BYTE_0xe697a7新旧][UNK_BYTE_0xe696b0新旧][UNK_BYTE_0xe697a7新旧][UNK_BYTE_0xe696b0新旧][UNK_BYTE_0xe697a7新旧][UNK_BYTE_0xe696b0新旧][UNK_BYTE_0xe697a7新旧][UNK_BYTE_0xe696b0新旧]
I don't know as much about this as you guys, but could it be that the tokenizer is splitting characters like 新 ("new") into raw bytes? The UTF-8 sequence for 新 is 0xe6 0x96 0xb0, so the fragments get wrapped in UNK_BYTE markers. Common Chinese characters always use 3 bytes in UTF-8:
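A quick way to see the byte counts directly (plain Python, nothing model-specific):

print("新".encode("utf-8").hex())  # e696b0, the 3 bytes in UNK_BYTE_0xe696b0
print("旧".encode("utf-8").hex())  # e697a7, the 3 bytes in UNK_BYTE_0xe697a7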
It matches the error output above.
The cgraph is still not correct. I'm testing with this tiny random weight: https://huggingface.co/ngxson/hunyuan-moe-tiny-random/tree/main. It seems like the problem is in the self-attention block.
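In case it helps narrow down where the divergence starts, here is a sketch of dumping per-layer hidden states from the HF side for a fixed token-ID prompt, so they can be lined up against llama.cpp's intermediate tensors. It assumes the tiny random checkpoint loads with trust_remote_code and that the remote code follows the standard forward signature; the token IDs are arbitrary.

import torch
from transformers import AutoModelForCausalLM

model = AutoModelForCausalLM.from_pretrained(
    "ngxson/hunyuan-moe-tiny-random", trust_remote_code=True, torch_dtype=torch.float32
)
model.eval()

ids = torch.tensor([[16, 17]])  # feed the same raw token IDs to llama.cpp
with torch.no_grad():
    out = model(ids, output_hidden_states=True, use_cache=False)

# One summary line per layer; a sudden jump marks the first bad block.
for i, h in enumerate(out.hidden_states):
    print(f"layer {i:2d}: mean={h.mean().item():+.6f} std={h.std().item():.6f}")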
I don't know if the improvements I am seeing are from your last commit. The changes I made were:
my edits are here: https://github.com/kooshi/llama.cpp/tree/hunyuan
The more I look at the upstream implementation, the more I wonder if it actually works. My Mac M3 Ultra can't load the original model even though it has 512GB of RAM. For now, I'm testing with the tiny weight.

Switching between … Also, … And more importantly, … If that is true, it means they messed up badly this time.
https://www.diffchecker.com/P3e0hQM5/
https://huggingface.co/tencent/Tencent-Hunyuan-Large/blob/main/Hunyuan-A52B-Instruct/
And https://www.diffchecker.com/P9FIR5OD/

In other words, it's almost Hunyuan Large? I'm not sure why the HF attention implementations would be bugged. But other reimplementations like vllm's seem to work, so maybe they can shed some light on this:
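One cheap sanity check on the "HF attention implementations are bugged" theory: load the same checkpoint twice with different attn_implementation values and compare logits. This is only a sketch; it assumes the remote modeling code honors attn_implementation, and it uses the tiny random weight to keep it small enough to run anywhere.

import torch
from transformers import AutoModelForCausalLM

path = "ngxson/hunyuan-moe-tiny-random"
ids = torch.tensor([[16, 17, 18, 19]])

logits = {}
for impl in ("eager", "sdpa"):
    m = AutoModelForCausalLM.from_pretrained(
        path, trust_remote_code=True, torch_dtype=torch.float32, attn_implementation=impl
    )
    m.eval()
    with torch.no_grad():
        logits[impl] = m(ids).logits

# If this is far from zero, the two upstream attention paths disagree with each other.
print("max abs diff:", (logits["eager"] - logits["sdpa"]).abs().max().item())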
I take that back, apparently vllm is only sometimes working with A13B, heh:
I had the original model from Huggingface work coherently on pure CPU. It uses the HunYuanSdpaAttention codepath. This is all tentative, as I just got it running at all.

If I compare logits for a single-token prompt, I get a very similar logit distribution from both llama.cpp and the HF implementation. With more than one token, things look different. I'm going purely with numerical token IDs for llama.cpp, as the tokenizer is messed up as observed (I tried 'a', token 64, for the single-token prompt, and a '12' prompt, tokens (16, 17), for the two-token test).

This is with the combined code from @ngxson and @kooshi, with the .gguf made with @kooshi's code (I took the latest efforts I saw here in the discussion to start off). Below in the dropdown is the test.py script and its stdout.

My machine has 256GB of memory, a Hetzner server with a modern AMD EPYC CPU. I do have a Mac Studio (M2, 192GB) as well, but for CPU work this Hetzner box is usually much faster. (I don't know why asking it to use bfloat16 helps; maybe it doesn't make giant copies of tensors when you ask it to do that. It's just something I observed and never checked what it's doing behind the scenes.)

test.py

This is a version of the example code from the Huggingface page that I modified a bit.

#!/usr/bin/env python
from transformers import AutoTokenizer, AutoModelForCausalLM
import torch
import os
import re
def main():
    with torch.no_grad():
        model_path = '/home/shannon/llama.cpp/tencent_Hunyuan-A13B-Instruct'
        tokenizer = AutoTokenizer.from_pretrained(model_path, local_files_only=True, trust_remote_code=True)
        # CPU-only, bfloat16; this is the configuration that ran coherently here
        model = AutoModelForCausalLM.from_pretrained(model_path, local_files_only=True, device_map="cpu", torch_dtype=torch.bfloat16, trust_remote_code=True)
        messages = [
            {"role": "user", "content": "Write a short summary of the benefits of regular exercise"},
        ]
        tokenized_chat = tokenizer.apply_chat_template(messages, tokenize=True, return_tensors="pt",
            enable_thinking=True  # Toggle thinking mode (default: True)
        )
        outputs = model.generate(tokenized_chat.to(model.device), max_new_tokens=20)
        output_text = tokenizer.decode(outputs[0])
        print(outputs)
        print(output_text)

if __name__ == '__main__':
    main()

stdout of test.py

The output contains the generation both as token IDs and as decoded text (the two print()s). To run this, you need to install …
I'm on and off this weekend trying to also figure out where exactly the computation graph is off. If I find out before someone else does, I'll let you all know. (It runs surprisingly fast on transformers+CPU; I'm used to that combo being extraordinarily slow. It is still very slow, just not "it will take 30 minutes to make 10 tokens" slow.)
Is it possible to load this model in 4-bit precision using Transformers? Does bitsandbytes support this model? I’m limited to a total of 72GB of VRAM across several GPUs, so bfloat16 won’t work for me.
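For what it's worth, the generic Transformers 4-bit path would look like the sketch below. Whether bitsandbytes actually handles this custom architecture is exactly the open question, so treat this as untested; the model path is the public repo id and the settings are ordinary NF4 defaults.

import torch
from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig

model_path = "tencent/Hunyuan-A13B-Instruct"
bnb = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_quant_type="nf4",
    bnb_4bit_compute_dtype=torch.bfloat16,
)
tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True)
model = AutoModelForCausalLM.from_pretrained(
    model_path,
    quantization_config=bnb,
    device_map="auto",  # shard layers across the available GPUs
    trust_remote_code=True,
)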
Their official inference script for running the int4 quant on vllm is using … (still didn't work for me, though).
To add to @ubergarm's options, I did notice there are some quantized versions like https://huggingface.co/tencent/Hunyuan-A13B-Instruct-FP8 or https://huggingface.co/tencent/Hunyuan-A13B-Instruct-GPTQ-Int4 (they look like they are designed to work with vllm). The GPTQ-Int4 one has a single … I haven't tried any of them. For computation graph work it feels better to get whatever is the highest precision I am able to run conveniently.
If someone can run it, could you please verify whether attention_mask is None or not?
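One quick way to check without editing the modeling file, as a sketch meant to be appended to the test.py above (it reuses that script's model and tokenized_chat variables and needs PyTorch >= 2.0 for with_kwargs):

def report_mask(module, args, kwargs):
    # Runs right before the top-level forward(); prints whether a mask was passed in.
    print("attention_mask is None:", kwargs.get("attention_mask") is None)

handle = model.register_forward_pre_hook(report_mask, with_kwargs=True)
model.generate(tokenized_chat.to(model.device), max_new_tokens=5)
handle.remove()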
@ngxson is this the part you wanted to see if it's None or not? The argument to forward()?

[screenshot of where the check was added]

Edit: took a bigger screenshot to show more clearly where I put that. Stdout tail, because that first paste is cut off; I see:
Edit2: I'm going to let this thing generate a full response, which might take a while. But I feel this might be a bit short as a test; it almost verbatim mentions the prompt in the <think> block, so maybe it's about to repeat itself or something. I'll paste it as a new comment when it's done. I just want more confirmation that the HF implementation itself works beyond very short generations.
Full response example of the stdout from test2.py (I cut off all the parts that said attention mask is None)
The code is almost the same as before; pasting it for reproducibility:

test2.py

#!/usr/bin/env python
from transformers import AutoTokenizer, AutoModelForCausalLM
import torch
import os
import re
def main():
    with torch.no_grad():
        model_path = '/home/shannon/llama.cpp/tencent_Hunyuan-A13B-Instruct'
        tokenizer = AutoTokenizer.from_pretrained(model_path, local_files_only=True, trust_remote_code=True)
        # Same CPU-only bfloat16 setup as test.py
        model = AutoModelForCausalLM.from_pretrained(model_path, local_files_only=True, device_map="cpu", torch_dtype=torch.bfloat16, trust_remote_code=True)
        messages = [
            {"role": "user", "content": "Write a short summary of the benefits of regular exercise"},
        ]
        tokenized_chat = tokenizer.apply_chat_template(messages, tokenize=True, return_tensors="pt",
            enable_thinking=True  # Toggle thinking mode (default: True)
        )
        outputs = model.generate(tokenized_chat.to(model.device), max_new_tokens=5000)
        output_text = tokenizer.decode(outputs[0])
        print(outputs)
        print(output_text)

if __name__ == '__main__':
    main()

The output looks normal to me and it answered the prompt. It does look to me like it works. CPU-only, 256GB Hetzner server.
Ok, so based on the investigations above, it seems like I'm now getting a 100% logits match between llama.cpp <> sdpa using the random weight. But I'm not sure why the official weight still doesn't work correctly.
I pulled that last commit into the amalgamation I had, and I'm getting <think> tags and some coherence this time. Tokenization is still messed up, so I'll turn my attention to what's going on over there. Maybe the computation graph itself is now fine, but we'll find out.

I don't know if you tried to run the official weights on a Metal GPU, but my experience has been that there are a whole lot of bugs in the MPS/Metal backend in torch. Things like …
Small correction: I only got the logits to match for the first 2 tokens. From the 3rd token on, things start to go crazy. So it's definitely a problem with RoPE.

(Meme taken from my blog post.)
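For anyone else chasing this: a position-dependent divergence like that (first tokens fine, later ones off) is the classic RoPE signature, since the rotation applied to Q/K grows with the position index and position 0 is left untouched. Below is a generic NeoX-style reference rotation for spot-checking; the base of 10000 and the split-half layout are assumptions, not necessarily what Hunyuan uses.

import torch

def rope_reference(x, pos, base=10000.0):
    # x: (seq, head_dim); applies the split-half rotation at each position.
    d = x.shape[-1]
    inv_freq = 1.0 / (base ** (torch.arange(0, d, 2, dtype=torch.float32) / d))
    ang = pos[:, None].float() * inv_freq[None, :]  # (seq, d/2)
    cos, sin = ang.cos(), ang.sin()
    x1, x2 = x[..., : d // 2], x[..., d // 2:]
    return torch.cat([x1 * cos - x2 * sin, x2 * cos + x1 * sin], dim=-1)

x = torch.randn(4, 8)
out = rope_reference(x, torch.arange(4))
print(torch.allclose(out[0], x[0]))  # True: position 0 is unrotated, later ones are not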
That exact code snippet also worked for me on a B200 GPU (with …).
I've been messing with tokenization all day. I tried handling the merges in the same way Qwen does, but that obliterated the coherence I had seen before. I'll pull the new cpp code and keep digging.
RoPE is fixed. However, a new problem appears: it seems like some engineers at Tencent thought they should make their top-k MoE selection a bit "special".

And by "special", I mean this block of code, which seems harmless at first. In short, what it does is keep track of the usage of each expert. If an expert is used too much (i.e. exceeds its capacity), it gets "de-prioritized". I assume this is meant to address the problem that MoE routers are extremely hard to train (ref: Qwen3MoE has some "unused" experts). It sounds like a good idea, but it is extremely difficult to reimplement in llama.cpp; a rough sketch of the idea is below.

This also makes the number of experts used by a given token uneven. Some tokens will use fewer experts than others, and some use no experts at all (due to the prioritization explained above). That sounds fine on the surface, but the actual implementation always calculates a fixed number of experts per token, which defeats the whole point.

I'm now confident that Tencent messed up this time.
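Here is a minimal sketch of the capacity-aware routing idea described above (illustrative only, not Tencent's actual code, and deliberately written as a sequential loop to make the order dependence obvious): each expert has a capacity, experts over capacity are skipped for later tokens, so late tokens can end up with fewer than k experts, or none.

import torch

def capped_topk_routing(scores, k, capacity):
    # scores: (n_tokens, n_experts) router scores for one batch
    n_tokens, n_experts = scores.shape
    used = torch.zeros(n_experts, dtype=torch.long)  # per-expert usage so far
    chosen = []
    for t in range(n_tokens):
        picks = []
        for e in torch.argsort(scores[t], descending=True).tolist():
            if used[e] < capacity:  # expert still has room
                picks.append(e)
                used[e] += 1
            if len(picks) == k:
                break
        chosen.append(picks)  # may hold fewer than k experts for late tokens
    return chosen

scores = torch.rand(6, 4)
print(capped_topk_routing(scores, k=2, capacity=3))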
Woo! Tokenization is working, and I get a whopping 75 tps in q4_0 on my gpus.
Edit: It still devolves into repetition pretty quickly though 😔
Fix #14415
STILL WIP
TODO: