metal: SSM_SCAN performance #14743


Merged 14 commits on Jul 25, 2025
Changes from 2 commits
1 change: 1 addition & 0 deletions ggml/src/ggml-metal/ggml-metal-impl.h
@@ -519,6 +519,7 @@ typedef struct {
int64_t n_group;
int64_t n_seq_tokens;
int64_t n_seqs;
int64_t s_off;
uint64_t nb01;
uint64_t nb02;
uint64_t nb03;
4 changes: 3 additions & 1 deletion ggml/src/ggml-metal/ggml-metal.m
@@ -2986,6 +2986,7 @@ static bool ggml_metal_encode_node(
/*.n_group =*/ n_group,
/*.n_seq_tokens =*/ n_seq_tokens,
/*.n_seqs =*/ n_seqs,
/*.s_off =*/ ggml_nelements(src1) * sizeof(float),
/*.nb01 =*/ nb01,
/*.nb02 =*/ nb02,
/*.nb03 =*/ nb03,
@@ -3016,7 +3017,8 @@

if (ne30 == 1) {
// Mamba-2
-[encoder dispatchThreadgroups:MTLSizeMake(d_inner, n_head, n_seqs) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)];
+[encoder setThreadgroupMemoryLength:32*sizeof(float) atIndex:0]; // SIMD size
+[encoder dispatchThreadgroups:MTLSizeMake(d_inner, n_head, n_seqs) threadsPerThreadgroup:MTLSizeMake(d_state, 1, 1)];
} else {
GGML_ASSERT(d_inner == 1);
[encoder dispatchThreadgroups:MTLSizeMake(n_head, n_seqs, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)];
39 changes: 29 additions & 10 deletions ggml/src/ggml-metal/ggml-metal.metal
@@ -1752,7 +1752,6 @@ kernel void kernel_ssm_scan_f32(
}

// ref: ggml.c:ggml_compute_forward_ssm_scan_f32, Mamba-2 part
-// TODO: optimize (e.g. by parallelizing over d_state)
kernel void kernel_ssm_scan_f32_group(
device const void * src0,
device const void * src1,
@@ -1762,10 +1761,14 @@ kernel void kernel_ssm_scan_f32_group(
device const void * src5,
device const void * src6,
device float * dst,
threadgroup float * shared [[threadgroup(0)]],
constant ggml_metal_kargs_ssm_scan & args,
-        uint3 tgpig[[threadgroup_position_in_grid]],
-        uint3 tpitg[[thread_position_in_threadgroup]],
-        uint3 ntg[[threads_per_threadgroup]]) {
+        uint3 tgpig[[threadgroup_position_in_grid]],
+        uint3 tpitg[[thread_position_in_threadgroup]],
+        ushort sgitg[[simdgroup_index_in_threadgroup]],
+        ushort tiisg[[thread_index_in_simdgroup]],
+        uint3 ntg[[threads_per_threadgroup]]) {

const int64_t i1 = tgpig.x;
const int64_t ir = tgpig.y; // current head
const int64_t i3 = tgpig.z; // current seq
@@ -1780,7 +1783,7 @@ kernel void kernel_ssm_scan_f32_group(
const int64_t ng = args.n_group;
const int64_t n_t = args.n_seq_tokens;

-    const int64_t s_off = nr * nh * n_t * args.n_seqs * sizeof(float);
+    const int64_t s_off = args.s_off;

device const int32_t * ids = (device const int32_t *) src6;

@@ -1798,15 +1801,31 @@
const float dt_soft_plus = dt[0] <= 20.0f ? log(1.0f + exp(dt[0])) : dt[0];
const float x_dt = x[0] * dt_soft_plus;
const float dA = exp(dt_soft_plus * A[0]);

threadgroup_barrier(mem_flags::mem_threadgroup);

float sumf = 0.0f;

-    for (int64_t i0 = 0; i0 < nc; ++i0) {
-        const int64_t i = i0 + i1*nc;
-        const float state = (s0[i] * dA) + (B[i0] * x_dt);
-        sumf += state * C[i0];
-        s[i] = state;
+    const int64_t i = tpitg.x + i1*nc;
+    const float state = (s0[i] * dA) + (B[tpitg.x] * x_dt);
+    sumf += state * C[tpitg.x];
+    s[i] = state;
Collaborator:

I wonder if it would be faster to store the intermediate states in a local variable instead of repeatedly in the destination buffer.

I'm not very experienced with Metal (the naïve version you're starting from was pretty much my first Metal kernel), but I assume it should be possible?

Unless I'm misunderstanding the memory model, each thread only handles a single state (as in s[i] always refers to the same place, but differs between threads).

I think this would only affect prompt processing speed, not really small batches, though.

Collaborator Author:

Interesting thought! I'm also very much still learning the memory model, so I'll play with this idea and see how far I can get it.

Collaborator Author:

Great suggestion. It was easy to implement and gives a nice bump in performance. Will commit and push shortly.

| PP | TG | B | N_KV | T_PP s | S_PP t/s | T_TG s | S_TG t/s | T s | S t/s |
| ---: | ---: | ---: | ---: | ---: | ---: | ---: | ---: | ---: | ---: |
| 256 | 256 | 1 | 512 | 0.312 | 821.40 | 5.700 | 44.91 | 6.012 | 85.17 |
| 256 | 256 | 2 | 1024 | 0.582 | 880.34 | 9.658 | 53.01 | 10.240 | 100.00 |
| 2560 | 256 | 1 | 2816 | 2.845 | 899.88 | 5.854 | 43.73 | 8.699 | 323.71 |
| 2560 | 256 | 2 | 5632 | 5.700 | 898.23 | 9.911 | 51.66 | 15.611 | 360.76 |


sumf = simd_sum(sumf);

threadgroup_barrier(mem_flags::mem_threadgroup);

// Use the shared buffer to hold the sum of each simd group
if (tiisg == 0) {
shared[sgitg] = sumf;
}

threadgroup_barrier(mem_flags::mem_threadgroup);

// Sum the simd buckets
sumf = shared[tiisg];
sumf = simd_sum(sumf);

y[0] = sumf;

// recurse