
Add support for SmallThinker model series #14898


Merged (18 commits into ggml-org:master on Jul 28, 2025)

Conversation

@wdl339 (Contributor) commented Jul 27, 2025

Purpose

SmallThinker is a family of on-device-native Mixture-of-Experts (MoE) language models specially designed for local deployment, co-developed by IPADS (the team behind the high-speed inference framework PowerInfer) and the School of AI at Shanghai Jiao Tong University, together with Zenergize AI. Designed from the ground up for resource-constrained environments, SmallThinker brings powerful, private, and low-latency AI directly to your personal devices, without relying on the cloud.

This PR adds support for the SmallThinker series of models to llama.cpp.

Modifications

  • Add support for SmallthinkerForCausalLM model conversion in convert-hf-to-gguf.py.
  • Add new LLM_ARCH_SMALLTHINKER architecture.
  • Add support for inference for models based on LLM_ARCH_SMALLTHINKER.
  • Implement a new function build_moe_ffn_from_probs to handle SmallThinker's unique architecture, where the MoE router is positioned before the attention block.
  • Implement a new function set_dense_start_swa_pattern. While the existing set_swa_pattern function enables a pattern where every Nth layer is dense, starting the count from SWA layers, the new function allows the pattern to start with a dense layer (see the sketch after this list).
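
For context, here is a minimal standalone sketch of how the two layer patterns differ. This is not the llama.cpp implementation; the period (n_pattern = 4), the layer count, and the boolean-per-layer representation are assumptions chosen only for illustration.

```cpp
#include <cstdio>
#include <vector>

int main() {
    const int n_layer = 8, n_pattern = 4;
    std::vector<bool> swa_existing(n_layer), swa_dense_start(n_layer);
    for (int il = 0; il < n_layer; ++il) {
        // existing style: every n_pattern-th layer is dense, counting from the SWA side
        // -> layers 0,1,2 SWA, layer 3 dense, layers 4,5,6 SWA, layer 7 dense, ...
        swa_existing[il]    = (il % n_pattern) < (n_pattern - 1);
        // dense-start style: the pattern begins with a dense layer
        // -> layer 0 dense, layers 1,2,3 SWA, layer 4 dense, layers 5,6,7 SWA, ...
        swa_dense_start[il] = (il % n_pattern) != 0;
    }
    for (int il = 0; il < n_layer; ++il) {
        std::printf("layer %d: existing=%-5s dense-start=%s\n", il,
                    swa_existing[il]    ? "SWA" : "dense",
                    swa_dense_start[il] ? "SWA" : "dense");
    }
    return 0;
}
```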

Testing

Clone the model from https://huggingface.co/PowerInfer/SmallThinker-4BA0.6B-Instruct and use convert-hf-to-gguf.py to convert it to GGUF format.
(Edit: We have https://huggingface.co/PowerInfer/SmallThinker-21BA3B-Instruct-GGUF now)

./build/bin/llama-cli -m /mnt/m2_4/wdl/smallthinker-4b.gguf -p "The meaning of life is" -n 64

The meaning of life is a profound and deeply personal question! There is no single answer, and different perspectives offer varying insights. Here are some major approaches to understanding it:

1. **Existential Meaning-Making**  
   Philosophers like Sartre argue life has no inherent meaning—*we create our own purpose

Full output:
./build/bin/llama-cli -m /mnt/m2_4/wdl/smallthinker-4b.gguf -p "The meaning of life is" -n 64
build: 6006 (92b518b4) with cc (Ubuntu 13.3.0-6ubuntu2~24.04) 13.3.0 for x86_64-linux-gnu
main: llama backend init
main: load the model and apply lora adapter, if any
llama_model_loader: loaded meta data with 31 key-value pairs and 323 tensors from /mnt/m2_4/wdl/smallthinker-4b.gguf (version GGUF V3 (latest))
llama_model_loader: Dumping metadata keys/values. Note: KV overrides do not apply in this output.
llama_model_loader: - kv   0:                       general.architecture str              = smallthinker
llama_model_loader: - kv   1:                               general.type str              = model
llama_model_loader: - kv   2:                               general.name str              = 4b_v7
llama_model_loader: - kv   3:                           general.finetune str              = 4b_v7
llama_model_loader: - kv   4:                         general.size_label str              = 32x758M
llama_model_loader: - kv   5:                   smallthinker.block_count u32              = 32
llama_model_loader: - kv   6:                smallthinker.context_length u32              = 32768
llama_model_loader: - kv   7:              smallthinker.embedding_length u32              = 1536
llama_model_loader: - kv   8:          smallthinker.attention.head_count u32              = 12
llama_model_loader: - kv   9:       smallthinker.attention.head_count_kv u32              = 2
llama_model_loader: - kv  10:                smallthinker.rope.freq_base f32              = 1500000.000000
llama_model_loader: - kv  11: smallthinker.attention.layer_norm_rms_epsilon f32              = 0.000001
llama_model_loader: - kv  12:          smallthinker.attention.key_length u32              = 128
llama_model_loader: - kv  13:        smallthinker.attention.value_length u32              = 128
llama_model_loader: - kv  14:                          general.file_type u32              = 1
llama_model_loader: - kv  15:                  smallthinker.expert_count u32              = 32
llama_model_loader: - kv  16:             smallthinker.expert_used_count u32              = 4
llama_model_loader: - kv  17:    smallthinker.expert_feed_forward_length u32              = 768
llama_model_loader: - kv  18:           smallthinker.feed_forward_length u32              = 768
llama_model_loader: - kv  19:            smallthinker.expert_gating_func u32              = 2
llama_model_loader: - kv  20:               general.quantization_version u32              = 2
llama_model_loader: - kv  21:                       tokenizer.ggml.model str              = gpt2
llama_model_loader: - kv  22:                         tokenizer.ggml.pre str              = qwen2
llama_model_loader: - kv  23:                      tokenizer.ggml.tokens arr[str,151936]  = ["!", "\"", "#", "$", "%", "&", "'", ...
llama_model_loader: - kv  24:                  tokenizer.ggml.token_type arr[i32,151936]  = [1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, ...
llama_model_loader: - kv  25:                      tokenizer.ggml.merges arr[str,151387]  = ["Ġ Ġ", "ĠĠ ĠĠ", "i n", "Ġ t",...
llama_model_loader: - kv  26:                tokenizer.ggml.eos_token_id u32              = 151645
llama_model_loader: - kv  27:            tokenizer.ggml.padding_token_id u32              = 151643
llama_model_loader: - kv  28:                tokenizer.ggml.bos_token_id u32              = 151643
llama_model_loader: - kv  29:               tokenizer.ggml.add_bos_token bool             = false
llama_model_loader: - kv  30:                    tokenizer.chat_template str              = {%- if tools %}\n    {{- '<|im_start|>...
llama_model_loader: - type  f32:   97 tensors
llama_model_loader: - type  f16:  226 tensors
print_info: file format = GGUF V3 (latest)
print_info: file type   = F16
print_info: file size   = 7.95 GiB (16.01 BPW) 
load: special tokens cache size = 26
load: token to piece cache size = 0.9311 MB
print_info: arch             = smallthinker
print_info: vocab_only       = 0
print_info: n_ctx_train      = 32768
print_info: n_embd           = 1536
print_info: n_layer          = 32
print_info: n_head           = 12
print_info: n_head_kv        = 2
print_info: n_rot            = 128
print_info: n_swa            = 0
print_info: is_swa_any       = 0
print_info: n_embd_head_k    = 128
print_info: n_embd_head_v    = 128
print_info: n_gqa            = 6
print_info: n_embd_k_gqa     = 256
print_info: n_embd_v_gqa     = 256
print_info: f_norm_eps       = 0.0e+00
print_info: f_norm_rms_eps   = 1.0e-06
print_info: f_clamp_kqv      = 0.0e+00
print_info: f_max_alibi_bias = 0.0e+00
print_info: f_logit_scale    = 0.0e+00
print_info: f_attn_scale     = 0.0e+00
print_info: n_ff             = 768
print_info: n_expert         = 32
print_info: n_expert_used    = 4
print_info: causal attn      = 1
print_info: pooling type     = 0
print_info: rope type        = 2
print_info: rope scaling     = linear
print_info: freq_base_train  = 1500000.0
print_info: freq_scale_train = 1
print_info: n_ctx_orig_yarn  = 32768
print_info: rope_finetuned   = unknown
print_info: model type       = 4B
print_info: model params     = 4.27 B
print_info: general.name     = 4b_v7
print_info: n_ff_exp         = 768
print_info: expert_gating_func   = sigmoid
print_info: vocab type       = BPE
print_info: n_vocab          = 151936
print_info: n_merges         = 151387
print_info: BOS token        = 151643 '<|endoftext|>'
print_info: EOS token        = 151645 '<|im_end|>'
print_info: EOT token        = 151645 '<|im_end|>'
print_info: PAD token        = 151643 '<|endoftext|>'
print_info: LF token         = 198 'Ċ'
print_info: FIM PRE token    = 151659 '<|fim_prefix|>'
print_info: FIM SUF token    = 151661 '<|fim_suffix|>'
print_info: FIM MID token    = 151660 '<|fim_middle|>'
print_info: FIM PAD token    = 151662 '<|fim_pad|>'
print_info: FIM REP token    = 151663 '<|repo_name|>'
print_info: FIM SEP token    = 151664 '<|file_sep|>'
print_info: EOG token        = 151643 '<|endoftext|>'
print_info: EOG token        = 151645 '<|im_end|>'
print_info: EOG token        = 151662 '<|fim_pad|>'
print_info: EOG token        = 151663 '<|repo_name|>'
print_info: EOG token        = 151664 '<|file_sep|>'
print_info: max token length = 256
load_tensors: loading model tensors, this can take a while... (mmap = true)
load_tensors:   CPU_Mapped model buffer size =  8144.63 MiB
............................................................................................
llama_context: constructing llama_context
llama_context: non-unified KV cache requires ggml_set_rows() - forcing unified KV cache
llama_context: n_seq_max     = 1
llama_context: n_ctx         = 4096
llama_context: n_ctx_per_seq = 4096
llama_context: n_batch       = 2048
llama_context: n_ubatch      = 512
llama_context: causal_attn   = 1
llama_context: flash_attn    = 0
llama_context: kv_unified    = true
llama_context: freq_base     = 1500000.0
llama_context: freq_scale    = 1
llama_context: n_ctx_per_seq (4096) < n_ctx_train (32768) -- the full capacity of the model will not be utilized
llama_context:        CPU  output buffer size =     0.58 MiB
llama_kv_cache_unified:        CPU KV buffer size =   128.00 MiB
llama_kv_cache_unified: size =  128.00 MiB (  4096 cells,  32 layers,  1/ 1 seqs), K (f16):   64.00 MiB, V (f16):   64.00 MiB
llama_kv_cache_unified: LLAMA_SET_ROWS=0, using old ggml_cpy() method for backwards compatibility
llama_context:        CPU compute buffer size =   299.75 MiB
llama_context: graph nodes  = 1799
llama_context: graph splits = 1
common_init_from_params: added <|endoftext|> logit bias = -inf
common_init_from_params: added <|im_end|> logit bias = -inf
common_init_from_params: added <|fim_pad|> logit bias = -inf
common_init_from_params: added <|repo_name|> logit bias = -inf
common_init_from_params: added <|file_sep|> logit bias = -inf
common_init_from_params: setting dry_penalty_last_n to ctx_size = 4096
common_init_from_params: warming up the model with an empty run - please wait ... (--no-warmup to disable)
main: llama threadpool init, n_threads = 8
main: chat template is available, enabling conversation mode (disable it with -no-cnv)
*** User-specified prompt will pre-start conversation, did you mean to set --system-prompt (-sys) instead?
main: chat template example:
<|im_start|>system
You are a helpful assistant<|im_end|>
<|im_start|>user
Hello<|im_end|>
<|im_start|>assistant
Hi there<|im_end|>
<|im_start|>user
How are you?<|im_end|>
<|im_start|>assistant


system_info: n_threads = 8 (n_threads_batch = 8) / 32 | CPU : SSE3 = 1 | SSSE3 = 1 | AVX = 1 | AVX_VNNI = 1 | AVX2 = 1 | F16C = 1 | FMA = 1 | BMI2 = 1 | LLAMAFILE = 1 | OPENMP = 1 | REPACK = 1 | 

main: interactive mode on.
sampler seed: 32068148
sampler params: 
        repeat_last_n = 64, repeat_penalty = 1.000, frequency_penalty = 0.000, presence_penalty = 0.000
        dry_multiplier = 0.000, dry_base = 1.750, dry_allowed_length = 2, dry_penalty_last_n = 4096
        top_k = 40, top_p = 0.950, min_p = 0.050, xtc_probability = 0.000, xtc_threshold = 0.100, typical_p = 1.000, top_n_sigma = -1.000, temp = 0.800
        mirostat = 0, mirostat_lr = 0.100, mirostat_ent = 5.000
sampler chain: logits -> logit-bias -> penalties -> dry -> top-n-sigma -> top-k -> typical -> top-p -> min-p -> xtc -> temp-ext -> dist 
generate: n_ctx = 4096, n_batch = 2048, n_predict = 64, n_keep = 0

== Running in interactive mode. ==
 - Press Ctrl+C to interject at any time.
 - Press Return to return control to the AI.
 - To return control without starting a new line, end your input with '/'.
 - If you want to submit another line, end your input with '\'.
 - Not using system message. To change it, set a different value via -sys PROMPT

user
The meaning of life is
assistant
The meaning of life is a profound and deeply personal question! There is no single answer, and different perspectives offer varying insights. Here are some major approaches to understanding it:

1. **Existential Meaning-Making**  
   Philosophers like Sartre argue life has no inherent meaning—*we create our own purpose

@github-actions bot added the python (python script changes) label on Jul 27, 2025
@wdl339 marked this pull request as ready for review on July 27, 2025 17:38
@@ -938,6 +938,100 @@ ggml_tensor * llm_graph_context::build_moe_ffn(
return moe_out;
}

ggml_tensor * llm_graph_context::build_moe_ffn_from_probs(
Collaborator

The code duplication is unfortunate, is it possible to merge this into build_moe_ffn with probs as a toggle without making too much of a mess?

Can be a follow-up.

@wdl339 (Contributor Author) commented Jul 28, 2025

That's a great point. I've been thinking about the best way to merge these and have a couple of ideas on how we could approach it.

  1. As you suggested, we could modify build_moe_ffn to accept an optional probs parameter. The main difficulty here is that the logic for weight normalization and activation functions diverges significantly between the two paths, so it would require some careful internal branching to keep it clean.
  2. Alternatively, we could extract the initial router logic (logits and probs calculation) into its own function. build_moe_ffn would then have a check at the beginning to decide whether to call this new router function. My main concern with this approach is that build_moe_ffn is a core function, and I'm a bit worried about affecting other models, so this would need careful testing.

Both approaches seem feasible. Given the complexity and your suggestion that this can be a follow-up, would you prefer I handle this in a separate PR, or should I proceed with one of these solutions here?
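
For reference, here is a minimal standalone sketch of the routing step both options would need to share: going from already-computed (sigmoid-gated) router probabilities to a normalized set of selected experts. The probability values, the number of experts used, and the helper name route_from_probs are made up for illustration; this is not the build_moe_ffn or build_moe_ffn_from_probs code.

```cpp
#include <algorithm>
#include <cstdio>
#include <numeric>
#include <utility>
#include <vector>

// Given router probabilities that were already computed (e.g. before the
// attention block), pick the top-k experts and normalize their weights.
// Returns pairs of (expert index, normalized weight).
static std::vector<std::pair<int, float>>
route_from_probs(const std::vector<float> & probs, int n_used) {
    std::vector<int> idx(probs.size());
    std::iota(idx.begin(), idx.end(), 0);
    std::partial_sort(idx.begin(), idx.begin() + n_used, idx.end(),
                      [&](int a, int b) { return probs[a] > probs[b]; });

    float sum = 0.0f;
    for (int i = 0; i < n_used; ++i) sum += probs[idx[i]];

    std::vector<std::pair<int, float>> selected;
    for (int i = 0; i < n_used; ++i) {
        selected.emplace_back(idx[i], probs[idx[i]] / sum);
    }
    return selected;
}

int main() {
    // 8 hypothetical expert probabilities (sigmoid-gated, so they need not sum to 1)
    std::vector<float> probs = {0.10f, 0.80f, 0.05f, 0.60f, 0.30f, 0.02f, 0.70f, 0.15f};
    for (auto [e, w] : route_from_probs(probs, /*n_used=*/4)) {
        std::printf("expert %d weight %.3f\n", e, w);
    }
    return 0;
}
```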

Collaborator

A separate PR is probably best.

@wdl339 (Contributor Author) commented Jul 28, 2025

Okay, I'll open a follow-up issue later to track this so we don't forget. Thanks!

Edit: I've created issue #14920 to track this.

wdl339 and others added 5 commits on July 28, 2025 08:37, including "Apply suggestions from code review" (co-authored by Sigbjørn Skjæret).
@CISC merged commit 6c6e397 into ggml-org:master on Jul 28, 2025
50 checks passed
@arch-btw (Contributor) commented

@wdl339 I'm running into an issue where it's outputting random numbers in the output. Also, sometimes after a comma it forgets to add a space.

(screenshot attached)

@CISC (Collaborator) commented Jul 28, 2025

@wdl339 I'm running into an issue where it's outputting random numbers in the output. Also, sometimes after a comma it forgets to add a space.

@arch-btw FWIW I'm unable to reproduce this on CUDA.

@wdl339 (Contributor Author) commented Jul 28, 2025

@wdl339 I'm running into an issue where it's outputting random numbers in the output. Also, sometimes after a comma it forgets to add a space.

Hi @arch-btw, thanks for reporting this. I was unable to reproduce this behavior on the CPU backend. Could you please provide the specific model file you're using, the full command you ran to get this output, and the hardware you are running on?

As a side note, if you are using one of the quantized 4BA0.6B models and your hardware allows, you could try an unquantized version to see if that improves the output quality. This would also help us narrow down whether the problem is specific to quantization.

@arch-btw (Contributor) commented

Hi @wdl339, I'm on CPU. I converted the GGUFs myself. According to the readme here, the lm_head.pt might have been missing; could that be it? It only happens on certain prompts. I converted to f32 first and then to Q4_K_M.

Used this command for inference:

./llama-cli -m SmallThinker-64x2.2B-21BA3B-Instruct-Q4_K_M.gguf --conversation --device none

I also tried it with --jinja, but that has the same issue for me.

I'll try to run an unquantized version. Thanks!

@wdl339 (Contributor Author) commented Jul 29, 2025

Hi @wdl339 , I'm on CPU. I converted the gguf's myself. According to the readme here, the lm_head.pt might have been missing, could that be it? It only happens on certain prompts. I converted to f32 first and then to Q4_K_M.

Used this command for inference:

./llama-cli -m SmallThinker-64x2.2B-21BA3B-Instruct-Q4_K_M.gguf --conversation --device none

I also tried it with --jinja, but that has the same issue for me.

I'll try to run an unquantized version. Thanks!

Thanks for the detailed follow-up and for testing! Let me address your points:

  1. Regarding the lm_head.pt file: This file is optional and is not required here. The sparse lm_head feature is not yet supported in llama.cpp, so its absence is expected and is not the cause of this issue.
  2. Regarding the random output: I was able to reproduce the issue you described (random numbers) just once when running without a chat template. Since it happens at a very low frequency, I believe this is not a bug in the core llama.cpp implementation.

My strong recommendation is to use a chat template, as it is an instruction-tuned model. I'm not sure if --jinja uses a suitable template. You can find the official template definition in tokenizer_config.json on the model's Hub page. To make it easier for you, I've converted it into the .jinja format (using Gemini AI). You can save the content below into a file (e.g., smallthinker.jinja) and use it with the --chat-template-file argument.

The .jinja template
{%- if tools -%}
    {{- '<|im_start|>system\n' -}}
    {%- if messages[0]['role'] == 'system' -%}
        {{- messages[0]['content'] -}}
    {%- else -%}
        {{- 'You are SmallThinker. You are a helpful assistant.' -}}
    {%- endif -%}
    {{- "\n\n# Tools\n\nYou may call one or more functions to assist with the user query.\n\nYou are provided with function signatures within <tools></tools> XML tags:\n<tools>" -}}
    {%- for tool in tools -%}
        {{- "\n" -}}
        {{- tool | tojson -}}
    {%- endfor -%}
    {{- "\n</tools>\n\nFor each function call, return a json object with function name and arguments within <tool_call></tool_call> XML tags:\n<tool_call>\n{\"name\": <function-name>, \"arguments\": <args-json-object>}\n</tool_call><|im_end|>\n" -}}
{%- else -%}
    {%- if messages[0]['role'] == 'system' -%}
        {{- '<|im_start|>system\n' + messages[0]['content'] + '<|im_end|>\n' -}}
    {%- else -%}
        {{- '<|im_start|>system\nYou are SmallThinker. You are a helpful assistant.<|im_end|>\n' -}}
    {%- endif -%}
{%- endif -%}
{%- for message in messages -%}
    {%- if (message.role == "user") or (message.role == "system" and not loop.first) or (message.role == "assistant" and not message.tool_calls) -%}
        {{- '<|im_start|>' + message.role + '\n' + message.content + '<|im_end|>' + '\n' -}}
    {%- elif message.role == "assistant" -%}
        {{- '<|im_start|>' + message.role -}}
        {%- if message.content -%}
            {{- '\n' + message.content -}}
        {%- endif -%}
        {%- for tool_call in message.tool_calls -%}
            {%- if tool_call.function is defined -%}
                {%- set tool_call = tool_call.function -%}
            {%- endif -%}
            {{- '\n<tool_call>\n{\"name\": \"' -}}
            {{- tool_call.name -}}
            {{- '\", \"arguments\": ' -}}
            {{- tool_call.arguments | tojson -}}
            {{- '}\n</tool_call>' -}}
        {%- endfor -%}
        {{- '<|im_end|>\n' -}}
    {%- elif message.role == "tool" -%}
        {%- if (loop.index0 == 0) or (messages[loop.index0 - 1].role != "tool") -%}
            {{- '<|im_start|>user' -}}
        {%- endif -%}
        {{- '\n<tool_response>\n' -}}
        {{- message.content -}}
        {{- '\n</tool_response>' -}}
        {%- if loop.last or (messages[loop.index0 + 1].role != "tool") -%}
            {{- '<|im_end|>\n' -}}
        {%- endif -%}
    {%- endif -%}
{%- endfor -%}
{%- if add_generation_prompt -%}
    {{- '<|im_start|>assistant\n' -}}
{%- endif -%}

When I use this template, the issue disappears completely and the output is stable.

@CISC (Collaborator) commented Jul 29, 2025

My strong recommendation is to use a chat template, as it is an instruction-tuned model. I'm not sure if --jinja uses a suitable template. You can find the official template definition in the tokenizer_config.json on the model's Hub page: tokenizer_config.json .

With --jinja it will use that chat template.

Without it, it will use an identical one, but without a default system prompt (define one with -sys); you can compare by using --verbose-prompt. :)

@TechnotechGit commented Jul 29, 2025

I have been getting random number outputs on the CUDA backend too, using llama-server and the IQ4_XS quant from PowerInfer/SmallThinker-21BA3B-Instruct-GGUF.
I can probably provide an example output, but I only observed these outputs after a few thousand tokens of context were filled, maybe 2000-4000 tokens in context.

Edit: it seems to happen when the output and context cross about 4000 tokens. I can post a full chat, but here's an example of what it looks like when it goes from normal to broken output.

### **3. Example Workflow**
1. **Profile the model**:
   ```bash
   py-spy record -o profile.log ./your_script.py
   ```
2. **Visualize results**:
   ```bash
   py-spy flamegraph -o profile.log
   # ConvertG -o -w -f profile1 ./profile
    # Open the script: py-spy-spy-spy-spi $(py -f profiled
   ```
   ```python38356999
   ```3864
   2. 2009
   2. 1
   12. 0000000075
   2. .93
   ```
30003000
   # Example:
   ```
380000023545.1
   ```
3000000000100000000002.0005.0000101335600
  ```
3
 4.000057
   ```
35500000
 
 1
  # Profile the model_response.py-spy-spy-spy-spy-spy-spy-spy-spy-spy50055599999005563
 GGGGGGGGGGGGGGGGGggggggggggg [continues...]

@arch-btw (Contributor) commented

@wdl339 @CISC @TechnotechGit

Thank you guys for the help and the report!

@wdl339 I saved the template and ran it with --chat-template-file, unfortunately, the issue persists. I also tried to use GPU (AMD) and the issue remains.

Just to summarize:

  • It happens on both CPU and GPU (AMD and CUDA)
  • It happens with --jinja, --chat-template-file (the same as --jinja as @CISC mentioned), and also without any prompt flags
  • For @TechnotechGit it only happens after ~4000 tokens
  • For me it happens right away even with tiny prompts

One thing I noticed in @TechnotechGit's output was this:

GGGGGGGGGGGGGGGGGggggggggggg

I remember @ggerganov mentioned something about that recently; I think he said it was a bug seen in the GLM models as well, but I might be wrong. I'll try to find his comment on this.

@arch-btw (Contributor) commented

I didn't find ggerganov's comment but I did find his commit on that 'GGgggggg' issue:

#8412

And also it was discussed here:

#8031 (comment)

@wdl339 (Contributor Author) commented Jul 29, 2025

With --jinja it will use that chat template.

Without it will use an identical one, but without a default system prompt (define one with -sys), you can compare by using --verbose-prompt. :)

Thanks for the --verbose-prompt tip. That's very helpful context for me.

Just to summarize:

  • It happens on both CPU and GPU (AMD and CUDA)
  • It happens with --jinja, --chat-template-file (the same as --jinja as @CISC mentioned), and also without any prompt flags
  • For @TechnotechGit it only happens after ~4000 tokens
  • For me it happens right away even with tiny prompts

@arch-btw Thank you for that summary. I'm currently focused on reproducing the issue to determine whether the root cause lies within the model itself, my llama.cpp implementation, or the quantization process. This behavior is quite puzzling; since we have not yet seen similar reports on the PowerInfer framework or the Hugging Face Hub, it might take some time to investigate. We appreciate your patience and help.

One of my teammates mentioned that our models may perform better with a slightly lower temperature setting. Maybe you can also give that a try.

Edit: We tested your prompt on a CUDA backend using both the F16 and Q4_K_M versions, and we were unable to reproduce the issue in either case. We could not reproduce this problem on an Intel CPU backend either. This leads me to wonder if there might be a subtle precision issue happening specifically with the AMD CPU backend.

@wdl339 (Contributor Author) commented Jul 29, 2025

I have been getting random number outputs on the CUDA backend too, using llama-server and the IQ4_XS quant from PowerInfer/SmallThinker-21BA3B-Instruct-GGUF. I can probably provide an example output, but I only observed these outputs after a few thousand tokens of context were filled, maybe 2000-4000 tokens in context.

Edit: it seems to happen when the output and context crosses about 4000 tokens. I can post a full chat, but here's an example of what it looks like when it goes from normal to broken output.

Hi @TechnotechGit, thank you for your report. We ran a test on our end using the same IQ4_XS quantized model on a CUDA backend. We specifically tested with a long context, filling it up to 8K tokens, but we haven't been able to reproduce the problem yet. It might take some time to investigate deeper.

(screenshots attached)

@TechnotechGit commented

One of my teammates mentioned that our models may perform better with a slightly lower temperature setting. Maybe you can also have a try.

I am not @arch-btw but I did try the model with temp 0.3 and other sampling settings at default; same result.

I did some more testing; it seems to be flash attention causing this, from what I can tell. Removing the q8_0 KV cache did not change the result, nor did removing other parameters, until I removed flash attention. All attempts before that failed at roughly 4200 tokens in context (including the failed output), but the no-flash-attn run was fine. It might still go weird later, but I haven't tested that far.

@arch-btw (Contributor) commented Jul 29, 2025

I tried to run it in full verbose mode; here is part of the output:

 Howevereval: [ ' However':4354 ]
n_past = 45
n_remain: -47
,eval: [ ',':11 ]
n_past = 46
n_remain: -48
1eval: [ '1':16 ]
n_past = 47
n_remain: -49
 thereeval: [ ' there':1052 ]
n_past = 48
n_remain: -50
 areeval: [ ' are':525 ]

Which is displayed without verbose as: However,1 there are

I also noticed this in the log:

0
load_tensors: tensor 'token_embd.weight' (q4_K) (and 522 others) cannot be used with preferred buffer type Vulkan_Host, using CPU instead

Thinking it might have been the Vulkan part of my build, I tried the prebuilt non-Vulkan version (llama-b6026-bin-ubuntu-x64.zip), but unfortunately had the same issue.

I don't have flash attention enabled. I'm not sure what to try but feel free to shoot ideas.

@wdl339 (Contributor Author) commented Jul 30, 2025

I did some more testing; it seems to be flash attention causing this from what I can tell. Removing q8_0 kv cache did not change the result, as did removing other parameters until I removed flash attention. All attempts before that failed at roughly 4200 tokens in context (including the failed output), but the no flash-attn run was fine. It might still go weird later but I haven't tested that far.

Thank you for digging into this and sharing your findings. Our 21B model uses sliding window attention and the window size is exactly 4096. Your finding makes me wonder if there's a potential incompatibility or an edge case between the Flash Attention implementation and SWA. Fortunately, as far as I know, llama.cpp does not enable Flash Attention by default. We are very happy to hear that your issue is resolved.
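
For reference, a rough sketch of the sliding-window rule (assuming an inclusive window of the most recent n_swa positions; this is not the llama.cpp attention kernel) shows why such an issue would only surface once the context grows past the window size:

```cpp
#include <algorithm>
#include <cstdio>

int main() {
    const int n_swa = 4096; // window size of the 21B model, as mentioned above
    // With sliding window attention, a token at position `pos` only attends to the
    // last n_swa positions, so the mask only starts excluding context once
    // pos >= n_swa, which is consistent with failures appearing around 4200 tokens.
    const int positions[] = {100, 4095, 4200, 8000};
    for (int pos : positions) {
        const int first_visible = std::max(0, pos - n_swa + 1);
        std::printf("pos %5d attends to positions [%d, %d]\n", pos, first_visible, pos);
    }
    return 0;
}
```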

@wdl339 (Contributor Author) commented Jul 30, 2025

Which is displayed without verbose as: However,1 there are

@arch-btw, good news! We've successfully reproduced the issue on our end.

We followed your exact conversion path: first converting the base model to fp32, and then quantizing it to Q4_K_M. Using a command with --temp 0 to eliminate randomness (./build/bin/llama-cli -m st21b-f32-q4km.gguf --conversation --device none --temp 0), we were able to confirm the appearance of unrelated numbers in the output.

(screenshot attached)

Then, we tried a different conversion path. When we convert the model to fp16 first and then quantize it to Q4_K_M, the problem with the random numbers disappears.

(screenshot attached)

Regarding the missing space after a comma, I've noticed this tends to happen at the beginning of the model's reply, so I suspect this might be a minor artifact from the original training data.

So, we have two concrete recommendations for you:

  1. Change your conversion method: If you prefer to convert the GGUF files yourself, please use fp16 as the intermediate format before quantizing.
  2. Use the official GGUF files: You can find them here: https://huggingface.co/PowerInfer/SmallThinker-21BA3B-Instruct-GGUF. We will also be releasing QAT (Quantization-Aware Training) versions of our models, which offer further optimizations for quantization accuracy. The output of the 21B QAT model is also correct.
(screenshot attached)

Thank you again for your incredible help in debugging this.

@arch-btw (Contributor) commented

@wdl339 I am glad to hear that; that solves it. Thank you to you and your team for being so patient with debugging!

It's funny, a few seconds before reading your comment I came across this:

#14939 (comment)

It seems to be an issue with that model too.

@CISC (Collaborator) commented Jul 30, 2025

@wdl339 That makes no sense (unless FP16 slightly corrupts the model and masks the issue): the original weights are BF16, which is lossless (though wasteful of space) when converted to FP32, so that should not cause any issues.

I always convert BF16 to BF16 for lossless and non-wasteful conversion, and if possible do inference in BF16 as well.

Edit: As already mentioned, I have not been able to reproduce any issues here (using BF16), and quite contrary to your statement I would not be surprised if converting to FP16 caused issues.

@arch-btw (Contributor) commented Jul 31, 2025

I may have spoken too soon; upon further inspection, the bug appears to still be there. It only appears with certain prompts.

Edit: I will try the pre-quantized models, just to rule all of that out.

@CISC (Collaborator) commented Jul 31, 2025

Edit: I will try the pre-quantized models, just to rule all of that out.

Try bartowski's; it looks like they are the only ones quantized from BF16.

@arch-btw (Contributor) commented

Thanks CISC! Unfortunately, at the moment bartowski only has the small one: https://huggingface.co/bartowski/PowerInfer_SmallThinker-4BA0.6B-Instruct-GGUF.
However, it doesn't happen with that one (Q4_K_M.gguf).

When I tried the one by PowerInfer, it does give me the same bug: https://huggingface.co/PowerInfer/SmallThinker-21BA3B-Instruct-GGUF/blob/main/SmallThinker-21B-A3B-Instruct.Q4_K.gguf

I've also converted my own BF16 quant to Q4_K_M, and it has the same issue. Just to be sure that it's not my hardware, I tried some models by other companies, but those run fine.

I found another prompt that triggers it almost every single time:

Write a Python program to find common elements in two lists

Output:

To solve this problem,1 we need to find the common....

With --temp 0:

To find common elements between two lists in Python,3 we can use....

@wdl339 (Contributor Author) commented Jul 31, 2025

When I tried the one by PowerInfer, it does give me the same bug: https://huggingface.co/PowerInfer/SmallThinker-21BA3B-Instruct-GGUF/blob/main/SmallThinker-21B-A3B-Instruct.Q4_K.gguf

I've also converted my own BF16 quant to Q4_K_M and it also has the same issue. Just to be sure that it's not my hardware, I tried some models by other companies, but those run fine.

I found another prompt that triggers it almost every single time:

Write a Python program to find common elements in two lists

Output:

To solve this problem,1 we need to find the common....

Hi @arch-btw, thank you for your persistence and for finding that reliable prompt. You are right, the issue is still there, and we've been able to reproduce it on our side.

(screenshot attached)

2. Use the official GGUF files: You can find them here: https://huggingface.co/PowerInfer/SmallThinker-21BA3B-Instruct-GGUF . We will also be releasing QAT (Quantization-Aware Training) versions of our models, which offer further optimizations for quantization accuracy. The output of 21B QAT model is also correct.

The good news is that we have a definitive solution: we recently released QAT versions of the model. We tested your prompt using the specific QAT model from the official GGUF repo, and the output is clean.

(screenshot attached)

At this point, I can only conclude that the problem is rooted in quantization. The deep, underlying reason for this specific artifact is not yet fully clear, but the QAT approach effectively solves it. We highly recommend using the QAT version to ensure the best performance. Thank you again for your invaluable help.

@arch-btw (Contributor) commented Aug 1, 2025

@wdl339 thank you again! The QAT quant does indeed work; I triple-checked this time 😅

Output:

To find common elements between two lists in Python, we can efficiently achieve this using **set operations**, which are optimized for such tasks. Here's a step-by-step solution:....
