Add support for SmallThinker model series #14898

Merged (18 commits, Jul 28, 2025)

82 changes: 82 additions & 0 deletions convert_hf_to_gguf.py
@@ -7589,6 +7589,88 @@ def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iterable[tuple[str, Tensor]]:
return [(self.map_tensor_name(name), data_torch)]


@ModelBase.register("SmallThinkerForCausalLM")
class SmallThinkerModel(TextModel):
model_arch = gguf.MODEL_ARCH.SMALLTHINKER

def set_gguf_parameters(self):
super().set_gguf_parameters()
if (n_experts := self.hparams.get("num_experts", self.hparams.get("moe_num_primary_experts"))) is not None:
self.gguf_writer.add_expert_count(n_experts)
if (n_experts_used := self.hparams.get("num_experts_per_tok", self.hparams.get("moe_num_active_primary_experts"))) is not None:
self.gguf_writer.add_expert_used_count(n_experts_used)
if (moe_intermediate_size := self.hparams.get("moe_ffn_hidden_size")) is not None:
self.gguf_writer.add_expert_feed_forward_length(moe_intermediate_size)
self.gguf_writer.add_feed_forward_length(moe_intermediate_size)
logger.info(f"gguf: expert feed forward length = {moe_intermediate_size}")
if (self.hparams.get('moe_primary_router_apply_softmax')):
self.gguf_writer.add_expert_gating_func(gguf.ExpertGatingFuncType.SOFTMAX)
else:
self.gguf_writer.add_expert_gating_func(gguf.ExpertGatingFuncType.SIGMOID)
# YaRN is not enabled by default
# To enable it, please refer to this guide: https://huggingface.co/Qwen/Qwen3-30B-A3B#processing-long-texts
rope_scaling = self.hparams.get("rope_scaling") or {}
if rope_scaling.get("rope_type", rope_scaling.get("type")) == "yarn" and "factor" in rope_scaling:
self.gguf_writer.add_rope_scaling_type(gguf.RopeScalingType.YARN)
self.gguf_writer.add_rope_scaling_factor(rope_scaling["factor"])
self.gguf_writer.add_rope_scaling_orig_ctx_len(rope_scaling["original_max_position_embeddings"])

sliding_window_layout = self.hparams.get("sliding_window_layout")
if sliding_window_layout:
for i in sliding_window_layout:
if i != 0:
sliding_window = self.hparams.get("sliding_window_size")
if sliding_window:
self.gguf_writer.add_sliding_window(sliding_window)
break

_experts: list[dict[str, Tensor]] | None = None

def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iterable[tuple[str, Tensor]]:
# process the experts separately
if name.find("experts") != -1:
n_experts = self.hparams.get("num_experts", self.hparams.get("moe_num_primary_experts"))
assert bid is not None

if self._experts is None:
self._experts = [{} for _ in range(self.block_count)]

self._experts[bid][name] = data_torch

if len(self._experts[bid]) >= n_experts * 3:
tensors: list[tuple[str, Tensor]] = []

# merge the experts into a single 3d tensor
for w_name in ["down", "gate", "up"]:
datas: list[Tensor] = []

for xid in range(n_experts):
ename = f"model.layers.{bid}.block_sparse_moe.experts.{xid}.{w_name}.weight"
datas.append(self._experts[bid][ename])
del self._experts[bid][ename]

data_torch = torch.stack(datas, dim=0)

merged_name = f"model.layers.{bid}.block_sparse_moe.experts.{w_name}.weight"

new_name = self.map_tensor_name(merged_name)

tensors.append((new_name, data_torch))
return tensors
else:
return []

return [(self.map_tensor_name(name), data_torch)]

def prepare_tensors(self):
super().prepare_tensors()

if self._experts is not None:
# flatten `list[dict[str, Tensor]]` into `list[str]`
experts = [k for d in self._experts for k in d.keys()]
if len(experts) > 0:
raise ValueError(f"Unprocessed experts: {experts}")

###### CONVERSION LOGIC ######
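For orientation, a minimal sketch of how the set_gguf_parameters handler above maps a SmallThinker-style config onto GGUF metadata; the keys mirror the hparams fields queried in the class, while the concrete values are invented for illustration only:

# Hypothetical SmallThinker config fragment (values made up for illustration);
# the keys match those read by SmallThinkerModel.set_gguf_parameters above.
hparams = {
    "moe_num_primary_experts": 32,              # -> add_expert_count(32)
    "moe_num_active_primary_experts": 4,        # -> add_expert_used_count(4)
    "moe_ffn_hidden_size": 768,                 # -> add_expert_feed_forward_length / add_feed_forward_length
    "moe_primary_router_apply_softmax": False,  # -> SIGMOID gating (SOFTMAX when True)
    # YaRN stays disabled unless rope_scaling declares type "yarn" and a "factor":
    "rope_scaling": {"rope_type": "yarn", "factor": 4.0, "original_max_position_embeddings": 32768},
    # any non-zero entry in the layout enables the sliding-window metadata:
    "sliding_window_layout": [0, 1, 1, 1],
    "sliding_window_size": 4096,
}

# fallback chain used above: prefer the generic key, fall back to the SmallThinker-specific one
n_experts      = hparams.get("num_experts", hparams.get("moe_num_primary_experts"))                 # 32
n_experts_used = hparams.get("num_experts_per_tok", hparams.get("moe_num_active_primary_experts"))  # 4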


20 changes: 20 additions & 0 deletions gguf-py/gguf/constants.py
@@ -376,6 +376,7 @@ class MODEL_ARCH(IntEnum):
SMOLLM3 = auto()
LFM2 = auto()
DREAM = auto()
SMALLTHINKER = auto()


class VISION_PROJECTOR_TYPE(IntEnum):
@@ -695,6 +696,7 @@ class MODEL_TENSOR(IntEnum):
MODEL_ARCH.SMOLLM3: "smollm3",
MODEL_ARCH.LFM2: "lfm2",
MODEL_ARCH.DREAM: "dream",
MODEL_ARCH.SMALLTHINKER: "smallthinker",
}

VISION_PROJECTOR_TYPE_NAMES: dict[VISION_PROJECTOR_TYPE, str] = {
@@ -2483,6 +2485,24 @@ class MODEL_TENSOR(IntEnum):
MODEL_TENSOR.ATTN_V,
MODEL_TENSOR.ATTN_OUT,
],
MODEL_ARCH.SMALLTHINKER: [
MODEL_TENSOR.TOKEN_EMBD,
MODEL_TENSOR.OUTPUT_NORM,
MODEL_TENSOR.OUTPUT,
MODEL_TENSOR.ATTN_NORM,
MODEL_TENSOR.ATTN_Q,
MODEL_TENSOR.ATTN_K,
MODEL_TENSOR.ATTN_V,
MODEL_TENSOR.ATTN_OUT,
MODEL_TENSOR.FFN_NORM,
MODEL_TENSOR.FFN_GATE,
MODEL_TENSOR.FFN_DOWN,
MODEL_TENSOR.FFN_UP,
MODEL_TENSOR.FFN_GATE_INP,
MODEL_TENSOR.FFN_GATE_EXP,
MODEL_TENSOR.FFN_DOWN_EXP,
MODEL_TENSOR.FFN_UP_EXP,
],
# TODO
}

7 changes: 7 additions & 0 deletions gguf-py/gguf/tensor_mapping.py
@@ -317,6 +317,7 @@ class TensorNameMap:
"model.layers.{bid}.feed_forward.router", # llama4 jamba
"encoder.layers.{bid}.mlp.router.layer", # nomic-bert-moe
"model.layers.{bid}.mlp.gate.wg", # hunyuan
"model.layers.{bid}.block_sparse_moe.primary_router", # smallthinker
),

MODEL_TENSOR.FFN_GATE_INP_SHEXP: (
@@ -362,6 +363,7 @@ class TensorNameMap:
"transformer.h.{bid}.mlp.c_fc_1", # exaone
"model.layers.{bid}.feed_forward.up_proj", # llama4 jamba granite-hybrid
"transformer_encoder.{bid}.ffn.w12", # neobert
"model.layers.{bid}.block_sparse_moe.up", # smallthinker
),

MODEL_TENSOR.FFN_UP_EXP: (
@@ -372,6 +374,7 @@
"model.layers.{bid}.block_sparse_moe.experts.w3", # phimoe (merged)
"model.layers.{bid}.feed_forward.experts.up_proj", # llama4
"encoder.layers.{bid}.mlp.experts.mlp.w1", # nomic-bert-moe
"model.layers.{bid}.block_sparse_moe.experts.up", # smallthinker
),

MODEL_TENSOR.FFN_UP_SHEXP: (
@@ -401,6 +404,7 @@
"model.layers.{bid}.residual_mlp.w1", # arctic
"transformer.h.{bid}.mlp.c_fc_0", # exaone
"model.layers.{bid}.feed_forward.gate_proj", # llama4 jamba granite-hybrid
"model.layers.{bid}.block_sparse_moe.gate", # smallthinker
),

MODEL_TENSOR.FFN_GATE_EXP: (
@@ -410,6 +414,7 @@
"model.layers.{bid}.mlp.experts.gate_proj", # qwen2moe olmoe (merged) ernie4.5-moe
"model.layers.{bid}.block_sparse_moe.experts.w1", # phimoe (merged)
"model.layers.{bid}.feed_forward.experts.gate_proj", # llama4
"model.layers.{bid}.block_sparse_moe.experts.gate", # smallthinker
),

MODEL_TENSOR.FFN_GATE_SHEXP: (
@@ -448,6 +453,7 @@
"model.layers.h.{bid}.mlp.c_proj", # exaone
"model.layers.{bid}.feed_forward.down_proj", # llama4 jamba granite-hybrid
"transformer_encoder.{bid}.ffn.w3", # neobert
"model.layers.{bid}.block_sparse_moe.down", # smallthinker
),

MODEL_TENSOR.FFN_DOWN_EXP: (
@@ -459,6 +465,7 @@
"model.layers.{bid}.block_sparse_moe.experts.w2", # phimoe (merged)
"model.layers.{bid}.feed_forward.experts.down_proj", # llama4
"encoder.layers.{bid}.mlp.experts.mlp.w2", # nomic-bert-moe
"model.layers.{bid}.block_sparse_moe.experts.down", # smallthinker
),

MODEL_TENSOR.FFN_DOWN_SHEXP: (
22 changes: 22 additions & 0 deletions src/llama-arch.cpp
@@ -88,6 +88,7 @@ static const std::map<llm_arch, const char *> LLM_ARCH_NAMES = {
{ LLM_ARCH_SMOLLM3, "smollm3" },
{ LLM_ARCH_LFM2, "lfm2" },
{ LLM_ARCH_DREAM, "dream" },
{ LLM_ARCH_SMALLTHINKER, "smallthinker" },
{ LLM_ARCH_UNKNOWN, "(unknown)" },
};

@@ -1933,6 +1934,27 @@ static const std::map<llm_arch, std::map<llm_tensor, const char *>> LLM_TENSOR_NAMES = {
{ LLM_TENSOR_TOKEN_EMBD_NORM, "token_embd_norm" },
}
},
{
LLM_ARCH_SMALLTHINKER,
{
{ LLM_TENSOR_TOKEN_EMBD, "token_embd" },
{ LLM_TENSOR_OUTPUT_NORM, "output_norm" },
{ LLM_TENSOR_OUTPUT, "output" },
{ LLM_TENSOR_ATTN_NORM, "blk.%d.attn_norm" },
{ LLM_TENSOR_ATTN_Q, "blk.%d.attn_q" },
{ LLM_TENSOR_ATTN_K, "blk.%d.attn_k" },
{ LLM_TENSOR_ATTN_V, "blk.%d.attn_v" },
{ LLM_TENSOR_ATTN_OUT, "blk.%d.attn_output" },
{ LLM_TENSOR_FFN_NORM, "blk.%d.ffn_norm" },
{ LLM_TENSOR_FFN_GATE, "blk.%d.ffn_gate" },
{ LLM_TENSOR_FFN_DOWN, "blk.%d.ffn_down" },
{ LLM_TENSOR_FFN_UP, "blk.%d.ffn_up" },
{ LLM_TENSOR_FFN_GATE_INP, "blk.%d.ffn_gate_inp" },
{ LLM_TENSOR_FFN_GATE_EXPS, "blk.%d.ffn_gate_exps" },
{ LLM_TENSOR_FFN_DOWN_EXPS, "blk.%d.ffn_down_exps" },
{ LLM_TENSOR_FFN_UP_EXPS, "blk.%d.ffn_up_exps" }
},
},
{
LLM_ARCH_DREAM,
{
1 change: 1 addition & 0 deletions src/llama-arch.h
@@ -92,6 +92,7 @@ enum llm_arch {
LLM_ARCH_SMOLLM3,
LLM_ARCH_LFM2,
LLM_ARCH_DREAM,
LLM_ARCH_SMALLTHINKER,
LLM_ARCH_UNKNOWN,
};

94 changes: 94 additions & 0 deletions src/llama-graph.cpp
@@ -938,6 +938,100 @@ ggml_tensor * llm_graph_context::build_moe_ffn(
return moe_out;
}

Review discussion on build_moe_ffn_from_probs:

Collaborator: The code duplication is unfortunate; is it possible to merge this into build_moe_ffn with probs as a toggle without making too much of a mess? Can be a follow-up.

@wdl339 (contributor, PR author), Jul 28, 2025: That's a great point. I've been thinking about the best way to merge these and have a couple of ideas on how we could approach it.

  1. As you suggested, we could modify build_moe_ffn to accept an optional probs parameter. The main difficulty here is that the logic for weight normalization and activation functions diverges significantly between the two paths, so it would require some careful internal branching to keep it clean.
  2. Alternatively, we could extract the initial router logic (logits and probs calculation) into its own function. build_moe_ffn would then have a check at the beginning to decide whether to call this new router function. My main concern with this approach is that build_moe_ffn is a core function, and I'm a bit worried about affecting other models, so this would need careful testing.

Both approaches seem feasible. Given the complexity and your suggestion that this can be a follow-up, would you prefer I handle this in a separate PR, or should I proceed with one of these solutions here?

Collaborator: A separate PR is probably best.

@wdl339 (contributor, PR author), Jul 28, 2025: Okay, I'll open a follow-up issue later to track this so we don't forget. Thanks!
Edit: I've created issue #14920 to track this.

ggml_tensor * llm_graph_context::build_moe_ffn_from_probs(
ggml_tensor * cur,
ggml_tensor * probs,
ggml_tensor * up_exps,
ggml_tensor * gate_exps,
ggml_tensor * down_exps,
ggml_tensor * exp_probs_b,
int64_t n_expert,
int64_t n_expert_used,
llama_expert_gating_func_type gating_op,
int il) const {
const int64_t n_embd = cur->ne[0];
const int64_t n_tokens = cur->ne[1];

// add experts selection bias - introduced in DeepSeek V3
// leave probs unbiased as it's later used to get expert weights
ggml_tensor * selection_probs = probs;
if (exp_probs_b != nullptr) {
selection_probs = ggml_add(ctx0, probs, exp_probs_b);
cb(selection_probs, "ffn_moe_probs_biased", il);
}

// select experts
ggml_tensor * selected_experts = ggml_top_k(ctx0, selection_probs, n_expert_used); // [n_expert_used, n_tokens]
cb(selected_experts->src[0], "ffn_moe_argsort", il);
cb(selected_experts, "ffn_moe_topk", il);

ggml_tensor * weights = ggml_get_rows(ctx0,
ggml_reshape_3d(ctx0, probs, 1, n_expert, n_tokens), selected_experts); // [1, n_expert_used, n_tokens]
cb(weights, "ffn_moe_weights", il);

weights = ggml_reshape_2d(ctx0, weights, n_expert_used, n_tokens);
if (gating_op == LLAMA_EXPERT_GATING_FUNC_TYPE_SOFTMAX) {
weights = ggml_soft_max(ctx0, weights);
} else {
weights = ggml_sigmoid(ctx0, weights);
ggml_tensor * weights_sum = ggml_sum_rows(ctx0, weights); // [1, n_tokens]
cb(weights_sum, "ffn_moe_weights_sum", il);

weights = ggml_div(ctx0, weights, weights_sum); // [n_expert_used, n_tokens]
cb(weights, "ffn_moe_weights_norm", il);
}

weights = ggml_reshape_3d(ctx0, weights, 1, n_expert_used, n_tokens);

cur = ggml_reshape_3d(ctx0, cur, n_embd, 1, n_tokens);

ggml_tensor * up = build_lora_mm_id(up_exps, cur, selected_experts); // [n_ff, n_expert_used, n_tokens]
cb(up, "ffn_moe_up", il);

ggml_tensor * experts = nullptr;
cur = build_lora_mm_id(gate_exps, cur, selected_experts); // [n_ff, n_expert_used, n_tokens]
cb(cur, "ffn_moe_gate", il);

cur = ggml_reglu_split(ctx0, cur, up);
cb(cur, "ffn_moe_reglu", il);

experts = build_lora_mm_id(down_exps, cur, selected_experts); // [n_embd, n_expert_used, n_tokens]
cb(experts, "ffn_moe_down", il);

experts = ggml_mul(ctx0, experts, weights);
cb(cur, "ffn_moe_weighted", il);

ggml_tensor * cur_experts[LLAMA_MAX_EXPERTS] = { nullptr };

assert(n_expert_used > 0);

// order the views before the adds
for (uint32_t i = 0; i < hparams.n_expert_used; ++i) {
cur_experts[i] = ggml_view_2d(ctx0, experts, n_embd, n_tokens, experts->nb[2], i*experts->nb[1]);

ggml_build_forward_expand(gf, cur_experts[i]);
}

// aggregate experts
// note: here we explicitly use hparams.n_expert_used instead of n_expert_used
// to avoid potentially a large number of add nodes during warmup
// ref: https://github.com/ggml-org/llama.cpp/pull/14753
ggml_tensor * moe_out = cur_experts[0];

for (uint32_t i = 1; i < hparams.n_expert_used; ++i) {
moe_out = ggml_add(ctx0, moe_out, cur_experts[i]);
}

if (n_expert_used == 1) {
// avoid returning a non-contiguous tensor
moe_out = ggml_cont(ctx0, moe_out);
}

cb(moe_out, "ffn_moe_out", il);

return moe_out;
}
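To make the routing arithmetic above easier to follow, here is a small NumPy sketch of the sigmoid-gating path for a single token; shapes and values are illustrative only, and this is a paraphrase of the graph code rather than the ggml implementation itself:

import numpy as np

# router probabilities for one token over n_expert experts (illustrative values)
probs = np.array([0.05, 0.40, 0.10, 0.30, 0.15])
n_expert_used = 2

# top-k expert selection (ggml_top_k over the optionally biased selection_probs)
selected = np.argsort(-probs)[:n_expert_used]     # array([1, 3])

# gather the unbiased probs of the selected experts (ggml_get_rows),
# then apply sigmoid and renormalize, matching the non-softmax branch above
weights = 1.0 / (1.0 + np.exp(-probs[selected]))
weights = weights / weights.sum()                 # "ffn_moe_weights_norm"

# each selected expert's gated FFN output is then scaled by its weight (ggml_mul)
# and the per-expert views are summed into moe_out (the chain of ggml_add calls)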

// input embeddings with optional lora
ggml_tensor * llm_graph_context::build_inp_embd(ggml_tensor * tok_embd) const {
const int64_t n_embd = hparams.n_embd;
12 changes: 12 additions & 0 deletions src/llama-graph.h
@@ -625,6 +625,18 @@
llama_expert_gating_func_type gating_op,
int il) const;

ggml_tensor * build_moe_ffn_from_probs(
ggml_tensor * cur,
ggml_tensor * probs,
ggml_tensor * up_exps,
ggml_tensor * gate_exps,
ggml_tensor * down_exps,
ggml_tensor * exp_probs_b,
int64_t n_expert,
int64_t n_expert_used,
llama_expert_gating_func_type gating_op,
int il) const;

//
// inputs
//
12 changes: 9 additions & 3 deletions src/llama-hparams.cpp
@@ -2,9 +2,15 @@

#include "ggml.h"

-void llama_hparams::set_swa_pattern(uint32_t n_pattern) {
-    for (uint32_t il = 0; il < n_layer; ++il) {
-        swa_layers[il] = n_pattern == 0 || (il % n_pattern < (n_pattern - 1));
+void llama_hparams::set_swa_pattern(uint32_t n_pattern, bool dense_first) {
+    if (dense_first) {
+        for (uint32_t il = 0; il < n_layer; ++il) {
+            swa_layers[il] = n_pattern == 0 || (il % n_pattern != 0);
+        }
+    } else {
+        for (uint32_t il = 0; il < n_layer; ++il) {
+            swa_layers[il] = n_pattern == 0 || (il % n_pattern < (n_pattern - 1));
+        }
}
}
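As a quick illustration of the new dense_first flag, the following Python sketch reproduces the layer-selection logic for n_pattern = 4 over 8 layers; True marks a sliding-window layer (as in swa_layers), and the pattern size and layer count are arbitrary:

n_layer, n_pattern = 8, 4

# dense_first = true: the first layer of every group is dense, the rest use SWA
dense_first = [il % n_pattern != 0 for il in range(n_layer)]
# -> [False, True, True, True, False, True, True, True]

# dense_first = false (previous behaviour): the last layer of every group is dense
default = [il % n_pattern < (n_pattern - 1) for il in range(n_layer)]
# -> [True, True, True, False, True, True, True, False]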
