
Conversation

joelpaulkoch
Contributor

Hey, this is the SmolLM3 model from Hugging Face. It's smol, fully open, and supports reasoning, so I figured it would be a nice addition to Bumblebee.

I didn't implement YaRN extrapolation.

Member

jonatanklosko left a comment

Hey @joelpaulkoch, this looks great! I dropped a few small comments and it's good to go :)

Comment on lines +588 to +602
question_answering_mapping = %{
  "output_norm" => "transformer.norm",
  "embedder.token_embedding" => "transformer.embed_tokens",
  "decoder.blocks.0.output_norm" => "transformer.layers.0.post_attention_layernorm",
  "decoder.blocks.0.self_attention.key" => "transformer.layers.0.self_attn.k_proj",
  "decoder.blocks.0.self_attention.query" => "transformer.layers.0.self_attn.q_proj",
  "decoder.blocks.0.self_attention.value" => "transformer.layers.0.self_attn.v_proj",
  "decoder.blocks.0.self_attention_norm" => "transformer.layers.0.input_layernorm",
  "decoder.blocks.0.self_attention.output" => "transformer.layers.0.self_attn.o_proj",
  "decoder.blocks.0.ffn.output" => "transformer.layers.0.mlp.down_proj",
  "decoder.blocks.0.ffn.intermediate" => "transformer.layers.0.mlp.up_proj",
  "decoder.blocks.0.ffn.gate" => "transformer.layers.0.mlp.gate_proj"
}

Map.merge(mapping, question_answering_mapping)
Member

They use a different prefix for all layers, so we can probably just do this:

Suggested change
- question_answering_mapping = %{
-   "output_norm" => "transformer.norm",
-   "embedder.token_embedding" => "transformer.embed_tokens",
-   "decoder.blocks.0.output_norm" => "transformer.layers.0.post_attention_layernorm",
-   "decoder.blocks.0.self_attention.key" => "transformer.layers.0.self_attn.k_proj",
-   "decoder.blocks.0.self_attention.query" => "transformer.layers.0.self_attn.q_proj",
-   "decoder.blocks.0.self_attention.value" => "transformer.layers.0.self_attn.v_proj",
-   "decoder.blocks.0.self_attention_norm" => "transformer.layers.0.input_layernorm",
-   "decoder.blocks.0.self_attention.output" => "transformer.layers.0.self_attn.o_proj",
-   "decoder.blocks.0.ffn.output" => "transformer.layers.0.mlp.down_proj",
-   "decoder.blocks.0.ffn.intermediate" => "transformer.layers.0.mlp.up_proj",
-   "decoder.blocks.0.ffn.gate" => "transformer.layers.0.mlp.gate_proj"
- }
- Map.merge(mapping, question_answering_mapping)
+ for {key, value} <- mapping, into: %{} do
+   {key, String.replace_leading(value, "model.", "transformer.")}
+ end
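
For illustration, this is roughly what the comprehension does, assuming the base mapping points at the usual "model." prefix (the two entries below are a hypothetical, abbreviated excerpt, not the real mapping):

mapping = %{
  "output_norm" => "model.norm",
  "decoder.blocks.0.ffn.gate" => "model.layers.0.mlp.gate_proj"
}

for {key, value} <- mapping, into: %{} do
  {key, String.replace_leading(value, "model.", "transformer.")}
end
#=> %{
#     "output_norm" => "transformer.norm",
#     "decoder.blocks.0.ffn.gate" => "transformer.layers.0.mlp.gate_proj"
#   }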

Comment on lines +25 to +27
Nx.tensor([
  [[-0.4167, -0.0137, 0.7160], [-0.2624, -1.1185, -0.3098], [-0.0383, -0.8390, -0.0039]]
])
Member

Just double-checking, these values come from Python, right?

Contributor Author

Yes, coming from Python :) Although the repo config is so tiny, it's not even hitting the no-RoPE layer case.

As a sidenote, I think next time I'll try to set up a simple validation script with pythonx so that it can be reused for contributing model implementations.
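
A rough sketch of what such a script could look like (not part of this PR; the tiny repo id is a placeholder, and the Pythonx calls reflect my understanding of its uv_init/eval/decode API):

# Install the Python dependencies in-process via uv.
Pythonx.uv_init("""
[project]
name = "validation"
version = "0.0.0"
requires-python = "==3.12.*"
dependencies = ["torch", "transformers"]
""")

# Reference logits computed with transformers.
{result, _globals} =
  Pythonx.eval(
    """
    import torch
    from transformers import AutoModelForCausalLM

    # placeholder repo id
    model = AutoModelForCausalLM.from_pretrained("bumblebee-testing/tiny-random-SmolLM3ForCausalLM")
    input_ids = torch.tensor([[10, 20, 30]])
    model(input_ids).logits.tolist()
    """,
    %{}
  )

reference = Pythonx.decode(result)

# Logits from the Bumblebee implementation, to compare against `reference` within a tolerance.
{:ok, %{model: model, params: params}} =
  Bumblebee.load_model({:hf, "bumblebee-testing/tiny-random-SmolLM3ForCausalLM"})

outputs = Axon.predict(model, params, %{"input_ids" => Nx.tensor([[10, 20, 30]])})
IO.inspect(Nx.to_list(outputs.logits), label: "bumblebee")
IO.inspect(reference, label: "python")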

For more details see https://www.reddit.com/r/LocalLLaMA/comments/14mrgpr/dynamically_scaled_rope_further_increases
"""
],
no_rope_layers: [
Member

This naming is very confusing; initially I thought it means not-RoPE, but 1 (true) actually enables RoPE. So I guess it rather means No- and Ro-PE.

One alternative configuration I can think of would be :rotary_embedding_enabled, with a list of booleans true/false (and if omitted, defaults to true). We can easily convert the representation when loading the config. What do you think?

On a sidenote, we generally use "block" wherever they say "layer" (because it is a group of whole layers).
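
For illustration, the conversion could look roughly like this when loading the config (module and function names here are hypothetical, not the actual implementation):

defmodule RotaryConfigSketch do
  # HF stores no_rope_layers as a list of 0/1 flags where 1 enables RoPE.
  # Convert it to a list of booleans for :rotary_embedding_enabled,
  # defaulting to RoPE enabled in every block when the key is absent.
  def rotary_embedding_enabled(data, num_blocks) do
    case data["no_rope_layers"] do
      nil -> List.duplicate(true, num_blocks)
      flags -> Enum.map(flags, &(&1 == 1))
    end
  end
end

RotaryConfigSketch.rotary_embedding_enabled(%{"no_rope_layers" => [1, 1, 1, 0]}, 4)
#=> [true, true, true, false]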

Contributor Author

I agree the naming is very confusing; I took it directly from Hugging Face to see what you would suggest, sorry. It's also very confusing that they have both no_rope_layers and no_rope_layer_interval.

:rotary_embedding_enabled sounds good to me 👍

Comment on lines +214 to +219
smollm3: %{
  special_tokens: %{
    eos: "<|im_end|>",
    pad: "<|im_end|>"
  }
},
Member

joelpaulkoch
Contributor Author

The implementation is basically llama + NoPE support (in the transformer block) + architectures that are supported but missing in llama (i.e. :for_question_answering and :for_token_classification). So, would you prefer to add the optional NoPE support and the extra architectures to the llama implementation and map smollm3 to llama?
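
To make the NoPE part concrete, it boils down to a per-block switch around the rotary embedding; a rough sketch (helper and argument names are just for illustration, not the actual code in this PR):

defmodule NopeSketch do
  # Apply rotary embeddings to query/key only in blocks where they are enabled;
  # NoPE blocks pass the projections through unchanged.
  def maybe_apply_rotary({query, key}, rotary_fun, rotary_embedding_enabled, block_idx) do
    if Enum.at(rotary_embedding_enabled, block_idx, true) do
      rotary_fun.(query, key)
    else
      {query, key}
    end
  end
end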
