Skip to content

Commit e16e24b

Browse files
committed
fix: Update logic to correctly do the multi-layer parallel sum
Signed-off-by: Gabe Goodhart <ghart@us.ibm.com>
1 parent 8d5a25d commit e16e24b

File tree

1 file changed

+26
-12
lines changed

1 file changed

+26
-12
lines changed

ggml/src/ggml-metal/ggml-metal.metal

Lines changed: 26 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -1767,7 +1767,8 @@ kernel void kernel_ssm_scan_f32_group(
17671767
uint3 tpitg[[thread_position_in_threadgroup]],
17681768
ushort sgitg[[simdgroup_index_in_threadgroup]],
17691769
ushort tiisg[[thread_index_in_simdgroup]],
1770-
uint3 ntg[[threads_per_threadgroup]]) {
1770+
ushort sgptg[[simdgroups_per_threadgroup]],
1771+
uint3 tgpg[[threadgroups_per_grid]]) {
17711772

17721773
const int64_t i1 = tgpig.x;
17731774
const int64_t ir = tgpig.y; // current head
@@ -1802,29 +1803,42 @@ kernel void kernel_ssm_scan_f32_group(
18021803
const float x_dt = x[0] * dt_soft_plus;
18031804
const float dA = exp(dt_soft_plus * A[0]);
18041805

1805-
threadgroup_barrier(mem_flags::mem_threadgroup);
1806-
1807-
float sumf = 0.0f;
1808-
18091806
const int64_t i = tpitg.x + i1*nc;
18101807
const float state = (s0[i] * dA) + (B[tpitg.x] * x_dt);
1811-
sumf += state * C[tpitg.x];
18121808
s[i] = state;
18131809

1814-
sumf = simd_sum(sumf);
1810+
// Parallel sum: This relies on the fact that this kernel will be
1811+
// dispatched with each threadgroup having (d_state, 1, 1) threads which
1812+
// are subdivided into SIMD groups of size `sgptg`. The goal is to
1813+
// compute y = sum({state * C[i] for i in range(d_state)}).
1814+
// To parallelize this effectively, we first use simd_sum over each SIMD
1815+
// group to compute the sum of each SIMD group, then place the result in
1816+
// the SIMD group's indexed bucket in the shared memory. We then sum
1817+
// over the individual group sums to compute the final sum.
18151818

1816-
threadgroup_barrier(mem_flags::mem_threadgroup);
1819+
// Computed for each thread
1820+
float sumf = state * C[tpitg.x];
18171821

1818-
// Use the shared buffer to hold the sum of each simd group
1822+
// Sum the threads in the simd group => simd sum
1823+
sumf = simd_sum(sumf);
1824+
1825+
// Once per simd group, place the group sum into the shared buffer
18191826
if (tiisg == 0) {
18201827
shared[sgitg] = sumf;
18211828
}
18221829

1830+
// Wait for all threads in the threadgroup to reach this point. This
1831+
// ensures that all elements of the shared buffer are populated with the
1832+
// sum of the individual simd groups.
18231833
threadgroup_barrier(mem_flags::mem_threadgroup);
18241834

1825-
// Sum the simd buckets
1826-
sumf = shared[tiisg];
1827-
sumf = simd_sum(sumf);
1835+
// Sum the simd buckets => threadgroup sum
1836+
sumf = 0.0f;
1837+
for (int64_t i0 = 0; i0 < sgptg; ++i0) {
1838+
sumf += shared[i0];
1839+
}
1840+
1841+
threadgroup_barrier(mem_flags::mem_threadgroup);
18281842

18291843
y[0] = sumf;
18301844

0 commit comments

Comments
 (0)