-
Notifications
You must be signed in to change notification settings - Fork 12.5k
metal: SSM_SCAN performance #14743
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
metal: SSM_SCAN performance #14743
Changes from 2 commits
ba74a24
8d5a25d
e16e24b
21db0b5
a5334f9
3866f76
641276a
d06d087
80545ef
16bc059
e55176a
f6d5e1a
c3711e1
d20b02d
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change | ||||||||||||||||||||||||||||||||||||||||||||||||||
---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
|
@@ -1752,7 +1752,6 @@ kernel void kernel_ssm_scan_f32( | |||||||||||||||||||||||||||||||||||||||||||||||||||
} | ||||||||||||||||||||||||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||||||||||||||||||||||||
// ref: ggml.c:ggml_compute_forward_ssm_scan_f32, Mamba-2 part | ||||||||||||||||||||||||||||||||||||||||||||||||||||
// TODO: optimize (e.g. by parallelizing over d_state) | ||||||||||||||||||||||||||||||||||||||||||||||||||||
kernel void kernel_ssm_scan_f32_group( | ||||||||||||||||||||||||||||||||||||||||||||||||||||
device const void * src0, | ||||||||||||||||||||||||||||||||||||||||||||||||||||
device const void * src1, | ||||||||||||||||||||||||||||||||||||||||||||||||||||
|
@@ -1762,10 +1761,14 @@ kernel void kernel_ssm_scan_f32_group( | |||||||||||||||||||||||||||||||||||||||||||||||||||
device const void * src5, | ||||||||||||||||||||||||||||||||||||||||||||||||||||
device const void * src6, | ||||||||||||||||||||||||||||||||||||||||||||||||||||
device float * dst, | ||||||||||||||||||||||||||||||||||||||||||||||||||||
threadgroup float * shared [[threadgroup(0)]], | ||||||||||||||||||||||||||||||||||||||||||||||||||||
constant ggml_metal_kargs_ssm_scan & args, | ||||||||||||||||||||||||||||||||||||||||||||||||||||
uint3 tgpig[[threadgroup_position_in_grid]], | ||||||||||||||||||||||||||||||||||||||||||||||||||||
uint3 tpitg[[thread_position_in_threadgroup]], | ||||||||||||||||||||||||||||||||||||||||||||||||||||
uint3 ntg[[threads_per_threadgroup]]) { | ||||||||||||||||||||||||||||||||||||||||||||||||||||
uint3 tgpig[[threadgroup_position_in_grid]], | ||||||||||||||||||||||||||||||||||||||||||||||||||||
uint3 tpitg[[thread_position_in_threadgroup]], | ||||||||||||||||||||||||||||||||||||||||||||||||||||
ushort sgitg[[simdgroup_index_in_threadgroup]], | ||||||||||||||||||||||||||||||||||||||||||||||||||||
ushort tiisg[[thread_index_in_simdgroup]], | ||||||||||||||||||||||||||||||||||||||||||||||||||||
uint3 ntg[[threads_per_threadgroup]]) { | ||||||||||||||||||||||||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||||||||||||||||||||||||
const int64_t i1 = tgpig.x; | ||||||||||||||||||||||||||||||||||||||||||||||||||||
const int64_t ir = tgpig.y; // current head | ||||||||||||||||||||||||||||||||||||||||||||||||||||
const int64_t i3 = tgpig.z; // current seq | ||||||||||||||||||||||||||||||||||||||||||||||||||||
|
@@ -1780,7 +1783,7 @@ kernel void kernel_ssm_scan_f32_group( | |||||||||||||||||||||||||||||||||||||||||||||||||||
const int64_t ng = args.n_group; | ||||||||||||||||||||||||||||||||||||||||||||||||||||
const int64_t n_t = args.n_seq_tokens; | ||||||||||||||||||||||||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||||||||||||||||||||||||
const int64_t s_off = nr * nh * n_t * args.n_seqs * sizeof(float); | ||||||||||||||||||||||||||||||||||||||||||||||||||||
const int64_t s_off = args.s_off; | ||||||||||||||||||||||||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||||||||||||||||||||||||
device const int32_t * ids = (device const int32_t *) src6; | ||||||||||||||||||||||||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||||||||||||||||||||||||
|
@@ -1798,15 +1801,31 @@ kernel void kernel_ssm_scan_f32_group( | |||||||||||||||||||||||||||||||||||||||||||||||||||
const float dt_soft_plus = dt[0] <= 20.0f ? log(1.0f + exp(dt[0])) : dt[0]; | ||||||||||||||||||||||||||||||||||||||||||||||||||||
const float x_dt = x[0] * dt_soft_plus; | ||||||||||||||||||||||||||||||||||||||||||||||||||||
const float dA = exp(dt_soft_plus * A[0]); | ||||||||||||||||||||||||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||||||||||||||||||||||||
threadgroup_barrier(mem_flags::mem_threadgroup); | ||||||||||||||||||||||||||||||||||||||||||||||||||||
gabe-l-hart marked this conversation as resolved.
Show resolved
Hide resolved
|
||||||||||||||||||||||||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||||||||||||||||||||||||
float sumf = 0.0f; | ||||||||||||||||||||||||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||||||||||||||||||||||||
for (int64_t i0 = 0; i0 < nc; ++i0) { | ||||||||||||||||||||||||||||||||||||||||||||||||||||
const int64_t i = i0 + i1*nc; | ||||||||||||||||||||||||||||||||||||||||||||||||||||
const float state = (s0[i] * dA) + (B[i0] * x_dt); | ||||||||||||||||||||||||||||||||||||||||||||||||||||
sumf += state * C[i0]; | ||||||||||||||||||||||||||||||||||||||||||||||||||||
s[i] = state; | ||||||||||||||||||||||||||||||||||||||||||||||||||||
const int64_t i = tpitg.x + i1*nc; | ||||||||||||||||||||||||||||||||||||||||||||||||||||
const float state = (s0[i] * dA) + (B[tpitg.x] * x_dt); | ||||||||||||||||||||||||||||||||||||||||||||||||||||
sumf += state * C[tpitg.x]; | ||||||||||||||||||||||||||||||||||||||||||||||||||||
s[i] = state; | ||||||||||||||||||||||||||||||||||||||||||||||||||||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I wonder if it would be faster to store the intermediate states in a local variable instead of repeatedly in the destination buffer. I'm not very experienced with Metal (the naïve version you're starting from was pretty much my first Metal kernel), but I assume it should be possible? Unless I'm misunderstanding the memory model, each thread only handles a single state (as in I think this would only affect prompt processing speed, not really small batches, though. There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Interesting thought! I'm also very much still learning the memory model, so I'll play with this idea and see how far I can get it. There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Great suggestion. It was easy to implement and gives a nice bump in performance. Will commit and push shortly
|
||||||||||||||||||||||||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||||||||||||||||||||||||
sumf = simd_sum(sumf); | ||||||||||||||||||||||||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||||||||||||||||||||||||
threadgroup_barrier(mem_flags::mem_threadgroup); | ||||||||||||||||||||||||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||||||||||||||||||||||||
// Use the shared buffer to hold the sum of each simd group | ||||||||||||||||||||||||||||||||||||||||||||||||||||
if (tiisg == 0) { | ||||||||||||||||||||||||||||||||||||||||||||||||||||
shared[sgitg] = sumf; | ||||||||||||||||||||||||||||||||||||||||||||||||||||
} | ||||||||||||||||||||||||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||||||||||||||||||||||||
threadgroup_barrier(mem_flags::mem_threadgroup); | ||||||||||||||||||||||||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||||||||||||||||||||||||
// Sum the simd buckets | ||||||||||||||||||||||||||||||||||||||||||||||||||||
sumf = shared[tiisg]; | ||||||||||||||||||||||||||||||||||||||||||||||||||||
sumf = simd_sum(sumf); | ||||||||||||||||||||||||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||||||||||||||||||||||||
y[0] = sumf; | ||||||||||||||||||||||||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||||||||||||||||||||||||
// recurse | ||||||||||||||||||||||||||||||||||||||||||||||||||||
|
Uh oh!
There was an error while loading. Please reload this page.