From ba74a2473045deda1be08a2cac42b8731644a6a8 Mon Sep 17 00:00:00 2001 From: Gabe Goodhart Date: Tue, 15 Jul 2025 17:03:38 -0600 Subject: [PATCH 01/12] feat: Add s_off as a parameter in the args struct This may not be necessary, but it more closely mirrors the CUDA kernel Branch: GraniteFourPerf Signed-off-by: Gabe Goodhart --- ggml/src/ggml-metal/ggml-metal-impl.h | 1 + 1 file changed, 1 insertion(+) diff --git a/ggml/src/ggml-metal/ggml-metal-impl.h b/ggml/src/ggml-metal/ggml-metal-impl.h index 752d55c216604..6f4427c55c50e 100644 --- a/ggml/src/ggml-metal/ggml-metal-impl.h +++ b/ggml/src/ggml-metal/ggml-metal-impl.h @@ -519,6 +519,7 @@ typedef struct { int64_t n_group; int64_t n_seq_tokens; int64_t n_seqs; + int64_t s_off; uint64_t nb01; uint64_t nb02; uint64_t nb03; From 8d5a25d3562617a95d800645c38d1b7b746bff5a Mon Sep 17 00:00:00 2001 From: Gabe Goodhart Date: Tue, 15 Jul 2025 17:04:31 -0600 Subject: [PATCH 02/12] perf: Parallelize mamba2 SSM_SCAN metal kernel over d_state This is a first attempt at optimizing the metal kernel. The changes here are: - Launch the kernel with a thread group of size d_state - Use simd groups and shared memory to do the summation for the y computation When tested with G4 tiny preview, this shows roughly a 3x speedup on prefill and 15% speedup on decode. Signed-off-by: Gabe Goodhart --- ggml/src/ggml-metal/ggml-metal.m | 4 ++- ggml/src/ggml-metal/ggml-metal.metal | 39 +++++++++++++++++++++------- 2 files changed, 32 insertions(+), 11 deletions(-) diff --git a/ggml/src/ggml-metal/ggml-metal.m b/ggml/src/ggml-metal/ggml-metal.m index 44ddc69d08f1c..de7d33046fc23 100644 --- a/ggml/src/ggml-metal/ggml-metal.m +++ b/ggml/src/ggml-metal/ggml-metal.m @@ -2986,6 +2986,7 @@ static bool ggml_metal_encode_node( /*.n_group =*/ n_group, /*.n_seq_tokens =*/ n_seq_tokens, /*.n_seqs =*/ n_seqs, + /*.s_off =*/ ggml_nelements(src1) * sizeof(float), /*.nb01 =*/ nb01, /*.nb02 =*/ nb02, /*.nb03 =*/ nb03, @@ -3016,7 +3017,8 @@ static bool ggml_metal_encode_node( if (ne30 == 1) { // Mamba-2 - [encoder dispatchThreadgroups:MTLSizeMake(d_inner, n_head, n_seqs) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)]; + [encoder setThreadgroupMemoryLength:32*sizeof(float) atIndex:0]; // SIMD size + [encoder dispatchThreadgroups:MTLSizeMake(d_inner, n_head, n_seqs) threadsPerThreadgroup:MTLSizeMake(d_state, 1, 1)]; } else { GGML_ASSERT(d_inner == 1); [encoder dispatchThreadgroups:MTLSizeMake(n_head, n_seqs, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)]; diff --git a/ggml/src/ggml-metal/ggml-metal.metal b/ggml/src/ggml-metal/ggml-metal.metal index 13235e2885241..ac2895b5164f6 100644 --- a/ggml/src/ggml-metal/ggml-metal.metal +++ b/ggml/src/ggml-metal/ggml-metal.metal @@ -1752,7 +1752,6 @@ kernel void kernel_ssm_scan_f32( } // ref: ggml.c:ggml_compute_forward_ssm_scan_f32, Mamba-2 part -// TODO: optimize (e.g. 
by parallelizing over d_state) kernel void kernel_ssm_scan_f32_group( device const void * src0, device const void * src1, @@ -1762,10 +1761,14 @@ kernel void kernel_ssm_scan_f32_group( device const void * src5, device const void * src6, device float * dst, + threadgroup float * shared [[threadgroup(0)]], constant ggml_metal_kargs_ssm_scan & args, - uint3 tgpig[[threadgroup_position_in_grid]], - uint3 tpitg[[thread_position_in_threadgroup]], - uint3 ntg[[threads_per_threadgroup]]) { + uint3 tgpig[[threadgroup_position_in_grid]], + uint3 tpitg[[thread_position_in_threadgroup]], + ushort sgitg[[simdgroup_index_in_threadgroup]], + ushort tiisg[[thread_index_in_simdgroup]], + uint3 ntg[[threads_per_threadgroup]]) { + const int64_t i1 = tgpig.x; const int64_t ir = tgpig.y; // current head const int64_t i3 = tgpig.z; // current seq @@ -1780,7 +1783,7 @@ kernel void kernel_ssm_scan_f32_group( const int64_t ng = args.n_group; const int64_t n_t = args.n_seq_tokens; - const int64_t s_off = nr * nh * n_t * args.n_seqs * sizeof(float); + const int64_t s_off = args.s_off; device const int32_t * ids = (device const int32_t *) src6; @@ -1798,15 +1801,31 @@ kernel void kernel_ssm_scan_f32_group( const float dt_soft_plus = dt[0] <= 20.0f ? log(1.0f + exp(dt[0])) : dt[0]; const float x_dt = x[0] * dt_soft_plus; const float dA = exp(dt_soft_plus * A[0]); + + threadgroup_barrier(mem_flags::mem_threadgroup); + float sumf = 0.0f; - for (int64_t i0 = 0; i0 < nc; ++i0) { - const int64_t i = i0 + i1*nc; - const float state = (s0[i] * dA) + (B[i0] * x_dt); - sumf += state * C[i0]; - s[i] = state; + const int64_t i = tpitg.x + i1*nc; + const float state = (s0[i] * dA) + (B[tpitg.x] * x_dt); + sumf += state * C[tpitg.x]; + s[i] = state; + + sumf = simd_sum(sumf); + + threadgroup_barrier(mem_flags::mem_threadgroup); + + // Use the shared buffer to hold the sum of each simd group + if (tiisg == 0) { + shared[sgitg] = sumf; } + threadgroup_barrier(mem_flags::mem_threadgroup); + + // Sum the simd buckets + sumf = shared[tiisg]; + sumf = simd_sum(sumf); + y[0] = sumf; // recurse From e16e24bebdddae8c939e1d80fc576eca014b2fc9 Mon Sep 17 00:00:00 2001 From: Gabe Goodhart Date: Fri, 18 Jul 2025 10:49:06 -0600 Subject: [PATCH 03/12] fix: Update logic to correctly do the multi-layer parallel sum Signed-off-by: Gabe Goodhart --- ggml/src/ggml-metal/ggml-metal.metal | 38 +++++++++++++++++++--------- 1 file changed, 26 insertions(+), 12 deletions(-) diff --git a/ggml/src/ggml-metal/ggml-metal.metal b/ggml/src/ggml-metal/ggml-metal.metal index ac2895b5164f6..27005a053db34 100644 --- a/ggml/src/ggml-metal/ggml-metal.metal +++ b/ggml/src/ggml-metal/ggml-metal.metal @@ -1767,7 +1767,8 @@ kernel void kernel_ssm_scan_f32_group( uint3 tpitg[[thread_position_in_threadgroup]], ushort sgitg[[simdgroup_index_in_threadgroup]], ushort tiisg[[thread_index_in_simdgroup]], - uint3 ntg[[threads_per_threadgroup]]) { + ushort sgptg[[simdgroups_per_threadgroup]], + uint3 tgpg[[threadgroups_per_grid]]) { const int64_t i1 = tgpig.x; const int64_t ir = tgpig.y; // current head @@ -1802,29 +1803,42 @@ kernel void kernel_ssm_scan_f32_group( const float x_dt = x[0] * dt_soft_plus; const float dA = exp(dt_soft_plus * A[0]); - threadgroup_barrier(mem_flags::mem_threadgroup); - - float sumf = 0.0f; - const int64_t i = tpitg.x + i1*nc; const float state = (s0[i] * dA) + (B[tpitg.x] * x_dt); - sumf += state * C[tpitg.x]; s[i] = state; - sumf = simd_sum(sumf); + // Parallel sum: This relies on the fact that this kernel will be + // dispatched with each 
threadgroup having (d_state, 1, 1) threads which + // are subdivided into `sgptg` SIMD groups. The goal is to + // compute y = sum({state * C[i] for i in range(d_state)}). + // To parallelize this effectively, we first use simd_sum over each SIMD + // group to compute the sum of each SIMD group, then place the result in + // the SIMD group's indexed bucket in the shared memory. We then sum + // over the individual group sums to compute the final sum. - threadgroup_barrier(mem_flags::mem_threadgroup); + // Computed for each thread + float sumf = state * C[tpitg.x]; - // Use the shared buffer to hold the sum of each simd group + // Sum the threads in the simd group => simd sum + sumf = simd_sum(sumf); + + // Once per simd group, place the group sum into the shared buffer if (tiisg == 0) { shared[sgitg] = sumf; } + // Wait for all threads in the threadgroup to reach this point. This + // ensures that all elements of the shared buffer are populated with the + // sum of the individual simd groups. threadgroup_barrier(mem_flags::mem_threadgroup); - // Sum the simd buckets - sumf = shared[tiisg]; - sumf = simd_sum(sumf); + // Sum the simd buckets => threadgroup sum + sumf = 0.0f; + for (int64_t i0 = 0; i0 < sgptg; ++i0) { + sumf += shared[i0]; + } + + threadgroup_barrier(mem_flags::mem_threadgroup); y[0] = sumf; From 21db0b598af4c4005ee89e9e95d802e57a92c767 Mon Sep 17 00:00:00 2001 From: Gabe Goodhart Date: Fri, 18 Jul 2025 11:05:31 -0600 Subject: [PATCH 04/12] fix: Correctly size the shared memory buffer and assert expected size relationships Branch: GraniteFourPerf Signed-off-by: Gabe Goodhart --- ggml/src/ggml-metal/ggml-metal.m | 10 +++++++++- 1 file changed, 9 insertions(+), 1 deletion(-) diff --git a/ggml/src/ggml-metal/ggml-metal.m b/ggml/src/ggml-metal/ggml-metal.m index de7d33046fc23..d515ec0a32616 100644 --- a/ggml/src/ggml-metal/ggml-metal.m +++ b/ggml/src/ggml-metal/ggml-metal.m @@ -3017,7 +3017,15 @@ static bool ggml_metal_encode_node( if (ne30 == 1) { // Mamba-2 - [encoder setThreadgroupMemoryLength:32*sizeof(float) atIndex:0]; // SIMD size + + // One shared memory bucket for each simd group in the threadgroup + const int64_t shmem_size = d_state / 32; + GGML_ASSERT(shmem_size * 32 == d_state); + + // One thread per element in d_state + GGML_ASSERT(d_state <= (int64_t)pipeline.maxTotalThreadsPerThreadgroup); + + [encoder setThreadgroupMemoryLength:(shmem_size)*sizeof(float) atIndex:0]; [encoder dispatchThreadgroups:MTLSizeMake(d_inner, n_head, n_seqs) threadsPerThreadgroup:MTLSizeMake(d_state, 1, 1)]; } else { GGML_ASSERT(d_inner == 1); [encoder dispatchThreadgroups:MTLSizeMake(n_head, n_seqs, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)]; From a5334f911e095b7e4df2de497f626953080722b8 Mon Sep 17 00:00:00 2001 From: Gabe Goodhart Date: Fri, 18 Jul 2025 13:55:51 -0600 Subject: [PATCH 05/12] refactor: Compute block offsets once rather than once per token Branch: GraniteFourPerf Signed-off-by: Gabe Goodhart --- ggml/src/ggml-metal/ggml-metal.metal | 18 ++++++++++++------ 1 file changed, 12 insertions(+), 6 deletions(-) diff --git a/ggml/src/ggml-metal/ggml-metal.metal b/ggml/src/ggml-metal/ggml-metal.metal index 27005a053db34..c50236ba49fc2 100644 --- a/ggml/src/ggml-metal/ggml-metal.metal +++ b/ggml/src/ggml-metal/ggml-metal.metal @@ -1791,13 +1791,19 @@ kernel void kernel_ssm_scan_f32_group( device const float * s0 = (device const float *) ((device const char *) src0 + ir*args.nb02 + ids[i3]*args.nb03); device float * s = (device float *) ((device char *) dst + ir*args.nb02 + i3*args.nb03 + s_off); + device const float * A = (device const float *) ((device const char *) src3 +
ir*args.nb31); // {1, nh} + device const float * x_block = (device const float *) ((device const char *) src1 + i1*nb10 + ir*args.nb11 + i3*args.nb13); + device const float * dt_block = (device const float *) ((device const char *) src2 + ir*nb20 + i3*args.nb22); + device const float * B_block = (device const float *) ((device const char *) src4 + (ir & (ng - 1))*args.nb41 + i3*args.nb43); + device const float * C_block = (device const float *) ((device const char *) src5 + (ir & (ng - 1))*args.nb51 + i3*args.nb53); + device float * y_block = (device float *) ((device char *) dst + (i1 + ir*(nr) + i3*(n_t*nh*nr))*nb00); + for (int64_t i2 = 0; i2 < n_t; ++i2) { - device const float * x = (device const float *) ((device const char *) src1 + i1*nb10 + ir*args.nb11 + i2*args.nb12 + i3*args.nb13); // {dim, nh, nt, ns} - device const float * dt = (device const float *) ((device const char *) src2 + ir*nb20 + i2*args.nb21 + i3*args.nb22); // {nh, nt, ns} - device const float * A = (device const float *) ((device const char *) src3 + ir*args.nb31); // {1, nh} - device const float * B = (device const float *) ((device const char *) src4 + (ir & (ng - 1))*args.nb41 + i2*args.nb42 + i3*args.nb43); // {d_state, ng, nt, ns} - device const float * C = (device const float *) ((device const char *) src5 + (ir & (ng - 1))*args.nb51 + i2*args.nb52 + i3*args.nb53); // {d_state, ng, nt, ns} - device float * y = (device float *) ((device char *) dst + (i1 + ir*(nr) + i2*(nh*nr) + i3*(n_t*nh*nr))*nb00); // {dim, nh, nt, ns} + device const float * x = (device const float *) ((device const char *) x_block + i2*args.nb12); // {dim, nh, nt, ns} + device const float * dt = (device const float *) ((device const char *) dt_block + i2*args.nb21); // {nh, nt, ns} + device const float * B = (device const float *) ((device const char *) B_block + i2*args.nb42); // {d_state, ng, nt, ns} + device const float * C = (device const float *) ((device const char *) C_block + i2*args.nb52); // {d_state, ng, nt, ns} + device float * y = (device float *) ((device char *) y_block + i2*(nh*nr*nb00)); // {dim, nh, nt, ns} const float dt_soft_plus = dt[0] <= 20.0f ? 
log(1.0f + exp(dt[0])) : dt[0]; const float x_dt = x[0] * dt_soft_plus; From 3866f766fe8e508354013ea9e10a5c87b31e7681 Mon Sep 17 00:00:00 2001 From: Gabe Goodhart Date: Mon, 21 Jul 2025 09:31:43 -0600 Subject: [PATCH 06/12] feat: Use local variable for state recursion Branch: GraniteFourPerf Signed-off-by: Gabe Goodhart --- ggml/src/ggml-metal/ggml-metal.metal | 15 ++++++++++----- 1 file changed, 10 insertions(+), 5 deletions(-) diff --git a/ggml/src/ggml-metal/ggml-metal.metal b/ggml/src/ggml-metal/ggml-metal.metal index c50236ba49fc2..0397cd9b531c3 100644 --- a/ggml/src/ggml-metal/ggml-metal.metal +++ b/ggml/src/ggml-metal/ggml-metal.metal @@ -1788,8 +1788,11 @@ kernel void kernel_ssm_scan_f32_group( device const int32_t * ids = (device const int32_t *) src6; - device const float * s0 = (device const float *) ((device const char *) src0 + ir*args.nb02 + ids[i3]*args.nb03); - device float * s = (device float *) ((device char *) dst + ir*args.nb02 + i3*args.nb03 + s_off); + device const float * s0_buff = (device const float *) ((device const char *) src0 + ir*args.nb02 + ids[i3]*args.nb03); + device float * s_buff = (device float *) ((device char *) dst + ir*args.nb02 + i3*args.nb03 + s_off); + const int64_t i = tpitg.x + i1*nc; + float s0 = s0_buff[i]; + float s = s_buff[i]; device const float * A = (device const float *) ((device const char *) src3 + ir*args.nb31); // {1, nh} device const float * x_block = (device const float *) ((device const char *) src1 + i1*nb10 + ir*args.nb11 + i3*args.nb13); @@ -1809,9 +1812,8 @@ kernel void kernel_ssm_scan_f32_group( const float x_dt = x[0] * dt_soft_plus; const float dA = exp(dt_soft_plus * A[0]); - const int64_t i = tpitg.x + i1*nc; - const float state = (s0[i] * dA) + (B[tpitg.x] * x_dt); - s[i] = state; + const float state = (s0 * dA) + (B[tpitg.x] * x_dt); + s = state; // Parallel sum: This relies on the fact that this kernel will be // dispatched with each threadgroup having (d_state, 1, 1) threads which @@ -1851,6 +1853,9 @@ kernel void kernel_ssm_scan_f32_group( // recurse s0 = s; } + + // Assign the final state to the output buffer + s_buff[i] = s; } kernel void kernel_rwkv_wkv6_f32( From 641276a8162295ccf750dac08f1a993c7b74b06a Mon Sep 17 00:00:00 2001 From: Gabe Goodhart Date: Tue, 22 Jul 2025 09:13:45 -0600 Subject: [PATCH 07/12] feat: Use a secondary simd_sum instead of a for loop Branch: GraniteFourPerf Signed-off-by: Gabe Goodhart --- ggml/src/ggml-metal/ggml-metal.metal | 17 ++++++++++------- 1 file changed, 10 insertions(+), 7 deletions(-) diff --git a/ggml/src/ggml-metal/ggml-metal.metal b/ggml/src/ggml-metal/ggml-metal.metal index 0397cd9b531c3..b7c474de861fa 100644 --- a/ggml/src/ggml-metal/ggml-metal.metal +++ b/ggml/src/ggml-metal/ggml-metal.metal @@ -1840,16 +1840,19 @@ kernel void kernel_ssm_scan_f32_group( // sum of the individual simd groups. 
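// (Example, illustrative only: with d_state = 128, typical for Mamba-2, the threadgroup holds 128 threads in sgptg == 4 simd groups of 32, so shared[0..3] each receive one partial sum before the final reduction below.)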
threadgroup_barrier(mem_flags::mem_threadgroup); - // Sum the simd buckets => threadgroup sum + // For simd group 0 at indices < num simd groups, extract the shared + // simd sum sumf = 0.0f; - for (int64_t i0 = 0; i0 < sgptg; ++i0) { - sumf += shared[i0]; + if (sgitg == 0) { + if (tiisg < sgptg) { + sumf = shared[tiisg]; + } + sumf = simd_sum(sumf); + if (tiisg == 0) { + y[0] = sumf; + } } - threadgroup_barrier(mem_flags::mem_threadgroup); - - y[0] = sumf; - // recurse s0 = s; } From d06d08769cfa39a9d49608a3fd2c9a5acaa1d07d Mon Sep 17 00:00:00 2001 From: Gabe Goodhart Date: Tue, 22 Jul 2025 09:22:08 -0600 Subject: [PATCH 08/12] feat: Add assertion and comment about relationship between simd size and num simd groups Branch: GraniteFourPerf Signed-off-by: Gabe Goodhart --- ggml/src/ggml-metal/ggml-metal.m | 7 +++++++ 1 file changed, 7 insertions(+) diff --git a/ggml/src/ggml-metal/ggml-metal.m b/ggml/src/ggml-metal/ggml-metal.m index d515ec0a32616..25d43122e948b 100644 --- a/ggml/src/ggml-metal/ggml-metal.m +++ b/ggml/src/ggml-metal/ggml-metal.m @@ -3022,6 +3022,13 @@ static bool ggml_metal_encode_node( const int64_t shmem_size = d_state / 32; GGML_ASSERT(shmem_size * 32 == d_state); + // The final simd_sum won't work if the number of simd groups is + // larger than the size of a single simd group. If this case is + // hit at some point, the logic in the second simd_sum could be + // expanded to handle this with one more sequential simd_sum to + // collapse simd group sums another time. + GGML_ASSERT(shmem_size <= 32); + // One thread per element in d_state GGML_ASSERT(d_state <= (int64_t)pipeline.maxTotalThreadsPerThreadgroup); From 80545ef568ddc7b946603a04efc6078b467eb172 Mon Sep 17 00:00:00 2001 From: Gabe Goodhart Date: Tue, 22 Jul 2025 11:00:19 -0600 Subject: [PATCH 09/12] feat: Parallelize over d_state for mamba-1 Branch: GraniteFourPerf Signed-off-by: Gabe Goodhart --- ggml/src/ggml-metal/ggml-metal.m | 13 ++-- ggml/src/ggml-metal/ggml-metal.metal | 101 +++++++++++++++++++++------ 2 files changed, 85 insertions(+), 29 deletions(-) diff --git a/ggml/src/ggml-metal/ggml-metal.m b/ggml/src/ggml-metal/ggml-metal.m index 25d43122e948b..9c3bba5f3e1c9 100644 --- a/ggml/src/ggml-metal/ggml-metal.m +++ b/ggml/src/ggml-metal/ggml-metal.m @@ -3015,12 +3015,9 @@ static bool ggml_metal_encode_node( [encoder setBuffer:id_dst offset:offs_dst atIndex:7]; [encoder setBytes:&args length:sizeof(args) atIndex:8]; - if (ne30 == 1) { - // Mamba-2 - - // One shared memory bucket for each simd group in the threadgroup + // One shared memory bucket for each simd group in the threadgroup + if (d_state >= 32) { const int64_t shmem_size = d_state / 32; - GGML_ASSERT(shmem_size * 32 == d_state); // The final simd_sum won't work if the number of simd groups is // larger than the size of a single simd group.
If this case is // hit at some point, the logic in the second simd_sum could be // expanded to handle this with one more sequential simd_sum to // collapse simd group sums another time. GGML_ASSERT(shmem_size <= 32); // One thread per element in d_state GGML_ASSERT(d_state <= (int64_t)pipeline.maxTotalThreadsPerThreadgroup); [encoder setThreadgroupMemoryLength:(shmem_size)*sizeof(float) atIndex:0]; + } + + if (ne30 == 1) { + // Mamba-2 [encoder dispatchThreadgroups:MTLSizeMake(d_inner, n_head, n_seqs) threadsPerThreadgroup:MTLSizeMake(d_state, 1, 1)]; } else { GGML_ASSERT(d_inner == 1); - [encoder dispatchThreadgroups:MTLSizeMake(n_head, n_seqs, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)]; + [encoder dispatchThreadgroups:MTLSizeMake(n_head, n_seqs, 1) threadsPerThreadgroup:MTLSizeMake(d_state, 1, 1)]; } } break; case GGML_OP_RWKV_WKV6: diff --git a/ggml/src/ggml-metal/ggml-metal.metal b/ggml/src/ggml-metal/ggml-metal.metal index b7c474de861fa..4ffa56d45b1d3 100644 --- a/ggml/src/ggml-metal/ggml-metal.metal +++ b/ggml/src/ggml-metal/ggml-metal.metal @@ -1700,10 +1700,16 @@ kernel void kernel_ssm_scan_f32( device const void * src5, device const void * src6, device float * dst, + threadgroup float * shared [[threadgroup(0)]], constant ggml_metal_kargs_ssm_scan & args, - uint3 tgpig[[threadgroup_position_in_grid]], - uint3 tpitg[[thread_position_in_threadgroup]], - uint3 ntg[[threads_per_threadgroup]]) { + uint3 tgpig[[threadgroup_position_in_grid]], + uint3 tpitg[[thread_position_in_threadgroup]], + ushort sgitg[[simdgroup_index_in_threadgroup]], + ushort tiisg[[thread_index_in_simdgroup]], + ushort sgptg[[simdgroups_per_threadgroup]], + uint3 tgpg[[threadgroups_per_grid]]) { + + const int64_t i0 = tpitg.x; const int64_t i1 = 0; const int64_t ir = tgpig.x; // current head const int64_t i3 = tgpig.y; // current seq @@ -1718,37 +1724,85 @@ kernel void kernel_ssm_scan_f32( const int64_t ng = args.n_group; const int64_t n_t = args.n_seq_tokens; - const int64_t s_off = nr * nh * n_t * args.n_seqs * sizeof(float); + const int64_t s_off = args.s_off; device const int32_t * ids = (device const int32_t *) src6; - device const float * s0 = (device const float *) ((device const char *) src0 + ir*args.nb02 + ids[i3]*args.nb03); - device float * s = (device float *) ((device char *) dst + ir*args.nb02 + i3*args.nb03 + s_off); + device const float * s0_buff = (device const float *) ((device const char *) src0 + ir*args.nb02 + ids[i3]*args.nb03); + device float * s_buff = (device float *) ((device char *) dst + ir*args.nb02 + i3*args.nb03 + s_off); + const int64_t i = i0 + i1*nc; + float s0 = s0_buff[i]; + float s = s_buff[i]; + + device const float * A = (device const float *) ((device const char *) src3 + ir*args.nb31); + device const float * x_block = (device const float *) ((device const char *) src1 + i1*nb10 + ir*args.nb11 + i3*args.nb13); + device const float * dt_block = (device const float *) ((device const char *) src2 + ir*nb20 + i3*args.nb22); + device const float * B_block = (device const float *) ((device const char *) src4 + (ir & (ng - 1))*args.nb41 + i3*args.nb43); + device const float * C_block = (device const float *) ((device const char *) src5 + (ir & (ng - 1))*args.nb51 + i3*args.nb53); + device float * y_block = (device float *) ((device char *) dst + (i1 + ir*(nr) + i3*(n_t*nh*nr))*nb00); for (int64_t i2 = 0; i2 < n_t; ++i2) { - device const float * x = (device const float *) ((device const char *) src1 + i1*nb10 + ir*args.nb11 + i2*args.nb12 + i3*args.nb13); // {dim, nh, nt, ns} - device const float * dt = (device const float *) ((device const char *) src2 + ir*nb20 + i2*args.nb21 + i3*args.nb22); // {nh, nt, ns} - device const float * A = (device const
float *) ((device const char *) src3 + ir*args.nb31); // {d_state, nh} - device const float * B = (device const float *) ((device const char *) src4 + (ir & (ng - 1))*args.nb41 + i2*args.nb42 + i3*args.nb43); // {d_state, ng, nt, ns} - device const float * C = (device const float *) ((device const char *) src5 + (ir & (ng - 1))*args.nb51 + i2*args.nb52 + i3*args.nb53); // {d_state, ng, nt, ns} - device float * y = (device float *) ((device char *) dst + (i1 + ir*(nr) + i2*(nh*nr) + i3*(n_t*nh*nr))*nb00); // {dim, nh, nt, ns} + device const float * x = (device const float *) ((device const char *) x_block + i2*args.nb12); // {dim, nh, nt, ns} + device const float * dt = (device const float *) ((device const char *) dt_block + i2*args.nb21); // {nh, nt, ns} + device const float * B = (device const float *) ((device const char *) B_block + i2*args.nb42); // {d_state, ng, nt, ns} + device const float * C = (device const float *) ((device const char *) C_block + i2*args.nb52); // {d_state, ng, nt, ns} + device float * y = (device float *) ((device char *) y_block + i2*(nh*nr*nb00)); // {dim, nh, nt, ns} const float dt_soft_plus = dt[0] <= 20.0f ? log(1.0f + exp(dt[0])) : dt[0]; const float x_dt = x[0] * dt_soft_plus; - float sumf = 0.0f; - for (int64_t i0 = 0; i0 < nc; ++i0) { - const int64_t i = i0 + i1*nc; - const float state = (s0[i] * exp(dt_soft_plus * A[i0])) + (B[i0] * x_dt); - sumf += state * C[i0]; - s[i] = state; - } + const float state = (s0 * exp(dt_soft_plus * A[i0])) + (B[i0] * x_dt); + s = state; + + // Parallel sum: This relies on the fact that this kernel will be + // dispatched with each threadgroup having (d_state, 1, 1) threads which + // are subdivided into `sgptg` SIMD groups. The goal is to + // compute y = sum({state * C[i] for i in range(d_state)}). + // To parallelize this effectively, we first use simd_sum over each SIMD + // group to compute the sum of each SIMD group, then place the result in + // the SIMD group's indexed bucket in the shared memory. We then sum + // over the individual group sums to compute the final sum. + + // Computed for each thread + float sumf = state * C[i0]; + + // Sum the threads in the simd group => simd sum + sumf = simd_sum(sumf); - y[0] = sumf; + if (sgptg > 1) { + + // Once per simd group, place the group sum into the shared buffer + if (tiisg == 0) { + shared[sgitg] = sumf; + } + + // Wait for all threads in the threadgroup to reach this point. This + // ensures that all elements of the shared buffer are populated with the + // sum of the individual simd groups.
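+ // Note (illustrative): Mamba-1 models typically use a small d_state + // (e.g. 16), so the threadgroup is a single simd group (sgptg == 1), + // this branch is not taken, and the else branch below writes y[0] + // directly from the first simd_sum.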
+ threadgroup_barrier(mem_flags::mem_threadgroup); + + // For simd group 0 at indices < num simd groups, extract the shared + // simd sum + sumf = 0.0f; + if (sgitg == 0) { + if (tiisg < sgptg) { + sumf = shared[tiisg]; + } + sumf = simd_sum(sumf); + if (tiisg == 0) { + y[0] = sumf; + } + } + } else if (tiisg == 0) { + y[0] = sumf; + } // recurse s0 = s; } + + // Assign the final state to the output buffer + s_buff[i] = s; } // ref: ggml.c:ggml_compute_forward_ssm_scan_f32, Mamba-2 part @@ -1770,6 +1824,7 @@ kernel void kernel_ssm_scan_f32_group( ushort sgptg[[simdgroups_per_threadgroup]], uint3 tgpg[[threadgroups_per_grid]]) { + const int64_t i0 = tpitg.x; const int64_t i1 = tgpig.x; const int64_t ir = tgpig.y; // current head const int64_t i3 = tgpig.z; // current seq @@ -1790,7 +1845,7 @@ kernel void kernel_ssm_scan_f32_group( device const float * s0_buff = (device const float *) ((device const char *) src0 + ir*args.nb02 + ids[i3]*args.nb03); device float * s_buff = (device float *) ((device char *) dst + ir*args.nb02 + i3*args.nb03 + s_off); - const int64_t i = tpitg.x + i1*nc; + const int64_t i = i0 + i1*nc; float s0 = s0_buff[i]; float s = s_buff[i]; @@ -1812,7 +1867,7 @@ kernel void kernel_ssm_scan_f32_group( const float x_dt = x[0] * dt_soft_plus; const float dA = exp(dt_soft_plus * A[0]); - const float state = (s0 * dA) + (B[tpitg.x] * x_dt); + const float state = (s0 * dA) + (B[i0] * x_dt); s = state; // Parallel sum: This relies on the fact that this kernel will be @@ -1825,7 +1880,7 @@ kernel void kernel_ssm_scan_f32_group( // over the individual group sums to compute the final sum. // Computed for each thread - float sumf = state * C[tpitg.x]; + float sumf = state * C[i0]; // Sum the threads in the simd group => simd sum sumf = simd_sum(sumf); From 16bc059660c1c59e566628201c0ca2c20c9f4bc3 Mon Sep 17 00:00:00 2001 From: Gabe Goodhart Date: Tue, 22 Jul 2025 11:37:43 -0600 Subject: [PATCH 10/12] feat: Parallel sum in SSM_CONV Branch: GraniteFourPerf Signed-off-by: Gabe Goodhart --- ggml/src/ggml-metal/ggml-metal.m | 21 ++++++++++++++- ggml/src/ggml-metal/ggml-metal.metal | 40 ++++++++++++++++++++++------ 2 files changed, 52 insertions(+), 9 deletions(-) diff --git a/ggml/src/ggml-metal/ggml-metal.m b/ggml/src/ggml-metal/ggml-metal.m index 9c3bba5f3e1c9..51ea6d217b10f 100644 --- a/ggml/src/ggml-metal/ggml-metal.m +++ b/ggml/src/ggml-metal/ggml-metal.m @@ -2909,7 +2909,26 @@ static bool ggml_metal_encode_node( [encoder setBuffer:id_dst offset:offs_dst atIndex:2]; [encoder setBytes:&args length:sizeof(args) atIndex:3]; - [encoder dispatchThreadgroups:MTLSizeMake(ne01, ne1, ne02) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)]; + const int64_t d_state = ne10; + + // One shared memory bucket for each simd group in the threadgroup + if (d_state >= 32) { + const int64_t shmem_size = d_state / 32; + + // The final simd_sum won't work if the number of simd groups is + // larger than the size of a single simd group. If this case is + // hit at some point, the logic in the second simd_sum could be + // expanded to handle this with one more sequential simd_sum to + // collapse simd group sums another time. 
+ GGML_ASSERT(shmem_size <= 32); + + // One thread per element in d_state + GGML_ASSERT(d_state <= (int64_t)pipeline.maxTotalThreadsPerThreadgroup); + + [encoder setThreadgroupMemoryLength:(shmem_size)*sizeof(float) atIndex:0]; + } + + [encoder dispatchThreadgroups:MTLSizeMake(ne01, ne1, ne02) threadsPerThreadgroup:MTLSizeMake(d_state, 1, 1)]; } break; case GGML_OP_SSM_SCAN: { diff --git a/ggml/src/ggml-metal/ggml-metal.metal b/ggml/src/ggml-metal/ggml-metal.metal index 4ffa56d45b1d3..c6871457908d4 100644 --- a/ggml/src/ggml-metal/ggml-metal.metal +++ b/ggml/src/ggml-metal/ggml-metal.metal @@ -1663,10 +1663,16 @@ kernel void kernel_ssm_conv_f32( device const void * src0, device const void * src1, device float * dst, + threadgroup float * shared [[threadgroup(0)]], constant ggml_metal_kargs_ssm_conv & args, - uint3 tgpig[[threadgroup_position_in_grid]], - uint3 tpitg[[thread_position_in_threadgroup]], - uint3 ntg[[threads_per_threadgroup]]) { + uint3 tgpig[[threadgroup_position_in_grid]], + uint3 tpitg[[thread_position_in_threadgroup]], + ushort sgitg[[simdgroup_index_in_threadgroup]], + ushort tiisg[[thread_index_in_simdgroup]], + ushort sgptg[[simdgroups_per_threadgroup]], + uint3 tgpg[[threadgroups_per_grid]]) { + + const int64_t i0 = tpitg.x; const int64_t ir = tgpig.x; const int64_t i2 = tgpig.y; const int64_t i3 = tgpig.z; @@ -1681,13 +1687,31 @@ kernel void kernel_ssm_conv_f32( device const float * c = (device const float *) ((device const char *) src1 + ir*args.nb11); device float * x = (device float *) ((device char *) dst + ir*args.nb0 + i2*args.nb1 + i3*args.nb2); - float sumf = 0.0f; + float sumf = s[i0] * c[i0]; - for (int64_t i0 = 0; i0 < nc; ++i0) { - sumf += s[i0] * c[i0]; - } + // Parallel sum: first sum over threads in simd group, then sum over simd + // group sums + sumf = simd_sum(sumf); - x[0] = sumf; + // If multiple simd groups per threadgroup, sum over simd group sums + if (sgptg > 1) { + if (tiisg == 0) { + shared[sgitg] = sumf; + } + threadgroup_barrier(mem_flags::mem_threadgroup); + sumf = 0.0f; + if (sgitg == 0) { + if (tiisg < sgptg) { + sumf = shared[tiisg]; + } + sumf = simd_sum(sumf); + if (tiisg == 0) { + x[0] = sumf; + } + } + } else if (tiisg == 0) { + x[0] = sumf; + } } // ref: ggml.c:ggml_compute_forward_ssm_scan_f32, Mamba-1 part From e55176a0dddefdeacd82df605743a1fbbd2488bd Mon Sep 17 00:00:00 2001 From: Gabe Goodhart Date: Wed, 23 Jul 2025 06:54:30 -0600 Subject: [PATCH 11/12] Revert "feat: Parallel sum in SSM_CONV" After discussion with @compilade, the amount of parallelism available here (the inner sum runs over only d_conv elements, typically 4) is not worth the cost in complexity or the overhead of the parallel sum. https://github.com/ggml-org/llama.cpp/pull/14743#discussion_r2223395357 This reverts commit 16bc059660c1c59e566628201c0ca2c20c9f4bc3.
Signed-off-by: Gabe Goodhart --- ggml/src/ggml-metal/ggml-metal.m | 21 +-------------- ggml/src/ggml-metal/ggml-metal.metal | 40 ++++++---------------------- 2 files changed, 9 insertions(+), 52 deletions(-) diff --git a/ggml/src/ggml-metal/ggml-metal.m b/ggml/src/ggml-metal/ggml-metal.m index 51ea6d217b10f..9c3bba5f3e1c9 100644 --- a/ggml/src/ggml-metal/ggml-metal.m +++ b/ggml/src/ggml-metal/ggml-metal.m @@ -2909,26 +2909,7 @@ static bool ggml_metal_encode_node( [encoder setBuffer:id_dst offset:offs_dst atIndex:2]; [encoder setBytes:&args length:sizeof(args) atIndex:3]; - const int64_t d_state = ne10; - - // One shared memory bucket for each simd group in the threadgroup - if (d_state >= 32) { - const int64_t shmem_size = d_state / 32; - - // The final simd_sum won't work if the number of simd groups is - // larger than the size of a single simd group. If this case is - // hit at some point, the logic in the second simd_sum could be - // expanded to handle this with one more sequential simd_sum to - // collapse simd group sums another time. - GGML_ASSERT(shmem_size <= 32); - - // One thread per element in d_state - GGML_ASSERT(d_state <= (int64_t)pipeline.maxTotalThreadsPerThreadgroup); - - [encoder setThreadgroupMemoryLength:(shmem_size)*sizeof(float) atIndex:0]; - } - - [encoder dispatchThreadgroups:MTLSizeMake(ne01, ne1, ne02) threadsPerThreadgroup:MTLSizeMake(d_state, 1, 1)]; + [encoder dispatchThreadgroups:MTLSizeMake(ne01, ne1, ne02) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)]; } break; case GGML_OP_SSM_SCAN: { diff --git a/ggml/src/ggml-metal/ggml-metal.metal b/ggml/src/ggml-metal/ggml-metal.metal index c6871457908d4..4ffa56d45b1d3 100644 --- a/ggml/src/ggml-metal/ggml-metal.metal +++ b/ggml/src/ggml-metal/ggml-metal.metal @@ -1663,16 +1663,10 @@ kernel void kernel_ssm_conv_f32( device const void * src0, device const void * src1, device float * dst, - threadgroup float * shared [[threadgroup(0)]], constant ggml_metal_kargs_ssm_conv & args, - uint3 tgpig[[threadgroup_position_in_grid]], - uint3 tpitg[[thread_position_in_threadgroup]], - ushort sgitg[[simdgroup_index_in_threadgroup]], - ushort tiisg[[thread_index_in_simdgroup]], - ushort sgptg[[simdgroups_per_threadgroup]], - uint3 tgpg[[threadgroups_per_grid]]) { - - const int64_t i0 = tpitg.x; + uint3 tgpig[[threadgroup_position_in_grid]], + uint3 tpitg[[thread_position_in_threadgroup]], + uint3 ntg[[threads_per_threadgroup]]) { const int64_t ir = tgpig.x; const int64_t i2 = tgpig.y; const int64_t i3 = tgpig.z; @@ -1687,31 +1681,13 @@ kernel void kernel_ssm_conv_f32( device const float * c = (device const float *) ((device const char *) src1 + ir*args.nb11); device float * x = (device float *) ((device char *) dst + ir*args.nb0 + i2*args.nb1 + i3*args.nb2); - float sumf = s[i0] * c[i0]; - - // Parallel sum: first sum over threads in simd group, then sum over simd - // group sums - sumf = simd_sum(sumf); + float sumf = 0.0f; - // If multiple simd groups per threadgroup, sum over simd group sums - if (sgptg > 1) { - if (tiisg == 0) { - shared[sgitg] = sumf; - } - threadgroup_barrier(mem_flags::mem_threadgroup); - sumf = 0.0f; - if (sgitg == 0) { - if (tiisg < sgptg) { - sumf = shared[tiisg]; - } - sumf = simd_sum(sumf); - if (tiisg == 0) { - x[0] = sumf; - } - } - } else if (tiisg == 0) { - x[0] = sumf; + for (int64_t i0 = 0; i0 < nc; ++i0) { + sumf += s[i0] * c[i0]; } + + x[0] = sumf; } From d20b02d106209f006eb3c9534146504a07d937f4 Mon Sep 17 00:00:00 2001 From: Gabe
Goodhart Date: Fri, 25 Jul 2025 09:11:49 -0600 Subject: [PATCH 12/12] refactor: Simplify shared memory sizing Branch: GraniteFourPerf Signed-off-by: Gabe Goodhart Co-Authored-By: Georgi Gerganov --- ggml/src/ggml-metal/ggml-metal.m | 15 ++++----------- 1 file changed, 4 insertions(+), 11 deletions(-) diff --git a/ggml/src/ggml-metal/ggml-metal.m b/ggml/src/ggml-metal/ggml-metal.m index aff3cc552fc71..337f7985badf3 100644 --- a/ggml/src/ggml-metal/ggml-metal.m +++ b/ggml/src/ggml-metal/ggml-metal.m @@ -3171,19 +3171,12 @@ static int ggml_metal_encode_node( [encoder setBytes:&args length:sizeof(args) atIndex:8]; // One shared memory bucket for each simd group in the threadgroup + // NOTE: Metal kernels require the buffer size to be a multiple of 16 bytes + // https://developer.apple.com/documentation/metal/mtlcomputecommandencoder/1443142-setthreadgroupmemorylength if (d_state >= 32) { - const int64_t shmem_size = d_state / 32; - - // The final simd_sum won't work if the number of simd groups is - // larger than the size of a single simd group. If this case is - // hit at some point, the logic in the second simd_sum could be - // expanded to handle this with one more sequential simd_sum to - // collapse simd group sums another time. - GGML_ASSERT(shmem_size <= 32); - - // One thread per element in d_state + GGML_ASSERT((int64_t)(d_state / 32) <= 32); + const int64_t shmem_size = 32; GGML_ASSERT(d_state <= (int64_t)pipeline.maxTotalThreadsPerThreadgroup); - [encoder setThreadgroupMemoryLength:(shmem_size)*sizeof(float) atIndex:0]; }
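Note: the two-stage reduction that patches 02-09 converge on, distilled here into a standalone sketch. This is illustrative only: the kernel name and buffer bindings are hypothetical, one input element per thread is assumed, and the simd width is taken to be 32 as in the patches above.

#include <metal_stdlib>
using namespace metal;

// Two-stage parallel sum over a single threadgroup:
// stage 1 reduces the lanes of each simd group with simd_sum; stage 2
// reduces the per-group partial sums staged in threadgroup memory.
kernel void reduce_sum_sketch(
        device const float * src    [[buffer(0)]],
        device       float * dst    [[buffer(1)]],
        threadgroup  float * shared [[threadgroup(0)]],
        uint   tpitg [[thread_position_in_threadgroup]],
        ushort sgitg [[simdgroup_index_in_threadgroup]],
        ushort tiisg [[thread_index_in_simdgroup]],
        ushort sgptg [[simdgroups_per_threadgroup]]) {
    // Stage 1: each simd group sums its own 32 lanes.
    float sumf = simd_sum(src[tpitg]);

    if (sgptg > 1) {
        // One shared bucket per simd group.
        if (tiisg == 0) {
            shared[sgitg] = sumf;
        }
        // All buckets must be written before simd group 0 reads them.
        threadgroup_barrier(mem_flags::mem_threadgroup);

        // Stage 2: simd group 0 sums the buckets; requires sgptg <= 32.
        if (sgitg == 0) {
            sumf = tiisg < sgptg ? shared[tiisg] : 0.0f;
            sumf = simd_sum(sumf);
            if (tiisg == 0) {
                dst[0] = sumf;
            }
        }
    } else if (tiisg == 0) {
        dst[0] = sumf;
    }
}

Dispatched with one threadgroup of n threads (n a multiple of 32, n <= maxTotalThreadsPerThreadgroup) and (n/32)*sizeof(float) bytes of threadgroup memory, rounded up to a multiple of 16 bytes per the note in patch 12.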