Add support for SmallThinker model series #14898

Conversation
@@ -938,6 +938,100 @@ ggml_tensor * llm_graph_context::build_moe_ffn(
     return moe_out;
 }

+ggml_tensor * llm_graph_context::build_moe_ffn_from_probs(
The code duplication is unfortunate, is it possible to merge this into `build_moe_ffn` with `probs` as a toggle without making too much of a mess?

Can be a follow-up.
That's a great point. I've been thinking about the best way to merge these and have a couple of ideas on how we could approach it.
- As you suggested, we could modify `build_moe_ffn` to accept an optional `probs` parameter. The main difficulty here is that the logic for weight normalization and activation functions diverges significantly between the two paths, so it would require some careful internal branching to keep it clean.
- Alternatively, we could extract the initial router logic (logits and probs calculation) into its own function. `build_moe_ffn` would then have a check at the beginning to decide whether to call this new router function. My main concern with this approach is that `build_moe_ffn` is a core function, and I'm a bit worried about affecting other models, so this would need careful testing.
Both approaches seem feasible. Given the complexity and your suggestion that this can be a follow-up, would you prefer I handle this in a separate PR, or should I proceed with one of these solutions here?
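For illustration, here is a minimal standalone sketch of the first idea, an optional `probs` toggle on `build_moe_ffn`. The `Tensor` type and both function bodies below are simplified stand-ins, not the real `ggml`/`llm_graph_context` API; it only shows the call-site shape such a merge could take.

```cpp
#include <cstdio>

// Stand-in for ggml_tensor; the real code would operate on graph tensors.
struct Tensor { float val; };

// Hypothetical helper holding the router math (logits + normalization).
static Tensor build_moe_router(Tensor input) {
    return Tensor{ input.val * 0.5f };  // placeholder for the gating computation
}

// Merged entry point: callers that already computed router probabilities
// (SmallThinker routes before attention) pass them in; everyone else keeps
// the current behaviour where the router runs inside the FFN.
static Tensor build_moe_ffn(Tensor input, const Tensor * probs = nullptr) {
    const Tensor p = probs ? *probs : build_moe_router(input);
    return Tensor{ input.val * p.val }; // placeholder for expert selection/mixing
}

int main() {
    Tensor x{2.0f};
    Tensor pre = build_moe_router(x);                // router ran earlier in the graph
    std::printf("%f\n", build_moe_ffn(x, &pre).val); // SmallThinker-style path
    std::printf("%f\n", build_moe_ffn(x).val);       // existing path, unchanged
}
```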
A separate PR is probably best.
Okay, I'll open a follow-up issue later to track this so we don't forget. Thanks!
Edit: I've created issue #14920 to track this.
@wdl339 I'm running into an issue where it's outputting random numbers. Also, sometimes after a comma it forgets to add a space.
Hi @arch-btw, thanks for reporting this. I was unable to reproduce this behavior on the CPU backend. Could you please provide the specific model file you're using, the full command you ran to get this output, and the hardware you are running on?

As a side note, if you are using one of the quantized 4BA0.6B models and your hardware allows, you could try an unquantized version to see if that improves the output quality. This would also help us narrow down whether the problem is specific to quantization.
Hi @wdl339, I'm on CPU. I converted the GGUFs myself. According to the readme here, the lm_head.pt might have been missing, could that be it? It only happens on certain prompts. I converted to f32 first and then to Q4_K_M. Used this command for inference:

I also tried it with

I'll try to run an unquantized version. Thanks!
Thanks for the detailed follow-up and for testing! Let me address your points:

My strong recommendation is to use a chat template, as it is an instruction-tuned model. I'm not sure if

The .jinja template

When I use this template, the issue disappears completely and the output is stable.
With

Without it, it will use an identical one, but without a default system prompt (define one with
I have been getting random number outputs on the CUDA backend too, using llama-server and the IQ4_XS quant from PowerInfer/SmallThinker-21BA3B-Instruct-GGUF.

Edit: it seems to happen when the output and context cross about 4000 tokens. I can post a full chat, but here's an example of what it looks like when it goes from normal to broken output.
Thank you guys for the help and the report! @wdl339 I saved the template and ran it with

Just to summarize:

One thing I noticed in @TechnotechGit's output was this:

I remember @ggerganov mentioned something about that recently; I think he said it was a bug that was seen in the GLM models as well, but I might be wrong. I'll try to find his comment on this.
I didn't find ggerganov's comment, but I did find his commit on that 'GGgggggg' issue:

And it was also discussed here:
Thanks for the
@arch-btw Thank you for that summary. I'm currently focused on reproducing the issue, to determine whether the root cause lies within the model itself, my llama.cpp implementation, or the quantization process. This behavior is quite puzzling; since we have not yet seen similar reports on the PowerInfer framework or the Hugging Face Hub, it might take some time to investigate. We appreciate your patience and help.

One of my teammates mentioned that our models may perform better with a slightly lower temperature setting. You could also give that a try.

Edit: We tested your prompt on a CUDA backend using both the F16 and Q4_K_M versions, and we were unable to reproduce the issue in either case. We could not reproduce the problem on an Intel CPU backend either. This leads me to wonder if there might be a subtle precision issue happening specifically on the AMD CPU backend.
Hi @TechnotechGit, thank you for your report. We ran a test on our end using the same IQ4_XS quantized model on a CUDA backend. We specifically tested with a long context, filling it up to 8K tokens, but we haven't been able to reproduce the problem yet. It might take some time to investigate deeper.
I am not @arch-btw, but I did try the model with temp 0.3 and other sampling settings at default; same result.

I did some more testing; it seems to be flash attention causing this, from what I can tell. Removing the q8_0 KV cache did not change the result, and neither did removing other parameters, until I removed flash attention. All attempts before that failed at roughly 4200 tokens in context (including the failed output), but the no-flash-attn run was fine. It might still go weird later, but I haven't tested that far.
I tried to run it in full verbose mode; here is part of the output:

Which is displayed without verbose mode as:

I also noticed this in the log:

Thinking it might have been the Vulkan part of my build, I tried the prebuilt non-Vulkan version (llama-b6026-bin-ubuntu-x64.zip) but unfortunately had the same issue. I don't have flash attention enabled. I'm not sure what to try next, but feel free to shoot ideas.
Thank you for digging into this and sharing your findings. Our 21B model uses sliding window attention, and the window size is exactly 4096. Your finding makes me wonder if there's a potential incompatibility or an edge case between the Flash Attention implementation and SWA. Fortunately, as far as I know, llama.cpp does not enable Flash Attention by default. We are very happy to hear that your issue is resolved.
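For context on why the roughly 4000-token threshold is suggestive: a sliding-window mask of size 4096 only lets a token attend to the previous 4096 positions, so the mask first starts dropping keys right around where the broken output was reported. A minimal illustrative sketch of such a mask (not llama.cpp's actual implementation):

```cpp
#include <cstdint>
#include <cstdio>

// Causal sliding-window attention: a query at q_pos may attend to a key at
// k_pos only if the key is not in the future and lies within the last
// `window` positions.
static bool swa_allowed(int64_t q_pos, int64_t k_pos, int64_t window) {
    return k_pos <= q_pos && (q_pos - k_pos) < window;
}

int main() {
    const int64_t window = 4096;  // SmallThinker-21B's stated SWA window size
    std::printf("%d\n", swa_allowed(4500,  200, window)); // 0: key fell out of the window
    std::printf("%d\n", swa_allowed(4500, 4400, window)); // 1: key still inside the window
}
```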
@arch-btw, good news! We've successfully reproduced the issue on our end. We followed your exact conversion path: first converting the base model to fp32, and then quantizing it to Q4_K_M. Using a command with --temp 0 to eliminate randomness (./build/bin/llama-cli -m st21b-f32-q4km.gguf --conversation --device none --temp 0), we were able to confirm the appearance of unrelated numbers in the output.

Then, we tried a different conversion path. When we convert the model to fp16 first and then quantize it to Q4_K_M, the problem with the random numbers disappears.

Regarding the missing space after a comma, I've noticed this tends to happen at the beginning of the model's reply, so I suspect it might be a minor artifact from the original training data.

So, we have two concrete recommendations for you:

Thank you again for your incredible help in debugging this.
@wdl339 I am glad to hear that; that solves it. Thank you to you and your team for being so patient with the debugging! It's funny, a few seconds before reading your comment I came across this: It seems to be an issue with that model too.
@wdl339 That makes no sense (unless FP16 slightly corrupts the model and masks the issue): the original weights are BF16, which is lossless (though wasteful of space) when converted to FP32, so that should not cause any issues. I always convert BF16 to BF16 for a lossless and non-wasteful conversion, and if possible do inference in BF16 as well.

Edit: As already mentioned, I have not been able to reproduce any issues here (using BF16), and quite contrary to your statement, I would not be surprised if converting to FP16 caused issues.
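To illustrate the conversion point with a standalone sketch (not the converter's actual code): BF16 is simply the top 16 bits of an FP32, so widening BF16 to FP32 is exact, whereas FP16 has a much smaller exponent range and must clamp or round values that BF16 represents directly.

```cpp
#include <cstdint>
#include <cstdio>
#include <cstring>

// BF16 -> FP32 is lossless: append 16 zero mantissa bits.
static float bf16_to_f32(uint16_t h) {
    const uint32_t bits = static_cast<uint32_t>(h) << 16;
    float f;
    std::memcpy(&f, &bits, sizeof(f));
    return f;
}

int main() {
    // 0x7F00 encodes roughly 1.7e38 in BF16. It survives the trip to FP32
    // exactly, but it is far above FP16's largest finite value (65504), so a
    // BF16 -> FP16 conversion would have to turn it into infinity.
    const uint16_t big_bf16 = 0x7F00;
    std::printf("BF16 0x%04X as FP32: %g (FP16 max finite: 65504)\n",
                static_cast<unsigned>(big_bf16), bf16_to_f32(big_bf16));
}
```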
I may have spoken too soon; upon further inspection, the bug appears to still be there. It only appears with certain prompts.

Edit: I will try the pre-quantized models, just to rule all of that out.
Try bartowski's; it looks like they are the only ones quantized from BF16.
Thanks CISC! Unfortunately, at the moment bartowski only has the small one: https://huggingface.co/bartowski/PowerInfer_SmallThinker-4BA0.6B-Instruct-GGUF. When I tried the one by PowerInfer, it does give me the same bug: https://huggingface.co/PowerInfer/SmallThinker-21BA3B-Instruct-GGUF/blob/main/SmallThinker-21B-A3B-Instruct.Q4_K.gguf

I've also converted my own BF16 version to Q4_K_M, and it has the same issue. Just to be sure that it's not my hardware, I tried some models from other companies, but those run fine.

I found another prompt that triggers it almost every single time:

Output:

With --temp 0:
Hi @arch-btw, thank you for your persistence and for finding that reliable prompt. You are right, the issue is still there, and we've been able to reproduce it on our side.

The good news is that we have a definitive solution. We recently released QAT versions of the model. We tested your prompt using this specific QAT model (official GGUF repo), and the output is clean.

At this point, I can only conclude that the problem is rooted in quantization. The deep, underlying reason for this specific artifact is not yet fully clear, but the QAT approach effectively solves it. We highly recommend using the QAT version to ensure the best performance. Thank you again for your invaluable help.
@wdl339 thank you again! The QAT quant does indeed work; I triple-checked this time 😅

Output:
Purpose
SmallThinker is a family of on-device native Mixture-of-Experts (MoE) language models specially designed for local deployment, co-developed by IPADS (the team behind the high-speed inference framework PowerInfer) and the School of AI at Shanghai Jiao Tong University, together with Zenergize AI. Designed from the ground up for resource-constrained environments, SmallThinker brings powerful, private, and low-latency AI directly to your personal devices, without relying on the cloud.
This PR is to add support for the SmallThinker series of models to llama.cpp.
Modifications
- Added a new function, `build_moe_ffn_from_probs`, to handle SmallThinker's unique architecture, where the MoE router is positioned before the attention block.
- Added `set_dense_start_swa_pattern`. While the existing `set_swa_pattern` function enables a pattern where every Nth layer is dense, starting the count from SWA layers, the new function allows the pattern to start with a dense layer (see the sketch below).
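To make the difference concrete, here is a simplified standalone sketch of the two patterns; the function names mirror the PR, but the bodies and signatures are illustrative only, not the actual llama.cpp hyperparameter helpers.

```cpp
#include <cstdio>
#include <vector>

// Existing behaviour: every n-th layer is dense, counting from the SWA layers,
// e.g. n = 4 -> SWA, SWA, SWA, dense, SWA, SWA, SWA, dense, ...
static std::vector<bool> swa_pattern(int n_layer, int n) {
    std::vector<bool> is_swa(n_layer);
    for (int il = 0; il < n_layer; ++il) {
        is_swa[il] = (il + 1) % n != 0;  // dense on the last layer of each group
    }
    return is_swa;
}

// Dense-start variant: the pattern begins with a dense layer instead,
// e.g. n = 4 -> dense, SWA, SWA, SWA, dense, SWA, SWA, SWA, ...
static std::vector<bool> dense_start_swa_pattern(int n_layer, int n) {
    std::vector<bool> is_swa(n_layer);
    for (int il = 0; il < n_layer; ++il) {
        is_swa[il] = il % n != 0;        // dense on the first layer of each group
    }
    return is_swa;
}

int main() {
    for (bool b : swa_pattern(8, 4))             std::printf("%c", b ? 'S' : 'D');
    std::printf("\n");                           // prints SSSDSSSD
    for (bool b : dense_start_swa_pattern(8, 4)) std::printf("%c", b ? 'S' : 'D');
    std::printf("\n");                           // prints DSSSDSSS
}
```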
Testing

Clone the model from https://huggingface.co/PowerInfer/SmallThinker-4BA0.6B-Instruct and use `convert-hf-to-gguf.py` to convert it to GGUF format. (Edit: We have https://huggingface.co/PowerInfer/SmallThinker-21BA3B-Instruct-GGUF now.)
full output