From bca0b4827d2c947e5251aaf3d01fa9b4fe108c09 Mon Sep 17 00:00:00 2001 From: Joel Koch Date: Mon, 1 Sep 2025 19:55:51 +0200 Subject: [PATCH 1/4] add smollm3 scaffold from llama --- lib/bumblebee.ex | 1 + lib/bumblebee/text/smollm3.ex | 485 +++++++++++++++++++++++++++ test/bumblebee/text/smollm3_test.exs | 102 ++++++ 3 files changed, 588 insertions(+) create mode 100644 lib/bumblebee/text/smollm3.ex create mode 100644 test/bumblebee/text/smollm3_test.exs diff --git a/lib/bumblebee.ex b/lib/bumblebee.ex index 51f2330f..6643b1d2 100644 --- a/lib/bumblebee.ex +++ b/lib/bumblebee.ex @@ -188,6 +188,7 @@ defmodule Bumblebee do "RobertaForTokenClassification" => {Bumblebee.Text.Roberta, :for_token_classification}, "RobertaForCausalLM" => {Bumblebee.Text.Roberta, :for_causal_language_modeling}, "RobertaModel" => {Bumblebee.Text.Roberta, :base}, + "SmolLM3ForCausalLM" => {Bumblebee.Text.SmolLM3, :for_causal_language_modeling}, "SwinModel" => {Bumblebee.Vision.Swin, :base}, "SwinForImageClassification" => {Bumblebee.Vision.Swin, :for_image_classification}, "T5Model" => {Bumblebee.Text.T5, :base}, diff --git a/lib/bumblebee/text/smollm3.ex b/lib/bumblebee/text/smollm3.ex new file mode 100644 index 00000000..8fdc541a --- /dev/null +++ b/lib/bumblebee/text/smollm3.ex @@ -0,0 +1,485 @@ +defmodule Bumblebee.Text.SmolLM3 do + alias Bumblebee.Shared + + options = + [ + vocab_size: [ + default: 32000, + doc: """ + the vocabulary size of the token embedding. This corresponds to the number of distinct + tokens that can be represented in model input and output + """ + ], + max_positions: [ + default: 1024, + doc: """ + the vocabulary size of the position embedding. This corresponds to the maximum sequence + length that this model can process. Typically this is set to a large value just in case, + such as 512, 1024 or 2048 + """ + ], + hidden_size: [ + default: 4096, + doc: "the dimensionality of hidden layers" + ], + intermediate_size: [ + default: 11008, + doc: "the dimensionality of intermediate layers" + ], + attention_head_size: [ + default: nil, + doc: """ + the size of the key, value, and query projection per attention head. + Defaults to `div(hidden_size, num_attention_heads)` + """ + ], + num_blocks: [ + default: 32, + doc: "the number of Transformer blocks in the model" + ], + num_attention_heads: [ + default: 32, + doc: "the number of attention heads for each attention layer in the model" + ], + num_key_value_heads: [ + default: nil, + doc: "the number of key value heads for each attention layer in the model" + ], + activation: [ + default: :silu, + doc: "the activation function" + ], + rotary_embedding_base: [ + default: 10_000, + doc: "base for computing rotary embedding frequency" + ], + rotary_embedding_scaling_strategy: [ + default: nil, + doc: """ + scaling configuration for rotary embedding. 
Currently the supported values are: + + * `%{type: :linear, factor: number()}` + + * `%{type: :dynamic, factor: number()}` + + * `%{type: :llama3, factor: number(), low_frequency_factor: number(), high_frequency_factor: number(), original_max_positions: pos_integer()}` + + For more details see https://www.reddit.com/r/LocalLLaMA/comments/14mrgpr/dynamically_scaled_rope_further_increases + """ + ], + layer_norm_epsilon: [ + default: 1.0e-12, + doc: "the epsilon used by RMS normalization layers" + ], + initializer_scale: [ + default: 0.02, + doc: + "the standard deviation of the normal initializer used for initializing kernel parameters" + ], + tie_word_embeddings: [ + default: false, + doc: "whether to tie input and output embedding weights" + ] + ] ++ + Shared.common_options([:num_labels, :id_to_label]) ++ Shared.token_options(pad_token_id: 0) + + @moduledoc """ + LLaMA model family. + + ## Architectures + + * `:base` - plain LLaMA without any head on top + + * `:for_causal_language_modeling` - LLaMA with a language modeling + head. The head returns logits for each token in the original + sequence + + * `:for_sequence_classification` - LLaMA with a sequence + classification head. The head returns logits corresponding to + possible classes + + ## Inputs + + * `"input_ids"` - `{batch_size, sequence_length}` + + Indices of input sequence tokens in the vocabulary. + + * `"attention_mask"` - `{batch_size, sequence_length}` + + Mask indicating which tokens to attend to. This is used to ignore + padding tokens, which are added when processing a batch of sequences + with different length. + + * `"position_ids"` - `{batch_size, sequence_length}` + + Indices of positions of each input sequence tokens in the position + embeddings. + + * `"attention_head_mask"` - `{encoder_num_blocks, encoder_num_attention_heads}` + + Mask to nullify selected heads of the self-attention blocks in + the encoder. + + * `"input_embeddings"` - `{batch_size, sequence_length, hidden_size}` + + Embedded representation of `"input_ids"`, which can be specified + for more control over how `"input_ids"` are embedded than the + model's internal embedding lookup. If `"input_embeddings"` are present, + then `"input_ids"` will be ignored. + + * `"cache"` + + A container with cached layer results used to speed up sequential + decoding (autoregression). With cache, certain hidden states are + taken from the cache, rather than recomputed on every decoding + pass. The cache should be treated as opaque and initialized with + `Bumblebee.Text.Generation.init_cache/4`. 
+ + ## Global layer options + + #{Shared.global_layer_options_doc([:output_hidden_states, :output_attentions])} + + ## Configuration + + #{Shared.options_doc(options)} + """ + + defstruct [architecture: :base] ++ Shared.option_defaults(options) + + @behaviour Bumblebee.ModelSpec + @behaviour Bumblebee.Configurable + @behaviour Bumblebee.Text.Generation + + import Bumblebee.Utils.Model, only: [join: 2] + + alias Bumblebee.Layers + + @impl true + def architectures(), + do: [ + :base, + :for_causal_language_modeling, + :for_sequence_classification + ] + + @impl true + def config(spec, opts) do + spec + |> Shared.put_config_attrs(opts) + |> Shared.validate_label_options() + end + + @impl true + def input_template(_spec) do + %{ + "input_ids" => Nx.template({1, 1}, :s64) + } + end + + @impl true + def init_cache(spec, batch_size, max_length, _inputs) do + Layers.Decoder.init_cache(batch_size, max_length, + hidden_size: spec.hidden_size, + attention_head_size: spec.attention_head_size, + decoder_num_attention_heads: spec.num_attention_heads, + decoder_num_blocks: spec.num_blocks + ) + end + + @impl true + def traverse_cache(_spec, cache, fun) do + Layers.Decoder.traverse_cache(cache, fun) + end + + @impl true + def model(%__MODULE__{architecture: :base} = spec) do + inputs = inputs(spec) + + inputs + |> core(spec) + |> Layers.output() + end + + def model(%__MODULE__{architecture: :for_causal_language_modeling} = spec) do + inputs = inputs(spec) + + outputs = core(inputs, spec) + logits = language_modeling_head(outputs.hidden_state, spec, name: "language_modeling_head") + + Layers.output(%{ + logits: logits, + hidden_states: outputs.hidden_states, + attentions: outputs.attentions, + cache: outputs.cache + }) + end + + def model(%__MODULE__{architecture: :for_sequence_classification} = spec) do + inputs = inputs(spec) + + outputs = core(inputs, spec) + + logits = + Axon.dense(outputs.hidden_state, spec.num_labels, + kernel_initializer: kernel_initializer(spec), + name: "sequence_classification_head.output", + use_bias: false + ) + + pooled_logits = + Layers.if_present inputs["input_ids"] do + Axon.layer( + fn logits, input_ids, _opts -> + indices = + input_ids + |> Nx.not_equal(spec.pad_token_id) + |> Nx.sum(axes: [-1]) + |> Nx.subtract(1) + |> Nx.as_type({:s, 64}) + + Bumblebee.Utils.Nx.batched_take(logits, indices) + end, + [logits, inputs["input_ids"]] + ) + else + Layers.take_token(logits, axis: 1, index: -1) + end + + Layers.output(%{ + logits: pooled_logits, + hidden_states: outputs.hidden_states, + attentions: outputs.attentions, + cache: outputs.cache + }) + end + + defp inputs(spec) do + shape = {nil, nil} + hidden_shape = {nil, nil, spec.hidden_size} + + attention_head_mask_shape = {spec.num_blocks, spec.num_attention_heads} + + Bumblebee.Utils.Model.inputs_to_map([ + Axon.input("input_ids", optional: true, shape: shape), + Axon.input("attention_mask", optional: true, shape: shape), + Axon.input("position_ids", optional: true, shape: shape), + Axon.input("attention_head_mask", optional: true, shape: attention_head_mask_shape), + Axon.input("input_embeddings", optional: true, shape: hidden_shape), + Axon.input("cache", optional: true) + ]) + end + + defp core(inputs, spec) do + embeddings = + embedder( + inputs["input_ids"], + inputs["input_embeddings"], + spec, + name: "embedder" + ) + + position_ids = + Layers.default inputs["position_ids"] do + Layers.default_position_ids(embeddings) + end + + decoder_outputs = + decoder( + embeddings, + position_ids, + inputs["attention_mask"], + 
inputs["attention_head_mask"], + inputs["cache"], + spec, + name: "decoder" + ) + + hidden_state = + Layers.rms_norm(decoder_outputs.hidden_state, + name: "output_norm", + epsilon: spec.layer_norm_epsilon + ) + + %{ + hidden_state: hidden_state, + hidden_states: Layers.append(decoder_outputs.hidden_states, hidden_state), + attentions: decoder_outputs.attentions, + cache: decoder_outputs.cache + } + end + + defp embedder(input_ids, input_embeddings, spec, opts) do + name = opts[:name] + + Layers.default input_embeddings do + Axon.embedding(input_ids, spec.vocab_size, spec.hidden_size, + kernel_initializer: kernel_initializer(spec), + name: join(name, "token_embedding") + ) + end + end + + defp decoder( + hidden_state, + position_ids, + attention_mask, + attention_head_mask, + cache, + spec, + opts + ) do + name = opts[:name] + + Layers.Transformer.blocks(hidden_state, + attention_mask: attention_mask, + attention_head_mask: attention_head_mask, + attention_head_size: spec.attention_head_size, + cache: cache, + num_blocks: spec.num_blocks, + num_attention_heads: spec.num_attention_heads, + num_key_value_heads: spec.num_key_value_heads, + hidden_size: spec.hidden_size, + kernel_initializer: kernel_initializer(spec), + layer_norm: &Layers.rms_norm(&1, name: &2, epsilon: spec.layer_norm_epsilon), + ffn: + &gated_ffn(&1, spec.intermediate_size, spec.hidden_size, + name: &2, + activation: spec.activation + ), + block_type: :norm_first, + causal: true, + rotary_embedding: [ + position_ids: position_ids, + max_positions: spec.max_positions, + base: spec.rotary_embedding_base, + scaling_strategy: spec.rotary_embedding_scaling_strategy + ], + query_use_bias: false, + key_use_bias: false, + value_use_bias: false, + output_use_bias: false, + name: join(name, "blocks") + ) + end + + defp gated_ffn(hidden_state, intermediate_size, output_size, opts) do + name = opts[:name] + activation = opts[:activation] + + intermediate = + Axon.dense(hidden_state, intermediate_size, + name: join(name, "intermediate"), + use_bias: false + ) + + gate = Axon.dense(hidden_state, intermediate_size, name: join(name, "gate"), use_bias: false) + + hidden_state = Axon.multiply(intermediate, Axon.activation(gate, activation)) + + Axon.dense(hidden_state, output_size, name: join(name, "output"), use_bias: false) + end + + defp language_modeling_head(hidden_state, spec, opts) do + name = opts[:name] + + # TODO: Tie lm-head to word embedding as a spec option + Layers.dense_transposed(hidden_state, spec.vocab_size, + kernel_initializer: kernel_initializer(spec), + name: join(name, "output") + ) + end + + defp kernel_initializer(spec) do + Axon.Initializers.normal(scale: spec.initializer_scale) + end + + defimpl Bumblebee.HuggingFace.Transformers.Config do + def load(spec, data) do + import Shared.Converters + + scaling_strategy_converter = fn name, value -> + # "type" has been renamed to "rope_type" + value = + case Map.pop(value, "type") do + {nil, value} -> value + {type, value} -> Map.put(value, "rope_type", type) + end + + case value do + %{"rope_type" => "linear", "factor" => factor} when is_number(factor) -> + {:ok, %{type: :linear, factor: factor}} + + %{"rope_type" => "dynamic", "factor" => factor} when is_number(factor) -> + {:ok, %{type: :dynamic, factor: factor}} + + %{ + "rope_type" => "llama3", + "factor" => factor, + "low_freq_factor" => low_frequency_factor, + "high_freq_factor" => high_frequency_factor, + "original_max_position_embeddings" => original_max_positions + } + when is_number(factor) and 
is_number(low_frequency_factor) and + is_number(high_frequency_factor) and + is_number(original_max_positions) -> + {:ok, + %{ + type: :llama3, + factor: factor, + low_frequency_factor: low_frequency_factor, + high_frequency_factor: high_frequency_factor, + original_max_positions: original_max_positions + }} + + _other -> + {:error, "invalid format for #{inspect(name)}, got: #{inspect(value)}"} + end + end + + opts = + convert!(data, + vocab_size: {"vocab_size", number()}, + tie_word_embeddings: {"tie_word_embeddings", boolean()}, + max_positions: {"max_position_embeddings", number()}, + hidden_size: {"hidden_size", number()}, + num_blocks: {"num_hidden_layers", number()}, + num_attention_heads: {"num_attention_heads", number()}, + num_key_value_heads: {"num_key_value_heads", number()}, + attention_head_size: {"head_dim", number()}, + intermediate_size: {"intermediate_size", number()}, + activation: {"hidden_act", activation()}, + rotary_embedding_base: {"rope_theta", number()}, + rotary_embedding_scaling_strategy: + {"rope_scaling", optional(scaling_strategy_converter)}, + initializer_scale: {"initializer_range", number()}, + layer_norm_epsilon: {"rms_norm_eps", number()}, + tie_word_embeddings: {"tie_word_embeddings", boolean()} + ) ++ Shared.common_options_from_transformers(data, spec) + + @for.config(spec, opts) + end + end + + defimpl Bumblebee.HuggingFace.Transformers.Model do + def params_mapping(spec) do + %{ + "embedder.token_embedding" => "model.embed_tokens", + "decoder.blocks.{n}.self_attention.query" => "model.layers.{n}.self_attn.q_proj", + "decoder.blocks.{n}.self_attention.key" => "model.layers.{n}.self_attn.k_proj", + "decoder.blocks.{n}.self_attention.value" => "model.layers.{n}.self_attn.v_proj", + "decoder.blocks.{n}.self_attention.output" => "model.layers.{n}.self_attn.o_proj", + "decoder.blocks.{n}.self_attention_norm" => "model.layers.{n}.input_layernorm", + "decoder.blocks.{n}.self_attention.rotary_embedding" => + "model.layers.{n}.self_attn.rotary_emb", + "decoder.blocks.{n}.ffn.gate" => "model.layers.{n}.mlp.gate_proj", + "decoder.blocks.{n}.ffn.intermediate" => "model.layers.{n}.mlp.up_proj", + "decoder.blocks.{n}.ffn.output" => "model.layers.{n}.mlp.down_proj", + "decoder.blocks.{n}.output_norm" => "model.layers.{n}.post_attention_layernorm", + "output_norm" => "model.norm", + "language_modeling_head.output" => + if(spec.tie_word_embeddings, do: "model.embed_tokens", else: "lm_head"), + "sequence_classification_head.output" => "score" + } + end + end +end diff --git a/test/bumblebee/text/smollm3_test.exs b/test/bumblebee/text/smollm3_test.exs new file mode 100644 index 00000000..ac53aa21 --- /dev/null +++ b/test/bumblebee/text/smollm3_test.exs @@ -0,0 +1,102 @@ +defmodule Bumblebee.Text.SmolLM3Test do + use ExUnit.Case, async: true + + import Bumblebee.TestHelpers + + @moduletag model_test_tags() + + test ":base" do + assert {:ok, %{model: model, params: params, spec: spec}} = + Bumblebee.load_model({:hf, "bumblebee-testing/tiny-random-LlamaModel"}) + + assert %Bumblebee.Text.Llama{architecture: :base} = spec + + inputs = %{ + "input_ids" => Nx.tensor([[10, 20, 30, 40, 50, 60, 70, 80, 0, 0]]), + "attention_mask" => Nx.tensor([[1, 1, 1, 1, 1, 1, 1, 1, 0, 0]]) + } + + outputs = Axon.predict(model, params, inputs) + + assert Nx.shape(outputs.hidden_state) == {1, 10, 32} + + assert_all_close( + outputs.hidden_state[[.., 1..3, 1..3]], + Nx.tensor([ + [[1.4799, -2.0333, 0.4759], [2.3749, -0.8369, -0.0206], [0.5767, -0.0515, -1.1795]] + ]) + ) + end + + test ":base 
rotary embedding scaling strategy :llama3" do + assert {:ok, %{model: model, params: params, spec: spec}} = + Bumblebee.load_model( + {:hf, + "bumblebee-testing/tiny-random-LlamaModel-rope_scaling-llama3-original_max_position_embeddings-64"} + ) + + assert %Bumblebee.Text.Llama{architecture: :base} = spec + + inputs = %{ + "input_ids" => Nx.tensor([[10, 20, 30, 40, 50, 60, 70, 80, 0, 0]]), + "attention_mask" => Nx.tensor([[1, 1, 1, 1, 1, 1, 1, 1, 0, 0]]) + } + + outputs = Axon.predict(model, params, inputs) + + assert Nx.shape(outputs.hidden_state) == {1, 10, 32} + + assert_all_close( + outputs.hidden_state[[.., 1..3, 1..3]], + Nx.tensor([ + [[1.4802, -2.0331, 0.4759], [2.3749, -0.8367, -0.0205], [0.5762, -0.0517, -1.1795]] + ]) + ) + end + + test ":for_sequence_classification" do + assert {:ok, %{model: model, params: params, spec: spec}} = + Bumblebee.load_model( + {:hf, "bumblebee-testing/tiny-random-LlamaForSequenceClassification"} + ) + + assert %Bumblebee.Text.Llama{architecture: :for_sequence_classification} = spec + + inputs = %{ + "input_ids" => Nx.tensor([[10, 20, 30, 40, 50, 60, 70, 80, 0, 0]]), + "attention_mask" => Nx.tensor([[1, 1, 1, 1, 1, 1, 1, 1, 0, 0]]) + } + + outputs = Axon.predict(model, params, inputs) + + assert Nx.shape(outputs.logits) == {1, 2} + + assert_all_close( + outputs.logits, + Nx.tensor([[-0.1964, -0.1069]]) + ) + end + + test ":for_causal_language_modeling" do + assert {:ok, %{model: model, params: params, spec: spec}} = + Bumblebee.load_model({:hf, "bumblebee-testing/tiny-random-LlamaForCausalLM"}) + + assert %Bumblebee.Text.Llama{architecture: :for_causal_language_modeling} = spec + + inputs = %{ + "input_ids" => Nx.tensor([[10, 20, 30, 40, 50, 60, 70, 80, 0, 0]]), + "attention_mask" => Nx.tensor([[1, 1, 1, 1, 1, 1, 1, 1, 0, 0]]) + } + + outputs = Axon.predict(model, params, inputs) + + assert Nx.shape(outputs.logits) == {1, 10, 1024} + + assert_all_close( + outputs.logits[[.., 1..3, 1..3]], + Nx.tensor([ + [[0.0469, -0.0751, 0.0349], [0.0617, -0.1357, -0.0204], [-0.1495, 0.0557, -0.0737]] + ]) + ) + end +end From bc5e0e1bfad9359641d9112ddc6e24a2e6303630 Mon Sep 17 00:00:00 2001 From: Joel Koch Date: Fri, 3 Oct 2025 19:36:53 +0200 Subject: [PATCH 2/4] implement smollm3 --- lib/bumblebee.ex | 5 + lib/bumblebee/layers.ex | 20 +++ lib/bumblebee/layers/transformer.ex | 17 +- lib/bumblebee/text/pre_trained_tokenizer.ex | 6 + lib/bumblebee/text/smollm3.ex | 165 +++++++++++++++++--- test/bumblebee/text/smollm3_test.exs | 76 ++++++--- 6 files changed, 248 insertions(+), 41 deletions(-) diff --git a/lib/bumblebee.ex b/lib/bumblebee.ex index 6643b1d2..39ef9165 100644 --- a/lib/bumblebee.ex +++ b/lib/bumblebee.ex @@ -188,7 +188,11 @@ defmodule Bumblebee do "RobertaForTokenClassification" => {Bumblebee.Text.Roberta, :for_token_classification}, "RobertaForCausalLM" => {Bumblebee.Text.Roberta, :for_causal_language_modeling}, "RobertaModel" => {Bumblebee.Text.Roberta, :base}, + "SmolLM3" => {Bumblebee.Text.SmolLM3, :base}, "SmolLM3ForCausalLM" => {Bumblebee.Text.SmolLM3, :for_causal_language_modeling}, + "SmolLM3ForQuestionAnswering" => {Bumblebee.Text.SmolLM3, :for_question_answering}, + "SmolLM3ForSequenceClassification" => {Bumblebee.Text.SmolLM3, :for_sequence_classification}, + "SmolLM3ForTokenClassification" => {Bumblebee.Text.SmolLM3, :for_token_classification}, "SwinModel" => {Bumblebee.Vision.Swin, :base}, "SwinForImageClassification" => {Bumblebee.Vision.Swin, :for_image_classification}, "T5Model" => {Bumblebee.Text.T5, :base}, @@ -255,6 +259,7 @@ defmodule 
Bumblebee do "phi" => :code_gen, "phi3" => :llama, "roberta" => :roberta, + "smollm3" => :smollm3, "t5" => :t5, "whisper" => :whisper, "xlm-roberta" => :xlm_roberta, diff --git a/lib/bumblebee/layers.ex b/lib/bumblebee/layers.ex index 3c9c92ca..559e65ad 100644 --- a/lib/bumblebee/layers.ex +++ b/lib/bumblebee/layers.ex @@ -1307,6 +1307,26 @@ defmodule Bumblebee.Layers do positions_cos_sin(position, inv_frequency) + %{ + type: :yarn, + factor: factor, + low_frequency_factor: low_frequency_factor, + high_frequency_factor: high_frequency_factor, + original_max_positions: original_max_positions + } -> + inv_frequency = inv_frequency(base, range) + + inv_frequency = + llama3_inv_frequency( + inv_frequency, + factor, + low_frequency_factor, + high_frequency_factor, + original_max_positions + ) + + positions_cos_sin(position, inv_frequency) + _other -> inv_frequency = inv_frequency(base, range) positions_cos_sin(position, inv_frequency) diff --git a/lib/bumblebee/layers/transformer.ex b/lib/bumblebee/layers/transformer.ex index 6cf93cd6..59ad9595 100644 --- a/lib/bumblebee/layers/transformer.ex +++ b/lib/bumblebee/layers/transformer.ex @@ -21,6 +21,10 @@ defmodule Bumblebee.Layers.Transformer do is configured, this option controls whether the bias from the first block is used for all other blocks. Defaults to `false` + * `:rotary_embedding` - configuration of rotary embedding. Can be: + - a keyword list (applied to all blocks) + - a function that takes the block index and returns the configuration + * `:name` - the prefix for layer names For all other options (including required options) see `block/2`. @@ -49,8 +53,7 @@ defmodule Bumblebee.Layers.Transformer do :layer_norm, :block_type, :attention_window_size, - :scale_attention_weights, - :rotary_embedding + :scale_attention_weights ] opts = @@ -60,6 +63,7 @@ defmodule Bumblebee.Layers.Transformer do [ :name, :num_blocks, + :rotary_embedding, attention_mask: Layers.none(), attention_head_mask: Layers.none(), attention_relative_bias: nil, @@ -80,6 +84,7 @@ defmodule Bumblebee.Layers.Transformer do cross_attention_mask = opts[:cross_attention_mask] cross_attention_head_mask = opts[:cross_attention_head_mask] cache = opts[:cache] + rotary_embedding = opts[:rotary_embedding] block_opts = Keyword.take(opts, block_opts_keys) @@ -109,6 +114,13 @@ defmodule Bumblebee.Layers.Transformer do opts[:attention_relative_bias] || Layers.none() end + block_rotary_embedding = + case rotary_embedding do + nil -> nil + fun when is_function(fun, 1) -> fun.(idx) + config when is_list(config) -> config + end + {hidden_state, attention, cross_attention, block_cache, attention_relative_bias} = block( state.hidden_state, @@ -121,6 +133,7 @@ defmodule Bumblebee.Layers.Transformer do cross_attention_head_mask: block_cross_attention_head_mask, block_cache: block_cache, offset: offset, + rotary_embedding: block_rotary_embedding, name: join(name, idx) ] ++ block_opts ) diff --git a/lib/bumblebee/text/pre_trained_tokenizer.ex b/lib/bumblebee/text/pre_trained_tokenizer.ex index 59ab3468..599ac647 100644 --- a/lib/bumblebee/text/pre_trained_tokenizer.ex +++ b/lib/bumblebee/text/pre_trained_tokenizer.ex @@ -211,6 +211,12 @@ defmodule Bumblebee.Text.PreTrainedTokenizer do mask: "" } }, + smollm3: %{ + special_tokens: %{ + eos: "<|im_end|>", + pad: "<|im_end|>" + } + }, t5: %{ special_tokens: %{ bos: "", diff --git a/lib/bumblebee/text/smollm3.ex b/lib/bumblebee/text/smollm3.ex index 8fdc541a..6c86de59 100644 --- a/lib/bumblebee/text/smollm3.ex +++ b/lib/bumblebee/text/smollm3.ex 
@@ -4,18 +4,19 @@ defmodule Bumblebee.Text.SmolLM3 do options = [ vocab_size: [ - default: 32000, + default: 128_256, doc: """ the vocabulary size of the token embedding. This corresponds to the number of distinct tokens that can be represented in model input and output """ ], max_positions: [ - default: 1024, + default: 65536, doc: """ the vocabulary size of the position embedding. This corresponds to the maximum sequence length that this model can process. Typically this is set to a large value just in case, - such as 512, 1024 or 2048 + such as 512, 1024 or 2048. + SmolLM3 supports up to 128k tokens with YaRN extrapolation. """ ], hidden_size: [ @@ -42,7 +43,7 @@ defmodule Bumblebee.Text.SmolLM3 do doc: "the number of attention heads for each attention layer in the model" ], num_key_value_heads: [ - default: nil, + default: 4, doc: "the number of key value heads for each attention layer in the model" ], activation: [ @@ -50,7 +51,7 @@ defmodule Bumblebee.Text.SmolLM3 do doc: "the activation function" ], rotary_embedding_base: [ - default: 10_000, + default: 5_000_000, doc: "base for computing rotary embedding frequency" ], rotary_embedding_scaling_strategy: [ @@ -64,6 +65,8 @@ defmodule Bumblebee.Text.SmolLM3 do * `%{type: :llama3, factor: number(), low_frequency_factor: number(), high_frequency_factor: number(), original_max_positions: pos_integer()}` + * `%{type: :yarn, factor: number(), original_max_positions: pos_integer()}` + For more details see https://www.reddit.com/r/LocalLLaMA/comments/14mrgpr/dynamically_scaled_rope_further_increases """ ], @@ -77,27 +80,45 @@ defmodule Bumblebee.Text.SmolLM3 do "the standard deviation of the normal initializer used for initializing kernel parameters" ], tie_word_embeddings: [ - default: false, + default: true, doc: "whether to tie input and output embedding weights" ] ] ++ Shared.common_options([:num_labels, :id_to_label]) ++ Shared.token_options(pad_token_id: 0) @moduledoc """ - LLaMA model family. + SmolLM3 is a 3B parameter language model designed to push the boundaries of small models. + It supports dual mode reasoning, 6 languages and long context. SmolLM3 is a fully open model + that offers strong performance at the 3B–4B scale. + + Key features + + * Instruct model optimized for hybrid reasoning + * Fully open model: open weights + full training details including public data mixture and training configs + * Long context: Trained on 64k context and supports up to 128k tokens using YARN extrapolation + * Multilingual: 6 natively supported (English, French, Spanish, German, Italian, and Portuguese) + + For more details see: https://huggingface.co/HuggingFaceTB/SmolLM3-3B ## Architectures - * `:base` - plain LLaMA without any head on top + * `:base` - plain SmolLM3 without any head on top - * `:for_causal_language_modeling` - LLaMA with a language modeling + * `:for_causal_language_modeling` - SmolLM3 with a language modeling head. The head returns logits for each token in the original sequence - * `:for_sequence_classification` - LLaMA with a sequence + * `:for_sequence_classification` - SmolLM3 with a sequence classification head. The head returns logits corresponding to possible classes + * `:for_token_classification` - SmolLM3 with a token classification + head. The head returns logits for each token in the original + sequence + + * `:for_question_answering` - SmolLM3 with a span classification head. 
+ The head returns logits for the span start and end positions + ## Inputs * `"input_ids"` - `{batch_size, sequence_length}` @@ -159,7 +180,9 @@ defmodule Bumblebee.Text.SmolLM3 do do: [ :base, :for_causal_language_modeling, - :for_sequence_classification + :for_sequence_classification, + :for_token_classification, + :for_question_answering ] @impl true @@ -253,6 +276,50 @@ defmodule Bumblebee.Text.SmolLM3 do }) end + def model(%__MODULE__{architecture: :for_token_classification} = spec) do + inputs = inputs(spec) + + outputs = core(inputs, spec) + + logits = + outputs.hidden_state + |> Axon.dropout( + rate: 0.1, + name: "token_classification_head.dropout" + ) + |> Axon.dense(spec.num_labels, + kernel_initializer: kernel_initializer(spec), + name: "token_classification_head.output" + ) + + Layers.output(%{ + logits: logits, + hidden_states: outputs.hidden_states, + attentions: outputs.attentions, + cache: outputs.cache + }) + end + + def model(%__MODULE__{architecture: :for_question_answering} = spec) do + inputs = inputs(spec) + outputs = core(inputs, spec) + + logits = + Axon.dense(outputs.hidden_state, 2, + kernel_initializer: kernel_initializer(spec), + name: "question_answering_head.output" + ) + + {start_logits, end_logits} = Layers.split_pair(logits) + + Layers.output(%{ + start_logits: start_logits, + end_logits: end_logits, + hidden_states: outputs.hidden_states, + attentions: outputs.attentions + }) + end + defp inputs(spec) do shape = {nil, nil} hidden_shape = {nil, nil, spec.hidden_size} @@ -330,6 +397,22 @@ defmodule Bumblebee.Text.SmolLM3 do ) do name = opts[:name] + # TODO: remove hardcoding of 4th layers, read from config + rotary_embedding_config = [ + position_ids: position_ids, + max_positions: spec.max_positions, + base: spec.rotary_embedding_base, + scaling_strategy: spec.rotary_embedding_scaling_strategy + ] + + nope_rotary_embedding = fn layer_idx -> + if rem(layer_idx + 1, 4) != 0 do + rotary_embedding_config + else + nil + end + end + Layers.Transformer.blocks(hidden_state, attention_mask: attention_mask, attention_head_mask: attention_head_mask, @@ -348,12 +431,7 @@ defmodule Bumblebee.Text.SmolLM3 do ), block_type: :norm_first, causal: true, - rotary_embedding: [ - position_ids: position_ids, - max_positions: spec.max_positions, - base: spec.rotary_embedding_base, - scaling_strategy: spec.rotary_embedding_scaling_strategy - ], + rotary_embedding: nope_rotary_embedding, query_use_bias: false, key_use_bias: false, value_use_bias: false, @@ -431,6 +509,20 @@ defmodule Bumblebee.Text.SmolLM3 do original_max_positions: original_max_positions }} + # TODO: implement yarn or find out if it's same as longrope + %{ + "rope_type" => "yarn", + "factor" => factor, + "original_max_position_embeddings" => original_max_positions + } + when is_number(factor) and is_number(original_max_positions) -> + {:ok, + %{ + type: :yarn, + factor: factor, + original_max_positions: original_max_positions + }} + _other -> {:error, "invalid format for #{inspect(name)}, got: #{inspect(value)}"} end @@ -462,15 +554,13 @@ defmodule Bumblebee.Text.SmolLM3 do defimpl Bumblebee.HuggingFace.Transformers.Model do def params_mapping(spec) do - %{ + base_mapping = %{ "embedder.token_embedding" => "model.embed_tokens", "decoder.blocks.{n}.self_attention.query" => "model.layers.{n}.self_attn.q_proj", "decoder.blocks.{n}.self_attention.key" => "model.layers.{n}.self_attn.k_proj", "decoder.blocks.{n}.self_attention.value" => "model.layers.{n}.self_attn.v_proj", "decoder.blocks.{n}.self_attention.output" => 
"model.layers.{n}.self_attn.o_proj", "decoder.blocks.{n}.self_attention_norm" => "model.layers.{n}.input_layernorm", - "decoder.blocks.{n}.self_attention.rotary_embedding" => - "model.layers.{n}.self_attn.rotary_emb", "decoder.blocks.{n}.ffn.gate" => "model.layers.{n}.mlp.gate_proj", "decoder.blocks.{n}.ffn.intermediate" => "model.layers.{n}.mlp.up_proj", "decoder.blocks.{n}.ffn.output" => "model.layers.{n}.mlp.down_proj", @@ -478,8 +568,41 @@ defmodule Bumblebee.Text.SmolLM3 do "output_norm" => "model.norm", "language_modeling_head.output" => if(spec.tie_word_embeddings, do: "model.embed_tokens", else: "lm_head"), - "sequence_classification_head.output" => "score" + "sequence_classification_head.output" => "score", + "token_classification_head.output" => "score", + "question_answering_head.output" => "qa_outputs" } + + # TODO: remove hardcoding, read from config + rotary_mapping = + for n <- 0..(spec.num_blocks - 1), rem(n + 1, 4) != 0 do + {"decoder.blocks.#{n}.self_attention.rotary_embedding", + "model.layers.#{n}.self_attn.rotary_emb"} + end + + mapping = Map.merge(base_mapping, Map.new(rotary_mapping)) + + case spec do + %{architecture: :for_question_answering} -> + question_answering_mapping = %{ + "output_norm" => "transformer.norm", + "embedder.token_embedding" => "transformer.embed_tokens", + "decoder.blocks.0.output_norm" => "transformer.layers.0.post_attention_layernorm", + "decoder.blocks.0.self_attention.key" => "transformer.layers.0.self_attn.k_proj", + "decoder.blocks.0.self_attention.query" => "transformer.layers.0.self_attn.q_proj", + "decoder.blocks.0.self_attention.value" => "transformer.layers.0.self_attn.v_proj", + "decoder.blocks.0.self_attention_norm" => "transformer.layers.0.input_layernorm", + "decoder.blocks.0.self_attention.output" => "transformer.layers.0.self_attn.o_proj", + "decoder.blocks.0.ffn.output" => "transformer.layers.0.mlp.down_proj", + "decoder.blocks.0.ffn.intermediate" => "transformer.layers.0.mlp.up_proj", + "decoder.blocks.0.ffn.gate" => "transformer.layers.0.mlp.gate_proj" + } + + Map.merge(mapping, question_answering_mapping) + + _else -> + mapping + end end end end diff --git a/test/bumblebee/text/smollm3_test.exs b/test/bumblebee/text/smollm3_test.exs index ac53aa21..a3f70c51 100644 --- a/test/bumblebee/text/smollm3_test.exs +++ b/test/bumblebee/text/smollm3_test.exs @@ -7,9 +7,9 @@ defmodule Bumblebee.Text.SmolLM3Test do test ":base" do assert {:ok, %{model: model, params: params, spec: spec}} = - Bumblebee.load_model({:hf, "bumblebee-testing/tiny-random-LlamaModel"}) + Bumblebee.load_model({:hf, "joelkoch/tiny_random_smollm3"}, architecture: :base) - assert %Bumblebee.Text.Llama{architecture: :base} = spec + assert %Bumblebee.Text.SmolLM3{architecture: :base} = spec inputs = %{ "input_ids" => Nx.tensor([[10, 20, 30, 40, 50, 60, 70, 80, 0, 0]]), @@ -18,24 +18,24 @@ defmodule Bumblebee.Text.SmolLM3Test do outputs = Axon.predict(model, params, inputs) - assert Nx.shape(outputs.hidden_state) == {1, 10, 32} + assert Nx.shape(outputs.hidden_state) == {1, 10, 64} assert_all_close( outputs.hidden_state[[.., 1..3, 1..3]], Nx.tensor([ - [[1.4799, -2.0333, 0.4759], [2.3749, -0.8369, -0.0206], [0.5767, -0.0515, -1.1795]] + [[-0.4167, -0.0137, 0.7160], [-0.2624, -1.1185, -0.3098], [-0.0383, -0.8390, -0.0039]] ]) ) end - test ":base rotary embedding scaling strategy :llama3" do + test ":for_question_answering" do assert {:ok, %{model: model, params: params, spec: spec}} = Bumblebee.load_model( - {:hf, - 
"bumblebee-testing/tiny-random-LlamaModel-rope_scaling-llama3-original_max_position_embeddings-64"} + {:hf, "joelkoch/tiny_smollm3_for_question_answering"}, + architecture: :for_question_answering ) - assert %Bumblebee.Text.Llama{architecture: :base} = spec + assert %Bumblebee.Text.SmolLM3{architecture: :for_question_answering} = spec inputs = %{ "input_ids" => Nx.tensor([[10, 20, 30, 40, 50, 60, 70, 80, 0, 0]]), @@ -44,12 +44,12 @@ defmodule Bumblebee.Text.SmolLM3Test do outputs = Axon.predict(model, params, inputs) - assert Nx.shape(outputs.hidden_state) == {1, 10, 32} + assert Nx.shape(outputs.end_logits) == {1, 10} assert_all_close( - outputs.hidden_state[[.., 1..3, 1..3]], + outputs.end_logits, Nx.tensor([ - [[1.4802, -2.0331, 0.4759], [2.3749, -0.8367, -0.0205], [0.5762, -0.0517, -1.1795]] + [0.0656, 0.0358, -0.0395, 0.0227, 0.0594, 0.0942, -0.2356, 0.0244, 0.0701, 0.0705] ]) ) end @@ -57,10 +57,11 @@ defmodule Bumblebee.Text.SmolLM3Test do test ":for_sequence_classification" do assert {:ok, %{model: model, params: params, spec: spec}} = Bumblebee.load_model( - {:hf, "bumblebee-testing/tiny-random-LlamaForSequenceClassification"} + {:hf, "joelkoch/tiny_smollm3_for_sequence_classification"}, + architecture: :for_sequence_classification ) - assert %Bumblebee.Text.Llama{architecture: :for_sequence_classification} = spec + assert %Bumblebee.Text.SmolLM3{architecture: :for_sequence_classification} = spec inputs = %{ "input_ids" => Nx.tensor([[10, 20, 30, 40, 50, 60, 70, 80, 0, 0]]), @@ -73,15 +74,17 @@ defmodule Bumblebee.Text.SmolLM3Test do assert_all_close( outputs.logits, - Nx.tensor([[-0.1964, -0.1069]]) + Nx.tensor([[0.1468, -0.0980]]) ) end test ":for_causal_language_modeling" do assert {:ok, %{model: model, params: params, spec: spec}} = - Bumblebee.load_model({:hf, "bumblebee-testing/tiny-random-LlamaForCausalLM"}) + Bumblebee.load_model({:hf, "joelkoch/tiny_random_smollm3"}, + architecture: :for_causal_language_modeling + ) - assert %Bumblebee.Text.Llama{architecture: :for_causal_language_modeling} = spec + assert %Bumblebee.Text.SmolLM3{architecture: :for_causal_language_modeling} = spec inputs = %{ "input_ids" => Nx.tensor([[10, 20, 30, 40, 50, 60, 70, 80, 0, 0]]), @@ -90,12 +93,49 @@ defmodule Bumblebee.Text.SmolLM3Test do outputs = Axon.predict(model, params, inputs) - assert Nx.shape(outputs.logits) == {1, 10, 1024} + assert Nx.shape(outputs.logits) == {1, 10, 128_256} assert_all_close( outputs.logits[[.., 1..3, 1..3]], Nx.tensor([ - [[0.0469, -0.0751, 0.0349], [0.0617, -0.1357, -0.0204], [-0.1495, 0.0557, -0.0737]] + [[0.0602, 0.1254, 0.0077], [0.0187, 0.0270, 0.0625], [-0.0079, -0.0478, 0.1872]] + ]) + ) + end + + test ":for_token_classification" do + assert {:ok, %{model: model, params: params, spec: spec}} = + Bumblebee.load_model( + {:hf, "joelkoch/tiny_smollm3_for_token_classification"}, + architecture: :for_token_classification + ) + + assert %Bumblebee.Text.SmolLM3{architecture: :for_token_classification} = spec + + inputs = %{ + "input_ids" => Nx.tensor([[10, 20, 30, 40, 50, 60, 70, 80, 0, 0]]), + "attention_mask" => Nx.tensor([[1, 1, 1, 1, 1, 1, 1, 1, 0, 0]]) + } + + outputs = Axon.predict(model, params, inputs) + + assert Nx.shape(outputs.logits) == {1, 10, 2} + + assert_all_close( + outputs.logits, + Nx.tensor([ + [ + [-0.1358, 0.1047], + [-0.0504, 0.1214], + [0.1960, -0.0031], + [0.0428, 0.0429], + [-0.0680, 0.1391], + [0.0828, 0.0945], + [-0.0144, -0.2466], + [0.0152, 0.1096], + [0.1437, -0.1766], + [0.1439, -0.1762] + ] ]) ) end From 
2177e11df759788437f6734f6a0713e2c95407f5 Mon Sep 17 00:00:00 2001 From: Joel Koch Date: Fri, 3 Oct 2025 19:54:49 +0200 Subject: [PATCH 3/4] get nope layers config from config --- lib/bumblebee/text/smollm3.ex | 40 +++++++++++++++++++++++++---------- 1 file changed, 29 insertions(+), 11 deletions(-) diff --git a/lib/bumblebee/text/smollm3.ex b/lib/bumblebee/text/smollm3.ex index 6c86de59..98754b57 100644 --- a/lib/bumblebee/text/smollm3.ex +++ b/lib/bumblebee/text/smollm3.ex @@ -70,6 +70,11 @@ defmodule Bumblebee.Text.SmolLM3 do For more details see https://www.reddit.com/r/LocalLLaMA/comments/14mrgpr/dynamically_scaled_rope_further_increases """ ], + no_rope_layers: [ + default: nil, + doc: + "a list containing 0 or 1 at the corresponding index for each layer. 0 means no rope layer, 1 means rope layer." + ], layer_norm_epsilon: [ default: 1.0e-12, doc: "the epsilon used by RMS normalization layers" @@ -397,7 +402,6 @@ defmodule Bumblebee.Text.SmolLM3 do ) do name = opts[:name] - # TODO: remove hardcoding of 4th layers, read from config rotary_embedding_config = [ position_ids: position_ids, max_positions: spec.max_positions, @@ -405,13 +409,20 @@ defmodule Bumblebee.Text.SmolLM3 do scaling_strategy: spec.rotary_embedding_scaling_strategy ] - nope_rotary_embedding = fn layer_idx -> - if rem(layer_idx + 1, 4) != 0 do - rotary_embedding_config - else - nil + nope_rotary_embedding = + case opts[:no_rope_layers] do + nil -> + rotary_embedding_config + + no_rope_layers -> + fn layer_index -> + if Enum.at(no_rope_layers, layer_index) == 1 do + rotary_embedding_config + else + nil + end + end end - end Layers.Transformer.blocks(hidden_state, attention_mask: attention_mask, @@ -573,11 +584,18 @@ defmodule Bumblebee.Text.SmolLM3 do "question_answering_head.output" => "qa_outputs" } - # TODO: remove hardcoding, read from config rotary_mapping = - for n <- 0..(spec.num_blocks - 1), rem(n + 1, 4) != 0 do - {"decoder.blocks.#{n}.self_attention.rotary_embedding", - "model.layers.#{n}.self_attn.rotary_emb"} + case spec.no_rope_layers do + nil -> + [] + + no_rope_layers -> + Enum.with_index(no_rope_layers, fn rope, index -> + if rope == 1 do + {"decoder.blocks.#{index}.self_attention.rotary_embedding", + "model.layers.#{index}.self_attn.rotary_emb"} + end + end) end mapping = Map.merge(base_mapping, Map.new(rotary_mapping)) From 177238f8b24cdd399bc262b6ea2ef0d7693cbdc2 Mon Sep 17 00:00:00 2001 From: Joel Koch Date: Fri, 3 Oct 2025 20:12:36 +0200 Subject: [PATCH 4/4] don't implement yarn --- lib/bumblebee/layers.ex | 20 -------------------- lib/bumblebee/text/smollm3.ex | 19 +------------------ 2 files changed, 1 insertion(+), 38 deletions(-) diff --git a/lib/bumblebee/layers.ex b/lib/bumblebee/layers.ex index 559e65ad..3c9c92ca 100644 --- a/lib/bumblebee/layers.ex +++ b/lib/bumblebee/layers.ex @@ -1307,26 +1307,6 @@ defmodule Bumblebee.Layers do positions_cos_sin(position, inv_frequency) - %{ - type: :yarn, - factor: factor, - low_frequency_factor: low_frequency_factor, - high_frequency_factor: high_frequency_factor, - original_max_positions: original_max_positions - } -> - inv_frequency = inv_frequency(base, range) - - inv_frequency = - llama3_inv_frequency( - inv_frequency, - factor, - low_frequency_factor, - high_frequency_factor, - original_max_positions - ) - - positions_cos_sin(position, inv_frequency) - _other -> inv_frequency = inv_frequency(base, range) positions_cos_sin(position, inv_frequency) diff --git a/lib/bumblebee/text/smollm3.ex b/lib/bumblebee/text/smollm3.ex index 98754b57..30779917 
100644 --- a/lib/bumblebee/text/smollm3.ex +++ b/lib/bumblebee/text/smollm3.ex @@ -16,7 +16,6 @@ defmodule Bumblebee.Text.SmolLM3 do the vocabulary size of the position embedding. This corresponds to the maximum sequence length that this model can process. Typically this is set to a large value just in case, such as 512, 1024 or 2048. - SmolLM3 supports up to 128k tokens with YaRN extrapolation. """ ], hidden_size: [ @@ -65,8 +64,6 @@ defmodule Bumblebee.Text.SmolLM3 do * `%{type: :llama3, factor: number(), low_frequency_factor: number(), high_frequency_factor: number(), original_max_positions: pos_integer()}` - * `%{type: :yarn, factor: number(), original_max_positions: pos_integer()}` - For more details see https://www.reddit.com/r/LocalLLaMA/comments/14mrgpr/dynamically_scaled_rope_further_increases """ ], @@ -100,7 +97,7 @@ defmodule Bumblebee.Text.SmolLM3 do * Instruct model optimized for hybrid reasoning * Fully open model: open weights + full training details including public data mixture and training configs - * Long context: Trained on 64k context and supports up to 128k tokens using YARN extrapolation + * Long context: Trained on 64k context and supports up to 128k tokens using YARN extrapolation (not implemented in `bumblebee`) * Multilingual: 6 natively supported (English, French, Spanish, German, Italian, and Portuguese) For more details see: https://huggingface.co/HuggingFaceTB/SmolLM3-3B @@ -520,20 +517,6 @@ defmodule Bumblebee.Text.SmolLM3 do original_max_positions: original_max_positions }} - # TODO: implement yarn or find out if it's same as longrope - %{ - "rope_type" => "yarn", - "factor" => factor, - "original_max_position_embeddings" => original_max_positions - } - when is_number(factor) and is_number(original_max_positions) -> - {:ok, - %{ - type: :yarn, - factor: factor, - original_max_positions: original_max_positions - }} - _other -> {:error, "invalid format for #{inspect(name)}, got: #{inspect(value)}"} end
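---

This series wires `SmolLM3ForCausalLM` (and the other SmolLM3 heads) into Bumblebee's model mapping and tokenizer types, so the model plugs into the usual text-generation serving flow. Below is a minimal usage sketch, assuming the `HuggingFaceTB/SmolLM3-3B` checkpoint referenced in the moduledoc loads with the params mapping above; the prompt, the `defn_options`, and the use of the EXLA compiler are illustrative and not part of this patch series.

```elixir
# Hedged usage sketch for the SmolLM3 support added in this series.
# The checkpoint name comes from the moduledoc; all options are illustrative.
{:ok, model_info} = Bumblebee.load_model({:hf, "HuggingFaceTB/SmolLM3-3B"})
{:ok, tokenizer} = Bumblebee.load_tokenizer({:hf, "HuggingFaceTB/SmolLM3-3B"})
{:ok, generation_config} = Bumblebee.load_generation_config({:hf, "HuggingFaceTB/SmolLM3-3B"})

serving =
  Bumblebee.Text.generation(model_info, tokenizer, generation_config,
    defn_options: [compiler: EXLA]
  )

# Run a single prompt through the causal language modeling head.
Nx.Serving.run(serving, "The capital of France is")
```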
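Patch 3 replaces the hardcoded every-fourth-layer NoPE rule with the `no_rope_layers` flags from the HuggingFace config, using the per-block `:rotary_embedding` option added to `Bumblebee.Layers.Transformer` (a keyword list applied to all blocks, or a one-arity function of the block index). The standalone sketch below shows that translation; the flag pattern and the rotary option values are illustrative, and the real implementation also carries `position_ids` and the scaling strategy in the config.

```elixir
# Sketch: mapping no_rope_layers flags (1 = block uses RoPE, 0 = NoPE block)
# to the per-block :rotary_embedding function accepted by Layers.Transformer.blocks/2.
no_rope_layers = [1, 1, 1, 0, 1, 1, 1, 0]
rotary_embedding_config = [max_positions: 65_536, base: 5_000_000]

per_block_rotary = fn layer_index ->
  if Enum.at(no_rope_layers, layer_index) == 1, do: rotary_embedding_config, else: nil
end

per_block_rotary.(2)
#=> [max_positions: 65536, base: 5000000]

per_block_rotary.(3)
#=> nil
```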