From efe27eb5cf9eb0252f4a06485674606e5c5f031d Mon Sep 17 00:00:00 2001 From: Dongliang Wei <2720609228@qq.com> Date: Tue, 8 Jul 2025 05:09:10 +0000 Subject: [PATCH 01/16] support smallthinker --- convert_hf_to_gguf.py | 132 ++++++++++++++++++++++++ gguf-py/gguf/constants.py | 20 ++++ gguf-py/gguf/tensor_mapping.py | 7 ++ src/llama-arch.cpp | 22 ++++ src/llama-arch.h | 1 + src/llama-graph.cpp | 131 ++++++++++++++++++++++++ src/llama-graph.h | 15 +++ src/llama-hparams.cpp | 6 ++ src/llama-hparams.h | 16 ++- src/llama-model.cpp | 181 +++++++++++++++++++++++++++++++++ 10 files changed, 530 insertions(+), 1 deletion(-) diff --git a/convert_hf_to_gguf.py b/convert_hf_to_gguf.py index dd80a4a05d596..22d0d9219945f 100755 --- a/convert_hf_to_gguf.py +++ b/convert_hf_to_gguf.py @@ -6535,6 +6535,138 @@ def set_gguf_parameters(self): super().set_gguf_parameters() self.gguf_writer.add_audio_stack_factor(self.global_config["stack_factor"]) + +@ModelBase.register("SmallthinkerForCausalLM") +class SmallthinkerModel(TextModel): + model_arch = gguf.MODEL_ARCH.SMALLTHINKER + + def set_gguf_parameters(self): + super().set_gguf_parameters() + if (n_experts := self.hparams.get("num_experts", self.hparams.get("moe_num_primary_experts"))) is not None: + self.gguf_writer.add_expert_count(n_experts) + if (n_experts_used := self.hparams.get("num_experts_per_tok", self.hparams.get("moe_num_active_primary_experts"))) is not None: + self.gguf_writer.add_expert_used_count(n_experts_used) + if (moe_intermediate_size := self.hparams.get("moe_ffn_hidden_size")) is not None: + self.gguf_writer.add_expert_feed_forward_length(moe_intermediate_size) + logger.info(f"gguf: expert feed forward length = {moe_intermediate_size}") + if (shared_expert_intermediate_size := self.hparams.get('shared_expert_intermediate_size')) is not None: + self.gguf_writer.add_expert_shared_feed_forward_length(shared_expert_intermediate_size) + logger.info(f"gguf: expert shared feed forward length = {shared_expert_intermediate_size}") + # YaRN is not enabled by default + # To enable it, please refer to this guide: https://huggingface.co/Qwen/Qwen3-30B-A3B#processing-long-texts + rope_scaling = self.hparams.get("rope_scaling") or {} + if rope_scaling.get("rope_type", rope_scaling.get("type")) == "yarn" and "factor" in rope_scaling: + self.gguf_writer.add_rope_scaling_type(gguf.RopeScalingType.YARN) + self.gguf_writer.add_rope_scaling_factor(rope_scaling["factor"]) + self.gguf_writer.add_rope_scaling_orig_ctx_len(rope_scaling["original_max_position_embeddings"]) + sliding_window = self.hparams.get("sliding_window") + self.gguf_writer.add_sliding_window(sliding_window) + + intermediate_size = self.hparams.get("ffn_hidden_size") + moe_intermediate_size = self.hparams.get("moe_ffn_hidden_size") + moe_layer_layout = self.hparams.get("moe_layer_layout") + ffn_layout = [] + for i, layout in enumerate(moe_layer_layout): + if layout == 0: + ffn_layout.append(intermediate_size) + elif layout == 1: + ffn_layout.append(moe_intermediate_size) + else: + raise ValueError(f"Unknown moe layer layout: {layout}") + self.gguf_writer.add_feed_forward_length(ffn_layout) + # def add_feed_forward_length(self, length: int | Sequence[int]) -> None: + + _experts: list[dict[str, Tensor]] | None = None + + def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iterable[tuple[str, Tensor]]: + # process the experts separately + if name.find("experts") != -1: + n_experts = self.hparams.get("num_experts", self.hparams.get("moe_num_primary_experts")) + assert bid 
is not None + + if self._experts is None: + self._experts = [{} for _ in range(self.block_count)] + + self._experts[bid][name] = data_torch + + if len(self._experts[bid]) >= n_experts * 3: + tensors: list[tuple[str, Tensor]] = [] + + # merge the experts into a single 3d tensor + for w_name in ["down", "gate", "up"]: + datas: list[Tensor] = [] + + for xid in range(n_experts): + ename = f"model.layers.{bid}.block_sparse_moe.experts.{xid}.{w_name}.weight" + datas.append(self._experts[bid][ename]) + del self._experts[bid][ename] + + data_torch = torch.stack(datas, dim=0) + + merged_name = f"model.layers.{bid}.block_sparse_moe.experts.{w_name}.weight" + + new_name = self.map_tensor_name(merged_name) + + tensors.append((new_name, data_torch)) + return tensors + else: + return [] + + return [(self.map_tensor_name(name), data_torch)] + + def prepare_tensors(self): + super().prepare_tensors() + + if self._experts is not None: + # flatten `list[dict[str, Tensor]]` into `list[str]` + experts = [k for d in self._experts for k in d.keys()] + if len(experts) > 0: + raise ValueError(f"Unprocessed experts: {experts}") + + def get_vocab_base(self) -> tuple[list[str], list[int], str]: + tokens: list[str] = [] + toktypes: list[int] = [] + + from transformers import AutoTokenizer + tokenizer = AutoTokenizer.from_pretrained(self.dir_model, trust_remote_code=True) + vocab_size = self.hparams.get("vocab_size", len(tokenizer.vocab)) + assert max(tokenizer.vocab.values()) < vocab_size + + tokpre = self.get_vocab_base_pre(tokenizer) + + reverse_vocab = {id_: encoded_tok for encoded_tok, id_ in tokenizer.vocab.items()} + added_vocab = tokenizer.get_added_vocab() + + added_tokens_decoder = tokenizer.added_tokens_decoder + + for i in range(vocab_size): + if i not in reverse_vocab: + tokens.append(f"[PAD{i}]") + toktypes.append(gguf.TokenType.UNUSED) + else: + token: str = reverse_vocab[i] + if token in added_vocab: + # The tokenizer in llama.cpp assumes the CONTROL and USER_DEFINED tokens are pre-normalized. + # To avoid unexpected issues - we make sure to normalize non-normalized tokens + if not added_tokens_decoder[i].normalized: + previous_token = token + token = tokenizer.decode(tokenizer.encode(token, add_special_tokens=False)) + if previous_token != token: + logger.info(f"{repr(previous_token)} is encoded and decoded back to {repr(token)} using AutoTokenizer") + + if added_tokens_decoder[i].special or self.does_token_look_special(token): + toktypes.append(gguf.TokenType.CONTROL) + else: + # NOTE: this was added for Gemma. + # Encoding and decoding the tokens above isn't sufficient for this case. 
+ token = token.replace(b"\xe2\x96\x81".decode("utf-8"), " ") # pre-normalize user-defined spaces + toktypes.append(gguf.TokenType.USER_DEFINED) + else: + toktypes.append(gguf.TokenType.NORMAL) + tokens.append(token) + + return tokens, toktypes, tokpre + ###### CONVERSION LOGIC ###### diff --git a/gguf-py/gguf/constants.py b/gguf-py/gguf/constants.py index c12609c6d9f99..9e709f5c6fbc2 100644 --- a/gguf-py/gguf/constants.py +++ b/gguf-py/gguf/constants.py @@ -357,6 +357,7 @@ class MODEL_ARCH(IntEnum): DOTS1 = auto() ARCEE = auto() ERNIE4_5 = auto() + SMALLTHINKER = auto() class VISION_PROJECTOR_TYPE(IntEnum): @@ -660,6 +661,7 @@ class MODEL_TENSOR(IntEnum): MODEL_ARCH.DOTS1: "dots1", MODEL_ARCH.ARCEE: "arcee", MODEL_ARCH.ERNIE4_5: "ernie4_5", + MODEL_ARCH.SMALLTHINKER: "smallthinker", } VISION_PROJECTOR_TYPE_NAMES: dict[VISION_PROJECTOR_TYPE, str] = { @@ -2211,6 +2213,24 @@ class MODEL_TENSOR(IntEnum): MODEL_TENSOR.FFN_DOWN, MODEL_TENSOR.FFN_UP, ], + MODEL_ARCH.SMALLTHINKER: [ + MODEL_TENSOR.TOKEN_EMBD, + MODEL_TENSOR.OUTPUT_NORM, + MODEL_TENSOR.OUTPUT, + MODEL_TENSOR.ATTN_NORM, + MODEL_TENSOR.ATTN_Q, + MODEL_TENSOR.ATTN_K, + MODEL_TENSOR.ATTN_V, + MODEL_TENSOR.ATTN_OUT, + MODEL_TENSOR.FFN_NORM, + MODEL_TENSOR.FFN_GATE, + MODEL_TENSOR.FFN_DOWN, + MODEL_TENSOR.FFN_UP, + MODEL_TENSOR.FFN_GATE_INP, + MODEL_TENSOR.FFN_GATE_EXP, + MODEL_TENSOR.FFN_DOWN_EXP, + MODEL_TENSOR.FFN_UP_EXP, + ], # TODO } diff --git a/gguf-py/gguf/tensor_mapping.py b/gguf-py/gguf/tensor_mapping.py index 51634ef6bdd2e..29f27df24acbd 100644 --- a/gguf-py/gguf/tensor_mapping.py +++ b/gguf-py/gguf/tensor_mapping.py @@ -303,6 +303,7 @@ class TensorNameMap: "model.layers.{bid}.block_sparse_moe.router.layer", # granitemoe "model.layers.{bid}.feed_forward.router", # llama4 "encoder.layers.{bid}.mlp.router.layer", # nomic-bert-moe + "model.layers.{bid}.block_sparse_moe.primary_router", # smallthinker ), MODEL_TENSOR.FFN_GATE_INP_SHEXP: ( @@ -346,6 +347,7 @@ class TensorNameMap: "transformer.h.{bid}.mlp.c_fc_1", # exaone "model.layers.{bid}.feed_forward.up_proj", # llama4 "transformer_encoder.{bid}.ffn.w12", # neobert + "model.layers.{bid}.block_sparse_moe.up", # smallthinker ), MODEL_TENSOR.FFN_UP_EXP: ( @@ -356,6 +358,7 @@ class TensorNameMap: "model.layers.{bid}.block_sparse_moe.experts.w3", # phimoe (merged) "model.layers.{bid}.feed_forward.experts.up_proj", # llama4 "encoder.layers.{bid}.mlp.experts.mlp.w1", # nomic-bert-moe + "model.layers.{bid}.block_sparse_moe.experts.up", # smallthinker ), MODEL_TENSOR.FFN_UP_SHEXP: ( @@ -383,6 +386,7 @@ class TensorNameMap: "model.layers.{bid}.residual_mlp.w1", # arctic "transformer.h.{bid}.mlp.c_fc_0", # exaone "model.layers.{bid}.feed_forward.gate_proj", # llama4 + "model.layers.{bid}.block_sparse_moe.gate", # smallthinker ), MODEL_TENSOR.FFN_GATE_EXP: ( @@ -392,6 +396,7 @@ class TensorNameMap: "model.layers.{bid}.mlp.experts.gate_proj", # qwen2moe olmoe (merged) "model.layers.{bid}.block_sparse_moe.experts.w1", # phimoe (merged) "model.layers.{bid}.feed_forward.experts.gate_proj", # llama4 + "model.layers.{bid}.block_sparse_moe.experts.gate", # smallthinker ), MODEL_TENSOR.FFN_GATE_SHEXP: ( @@ -429,6 +434,7 @@ class TensorNameMap: "model.layers.h.{bid}.mlp.c_proj", # exaone "model.layers.{bid}.feed_forward.down_proj", # llama4 "transformer_encoder.{bid}.ffn.w3", # neobert + "model.layers.{bid}.block_sparse_moe.down", # smallthinker ), MODEL_TENSOR.FFN_DOWN_EXP: ( @@ -440,6 +446,7 @@ class TensorNameMap: "model.layers.{bid}.block_sparse_moe.experts.w2", # phimoe (merged) 
"model.layers.{bid}.feed_forward.experts.down_proj", # llama4 "encoder.layers.{bid}.mlp.experts.mlp.w2", # nomic-bert-moe + "model.layers.{bid}.block_sparse_moe.experts.down", # smallthinker ), MODEL_TENSOR.FFN_DOWN_SHEXP: ( diff --git a/src/llama-arch.cpp b/src/llama-arch.cpp index ab24054305857..1881afce2135d 100644 --- a/src/llama-arch.cpp +++ b/src/llama-arch.cpp @@ -78,6 +78,7 @@ static const std::map LLM_ARCH_NAMES = { { LLM_ARCH_DOTS1, "dots1" }, { LLM_ARCH_ARCEE, "arcee" }, { LLM_ARCH_ERNIE4_5, "ernie4_5" }, + { LLM_ARCH_SMALLTHINKER, "smallthinker" }, { LLM_ARCH_UNKNOWN, "(unknown)" }, }; @@ -1694,6 +1695,27 @@ static const std::map> LLM_TENSOR_N { LLM_TENSOR_FFN_UP, "blk.%d.ffn_up" }, }, }, + { + LLM_ARCH_SMALLTHINKER, + { + { LLM_TENSOR_TOKEN_EMBD, "token_embd" }, + { LLM_TENSOR_OUTPUT_NORM, "output_norm" }, + { LLM_TENSOR_OUTPUT, "output" }, + { LLM_TENSOR_ATTN_NORM, "blk.%d.attn_norm" }, + { LLM_TENSOR_ATTN_Q, "blk.%d.attn_q" }, + { LLM_TENSOR_ATTN_K, "blk.%d.attn_k" }, + { LLM_TENSOR_ATTN_V, "blk.%d.attn_v" }, + { LLM_TENSOR_ATTN_OUT, "blk.%d.attn_output" }, + { LLM_TENSOR_FFN_NORM, "blk.%d.ffn_norm" }, + { LLM_TENSOR_FFN_GATE, "blk.%d.ffn_gate" }, + { LLM_TENSOR_FFN_DOWN, "blk.%d.ffn_down" }, + { LLM_TENSOR_FFN_UP, "blk.%d.ffn_up" }, + { LLM_TENSOR_FFN_GATE_INP, "blk.%d.ffn_gate_inp" }, + { LLM_TENSOR_FFN_GATE_EXPS, "blk.%d.ffn_gate_exps" }, + { LLM_TENSOR_FFN_DOWN_EXPS, "blk.%d.ffn_down_exps" }, + { LLM_TENSOR_FFN_UP_EXPS, "blk.%d.ffn_up_exps" } + }, + }, { LLM_ARCH_UNKNOWN, { diff --git a/src/llama-arch.h b/src/llama-arch.h index b769831dff5ec..cbcf0a03d78e4 100644 --- a/src/llama-arch.h +++ b/src/llama-arch.h @@ -82,6 +82,7 @@ enum llm_arch { LLM_ARCH_DOTS1, LLM_ARCH_ARCEE, LLM_ARCH_ERNIE4_5, + LLM_ARCH_SMALLTHINKER, LLM_ARCH_UNKNOWN, }; diff --git a/src/llama-graph.cpp b/src/llama-graph.cpp index 7f0e8c67f1325..6af0551cba70d 100644 --- a/src/llama-graph.cpp +++ b/src/llama-graph.cpp @@ -797,6 +797,137 @@ ggml_tensor * llm_graph_context::build_moe_ffn( return moe_out; } +ggml_tensor * llm_graph_context::build_moe_ffn_from_probs( + ggml_tensor * cur, + ggml_tensor * probs, + ggml_tensor * up_exps, + ggml_tensor * gate_exps, + ggml_tensor * down_exps, + ggml_tensor * exp_probs_b, + int64_t n_expert, + int64_t n_expert_used, + llm_ffn_op_type type_op, + bool norm_w, + bool scale_w, + float w_scale, + int il) const { + const int64_t n_embd = cur->ne[0]; + const int64_t n_tokens = cur->ne[1]; + const bool weight_before_ffn = arch == LLM_ARCH_LLAMA4; // for llama4, we apply the sigmoid-ed weights before the FFN + + // add experts selection bias - introduced in DeepSeek V3 + // leave probs unbiased as it's later used to get expert weights + ggml_tensor * selection_probs = probs; + if (exp_probs_b != nullptr) { + selection_probs = ggml_add(ctx0, probs, exp_probs_b); + cb(selection_probs, "ffn_moe_probs_biased", il); + } + + // select experts + ggml_tensor * selected_experts = ggml_top_k(ctx0, selection_probs, n_expert_used); // [n_expert_used, n_tokens] + cb(selected_experts->src[0], "ffn_moe_argsort", il); + cb(selected_experts, "ffn_moe_topk", il); + + ggml_tensor * weights = ggml_get_rows(ctx0, + ggml_reshape_3d(ctx0, probs, 1, n_expert, n_tokens), selected_experts); // [1, n_expert_used, n_tokens] + cb(weights, "ffn_moe_weights", il); + + if (norm_w) { + weights = ggml_reshape_2d(ctx0, weights, n_expert_used, n_tokens); + + ggml_tensor * weights_sum = ggml_sum_rows(ctx0, weights); // [1, n_tokens] + cb(weights_sum, "ffn_moe_weights_sum", il); + + weights = ggml_div(ctx0, 
weights, weights_sum); // [n_expert_used, n_tokens] + cb(weights, "ffn_moe_weights_norm", il); + + weights = ggml_reshape_3d(ctx0, weights, 1, n_expert_used, n_tokens); + } + if (scale_w) { + weights = ggml_scale(ctx0, weights, w_scale); + cb(weights, "ffn_moe_weights_scaled", il); + } + + cur = ggml_reshape_3d(ctx0, cur, n_embd, 1, n_tokens); + + if (weight_before_ffn) { + // repeat cur to [n_embd, n_expert_used, n_tokens] + ggml_tensor * repeated = ggml_repeat_4d(ctx0, cur, n_embd, n_expert_used, n_tokens, 1); + cur = ggml_mul(ctx0, repeated, weights); + cb(cur, "ffn_moe_weighted", il); + } + + ggml_tensor * up = build_lora_mm_id(up_exps, cur, selected_experts); // [n_ff, n_expert_used, n_tokens] + cb(up, "ffn_moe_up", il); + + ggml_tensor * experts = nullptr; + if (gate_exps) { + cur = build_lora_mm_id(gate_exps, cur, selected_experts); // [n_ff, n_expert_used, n_tokens] + cb(cur, "ffn_moe_gate", il); + } else { + cur = up; + } + + switch (type_op) { + case LLM_FFN_SILU: + if (gate_exps) { + cur = ggml_swiglu_split(ctx0, cur, up); + cb(cur, "ffn_moe_swiglu", il); + } else { + cur = ggml_silu(ctx0, cur); + cb(cur, "ffn_moe_silu", il); + } break; + case LLM_FFN_GELU: + if (gate_exps) { + cur = ggml_geglu_split(ctx0, cur, up); + cb(cur, "ffn_moe_geglu", il); + } else { + cur = ggml_gelu(ctx0, cur); + cb(cur, "ffn_moe_gelu", il); + } break; + case LLM_FFN_RELU: + if (gate_exps) { + cur = ggml_reglu_split(ctx0, cur, up); + cb(cur, "ffn_moe_reglu", il); + } else { + cur = ggml_relu(ctx0, cur); + cb(cur, "ffn_moe_relu", il); + } break; + default: + GGML_ABORT("fatal error"); + } + + experts = build_lora_mm_id(down_exps, cur, selected_experts); // [n_embd, n_expert_used, n_tokens] + cb(experts, "ffn_moe_down", il); + + if (!weight_before_ffn) { + experts = ggml_mul(ctx0, experts, weights); + cb(cur, "ffn_moe_weighted", il); + } + + // aggregate experts + ggml_tensor * moe_out = nullptr; + for (int i = 0; i < n_expert_used; ++i) { + ggml_tensor * cur_expert = ggml_view_2d(ctx0, experts, n_embd, n_tokens, + experts->nb[2], i*experts->nb[1]); + + if (i == 0) { + moe_out = cur_expert; + } else { + moe_out = ggml_add(ctx0, moe_out, cur_expert); + } + } + + if (n_expert_used == 1) { + // avoid returning a non-contiguous tensor + moe_out = ggml_cont(ctx0, moe_out); + } + + cb(moe_out, "ffn_moe_out", il); + + return moe_out; +} + // input embeddings with optional lora ggml_tensor * llm_graph_context::build_inp_embd(ggml_tensor * tok_embd) const { const int64_t n_embd = hparams.n_embd; diff --git a/src/llama-graph.h b/src/llama-graph.h index 7bdf656768a0c..2e84d81ad1c3d 100644 --- a/src/llama-graph.h +++ b/src/llama-graph.h @@ -563,6 +563,21 @@ struct llm_graph_context { llama_expert_gating_func_type gating_op, int il) const; + ggml_tensor * build_moe_ffn_from_probs( + ggml_tensor * cur, + ggml_tensor * probs, + ggml_tensor * up_exps, + ggml_tensor * gate_exps, + ggml_tensor * down_exps, + ggml_tensor * exp_probs_b, + int64_t n_expert, + int64_t n_expert_used, + llm_ffn_op_type type_op, + bool norm_w, + bool scale_w, + float w_scale, + int il) const; + // // inputs // diff --git a/src/llama-hparams.cpp b/src/llama-hparams.cpp index 86c814d51b901..663db0a12884b 100644 --- a/src/llama-hparams.cpp +++ b/src/llama-hparams.cpp @@ -8,6 +8,12 @@ void llama_hparams::set_swa_pattern(uint32_t n_pattern) { } } +void llama_hparams::set_dense_start_swa_pattern(uint32_t n_pattern) { + for (uint32_t il = 0; il < n_layer; ++il) { + swa_layers[il] = n_pattern == 0 || (il % n_pattern != 0); + } +} + bool 
llama_hparams::is_swa_any() const { for (uint32_t il = 0; il < n_layer; ++il) { if (swa_layers[il]) { diff --git a/src/llama-hparams.h b/src/llama-hparams.h index 476d0a5eade28..255f06e0ace73 100644 --- a/src/llama-hparams.h +++ b/src/llama-hparams.h @@ -138,7 +138,7 @@ struct llama_hparams { // for Classifiers uint32_t n_cls_out = 1; - // llama4 + // llama4 smallthinker uint32_t n_moe_layer_step = 0; uint32_t n_no_rope_layer_step = 4; uint32_t n_attn_temp_floor_scale = 8192; @@ -172,6 +172,20 @@ struct llama_hparams { // etc ... void set_swa_pattern(uint32_t n_pattern); + // this value n_pattern means that every nth layer is dense (i.e. non-SWA) + // note that if n_pattern == 0, all layers are SWA + // if n_pattern == 1, all layers are dense + // example: n_pattern = 3 + // il == 0: dense + // il == 1: swa + // il == 2: swa + // il == 3: dense + // il == 4: swa + // il == 5: swa + // il == 6: dense + // etc ... + void set_dense_start_swa_pattern(uint32_t n_pattern); + // return true if one of the layers is SWA bool is_swa_any() const; diff --git a/src/llama-model.cpp b/src/llama-model.cpp index 0573c5bcea0a4..4f9c206149667 100644 --- a/src/llama-model.cpp +++ b/src/llama-model.cpp @@ -1549,6 +1549,21 @@ void llama_model::load_hparams(llama_model_loader & ml) { default: type = LLM_TYPE_UNKNOWN; } } break; + case LLM_ARCH_SMALLTHINKER: + { + hparams.swa_type = LLAMA_SWA_TYPE_STANDARD; + hparams.n_swa = 4096; + hparams.set_dense_start_swa_pattern(4); + + ml.get_key(LLM_KV_EXPERT_FEED_FORWARD_LENGTH, hparams.n_ff_exp, false); + + ml.get_key(LLM_KV_ATTENTION_LAYERNORM_RMS_EPS, hparams.f_norm_rms_eps); + + switch (hparams.n_layer) { + default: + type = LLM_TYPE_UNKNOWN; + } + } break; default: throw std::runtime_error("unsupported model architecture"); } @@ -4475,6 +4490,48 @@ bool llama_model::load_tensors(llama_model_loader & ml) { layer.ffn_up = create_tensor(tn(LLM_TENSOR_FFN_UP, "weight", i), {n_embd, n_ff}, 0); } } break; + case LLM_ARCH_SMALLTHINKER: + { + tok_embd = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), { n_embd, n_vocab }, 0); + + // output + output_norm = create_tensor(tn(LLM_TENSOR_OUTPUT_NORM, "weight"), { n_embd }, 0); + output = create_tensor(tn(LLM_TENSOR_OUTPUT, "weight"), { n_embd, n_vocab }, 0); + + for (int i = 0; i < n_layer; ++i) { + auto & layer = layers[i]; + const int64_t n_ff_cur = hparams.n_ff_arr[i]; + + layer.attn_norm = create_tensor(tn(LLM_TENSOR_ATTN_NORM, "weight", i), { n_embd }, 0); + + layer.wq = create_tensor(tn(LLM_TENSOR_ATTN_Q, "weight", i), { n_embd, n_embd_head_k * n_head }, 0); + layer.wk = create_tensor(tn(LLM_TENSOR_ATTN_K, "weight", i), { n_embd, n_embd_gqa }, 0); + layer.wv = create_tensor(tn(LLM_TENSOR_ATTN_V, "weight", i), { n_embd, n_embd_gqa }, 0); + layer.wo = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "weight", i), { n_embd_head_k * n_head, n_embd }, 0); + + layer.ffn_norm = create_tensor(tn(LLM_TENSOR_FFN_NORM, "weight", i), { n_embd }, 0); + + if (n_expert == 0) { + throw std::runtime_error("n_expert must be > 0 for SMALLTHINKER"); + } + if (n_expert_used == 0) { + throw std::runtime_error("n_expert_used must be > 0 for SMALLTHINKER"); + } + + // MoE branch + const int64_t n_ff_exp = hparams.n_ff_exp; + if (n_ff_exp == n_ff_cur) { + layer.ffn_gate_inp = create_tensor(tn(LLM_TENSOR_FFN_GATE_INP, "weight", i), { n_embd, n_expert }, 0); + layer.ffn_gate_exps = create_tensor(tn(LLM_TENSOR_FFN_GATE_EXPS, "weight", i), { n_embd, n_ff_exp, n_expert }, 0); + layer.ffn_down_exps = create_tensor(tn(LLM_TENSOR_FFN_DOWN_EXPS, "weight", i), { 
n_ff_exp, n_embd, n_expert }, 0); + layer.ffn_up_exps = create_tensor(tn(LLM_TENSOR_FFN_UP_EXPS, "weight", i), { n_embd, n_ff_exp, n_expert }, 0); + } else { + layer.ffn_gate = create_tensor(tn(LLM_TENSOR_FFN_GATE, "weight", i), { n_embd, n_ff_cur }, 0); + layer.ffn_down = create_tensor(tn(LLM_TENSOR_FFN_DOWN, "weight", i), { n_ff_cur,n_embd }, 0); + layer.ffn_up = create_tensor(tn(LLM_TENSOR_FFN_UP, "weight", i), { n_embd, n_ff_cur }, 0); + } + } + } break; default: throw std::runtime_error("unknown architecture"); } @@ -4795,6 +4852,10 @@ void llama_model::print_info() const { LLAMA_LOG_INFO("%s: expert_weights_norm = %d\n", __func__, hparams.expert_weights_norm); } + if (arch == LLM_ARCH_SMALLTHINKER) { + LLAMA_LOG_INFO("%s: n_ff_exp = %d\n", __func__, hparams.n_ff_exp); + } + vocab.print_info(); } @@ -14660,6 +14721,121 @@ struct llm_build_arcee : public llm_graph_context { } }; +struct llm_build_smallthinker : public llm_graph_context{ + llm_build_smallthinker(const llama_model & model, const llm_graph_params & params, ggml_cgraph * gf) : llm_graph_context(params){ + const int64_t n_embd_head = hparams.n_embd_head_v; + + GGML_ASSERT(n_embd_head == hparams.n_embd_head_k); + GGML_ASSERT(n_embd_head == hparams.n_rot); + + ggml_tensor * cur; + ggml_tensor * inpL; + + inpL = build_inp_embd(model.tok_embd); + + // inp_pos - contains the positions + ggml_tensor * inp_pos = build_inp_pos(); + + auto * inp_attn = build_attn_inp_kv_unified_iswa(); + + for (int il = 0; il < n_layer; ++il) { + ggml_tensor * inpSA = inpL; + ggml_tensor * probs = nullptr; + bool is_moe = hparams.n_ff_exp == hparams.n_ff_arr[il]; + + if (is_moe) { + ggml_tensor * logits = build_lora_mm(model.layers[il].ffn_gate_inp, inpL); // [n_expert, n_tokens] + cb(logits, "ffn_moe_logits", il); + + probs = ggml_sigmoid(ctx0, logits); // [n_expert, n_tokens] + cb(probs, "ffn_moe_probs", il); + } + + // norm + cur = build_norm(inpL,model.layers[il].attn_norm, NULL, LLM_NORM_RMS, il); + cb(cur, "attn_norm", il); + + // self_attention + { + // compute Q and K and RoPE them + struct ggml_tensor * Qcur = build_lora_mm(model.layers[il].wq, cur); + cb(Qcur, "Qcur", il); + + struct ggml_tensor * Kcur = build_lora_mm(model.layers[il].wk, cur); + cb(Kcur, "Kcur", il); + + struct ggml_tensor * Vcur = build_lora_mm(model.layers[il].wv, cur); + cb(Vcur, "Vcur", il); + + Qcur = ggml_reshape_3d(ctx0, Qcur, n_embd_head, n_head, n_tokens); + Kcur = ggml_reshape_3d(ctx0, Kcur, n_embd_head, n_head_kv, n_tokens); + Vcur = ggml_reshape_3d(ctx0, Vcur, n_embd_head, n_head_kv, n_tokens); + + if(il % hparams.n_no_rope_layer_step) { + Qcur = ggml_rope_ext(ctx0, Qcur, inp_pos, nullptr, n_rot, rope_type, n_ctx_orig, freq_base, freq_scale, + ext_factor, attn_factor, beta_fast, beta_slow); + + Kcur = ggml_rope_ext(ctx0, Kcur, inp_pos, nullptr, n_rot, rope_type, n_ctx_orig, freq_base, freq_scale, + ext_factor, attn_factor, beta_fast, beta_slow); + } + + cb(Qcur, "Qcur", il); + cb(Kcur, "Kcur", il); + + cur = build_attn(inp_attn, gf, model.layers[il].wo, model.layers[il].bo, Qcur, Kcur, Vcur, + nullptr,nullptr, 1.0f / sqrtf(float(n_embd_head)), il); + } + + if (il == n_layer - 1) { + // skip computing output for unused tokens + ggml_tensor * inp_out_ids = build_inp_out_ids(); + cur = ggml_get_rows(ctx0, cur, inp_out_ids); + inpSA = ggml_get_rows(ctx0, inpSA, inp_out_ids); + } + + ggml_tensor * ffn_inp = ggml_add(ctx0, cur, inpSA); + cb(ffn_inp, "ffn_inp", il); + + // MoE branch + cur = build_norm(ffn_inp, model.layers[il].ffn_norm, NULL, LLM_NORM_RMS, il); + 
cb(cur, "ffn_norm", il); + + ggml_tensor * ffn_out = nullptr; + if (is_moe) { + ffn_out = build_moe_ffn_from_probs(cur, probs, model.layers[il].ffn_up_exps, + model.layers[il].ffn_gate_exps, model.layers[il].ffn_down_exps, + nullptr, n_expert, n_expert_used, LLM_FFN_RELU, true, false, 0.0, il); + + } else { + ffn_out = build_ffn(cur, model.layers[il].ffn_up, NULL, NULL, model.layers[il].ffn_gate, NULL, NULL, + model.layers[il].ffn_down, NULL, NULL, NULL, LLM_FFN_RELU, LLM_FFN_PAR, il); + } + + cb(ffn_out, "ffn_out", il); + cur = ffn_out; + + cur = ggml_add(ctx0, cur, ffn_inp); + cur = build_cvec(cur, il); + cb(cur, "l_out", il); + + // input for next layer + inpL = cur; + } + + cur = inpL; + + cur = build_norm(cur, model.output_norm, NULL, LLM_NORM_RMS, -1); + cb(cur, "result_norm", -1); + + // lm_head + cur = build_lora_mm(model.output, cur); + cb(cur, "result_output", -1); + res->t_logits = cur; + + ggml_build_forward_expand(gf, cur); + } +}; + llama_memory_i * llama_model::create_memory(const llama_memory_params & params, llama_cparams & cparams) const { llama_memory_i * res; @@ -15040,6 +15216,10 @@ llm_graph_result_ptr llama_model::build_graph( { llm = std::make_unique(*this, params, gf); } break; + case LLM_ARCH_SMALLTHINKER: + { + llm = std::make_unique(*this, params, gf); + } break; default: GGML_ABORT("fatal error"); } @@ -15228,6 +15408,7 @@ llama_rope_type llama_model_rope_type(const llama_model * model) { case LLM_ARCH_EXAONE: case LLM_ARCH_MINICPM3: case LLM_ARCH_DOTS1: + case LLM_ARCH_SMALLTHINKER: return LLAMA_ROPE_TYPE_NEOX; case LLM_ARCH_QWEN2VL: From a6d6eafe06ca3eee23bfdba0ef50ef8644db13c1 Mon Sep 17 00:00:00 2001 From: Dongliang Wei <2720609228@qq.com> Date: Thu, 24 Jul 2025 09:24:33 +0000 Subject: [PATCH 02/16] support 20b softmax, 4b no sliding window --- convert_hf_to_gguf.py | 14 +++++++++++++- src/llama-model.cpp | 44 +++++++++++++++++++++++++++++++++---------- 2 files changed, 47 insertions(+), 11 deletions(-) diff --git a/convert_hf_to_gguf.py b/convert_hf_to_gguf.py index 22d0d9219945f..3299b0709b353 100755 --- a/convert_hf_to_gguf.py +++ b/convert_hf_to_gguf.py @@ -6552,6 +6552,10 @@ def set_gguf_parameters(self): if (shared_expert_intermediate_size := self.hparams.get('shared_expert_intermediate_size')) is not None: self.gguf_writer.add_expert_shared_feed_forward_length(shared_expert_intermediate_size) logger.info(f"gguf: expert shared feed forward length = {shared_expert_intermediate_size}") + if (self.hparams.get('moe_primary_router_apply_softmax')): + self.gguf_writer.add_expert_gating_func(gguf.ExpertGatingFuncType.SOFTMAX) + else: + self.gguf_writer.add_expert_gating_func(gguf.ExpertGatingFuncType.SIGMOID) # YaRN is not enabled by default # To enable it, please refer to this guide: https://huggingface.co/Qwen/Qwen3-30B-A3B#processing-long-texts rope_scaling = self.hparams.get("rope_scaling") or {} @@ -6559,8 +6563,16 @@ def set_gguf_parameters(self): self.gguf_writer.add_rope_scaling_type(gguf.RopeScalingType.YARN) self.gguf_writer.add_rope_scaling_factor(rope_scaling["factor"]) self.gguf_writer.add_rope_scaling_orig_ctx_len(rope_scaling["original_max_position_embeddings"]) + sliding_window = self.hparams.get("sliding_window") - self.gguf_writer.add_sliding_window(sliding_window) + sliding_window_layout = self.hparams.get("sliding_window_layout") + if sliding_window and sliding_window_layout: + for i in sliding_window_layout: + if i != 0: + self.gguf_writer.add_sliding_window(sliding_window) + break + elif sliding_window: + 
self.gguf_writer.add_sliding_window(sliding_window) intermediate_size = self.hparams.get("ffn_hidden_size") moe_intermediate_size = self.hparams.get("moe_ffn_hidden_size") diff --git a/src/llama-model.cpp b/src/llama-model.cpp index 4f9c206149667..6fabcedb56cda 100644 --- a/src/llama-model.cpp +++ b/src/llama-model.cpp @@ -1551,17 +1551,25 @@ void llama_model::load_hparams(llama_model_loader & ml) { } break; case LLM_ARCH_SMALLTHINKER: { - hparams.swa_type = LLAMA_SWA_TYPE_STANDARD; - hparams.n_swa = 4096; - hparams.set_dense_start_swa_pattern(4); + const bool found_swa = ml.get_key(LLM_KV_ATTENTION_SLIDING_WINDOW, hparams.n_swa, false); - ml.get_key(LLM_KV_EXPERT_FEED_FORWARD_LENGTH, hparams.n_ff_exp, false); + if (found_swa && hparams.n_swa > 0) { + hparams.swa_type = LLAMA_SWA_TYPE_STANDARD; + hparams.n_swa = 4096; + hparams.set_dense_start_swa_pattern(4); + } else { + hparams.swa_type = LLAMA_SWA_TYPE_NONE; + hparams.n_no_rope_layer_step = hparams.n_layer; + } + ml.get_key(LLM_KV_EXPERT_FEED_FORWARD_LENGTH, hparams.n_ff_exp, false); ml.get_key(LLM_KV_ATTENTION_LAYERNORM_RMS_EPS, hparams.f_norm_rms_eps); + ml.get_key(LLM_KV_EXPERT_GATING_FUNC, hparams.expert_gating_func, false); switch (hparams.n_layer) { - default: - type = LLM_TYPE_UNKNOWN; + case 32: type = LLM_TYPE_4B; break; + case 52: type = LLM_TYPE_20B; break; + default: type = LLM_TYPE_UNKNOWN; } } break; default: throw std::runtime_error("unsupported model architecture"); @@ -4854,6 +4862,7 @@ void llama_model::print_info() const { if (arch == LLM_ARCH_SMALLTHINKER) { LLAMA_LOG_INFO("%s: n_ff_exp = %d\n", __func__, hparams.n_ff_exp); + LLAMA_LOG_INFO("%s: expert_gating_func = %s\n", __func__, llama_expert_gating_func_name((llama_expert_gating_func_type) hparams.expert_gating_func)); } vocab.print_info(); @@ -14736,7 +14745,12 @@ struct llm_build_smallthinker : public llm_graph_context{ // inp_pos - contains the positions ggml_tensor * inp_pos = build_inp_pos(); - auto * inp_attn = build_attn_inp_kv_unified_iswa(); + llm_graph_input_i * inp_attn = nullptr; + if (hparams.is_swa_any()) { + inp_attn = build_attn_inp_kv_unified_iswa(); + } else { + inp_attn = build_attn_inp_kv_unified(); + } for (int il = 0; il < n_layer; ++il) { ggml_tensor * inpSA = inpL; @@ -14747,7 +14761,11 @@ struct llm_build_smallthinker : public llm_graph_context{ ggml_tensor * logits = build_lora_mm(model.layers[il].ffn_gate_inp, inpL); // [n_expert, n_tokens] cb(logits, "ffn_moe_logits", il); - probs = ggml_sigmoid(ctx0, logits); // [n_expert, n_tokens] + if (hparams.expert_gating_func == LLAMA_EXPERT_GATING_FUNC_TYPE_SOFTMAX) { + probs = ggml_soft_max(ctx0, logits); // [n_expert, n_tokens] + } else { + probs = ggml_sigmoid(ctx0, logits); // [n_expert, n_tokens] + } cb(probs, "ffn_moe_probs", il); } @@ -14782,8 +14800,13 @@ struct llm_build_smallthinker : public llm_graph_context{ cb(Qcur, "Qcur", il); cb(Kcur, "Kcur", il); - cur = build_attn(inp_attn, gf, model.layers[il].wo, model.layers[il].bo, Qcur, Kcur, Vcur, - nullptr,nullptr, 1.0f / sqrtf(float(n_embd_head)), il); + if (hparams.is_swa_any()) { + cur = build_attn(static_cast(inp_attn), gf, model.layers[il].wo, model.layers[il].bo, Qcur,Kcur, Vcur, + nullptr,nullptr, 1.0f / sqrtf(float(n_embd_head)), il); + } else { + cur = build_attn(static_cast(inp_attn), gf, model.layers[il].wo, model.layers[il].bo, Qcur,Kcur, Vcur, + nullptr,nullptr, 1.0f / sqrtf(float(n_embd_head)), il); + } } if (il == n_layer - 1) { @@ -14791,6 +14814,7 @@ struct llm_build_smallthinker : public llm_graph_context{ ggml_tensor 
* inp_out_ids = build_inp_out_ids(); cur = ggml_get_rows(ctx0, cur, inp_out_ids); inpSA = ggml_get_rows(ctx0, inpSA, inp_out_ids); + if (probs != nullptr) { probs = ggml_get_rows(ctx0, probs, inp_out_ids); } } ggml_tensor * ffn_inp = ggml_add(ctx0, cur, inpSA); From 8e2cb21fb2ebe5dbf4fdedc1482b5c262ec940ed Mon Sep 17 00:00:00 2001 From: Dongliang Wei <2720609228@qq.com> Date: Thu, 24 Jul 2025 12:49:51 +0000 Subject: [PATCH 03/16] new build_moe_ffn_from_probs, and can run 4b --- src/llama-graph.cpp | 97 ++++++++++++++------------------------------- src/llama-graph.h | 5 +-- src/llama-model.cpp | 30 +++++++------- 3 files changed, 45 insertions(+), 87 deletions(-) diff --git a/src/llama-graph.cpp b/src/llama-graph.cpp index 4df39d0f1f58d..1b9cc4aec0632 100644 --- a/src/llama-graph.cpp +++ b/src/llama-graph.cpp @@ -947,14 +947,10 @@ ggml_tensor * llm_graph_context::build_moe_ffn_from_probs( ggml_tensor * exp_probs_b, int64_t n_expert, int64_t n_expert_used, - llm_ffn_op_type type_op, - bool norm_w, - bool scale_w, - float w_scale, + llama_expert_gating_func_type gating_op, int il) const { const int64_t n_embd = cur->ne[0]; const int64_t n_tokens = cur->ne[1]; - const bool weight_before_ffn = arch == LLM_ARCH_LLAMA4; // for llama4, we apply the sigmoid-ed weights before the FFN // add experts selection bias - introduced in DeepSeek V3 // leave probs unbiased as it's later used to get expert weights @@ -973,90 +969,57 @@ ggml_tensor * llm_graph_context::build_moe_ffn_from_probs( ggml_reshape_3d(ctx0, probs, 1, n_expert, n_tokens), selected_experts); // [1, n_expert_used, n_tokens] cb(weights, "ffn_moe_weights", il); - if (norm_w) { - weights = ggml_reshape_2d(ctx0, weights, n_expert_used, n_tokens); - + weights = ggml_reshape_2d(ctx0, weights, n_expert_used, n_tokens); + if (gating_op == LLAMA_EXPERT_GATING_FUNC_TYPE_SOFTMAX) { + weights = ggml_soft_max(ctx0, weights); + } else { + weights = ggml_sigmoid(ctx0, weights); ggml_tensor * weights_sum = ggml_sum_rows(ctx0, weights); // [1, n_tokens] cb(weights_sum, "ffn_moe_weights_sum", il); weights = ggml_div(ctx0, weights, weights_sum); // [n_expert_used, n_tokens] cb(weights, "ffn_moe_weights_norm", il); - - weights = ggml_reshape_3d(ctx0, weights, 1, n_expert_used, n_tokens); - } - if (scale_w) { - weights = ggml_scale(ctx0, weights, w_scale); - cb(weights, "ffn_moe_weights_scaled", il); } - cur = ggml_reshape_3d(ctx0, cur, n_embd, 1, n_tokens); + weights = ggml_reshape_3d(ctx0, weights, 1, n_expert_used, n_tokens); - if (weight_before_ffn) { - // repeat cur to [n_embd, n_expert_used, n_tokens] - ggml_tensor * repeated = ggml_repeat_4d(ctx0, cur, n_embd, n_expert_used, n_tokens, 1); - cur = ggml_mul(ctx0, repeated, weights); - cb(cur, "ffn_moe_weighted", il); - } + cur = ggml_reshape_3d(ctx0, cur, n_embd, 1, n_tokens); ggml_tensor * up = build_lora_mm_id(up_exps, cur, selected_experts); // [n_ff, n_expert_used, n_tokens] cb(up, "ffn_moe_up", il); ggml_tensor * experts = nullptr; - if (gate_exps) { - cur = build_lora_mm_id(gate_exps, cur, selected_experts); // [n_ff, n_expert_used, n_tokens] - cb(cur, "ffn_moe_gate", il); - } else { - cur = up; - } + cur = build_lora_mm_id(gate_exps, cur, selected_experts); // [n_ff, n_expert_used, n_tokens] + cb(cur, "ffn_moe_gate", il); - switch (type_op) { - case LLM_FFN_SILU: - if (gate_exps) { - cur = ggml_swiglu_split(ctx0, cur, up); - cb(cur, "ffn_moe_swiglu", il); - } else { - cur = ggml_silu(ctx0, cur); - cb(cur, "ffn_moe_silu", il); - } break; - case LLM_FFN_GELU: - if (gate_exps) { - cur = 
ggml_geglu_split(ctx0, cur, up); - cb(cur, "ffn_moe_geglu", il); - } else { - cur = ggml_gelu(ctx0, cur); - cb(cur, "ffn_moe_gelu", il); - } break; - case LLM_FFN_RELU: - if (gate_exps) { - cur = ggml_reglu_split(ctx0, cur, up); - cb(cur, "ffn_moe_reglu", il); - } else { - cur = ggml_relu(ctx0, cur); - cb(cur, "ffn_moe_relu", il); - } break; - default: - GGML_ABORT("fatal error"); - } + cur = ggml_reglu_split(ctx0, cur, up); + cb(cur, "ffn_moe_reglu", il); experts = build_lora_mm_id(down_exps, cur, selected_experts); // [n_embd, n_expert_used, n_tokens] cb(experts, "ffn_moe_down", il); - if (!weight_before_ffn) { - experts = ggml_mul(ctx0, experts, weights); - cb(cur, "ffn_moe_weighted", il); + experts = ggml_mul(ctx0, experts, weights); + cb(cur, "ffn_moe_weighted", il); + + ggml_tensor * cur_experts[LLAMA_MAX_EXPERTS] = { nullptr }; + + assert(n_expert_used > 0); + + // order the views before the adds + for (uint32_t i = 0; i < hparams.n_expert_used; ++i) { + cur_experts[i] = ggml_view_2d(ctx0, experts, n_embd, n_tokens, experts->nb[2], i*experts->nb[1]); + + ggml_build_forward_expand(gf, cur_experts[i]); } // aggregate experts - ggml_tensor * moe_out = nullptr; - for (int i = 0; i < n_expert_used; ++i) { - ggml_tensor * cur_expert = ggml_view_2d(ctx0, experts, n_embd, n_tokens, - experts->nb[2], i*experts->nb[1]); + // note: here we explicitly use hparams.n_expert_used instead of n_expert_used + // to avoid potentially a large number of add nodes during warmup + // ref: https://github.com/ggml-org/llama.cpp/pull/14753 + ggml_tensor * moe_out = cur_experts[0]; - if (i == 0) { - moe_out = cur_expert; - } else { - moe_out = ggml_add(ctx0, moe_out, cur_expert); - } + for (uint32_t i = 1; i < hparams.n_expert_used; ++i) { + moe_out = ggml_add(ctx0, moe_out, cur_experts[i]); } if (n_expert_used == 1) { diff --git a/src/llama-graph.h b/src/llama-graph.h index 30ccb20c43c86..d4d565754c500 100644 --- a/src/llama-graph.h +++ b/src/llama-graph.h @@ -634,10 +634,7 @@ struct llm_graph_context { ggml_tensor * exp_probs_b, int64_t n_expert, int64_t n_expert_used, - llm_ffn_op_type type_op, - bool norm_w, - bool scale_w, - float w_scale, + llama_expert_gating_func_type gating_op, int il) const; // diff --git a/src/llama-model.cpp b/src/llama-model.cpp index 6e523e455a480..762119319cbd2 100644 --- a/src/llama-model.cpp +++ b/src/llama-model.cpp @@ -5191,7 +5191,12 @@ bool llama_model::load_tensors(llama_model_loader & ml) { // output output_norm = create_tensor(tn(LLM_TENSOR_OUTPUT_NORM, "weight"), { n_embd }, 0); - output = create_tensor(tn(LLM_TENSOR_OUTPUT, "weight"), { n_embd, n_vocab }, 0); + output = create_tensor(tn(LLM_TENSOR_OUTPUT, "weight"), {n_embd, n_vocab}, TENSOR_NOT_REQUIRED); + + // if output is NULL, init from the input tok embed + if (output == NULL) { + output = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, TENSOR_DUPLICATED); + } for (int i = 0; i < n_layer; ++i) { auto & layer = layers[i]; @@ -17078,7 +17083,7 @@ struct llm_build_lfm2 : public llm_graph_context { }; struct llm_build_smallthinker : public llm_graph_context{ - llm_build_smallthinker(const llama_model & model, const llm_graph_params & params, ggml_cgraph * gf) : llm_graph_context(params){ + llm_build_smallthinker(const llama_model & model, const llm_graph_params & params) : llm_graph_context(params){ const int64_t n_embd_head = hparams.n_embd_head_v; GGML_ASSERT(n_embd_head == hparams.n_embd_head_k); @@ -17105,15 +17110,8 @@ struct llm_build_smallthinker : public llm_graph_context{ bool is_moe 
= hparams.n_ff_exp == hparams.n_ff_arr[il]; if (is_moe) { - ggml_tensor * logits = build_lora_mm(model.layers[il].ffn_gate_inp, inpL); // [n_expert, n_tokens] - cb(logits, "ffn_moe_logits", il); - - if (hparams.expert_gating_func == LLAMA_EXPERT_GATING_FUNC_TYPE_SOFTMAX) { - probs = ggml_soft_max(ctx0, logits); // [n_expert, n_tokens] - } else { - probs = ggml_sigmoid(ctx0, logits); // [n_expert, n_tokens] - } - cb(probs, "ffn_moe_probs", il); + probs = build_lora_mm(model.layers[il].ffn_gate_inp, inpL); // [n_expert, n_tokens] + cb(probs, "ffn_moe_logits", il); } // norm @@ -17148,10 +17146,10 @@ struct llm_build_smallthinker : public llm_graph_context{ cb(Kcur, "Kcur", il); if (hparams.is_swa_any()) { - cur = build_attn(static_cast(inp_attn), gf, model.layers[il].wo, model.layers[il].bo, Qcur,Kcur, Vcur, + cur = build_attn(static_cast(inp_attn), model.layers[il].wo, model.layers[il].bo, Qcur,Kcur, Vcur, nullptr,nullptr, 1.0f / sqrtf(float(n_embd_head)), il); } else { - cur = build_attn(static_cast(inp_attn), gf, model.layers[il].wo, model.layers[il].bo, Qcur,Kcur, Vcur, + cur = build_attn(static_cast(inp_attn), model.layers[il].wo, model.layers[il].bo, Qcur,Kcur, Vcur, nullptr,nullptr, 1.0f / sqrtf(float(n_embd_head)), il); } } @@ -17175,8 +17173,8 @@ struct llm_build_smallthinker : public llm_graph_context{ if (is_moe) { ffn_out = build_moe_ffn_from_probs(cur, probs, model.layers[il].ffn_up_exps, model.layers[il].ffn_gate_exps, model.layers[il].ffn_down_exps, - nullptr, n_expert, n_expert_used, LLM_FFN_RELU, true, false, 0.0, il); - + nullptr, n_expert, n_expert_used, + static_cast(hparams.expert_gating_func), il); } else { ffn_out = build_ffn(cur, model.layers[il].ffn_up, NULL, NULL, model.layers[il].ffn_gate, NULL, NULL, model.layers[il].ffn_down, NULL, NULL, NULL, LLM_FFN_RELU, LLM_FFN_PAR, il); @@ -17647,7 +17645,7 @@ ggml_cgraph * llama_model::build_graph(const llm_graph_params & params) const { } break; case LLM_ARCH_SMALLTHINKER: { - llm = std::make_unique(*this, params, gf); + llm = std::make_unique(*this, params); } break; default: GGML_ABORT("fatal error"); From e28d2c56bc27c213e2c939ed9139b07c9d9574a5 Mon Sep 17 00:00:00 2001 From: Dongliang Wei <2720609228@qq.com> Date: Fri, 25 Jul 2025 05:26:31 +0000 Subject: [PATCH 04/16] fix 4b rope bug --- convert_hf_to_gguf.py | 28 +++++----------------------- src/llama-model.cpp | 17 +++++------------ 2 files changed, 10 insertions(+), 35 deletions(-) diff --git a/convert_hf_to_gguf.py b/convert_hf_to_gguf.py index bb754c578c66f..b02744705cb53 100755 --- a/convert_hf_to_gguf.py +++ b/convert_hf_to_gguf.py @@ -7590,8 +7590,8 @@ def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iter -@ModelBase.register("SmallthinkerForCausalLM") -class SmallthinkerModel(TextModel): +@ModelBase.register("SmallThinkerForCausalLM") +class SmallThinkerModel(TextModel): model_arch = gguf.MODEL_ARCH.SMALLTHINKER def set_gguf_parameters(self): @@ -7602,10 +7602,8 @@ def set_gguf_parameters(self): self.gguf_writer.add_expert_used_count(n_experts_used) if (moe_intermediate_size := self.hparams.get("moe_ffn_hidden_size")) is not None: self.gguf_writer.add_expert_feed_forward_length(moe_intermediate_size) + self.gguf_writer.add_feed_forward_length(moe_intermediate_size) logger.info(f"gguf: expert feed forward length = {moe_intermediate_size}") - if (shared_expert_intermediate_size := self.hparams.get('shared_expert_intermediate_size')) is not None: - self.gguf_writer.add_expert_shared_feed_forward_length(shared_expert_intermediate_size) - 
logger.info(f"gguf: expert shared feed forward length = {shared_expert_intermediate_size}") if (self.hparams.get('moe_primary_router_apply_softmax')): self.gguf_writer.add_expert_gating_func(gguf.ExpertGatingFuncType.SOFTMAX) else: @@ -7618,29 +7616,13 @@ def set_gguf_parameters(self): self.gguf_writer.add_rope_scaling_factor(rope_scaling["factor"]) self.gguf_writer.add_rope_scaling_orig_ctx_len(rope_scaling["original_max_position_embeddings"]) - sliding_window = self.hparams.get("sliding_window") sliding_window_layout = self.hparams.get("sliding_window_layout") - if sliding_window and sliding_window_layout: + if sliding_window_layout: for i in sliding_window_layout: if i != 0: + sliding_window = self.hparams.get("sliding_window_size") self.gguf_writer.add_sliding_window(sliding_window) break - elif sliding_window: - self.gguf_writer.add_sliding_window(sliding_window) - - intermediate_size = self.hparams.get("ffn_hidden_size") - moe_intermediate_size = self.hparams.get("moe_ffn_hidden_size") - moe_layer_layout = self.hparams.get("moe_layer_layout") - ffn_layout = [] - for i, layout in enumerate(moe_layer_layout): - if layout == 0: - ffn_layout.append(intermediate_size) - elif layout == 1: - ffn_layout.append(moe_intermediate_size) - else: - raise ValueError(f"Unknown moe layer layout: {layout}") - self.gguf_writer.add_feed_forward_length(ffn_layout) - # def add_feed_forward_length(self, length: int | Sequence[int]) -> None: _experts: list[dict[str, Tensor]] | None = None diff --git a/src/llama-model.cpp b/src/llama-model.cpp index 762119319cbd2..4d48967f1907a 100644 --- a/src/llama-model.cpp +++ b/src/llama-model.cpp @@ -5200,7 +5200,6 @@ bool llama_model::load_tensors(llama_model_loader & ml) { for (int i = 0; i < n_layer; ++i) { auto & layer = layers[i]; - const int64_t n_ff_cur = hparams.n_ff_arr[i]; layer.attn_norm = create_tensor(tn(LLM_TENSOR_ATTN_NORM, "weight", i), { n_embd }, 0); @@ -5220,16 +5219,10 @@ bool llama_model::load_tensors(llama_model_loader & ml) { // MoE branch const int64_t n_ff_exp = hparams.n_ff_exp; - if (n_ff_exp == n_ff_cur) { - layer.ffn_gate_inp = create_tensor(tn(LLM_TENSOR_FFN_GATE_INP, "weight", i), { n_embd, n_expert }, 0); - layer.ffn_gate_exps = create_tensor(tn(LLM_TENSOR_FFN_GATE_EXPS, "weight", i), { n_embd, n_ff_exp, n_expert }, 0); - layer.ffn_down_exps = create_tensor(tn(LLM_TENSOR_FFN_DOWN_EXPS, "weight", i), { n_ff_exp, n_embd, n_expert }, 0); - layer.ffn_up_exps = create_tensor(tn(LLM_TENSOR_FFN_UP_EXPS, "weight", i), { n_embd, n_ff_exp, n_expert }, 0); - } else { - layer.ffn_gate = create_tensor(tn(LLM_TENSOR_FFN_GATE, "weight", i), { n_embd, n_ff_cur }, 0); - layer.ffn_down = create_tensor(tn(LLM_TENSOR_FFN_DOWN, "weight", i), { n_ff_cur,n_embd }, 0); - layer.ffn_up = create_tensor(tn(LLM_TENSOR_FFN_UP, "weight", i), { n_embd, n_ff_cur }, 0); - } + layer.ffn_gate_inp = create_tensor(tn(LLM_TENSOR_FFN_GATE_INP, "weight", i), { n_embd, n_expert }, 0); + layer.ffn_gate_exps = create_tensor(tn(LLM_TENSOR_FFN_GATE_EXPS, "weight", i), { n_embd, n_ff_exp, n_expert }, 0); + layer.ffn_down_exps = create_tensor(tn(LLM_TENSOR_FFN_DOWN_EXPS, "weight", i), { n_ff_exp, n_embd, n_expert }, 0); + layer.ffn_up_exps = create_tensor(tn(LLM_TENSOR_FFN_UP_EXPS, "weight", i), { n_embd, n_ff_exp, n_expert }, 0); } } break; default: @@ -17134,7 +17127,7 @@ struct llm_build_smallthinker : public llm_graph_context{ Kcur = ggml_reshape_3d(ctx0, Kcur, n_embd_head, n_head_kv, n_tokens); Vcur = ggml_reshape_3d(ctx0, Vcur, n_embd_head, n_head_kv, n_tokens); - if(il % 
hparams.n_no_rope_layer_step) { + if(hparams.n_no_rope_layer_step == n_layer || il % hparams.n_no_rope_layer_step != 0) { Qcur = ggml_rope_ext(ctx0, Qcur, inp_pos, nullptr, n_rot, rope_type, n_ctx_orig, freq_base, freq_scale, ext_factor, attn_factor, beta_fast, beta_slow); From 8c6af02a9c810b91e3191822608ce96ec7f44d95 Mon Sep 17 00:00:00 2001 From: Dongliang Wei <2720609228@qq.com> Date: Sun, 27 Jul 2025 06:18:57 +0000 Subject: [PATCH 05/16] fix python type check --- convert_hf_to_gguf.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/convert_hf_to_gguf.py b/convert_hf_to_gguf.py index b02744705cb53..c8ecd7c8dbd12 100755 --- a/convert_hf_to_gguf.py +++ b/convert_hf_to_gguf.py @@ -7621,7 +7621,8 @@ def set_gguf_parameters(self): for i in sliding_window_layout: if i != 0: sliding_window = self.hparams.get("sliding_window_size") - self.gguf_writer.add_sliding_window(sliding_window) + if sliding_window: + self.gguf_writer.add_sliding_window(sliding_window) break _experts: list[dict[str, Tensor]] | None = None From 92b518b4e880fe5a3f6a728ba72133c73caeb950 Mon Sep 17 00:00:00 2001 From: Dongliang Wei <2720609228@qq.com> Date: Sun, 27 Jul 2025 07:38:41 +0000 Subject: [PATCH 06/16] remove is_moe judge --- src/llama-model.cpp | 15 +++------------ 1 file changed, 3 insertions(+), 12 deletions(-) diff --git a/src/llama-model.cpp b/src/llama-model.cpp index 3907daf63f633..1c31e7af1969e 100644 --- a/src/llama-model.cpp +++ b/src/llama-model.cpp @@ -17103,12 +17103,9 @@ struct llm_build_smallthinker : public llm_graph_context{ for (int il = 0; il < n_layer; ++il) { ggml_tensor * inpSA = inpL; ggml_tensor * probs = nullptr; - bool is_moe = hparams.n_ff_exp == hparams.n_ff_arr[il]; - if (is_moe) { - probs = build_lora_mm(model.layers[il].ffn_gate_inp, inpL); // [n_expert, n_tokens] - cb(probs, "ffn_moe_logits", il); - } + probs = build_lora_mm(model.layers[il].ffn_gate_inp, inpL); // [n_expert, n_tokens] + cb(probs, "ffn_moe_logits", il); // norm cur = build_norm(inpL,model.layers[il].attn_norm, NULL, LLM_NORM_RMS, il); @@ -17165,16 +17162,10 @@ struct llm_build_smallthinker : public llm_graph_context{ cur = build_norm(ffn_inp, model.layers[il].ffn_norm, NULL, LLM_NORM_RMS, il); cb(cur, "ffn_norm", il); - ggml_tensor * ffn_out = nullptr; - if (is_moe) { - ffn_out = build_moe_ffn_from_probs(cur, probs, model.layers[il].ffn_up_exps, + ggml_tensor * ffn_out = build_moe_ffn_from_probs(cur, probs, model.layers[il].ffn_up_exps, model.layers[il].ffn_gate_exps, model.layers[il].ffn_down_exps, nullptr, n_expert, n_expert_used, static_cast(hparams.expert_gating_func), il); - } else { - ffn_out = build_ffn(cur, model.layers[il].ffn_up, NULL, NULL, model.layers[il].ffn_gate, NULL, NULL, - model.layers[il].ffn_down, NULL, NULL, NULL, LLM_FFN_RELU, LLM_FFN_PAR, il); - } cb(ffn_out, "ffn_out", il); cur = ffn_out; From 4186babce2d4c2b56987eaae00806d8317d0952c Mon Sep 17 00:00:00 2001 From: Dongliang Wei <2720609228@qq.com> Date: Sun, 27 Jul 2025 16:01:15 +0000 Subject: [PATCH 07/16] remove set_dense_start_swa_pattern function and modify set_swa_pattern function --- src/llama-hparams.cpp | 18 +++++++++--------- src/llama-hparams.h | 19 ++++++------------- src/llama-model.cpp | 2 +- 3 files changed, 16 insertions(+), 23 deletions(-) diff --git a/src/llama-hparams.cpp b/src/llama-hparams.cpp index 36ea4283ea861..7a06368dcda68 100644 --- a/src/llama-hparams.cpp +++ b/src/llama-hparams.cpp @@ -2,15 +2,15 @@ #include "ggml.h" -void llama_hparams::set_swa_pattern(uint32_t n_pattern) { - for (uint32_t il = 
0; il < n_layer; ++il) { - swa_layers[il] = n_pattern == 0 || (il % n_pattern < (n_pattern - 1)); - } -} - -void llama_hparams::set_dense_start_swa_pattern(uint32_t n_pattern) { - for (uint32_t il = 0; il < n_layer; ++il) { - swa_layers[il] = n_pattern == 0 || (il % n_pattern != 0); +void llama_hparams::set_swa_pattern(uint32_t n_pattern, bool dense_first) { + if (dense_first) { + for (uint32_t il = 0; il < n_layer; ++il) { + swa_layers[il] = n_pattern == 0 || (il % n_pattern != 0); + } + } else { + for (uint32_t il = 0; il < n_layer; ++il) { + swa_layers[il] = n_pattern == 0 || (il % n_pattern < (n_pattern - 1)); + } } } diff --git a/src/llama-hparams.h b/src/llama-hparams.h index ba5dfff917b20..d57aa3858b23d 100644 --- a/src/llama-hparams.h +++ b/src/llama-hparams.h @@ -161,9 +161,10 @@ struct llama_hparams { enum llama_rope_scaling_type rope_scaling_type_train = LLAMA_ROPE_SCALING_TYPE_NONE; // this value n_pattern means that every nth layer is dense (i.e. non-SWA) + // dense_first means whether the pattern is start with a dense layer // note that if n_pattern == 0, all layers are SWA // if n_pattern == 1, all layers are dense - // example: n_pattern = 3 + // example 1: n_pattern = 3, dense_first = false // il == 0: swa // il == 1: swa // il == 2: dense @@ -172,21 +173,13 @@ struct llama_hparams { // il == 5: dense // il == 6: swa // etc ... - void set_swa_pattern(uint32_t n_pattern); - - // this value n_pattern means that every nth layer is dense (i.e. non-SWA) - // note that if n_pattern == 0, all layers are SWA - // if n_pattern == 1, all layers are dense - // example: n_pattern = 3 + // example 2: n_pattern = 2, dense_first = true // il == 0: dense // il == 1: swa - // il == 2: swa - // il == 3: dense - // il == 4: swa - // il == 5: swa - // il == 6: dense + // il == 2: dense + // il == 3: swa // etc ... 
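// ---------------------------------------------------------------------------
// Illustration only (not part of the patch): a minimal, self-contained sketch
// of what the dense_first branch of set_swa_pattern() (see the
// llama-hparams.cpp hunk above) produces for the call SmallThinker uses,
// hparams.set_swa_pattern(4, true): layer 0 is dense, layers 1-3 use SWA,
// layer 4 is dense again, and so on.
#include <cstdint>
#include <cstdio>

int main() {
    const uint32_t n_pattern = 4;  // SmallThinker: set_swa_pattern(4, /*dense_first=*/true)
    for (uint32_t il = 0; il < 8; ++il) {
        // same predicate as the dense_first branch of set_swa_pattern()
        const bool is_swa = n_pattern == 0 || (il % n_pattern != 0);
        std::printf("il == %u: %s\n", (unsigned) il, is_swa ? "swa" : "dense");
    }
    return 0;
}
// ---------------------------------------------------------------------------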
- void set_dense_start_swa_pattern(uint32_t n_pattern); + void set_swa_pattern(uint32_t n_pattern, bool dense_first = false); // return true if one of the layers is SWA bool is_swa_any() const; diff --git a/src/llama-model.cpp b/src/llama-model.cpp index 1c31e7af1969e..916c32aa2f48a 100644 --- a/src/llama-model.cpp +++ b/src/llama-model.cpp @@ -1775,7 +1775,7 @@ void llama_model::load_hparams(llama_model_loader & ml) { if (found_swa && hparams.n_swa > 0) { hparams.swa_type = LLAMA_SWA_TYPE_STANDARD; hparams.n_swa = 4096; - hparams.set_dense_start_swa_pattern(4); + hparams.set_swa_pattern(4, true); } else { hparams.swa_type = LLAMA_SWA_TYPE_NONE; hparams.n_no_rope_layer_step = hparams.n_layer; From f1d4698f7a32c9e5e94fbca714f141bc46946ce6 Mon Sep 17 00:00:00 2001 From: Dongliang Wei <2720609228@qq.com> Date: Sun, 27 Jul 2025 16:26:21 +0000 Subject: [PATCH 08/16] trim trailing whitespace --- convert_hf_to_gguf.py | 5 ++--- src/llama-model.cpp | 2 +- 2 files changed, 3 insertions(+), 4 deletions(-) diff --git a/convert_hf_to_gguf.py b/convert_hf_to_gguf.py index c8ecd7c8dbd12..60fb23e5de35f 100755 --- a/convert_hf_to_gguf.py +++ b/convert_hf_to_gguf.py @@ -7589,7 +7589,6 @@ def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iter return [(self.map_tensor_name(name), data_torch)] - @ModelBase.register("SmallThinkerForCausalLM") class SmallThinkerModel(TextModel): model_arch = gguf.MODEL_ARCH.SMALLTHINKER @@ -7615,7 +7614,7 @@ def set_gguf_parameters(self): self.gguf_writer.add_rope_scaling_type(gguf.RopeScalingType.YARN) self.gguf_writer.add_rope_scaling_factor(rope_scaling["factor"]) self.gguf_writer.add_rope_scaling_orig_ctx_len(rope_scaling["original_max_position_embeddings"]) - + sliding_window_layout = self.hparams.get("sliding_window_layout") if sliding_window_layout: for i in sliding_window_layout: @@ -7715,7 +7714,7 @@ def get_vocab_base(self) -> tuple[list[str], list[int], str]: tokens.append(token) return tokens, toktypes, tokpre - + ###### CONVERSION LOGIC ###### diff --git a/src/llama-model.cpp b/src/llama-model.cpp index 916c32aa2f48a..8db5c54c7beaf 100644 --- a/src/llama-model.cpp +++ b/src/llama-model.cpp @@ -17126,7 +17126,7 @@ struct llm_build_smallthinker : public llm_graph_context{ Qcur = ggml_reshape_3d(ctx0, Qcur, n_embd_head, n_head, n_tokens); Kcur = ggml_reshape_3d(ctx0, Kcur, n_embd_head, n_head_kv, n_tokens); Vcur = ggml_reshape_3d(ctx0, Vcur, n_embd_head, n_head_kv, n_tokens); - + if(hparams.n_no_rope_layer_step == n_layer || il % hparams.n_no_rope_layer_step != 0) { Qcur = ggml_rope_ext(ctx0, Qcur, inp_pos, nullptr, n_rot, rope_type, n_ctx_orig, freq_base, freq_scale, ext_factor, attn_factor, beta_fast, beta_slow); From f10cd4676ac9d0a56fa6f7053c0002eb19d2b8c9 Mon Sep 17 00:00:00 2001 From: Dongliang Wei <121270393+wdl339@users.noreply.github.com> Date: Mon, 28 Jul 2025 08:37:11 +0800 Subject: [PATCH 09/16] remove get_vocab_base of SmallThinkerModel in convert_hf_to_gguf.py MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Co-authored-by: Sigbjørn Skjæret --- convert_hf_to_gguf.py | 44 ------------------------------------------- 1 file changed, 44 deletions(-) diff --git a/convert_hf_to_gguf.py b/convert_hf_to_gguf.py index 60fb23e5de35f..209f935963aba 100755 --- a/convert_hf_to_gguf.py +++ b/convert_hf_to_gguf.py @@ -7671,50 +7671,6 @@ def prepare_tensors(self): if len(experts) > 0: raise ValueError(f"Unprocessed experts: {experts}") - def get_vocab_base(self) -> tuple[list[str], list[int], str]: - 
tokens: list[str] = [] - toktypes: list[int] = [] - - from transformers import AutoTokenizer - tokenizer = AutoTokenizer.from_pretrained(self.dir_model, trust_remote_code=True) - vocab_size = self.hparams.get("vocab_size", len(tokenizer.vocab)) - assert max(tokenizer.vocab.values()) < vocab_size - - tokpre = self.get_vocab_base_pre(tokenizer) - - reverse_vocab = {id_: encoded_tok for encoded_tok, id_ in tokenizer.vocab.items()} - added_vocab = tokenizer.get_added_vocab() - - added_tokens_decoder = tokenizer.added_tokens_decoder - - for i in range(vocab_size): - if i not in reverse_vocab: - tokens.append(f"[PAD{i}]") - toktypes.append(gguf.TokenType.UNUSED) - else: - token: str = reverse_vocab[i] - if token in added_vocab: - # The tokenizer in llama.cpp assumes the CONTROL and USER_DEFINED tokens are pre-normalized. - # To avoid unexpected issues - we make sure to normalize non-normalized tokens - if not added_tokens_decoder[i].normalized: - previous_token = token - token = tokenizer.decode(tokenizer.encode(token, add_special_tokens=False)) - if previous_token != token: - logger.info(f"{repr(previous_token)} is encoded and decoded back to {repr(token)} using AutoTokenizer") - - if added_tokens_decoder[i].special or self.does_token_look_special(token): - toktypes.append(gguf.TokenType.CONTROL) - else: - # NOTE: this was added for Gemma. - # Encoding and decoding the tokens above isn't sufficient for this case. - token = token.replace(b"\xe2\x96\x81".decode("utf-8"), " ") # pre-normalize user-defined spaces - toktypes.append(gguf.TokenType.USER_DEFINED) - else: - toktypes.append(gguf.TokenType.NORMAL) - tokens.append(token) - - return tokens, toktypes, tokpre - ###### CONVERSION LOGIC ###### From 4af8b591d8847811c6104a4113b494c4fd4de523 Mon Sep 17 00:00:00 2001 From: Dongliang Wei <121270393+wdl339@users.noreply.github.com> Date: Mon, 28 Jul 2025 08:43:44 +0800 Subject: [PATCH 10/16] better whitespace MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Apply suggestions from code review Co-authored-by: Sigbjørn Skjæret --- src/llama-model.cpp | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/src/llama-model.cpp b/src/llama-model.cpp index 8db5c54c7beaf..27ae48814b3ad 100644 --- a/src/llama-model.cpp +++ b/src/llama-model.cpp @@ -1781,7 +1781,7 @@ void llama_model::load_hparams(llama_model_loader & ml) { hparams.n_no_rope_layer_step = hparams.n_layer; } - ml.get_key(LLM_KV_EXPERT_FEED_FORWARD_LENGTH, hparams.n_ff_exp, false); + ml.get_key(LLM_KV_EXPERT_FEED_FORWARD_LENGTH, hparams.n_ff_exp, false); ml.get_key(LLM_KV_ATTENTION_LAYERNORM_RMS_EPS, hparams.f_norm_rms_eps); ml.get_key(LLM_KV_EXPERT_GATING_FUNC, hparams.expert_gating_func, false); @@ -5202,7 +5202,7 @@ bool llama_model::load_tensors(llama_model_loader & ml) { } for (int i = 0; i < n_layer; ++i) { - auto & layer = layers[i]; + auto & layer = layers[i]; layer.attn_norm = create_tensor(tn(LLM_TENSOR_ATTN_NORM, "weight", i), { n_embd }, 0); @@ -5554,7 +5554,7 @@ void llama_model::print_info() const { } if (arch == LLM_ARCH_SMALLTHINKER) { - LLAMA_LOG_INFO("%s: n_ff_exp = %d\n", __func__, hparams.n_ff_exp); + LLAMA_LOG_INFO("%s: n_ff_exp = %d\n", __func__, hparams.n_ff_exp); LLAMA_LOG_INFO("%s: expert_gating_func = %s\n", __func__, llama_expert_gating_func_name((llama_expert_gating_func_type) hparams.expert_gating_func)); } From e2c900cea1b1e9ef03fec7a3232b0cb2cb79562b Mon Sep 17 00:00:00 2001 From: Dongliang Wei <121270393+wdl339@users.noreply.github.com> Date: Mon, 28 Jul 2025 
From 594af99379c3897e0f0ed42bbe06981e0a405e1f Mon Sep 17 00:00:00 2001
From: Dongliang Wei <121270393+wdl339@users.noreply.github.com>
Date: Mon, 28 Jul 2025 09:03:15 +0800
Subject: [PATCH 12/16] Improve null pointer check for probs
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

Co-authored-by: Sigbjørn Skjæret
---
 src/llama-model.cpp | 4 +++-
 1 file changed, 3 insertions(+), 1 deletion(-)

diff --git a/src/llama-model.cpp b/src/llama-model.cpp
index aae45f55b6e2b..756bc9c60ee38 100644
--- a/src/llama-model.cpp
+++ b/src/llama-model.cpp
@@ -17148,7 +17148,9 @@ struct llm_build_smallthinker : public llm_graph_context{
                 ggml_tensor * inp_out_ids = build_inp_out_ids();
                 cur = ggml_get_rows(ctx0, cur, inp_out_ids);
                 inpSA = ggml_get_rows(ctx0, inpSA, inp_out_ids);
-                if (probs != nullptr) { probs = ggml_get_rows(ctx0, probs, inp_out_ids); }
+                if (probs) {
+                    probs = ggml_get_rows(ctx0, probs, inp_out_ids);
+                }
             }
 
             ggml_tensor * ffn_inp = ggml_add(ctx0, cur, inpSA);

From 29e1fe0a6318fda58854a1e92fe259718cae6753 Mon Sep 17 00:00:00 2001
From: Dongliang Wei <2720609228@qq.com>
Date: Mon, 28 Jul 2025 01:21:02 +0000
Subject: [PATCH 13/16] use template parameter for SWA attention logic

---
 src/llama-model.cpp | 23 +++++++++++++----------
 1 file changed, 13 insertions(+), 10 deletions(-)

diff --git a/src/llama-model.cpp b/src/llama-model.cpp
index 756bc9c60ee38..e69c4ad7d58cb 100644
--- a/src/llama-model.cpp
+++ b/src/llama-model.cpp
@@ -17074,6 +17074,7 @@ struct llm_build_lfm2 : public llm_graph_context {
     }
 };
 
+template <bool iswa>
 struct llm_build_smallthinker : public llm_graph_context{
     llm_build_smallthinker(const llama_model & model, const llm_graph_params & params) : llm_graph_context(params){
         const int64_t n_embd_head = hparams.n_embd_head_v;
@@ -17089,8 +17090,10 @@ struct llm_build_smallthinker : public llm_graph_context{
         // inp_pos - contains the positions
         ggml_tensor * inp_pos = build_inp_pos();
 
-        llm_graph_input_i * inp_attn = nullptr;
-        if (hparams.is_swa_any()) {
+        using inp_attn_type = std::conditional_t<iswa, llm_graph_input_attn_kv_unified_iswa, llm_graph_input_attn_kv_unified>;
+        inp_attn_type * inp_attn = nullptr;
+
+        if constexpr (iswa) {
             inp_attn = build_attn_inp_kv_unified_iswa();
         } else {
             inp_attn = build_attn_inp_kv_unified();
@@ -17134,13 +17137,9 @@ struct llm_build_smallthinker : public llm_graph_context{
                 cb(Qcur, "Qcur", il);
                 cb(Kcur, "Kcur", il);
 
-                if (hparams.is_swa_any()) {
-                    cur = build_attn(static_cast<llm_graph_input_attn_kv_unified_iswa *>(inp_attn), model.layers[il].wo, model.layers[il].bo, Qcur,Kcur, Vcur,
-                                     nullptr,nullptr, 1.0f / sqrtf(float(n_embd_head)), il);
-                } else {
-                    cur = build_attn(static_cast<llm_graph_input_attn_kv_unified *>(inp_attn), model.layers[il].wo, model.layers[il].bo, Qcur,Kcur, Vcur,
-                                     nullptr,nullptr, 1.0f / sqrtf(float(n_embd_head)), il);
-                }
+                cur = build_attn(inp_attn,
+                        model.layers[il].wo, model.layers[il].bo,
+                        Qcur, Kcur, Vcur, nullptr, nullptr, 1.0f / sqrtf(float(n_embd_head)), il);
             }
 
             if (il == n_layer - 1) {
@@ -17630,7 +17629,11 @@ ggml_cgraph * llama_model::build_graph(const llm_graph_params & params) const {
             } break;
         case LLM_ARCH_SMALLTHINKER:
             {
-                llm = std::make_unique<llm_build_smallthinker>(*this, params);
+                if (hparams.swa_type == LLAMA_SWA_TYPE_STANDARD) {
+                    llm = std::make_unique<llm_build_smallthinker<true>> (*this, params);
+                } else {
+                    llm = std::make_unique<llm_build_smallthinker<false>>(*this, params);
+                }
             } break;
         default:
             GGML_ABORT("fatal error");
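
Note on PATCH 13: the builder is now templated on a bool, the concrete attention-input type is chosen with std::conditional_t, and the construction branch is resolved with if constexpr, so the later build_attn call needs no static_cast. A minimal standalone sketch of that pattern follows; the two struct names and the printed strings are placeholders, not the real llm_graph_input classes.

#include <cstdio>
#include <type_traits>

// Placeholder stand-ins for the two attention-input classes.
struct attn_inp_full { };
struct attn_inp_iswa { int n_swa = 4096; };

template <bool iswa>
struct graph_builder {
    // The bool template parameter picks the concrete input type at compile
    // time, so no common base pointer or static_cast is needed later.
    using inp_attn_type = std::conditional_t<iswa, attn_inp_iswa, attn_inp_full>;

    void build() const {
        inp_attn_type inp_attn{};
        if constexpr (iswa) {
            // Only compiled for the iswa = true instantiation, so it may use
            // the member that exists only on the sliding-window input type.
            std::printf("unified KV cache, sliding window = %d\n", inp_attn.n_swa);
        } else {
            (void) inp_attn; // full attention needs no extra parameters here
            std::printf("unified KV cache, full attention\n");
        }
    }
};

int main() {
    // Mirrors the dispatch added to build_graph(): the SWA flag is known from
    // the hyperparameters, so the right instantiation is chosen once, up front.
    const bool use_swa = true;
    if (use_swa) {
        graph_builder<true>{}.build();
    } else {
        graph_builder<false>{}.build();
    }
    return 0;
}
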
From 5d09d11bb879ce0968d6c952a3790f36ce64280b Mon Sep 17 00:00:00 2001
From: Dongliang Wei <121270393+wdl339@users.noreply.github.com>
Date: Mon, 28 Jul 2025 14:25:23 +0800
Subject: [PATCH 14/16] better whitespace

Co-authored-by: Georgi Gerganov
---
 src/llama-model.cpp | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/src/llama-model.cpp b/src/llama-model.cpp
index e69c4ad7d58cb..53eed720b2b6e 100644
--- a/src/llama-model.cpp
+++ b/src/llama-model.cpp
@@ -17126,7 +17126,7 @@ struct llm_build_smallthinker : public llm_graph_context{
                 Kcur = ggml_reshape_3d(ctx0, Kcur, n_embd_head, n_head_kv, n_tokens);
                 Vcur = ggml_reshape_3d(ctx0, Vcur, n_embd_head, n_head_kv, n_tokens);
 
-                if(hparams.n_no_rope_layer_step == n_layer || il % hparams.n_no_rope_layer_step != 0) {
+                if (hparams.n_no_rope_layer_step == n_layer || il % hparams.n_no_rope_layer_step != 0) {
                     Qcur = ggml_rope_ext(ctx0, Qcur, inp_pos, nullptr,
                             n_rot, rope_type, n_ctx_orig, freq_base, freq_scale,
                             ext_factor, attn_factor, beta_fast, beta_slow);

From bb3dd583b6221e718c26cc6ffb0de8e4060d846b Mon Sep 17 00:00:00 2001
From: Dongliang Wei <2720609228@qq.com>
Date: Mon, 28 Jul 2025 06:41:12 +0000
Subject: [PATCH 15/16] move the creation of inp_out_ids before the layer loop

---
 src/llama-model.cpp | 6 +++---
 1 file changed, 3 insertions(+), 3 deletions(-)

diff --git a/src/llama-model.cpp b/src/llama-model.cpp
index 53eed720b2b6e..5bf877683514b 100644
--- a/src/llama-model.cpp
+++ b/src/llama-model.cpp
@@ -17099,6 +17099,8 @@ struct llm_build_smallthinker : public llm_graph_context{
             inp_attn = build_attn_inp_kv_unified();
         }
 
+        ggml_tensor * inp_out_ids = build_inp_out_ids();
+
         for (int il = 0; il < n_layer; ++il) {
             ggml_tensor * inpSA = inpL;
             ggml_tensor * probs = nullptr;
@@ -17142,9 +17144,7 @@ struct llm_build_smallthinker : public llm_graph_context{
                         Qcur, Kcur, Vcur, nullptr, nullptr, 1.0f / sqrtf(float(n_embd_head)), il);
             }
 
-            if (il == n_layer - 1) {
-                // skip computing output for unused tokens
-                ggml_tensor * inp_out_ids = build_inp_out_ids();
+            if (il == n_layer - 1 && inp_out_ids) {
                 cur = ggml_get_rows(ctx0, cur, inp_out_ids);
                 inpSA = ggml_get_rows(ctx0, inpSA, inp_out_ids);
                 if (probs) {
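
Note on PATCH 15: the set of output row indices does not depend on the layer index, so it can be built once before the loop and the gather applied only at the last layer. Below is a toy standalone sketch of that idea, using plain vectors and invented sizes rather than the ggml API.

#include <cstdio>
#include <vector>

int main() {
    const int n_layer  = 4;
    const int n_tokens = 8;

    // Built once, outside the loop (stands in for build_inp_out_ids()).
    const std::vector<int> out_ids = {2, 5, 7};   // tokens whose logits are needed

    std::vector<float> cur(n_tokens, 0.0f);       // stand-in for the hidden state

    for (int il = 0; il < n_layer; ++il) {
        for (float & x : cur) {
            x += 1.0f;                            // stand-in for the per-layer computation
        }
        if (il == n_layer - 1 && !out_ids.empty()) {
            // Gather only the rows that feed the output head
            // (stands in for ggml_get_rows(ctx0, cur, inp_out_ids)).
            std::vector<float> gathered;
            for (int id : out_ids) {
                gathered.push_back(cur[id]);
            }
            cur = gathered;
        }
    }

    std::printf("kept %zu of %d rows for the output head\n", cur.size(), n_tokens);
    return 0;
}
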
From e338c30c1f7a06eec774b5c547986f2b81d9a207 Mon Sep 17 00:00:00 2001
From: Dongliang Wei <2720609228@qq.com>
Date: Mon, 28 Jul 2025 06:54:26 +0000
Subject: [PATCH 16/16] remove redundant judge for probs

---
 src/llama-model.cpp | 4 +---
 1 file changed, 1 insertion(+), 3 deletions(-)

diff --git a/src/llama-model.cpp b/src/llama-model.cpp
index 5bf877683514b..29cbb8113af79 100644
--- a/src/llama-model.cpp
+++ b/src/llama-model.cpp
@@ -17147,9 +17147,7 @@ struct llm_build_smallthinker : public llm_graph_context{
             if (il == n_layer - 1 && inp_out_ids) {
                 cur = ggml_get_rows(ctx0, cur, inp_out_ids);
                 inpSA = ggml_get_rows(ctx0, inpSA, inp_out_ids);
-                if (probs) {
-                    probs = ggml_get_rows(ctx0, probs, inp_out_ids);
-                }
+                probs = ggml_get_rows(ctx0, probs, inp_out_ids);
             }
 
             ggml_tensor * ffn_inp = ggml_add(ctx0, cur, inpSA);
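
Note on PATCH 16: probs is now gathered unconditionally together with cur and inpSA, so every per-token tensor that reaches the output is trimmed by the same index set and the rows stay aligned. The sketch below illustrates that row-alignment idea with plain vectors standing in for ggml tensors; all values are invented.

#include <cstdio>
#include <vector>

// Apply the same index set to each per-token array; trimming only one of
// them would make later row-wise operations mix rows of different tokens.
static std::vector<float> get_rows(const std::vector<float> & src, const std::vector<int> & ids) {
    std::vector<float> dst;
    dst.reserve(ids.size());
    for (int id : ids) {
        dst.push_back(src[id]);
    }
    return dst;
}

int main() {
    const std::vector<int> out_ids = {1, 3};          // rows kept for the output

    std::vector<float> cur   = {0.f, 1.f, 2.f, 3.f};  // attention output
    std::vector<float> inpSA = {4.f, 5.f, 6.f, 7.f};  // residual input
    std::vector<float> probs = {.1f, .2f, .3f, .4f};  // router probabilities

    cur   = get_rows(cur,   out_ids);
    inpSA = get_rows(inpSA, out_ids);
    probs = get_rows(probs, out_ids);

    std::printf("rows kept: %zu %zu %zu\n", cur.size(), inpSA.size(), probs.size());
    return 0;
}
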