@@ -1767,7 +1767,8 @@ kernel void kernel_ssm_scan_f32_group(
1767
1767
uint3 tpitg[[thread_position_in_threadgroup]],
1768
1768
ushort sgitg[[simdgroup_index_in_threadgroup]],
1769
1769
ushort tiisg[[thread_index_in_simdgroup]],
1770
- uint3 ntg[[threads_per_threadgroup]]) {
1770
+ ushort sgptg[[simdgroups_per_threadgroup]],
1771
+ uint3 tgpg[[threadgroups_per_grid]]) {
1771
1772
1772
1773
const int64_t i1 = tgpig.x ;
1773
1774
const int64_t ir = tgpig.y ; // current head
@@ -1802,29 +1803,42 @@ kernel void kernel_ssm_scan_f32_group(
1802
1803
const float x_dt = x[0 ] * dt_soft_plus;
1803
1804
const float dA = exp (dt_soft_plus * A[0 ]);
1804
1805
1805
- threadgroup_barrier (mem_flags::mem_threadgroup);
1806
-
1807
- float sumf = 0 .0f ;
1808
-
1809
1806
const int64_t i = tpitg.x + i1*nc;
1810
1807
const float state = (s0[i] * dA) + (B[tpitg.x ] * x_dt);
1811
- sumf += state * C[tpitg.x ];
1812
1808
s[i] = state;
1813
1809
1814
- sumf = simd_sum (sumf);
1810
+ // Parallel sum: This relies on the fact that this kernel will be
1811
+ // dispatched with each threadgroup having (d_state, 1, 1) threads which
1812
+ // are subdivided into `sgptg` SIMD groups. The goal is to
1813
+ // compute y = sum({state * C[i] for i in range(d_state)}).
1814
+ // To parallelize this effectively, we first use simd_sum over each SIMD
1815
+ // group to compute the sum of each SIMD group, then place the result in
1816
+ // the SIMD group's indexed bucket in the shared memory. We then sum
1817
+ // over the individual group sums to compute the final sum.
1815
1818
1816
- threadgroup_barrier (mem_flags::mem_threadgroup);
1819
+ // Computed for each thread
1820
+ float sumf = state * C[tpitg.x ];
1817
1821
1818
- // Use the shared buffer to hold the sum of each simd group
1822
+ // Sum the threads in the simd group => simd sum
1823
+ sumf = simd_sum (sumf);
1824
+
1825
+ // Once per simd group, place the group sum into the shared buffer
1819
1826
if (tiisg == 0 ) {
1820
1827
shared[sgitg] = sumf;
1821
1828
}
1822
1829
1830
+ // Wait for all threads in the threadgroup to reach this point. This
1831
+ // ensures that all elements of the shared buffer are populated with the
1832
+ // sum of the individual simd groups.
1823
1833
threadgroup_barrier (mem_flags::mem_threadgroup);
1824
1834
1825
- // Sum the simd buckets
1826
- sumf = shared[tiisg];
1827
- sumf = simd_sum (sumf);
1835
+ // Sum the simd buckets => threadgroup sum
1836
+ sumf = 0 .0f ;
1837
+ for (int64_t i0 = 0 ; i0 < sgptg; ++i0) {
1838
+ sumf += shared[i0];
1839
+ }
1840
+
1841
+ threadgroup_barrier (mem_flags::mem_threadgroup);
1828
1842
1829
1843
y[0 ] = sumf;
1830
1844
0 commit comments