
Commit 0b0de0b

metal: SSM_SCAN performance (llama/14743)
* feat: Add s_off as a parameter in the args struct

  This may not be necessary, but it more closely mirrors the CUDA kernel.

* perf: Parallelize mamba2 SSM_SCAN metal kernel over d_state

  This is a first attempt at optimizing the metal kernel. The changes here are:

  - Launch the kernel with a threadgroup of size d_state
  - Use simd groups and shared memory to do the summation for the y computation

  When tested with the G4 tiny preview, this shows roughly a 3x speedup on prefill and a 15% speedup on decode.

* fix: Update logic to correctly do the multi-layer parallel sum

* fix: Correctly size the shared memory buffer and assert expected size relationships

* refactor: Compute block offsets once rather than once per token

* feat: Use local variable for state recursion

* feat: Use a secondary simd_sum instead of a for loop

* feat: Add assertion and comment about the relationship between simd size and the number of simd groups

* feat: Parallelize over d_state for mamba-1

* feat: Parallel sum in SSM_CONV

* Revert "feat: Parallel sum in SSM_CONV"

  After discussion with @compilade, the size of the parallelism here is not worth the cost in complexity or overhead of the parallel for. ggml-org/llama.cpp#14743 (comment)

  This reverts commit 16bc059660c1c59e566628201c0ca2c20c9f4bc3.

* refactor: Simplify shared memory sizing

Branch: GraniteFourPerf

Signed-off-by: Gabe Goodhart <ghart@us.ibm.com>
Co-authored-by: Georgi Gerganov <ggerganov@gmail.com>
1 parent d414c3f commit 0b0de0b
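The parallel sum described in the commit message is a standard two-stage simdgroup reduction. Below is a minimal, self-contained sketch of that pattern, assuming a 32-wide simd group and a threadgroup size that is a multiple of 32; the kernel name `reduce_sketch` and the buffers `vals`/`out` are illustrative, not part of this commit.

    #include <metal_stdlib>
    using namespace metal;

    // Hypothetical standalone kernel illustrating the two-stage reduction
    // this commit applies inside SSM_SCAN.
    kernel void reduce_sketch(
            device const float * vals,                      // threadgroup-size inputs
            device       float * out,                       // single reduced output
            threadgroup  float * shared [[threadgroup(0)]], // one slot per simd group
            uint   tpitg[[thread_position_in_threadgroup]],
            ushort tiisg[[thread_index_in_simdgroup]],
            ushort sgitg[[simdgroup_index_in_threadgroup]],
            ushort sgptg[[simdgroups_per_threadgroup]]) {
        // Stage 1: each 32-wide simd group reduces its values in registers
        float partial = simd_sum(vals[tpitg]);

        // One lane per simd group publishes its partial sum
        if (tiisg == 0) {
            shared[sgitg] = partial;
        }
        threadgroup_barrier(mem_flags::mem_threadgroup);

        // Stage 2: the first simd group reduces the published partials
        if (sgitg == 0) {
            float total = tiisg < sgptg ? shared[tiisg] : 0.0f;
            total = simd_sum(total);
            if (tiisg == 0) {
                out[0] = total;
            }
        }
    }

In the commit itself, each thread's contribution is `state * C[i0]` rather than a buffer load, and the reduction runs once per token inside the scan loop.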

3 files changed: +156 −42 lines changed


ggml/src/ggml-metal/ggml-metal-impl.h

Lines changed: 1 addition & 0 deletions

@@ -528,6 +528,7 @@ typedef struct {
     int64_t  n_group;
     int64_t  n_seq_tokens;
     int64_t  n_seqs;
+    int64_t  s_off;
     uint64_t nb01;
     uint64_t nb02;
     uint64_t nb03;

ggml/src/ggml-metal/ggml-metal.m

Lines changed: 13 additions & 2 deletions

@@ -3141,6 +3141,7 @@ static int ggml_metal_encode_node(
             /*.n_group      =*/ n_group,
             /*.n_seq_tokens =*/ n_seq_tokens,
             /*.n_seqs       =*/ n_seqs,
+            /*.s_off        =*/ ggml_nelements(src1) * sizeof(float),
             /*.nb01         =*/ nb01,
             /*.nb02         =*/ nb02,
             /*.nb03         =*/ nb03,
@@ -3169,12 +3170,22 @@ static int ggml_metal_encode_node(
         [encoder setBuffer:id_dst offset:offs_dst atIndex:7];
         [encoder setBytes:&args length:sizeof(args) atIndex:8];

+        // One shared memory bucket for each simd group in the threadgroup
+        // NOTE: Metal kernels require the buffer size to be multiple of 16 bytes
+        // https://developer.apple.com/documentation/metal/mtlcomputecommandencoder/1443142-setthreadgroupmemorylength
+        if (d_state >= 32) {
+            GGML_ASSERT((int64_t)(d_state / 32) <= 32);
+            const int64_t shmem_size = 32;
+            GGML_ASSERT(d_state <= (int64_t)pipeline.maxTotalThreadsPerThreadgroup);
+            [encoder setThreadgroupMemoryLength:(shmem_size)*sizeof(float) atIndex:0];
+        }
+
         if (ne30 == 1) {
             // Mamba-2
-            [encoder dispatchThreadgroups:MTLSizeMake(d_inner, n_head, n_seqs) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)];
+            [encoder dispatchThreadgroups:MTLSizeMake(d_inner, n_head, n_seqs) threadsPerThreadgroup:MTLSizeMake(d_state, 1, 1)];
         } else {
             GGML_ASSERT(d_inner == 1);
-            [encoder dispatchThreadgroups:MTLSizeMake(n_head, n_seqs, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)];
+            [encoder dispatchThreadgroups:MTLSizeMake(n_head, n_seqs, 1) threadsPerThreadgroup:MTLSizeMake(d_state, 1, 1)];
         }
     } break;
 case GGML_OP_RWKV_WKV6:
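A note on the numbers in the guard above: with a simd width of 32, a threadgroup of d_state threads splits into d_state/32 simd groups, and the second reduction stage is a single simd_sum, which can combine at most 32 partials. A back-of-the-envelope sketch of that arithmetic (illustrative constants, not code from the commit):

    // Illustrative arithmetic mirroring the asserts above (not commit code).
    const long d_state       = 128;                    // example state size (power of two)
    const long simd_width    = 32;                     // threads per simd group on Apple GPUs
    const long n_simd_groups = d_state / simd_width;   // = 4 partial sums for stage 2
    // Stage 2 combines the partials with one simd_sum, which covers at most
    // 32 lanes, hence GGML_ASSERT((d_state / 32) <= 32), i.e. d_state <= 1024.
    const long shmem_bytes   = 32 * sizeof(float);     // 128 bytes: the worst case,
                                                       // and a multiple of 16 as Metal requires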

ggml/src/ggml-metal/ggml-metal.metal

Lines changed: 142 additions & 40 deletions

@@ -1823,10 +1823,16 @@ kernel void kernel_ssm_scan_f32(
         device const void * src5,
         device const void * src6,
         device       float * dst,
+        threadgroup  float * shared [[threadgroup(0)]],
         constant ggml_metal_kargs_ssm_scan & args,
-        uint3 tgpig[[threadgroup_position_in_grid]],
-        uint3 tpitg[[thread_position_in_threadgroup]],
-        uint3   ntg[[threads_per_threadgroup]]) {
+        uint3  tgpig[[threadgroup_position_in_grid]],
+        uint3  tpitg[[thread_position_in_threadgroup]],
+        ushort sgitg[[simdgroup_index_in_threadgroup]],
+        ushort tiisg[[thread_index_in_simdgroup]],
+        ushort sgptg[[simdgroups_per_threadgroup]],
+        uint3  tgpg[[threadgroups_per_grid]]) {
+
+    const int64_t i0 = tpitg.x;
     const int64_t i1 = 0;
     const int64_t ir = tgpig.x; // current head
     const int64_t i3 = tgpig.y; // current seq
@@ -1841,41 +1847,88 @@ kernel void kernel_ssm_scan_f32(
     const int64_t ng  = args.n_group;
     const int64_t n_t = args.n_seq_tokens;

-    const int64_t s_off = nr * nh * n_t * args.n_seqs * sizeof(float);
+    const int64_t s_off = args.s_off;

     device const int32_t * ids = (device const int32_t *) src6;

-    device const float * s0 = (device const float *) ((device const char *) src0 + ir*args.nb02 + ids[i3]*args.nb03);
-    device       float * s  = (device       float *) ((device       char *) dst + ir*args.nb02 + i3*args.nb03 + s_off);
+    device const float * s0_buff = (device const float *) ((device const char *) src0 + ir*args.nb02 + ids[i3]*args.nb03);
+    device       float * s_buff  = (device       float *) ((device       char *) dst + ir*args.nb02 + i3*args.nb03 + s_off);
+    const int64_t i = i0 + i1*nc;
+    float s0 = s0_buff[i];
+    float s  = s_buff[i];
+
+    device const float * A        = (device const float *) ((device const char *) src3 + ir*args.nb31);
+    device const float * x_block  = (device const float *) ((device const char *) src1 + i1*nb10 + ir*args.nb11 + i3*args.nb13);
+    device const float * dt_block = (device const float *) ((device const char *) src2 + ir*nb20 + i3*args.nb22);
+    device const float * B_block  = (device const float *) ((device const char *) src4 + (ir & (ng - 1))*args.nb41 + i3*args.nb43);
+    device const float * C_block  = (device const float *) ((device const char *) src5 + (ir & (ng - 1))*args.nb51 + i3*args.nb53);
+    device       float * y_block  = (device       float *) ((device       char *) dst + (i1 + ir*(nr) + i3*(n_t*nh*nr))*nb00);

     for (int64_t i2 = 0; i2 < n_t; ++i2) {
-        device const float * x  = (device const float *) ((device const char *) src1 + i1*nb10 + ir*args.nb11 + i2*args.nb12 + i3*args.nb13); // {dim, nh, nt, ns}
-        device const float * dt = (device const float *) ((device const char *) src2 + ir*nb20 + i2*args.nb21 + i3*args.nb22); // {nh, nt, ns}
-        device const float * A  = (device const float *) ((device const char *) src3 + ir*args.nb31); // {d_state, nh}
-        device const float * B  = (device const float *) ((device const char *) src4 + (ir & (ng - 1))*args.nb41 + i2*args.nb42 + i3*args.nb43); // {d_state, ng, nt, ns}
-        device const float * C  = (device const float *) ((device const char *) src5 + (ir & (ng - 1))*args.nb51 + i2*args.nb52 + i3*args.nb53); // {d_state, ng, nt, ns}
-        device       float * y  = (device       float *) ((device       char *) dst + (i1 + ir*(nr) + i2*(nh*nr) + i3*(n_t*nh*nr))*nb00); // {dim, nh, nt, ns}
+        device const float * x  = (device const float *) ((device const char *) x_block  + i2*args.nb12); // {dim, nh, nt, ns}
+        device const float * dt = (device const float *) ((device const char *) dt_block + i2*args.nb21); // {nh, nt, ns}
+        device const float * B  = (device const float *) ((device const char *) B_block  + i2*args.nb42); // {d_state, ng, nt, ns}
+        device const float * C  = (device const float *) ((device const char *) C_block  + i2*args.nb52); // {d_state, ng, nt, ns}
+        device       float * y  = (device       float *) ((device       char *) y_block  + i2*(nh*nr*nb00)); // {dim, nh, nt, ns}

         const float dt_soft_plus = dt[0] <= 20.0f ? log(1.0f + exp(dt[0])) : dt[0];
         const float x_dt = x[0] * dt_soft_plus;
-        float sumf = 0.0f;

-        for (int64_t i0 = 0; i0 < nc; ++i0) {
-            const int64_t i = i0 + i1*nc;
-            const float state = (s0[i] * exp(dt_soft_plus * A[i0])) + (B[i0] * x_dt);
-            sumf += state * C[i0];
-            s[i] = state;
-        }
+        const float state = (s0 * exp(dt_soft_plus * A[i0])) + (B[i0] * x_dt);
+        s = state;
+
+        // Parallel sum: This relies on the fact that this kernel will be
+        // dispatched with each threadgroup having (d_state, 1, 1) threads which
+        // are subdivided into SIMD groups of size `sgptg`. The goal is to
+        // compute y = sum({state * C[i] for i in range(d_state)}).
+        // To parallelize this effectively, we first use simd_sum over each
+        // SIMD group to compute the sum of each SIMD group, then place the
+        // result in the SIMD group's indexed bucket in the shared memory. We
+        // then sum over the individual group sums to compute the final sum.
+
+        // Computed for each thread
+        float sumf = state * C[i0];

-        y[0] = sumf;
+        // Sum the threads in the simd group => simd sum
+        sumf = simd_sum(sumf);
+
+        if (sgptg > 1) {
+
+            // Once per simd group, place the group sum into the shared buffer
+            if (tiisg == 0) {
+                shared[sgitg] = sumf;
+            }
+
+            // Wait for all threads in the threadgroup to reach this point. This
+            // ensures that all elements of the shared buffer are populated with
+            // the sum of the individual simd groups.
+            threadgroup_barrier(mem_flags::mem_threadgroup);
+
+            // For simd group 0 at indices < num simd groups, extract the shared
+            // simd sum
+            sumf = 0.0f;
+            if (sgitg == 0) {
+                if (tiisg < sgptg) {
+                    sumf = shared[tiisg];
+                }
+                sumf = simd_sum(sumf);
+                if (tiisg == 0) {
+                    y[0] = sumf;
+                }
+            }
+        } else if (tiisg == 0) {
+            y[0] = sumf;
+        }

         // recurse
         s0 = s;
     }
+
+    // Assign the final state to the output buffer
+    s_buff[i] = s;
 }

 // ref: ggml.c:ggml_compute_forward_ssm_scan_f32, Mamba-2 part
-// TODO: optimize (e.g. by parallelizing over d_state)
 kernel void kernel_ssm_scan_f32_group(
         device const void * src0,
         device const void * src1,
@@ -1885,10 +1938,16 @@ kernel void kernel_ssm_scan_f32_group(
         device const void * src5,
         device const void * src6,
         device       float * dst,
+        threadgroup  float * shared [[threadgroup(0)]],
         constant ggml_metal_kargs_ssm_scan & args,
-        uint3 tgpig[[threadgroup_position_in_grid]],
-        uint3 tpitg[[thread_position_in_threadgroup]],
-        uint3   ntg[[threads_per_threadgroup]]) {
+        uint3  tgpig[[threadgroup_position_in_grid]],
+        uint3  tpitg[[thread_position_in_threadgroup]],
+        ushort sgitg[[simdgroup_index_in_threadgroup]],
+        ushort tiisg[[thread_index_in_simdgroup]],
+        ushort sgptg[[simdgroups_per_threadgroup]],
+        uint3  tgpg[[threadgroups_per_grid]]) {
+
+    const int64_t i0 = tpitg.x;
     const int64_t i1 = tgpig.x;
     const int64_t ir = tgpig.y; // current head
     const int64_t i3 = tgpig.z; // current seq
@@ -1903,38 +1962,81 @@ kernel void kernel_ssm_scan_f32_group(
     const int64_t ng  = args.n_group;
     const int64_t n_t = args.n_seq_tokens;

-    const int64_t s_off = nr * nh * n_t * args.n_seqs * sizeof(float);
+    const int64_t s_off = args.s_off;

     device const int32_t * ids = (device const int32_t *) src6;

-    device const float * s0 = (device const float *) ((device const char *) src0 + ir*args.nb02 + ids[i3]*args.nb03);
-    device       float * s  = (device       float *) ((device       char *) dst + ir*args.nb02 + i3*args.nb03 + s_off);
+    device const float * s0_buff = (device const float *) ((device const char *) src0 + ir*args.nb02 + ids[i3]*args.nb03);
+    device       float * s_buff  = (device       float *) ((device       char *) dst + ir*args.nb02 + i3*args.nb03 + s_off);
+    const int64_t i = i0 + i1*nc;
+    float s0 = s0_buff[i];
+    float s  = s_buff[i];
+
+    device const float * A        = (device const float *) ((device const char *) src3 + ir*args.nb31); // {1, nh}
+    device const float * x_block  = (device const float *) ((device const char *) src1 + i1*nb10 + ir*args.nb11 + i3*args.nb13);
+    device const float * dt_block = (device const float *) ((device const char *) src2 + ir*nb20 + i3*args.nb22);
+    device const float * B_block  = (device const float *) ((device const char *) src4 + (ir & (ng - 1))*args.nb41 + i3*args.nb43);
+    device const float * C_block  = (device const float *) ((device const char *) src5 + (ir & (ng - 1))*args.nb51 + i3*args.nb53);
+    device       float * y_block  = (device       float *) ((device       char *) dst + (i1 + ir*(nr) + i3*(n_t*nh*nr))*nb00);

     for (int64_t i2 = 0; i2 < n_t; ++i2) {
-        device const float * x  = (device const float *) ((device const char *) src1 + i1*nb10 + ir*args.nb11 + i2*args.nb12 + i3*args.nb13); // {dim, nh, nt, ns}
-        device const float * dt = (device const float *) ((device const char *) src2 + ir*nb20 + i2*args.nb21 + i3*args.nb22); // {nh, nt, ns}
-        device const float * A  = (device const float *) ((device const char *) src3 + ir*args.nb31); // {1, nh}
-        device const float * B  = (device const float *) ((device const char *) src4 + (ir & (ng - 1))*args.nb41 + i2*args.nb42 + i3*args.nb43); // {d_state, ng, nt, ns}
-        device const float * C  = (device const float *) ((device const char *) src5 + (ir & (ng - 1))*args.nb51 + i2*args.nb52 + i3*args.nb53); // {d_state, ng, nt, ns}
-        device       float * y  = (device       float *) ((device       char *) dst + (i1 + ir*(nr) + i2*(nh*nr) + i3*(n_t*nh*nr))*nb00); // {dim, nh, nt, ns}
+        device const float * x  = (device const float *) ((device const char *) x_block  + i2*args.nb12); // {dim, nh, nt, ns}
+        device const float * dt = (device const float *) ((device const char *) dt_block + i2*args.nb21); // {nh, nt, ns}
+        device const float * B  = (device const float *) ((device const char *) B_block  + i2*args.nb42); // {d_state, ng, nt, ns}
+        device const float * C  = (device const float *) ((device const char *) C_block  + i2*args.nb52); // {d_state, ng, nt, ns}
+        device       float * y  = (device       float *) ((device       char *) y_block  + i2*(nh*nr*nb00)); // {dim, nh, nt, ns}

         const float dt_soft_plus = dt[0] <= 20.0f ? log(1.0f + exp(dt[0])) : dt[0];
         const float x_dt = x[0] * dt_soft_plus;
         const float dA = exp(dt_soft_plus * A[0]);
-        float sumf = 0.0f;

-        for (int64_t i0 = 0; i0 < nc; ++i0) {
-            const int64_t i = i0 + i1*nc;
-            const float state = (s0[i] * dA) + (B[i0] * x_dt);
-            sumf += state * C[i0];
-            s[i] = state;
+        const float state = (s0 * dA) + (B[i0] * x_dt);
+        s = state;
+
+        // Parallel sum: This relies on the fact that this kernel will be
+        // dispatched with each threadgroup having (d_state, 1, 1) threads which
+        // are subdivided into SIMD groups of size `sgptg`. The goal is to
+        // compute y = sum({state * C[i] for i in range(d_state)}).
+        // To parallelize this effectively, we first use simd_sum over each
+        // SIMD group to compute the sum of each SIMD group, then place the
+        // result in the SIMD group's indexed bucket in the shared memory. We
+        // then sum over the individual group sums to compute the final sum.
+
+        // Computed for each thread
+        float sumf = state * C[i0];
+
+        // Sum the threads in the simd group => simd sum
+        sumf = simd_sum(sumf);
+
+        // Once per simd group, place the group sum into the shared buffer
+        if (tiisg == 0) {
+            shared[sgitg] = sumf;
         }

-        y[0] = sumf;
+        // Wait for all threads in the threadgroup to reach this point. This
+        // ensures that all elements of the shared buffer are populated with
+        // the sum of the individual simd groups.
+        threadgroup_barrier(mem_flags::mem_threadgroup);
+
+        // For simd group 0 at indices < num simd groups, extract the shared
+        // simd sum
+        sumf = 0.0f;
+        if (sgitg == 0) {
+            if (tiisg < sgptg) {
+                sumf = shared[tiisg];
+            }
+            sumf = simd_sum(sumf);
+            if (tiisg == 0) {
+                y[0] = sumf;
+            }
+        }

         // recurse
         s0 = s;
     }
+
+    // Assign the final state to the output buffer
+    s_buff[i] = s;
 }

 kernel void kernel_rwkv_wkv6_f32(
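For reference, the recurrence that both kernels now distribute across d_state threads is exactly the scalar loop this commit removes. Below is a hedged Metal sketch of that sequential form; the helper name `ssm_scan_step_ref` is illustrative and not part of the commit. In the Mamba-2 kernel A is a single scalar per head, so the exp factor hoists out of the loop as dA.

    #include <metal_stdlib>
    using namespace metal;

    // Illustrative helper (not in the commit): the recurrence the kernels
    // compute for one head and one token, matching the removed scalar loop.
    float ssm_scan_step_ref(thread float * state,        // d_state recurrent values
                            device const float * A,
                            device const float * B,
                            device const float * C,
                            float x, float dt, int d_state) {
        // softplus with the same large-input cutoff used by the kernels
        const float dt_soft_plus = dt <= 20.0f ? log(1.0f + exp(dt)) : dt;
        const float x_dt = x * dt_soft_plus;
        float y = 0.0f;
        for (int i = 0; i < d_state; ++i) {
            state[i] = state[i] * exp(dt_soft_plus * A[i]) + B[i] * x_dt;
            y += state[i] * C[i];
        }
        return y; // the kernels produce this sum via the two-stage simd reduction
    }

In the parallel version, each thread owns one index i, keeps state[i] in a register across the token loop, and only the sum over i requires cross-thread communication.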
