Adding the 128-head template and changing the Makefile for the Python version #110

Open · wants to merge 2 commits into base: mla

28 changes: 0 additions & 28 deletions kernels/attn/demo/mla_decode/Makefile

This file was deleted.
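
The PR title says the Makefile is being replaced in favor of a Python-driven build, but the replacement itself is not part of this diff. Purely as a hypothetical sketch of what such a build script could look like (the file name setup.py, the sm_90a architecture flag, and the source path are assumptions, not taken from this PR):

# setup.py -- hypothetical sketch only; the PR does not show the actual replacement for the deleted Makefile.
from setuptools import setup
from torch.utils.cpp_extension import CUDAExtension, BuildExtension

setup(
    name="mla_decode",
    ext_modules=[
        CUDAExtension(
            name="mla_decode",
            sources=["template_mla_decode.cu"],               # assumed source path
            extra_compile_args={"nvcc": ["-arch=sm_90a",      # assumed H100 target
                                         "--expt-relaxed-constexpr"]},
        )
    ],
    cmdclass={"build_ext": BuildExtension},
)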

19 changes: 18 additions & 1 deletion kernels/attn/demo/mla_decode/template_mla_decode.cu
@@ -8,7 +8,7 @@ using namespace kittens;
using namespace kittens::prototype;
using namespace kittens::prototype::interpreter;

static constexpr int QKRot_D = 64, QVO_D = 512, QVO_Dd2 = QVO_D/2, NUM_ROWS = 32, PAGE_SIZE = 256;
static constexpr int QKRot_D = 64, QVO_D = 512, QVO_Dd2 = QVO_D/2, NUM_ROWS = 32, PAGE_SIZE = 64;
using qrot_tile = st_bf<64, QKRot_D>;
using qvo_tile = st_bf<64, QVO_D>;
using q_global = kittens::gl<bf16, -1, -1, -1, QKRot_D, qrot_tile>; // B * R * H * D_QKRot_D
@@ -523,6 +523,23 @@ PYBIND11_MODULE(mla_decode, m) {
&config<8>::globals::tic
#ifdef KITTENS_TIMINGS
, &config<8>::globals::timings
#endif
);
kittens::py::bind_kernel<interpreter::kernel<config<128>, partial_template<128>, reduction_template<128>>>(m, "mla_decode_128_heads",
&config<128>::globals::instructions,
&config<128>::globals::Q,
&config<128>::globals::QV,
&config<128>::globals::K_cache,
&config<128>::globals::V_cache,
&config<128>::globals::Table,
&config<128>::globals::O,
&config<128>::globals::O_scratch,
&config<128>::globals::Lvec_scratch,
&config<128>::globals::semaphore,
&config<128>::globals::Softmax_scale,
&config<128>::globals::tic
#ifdef KITTENS_TIMINGS
, &config<128>::globals::timings
#endif
);
m.def("__get_quality__", &get_quality,
78 changes: 57 additions & 21 deletions kernels/attn/demo/mla_decode/test.py
@@ -17,7 +17,7 @@
torch.manual_seed(0)

D_Main, D_Rot = 512, 64
PAGE_SIZE = 256
PAGE_SIZE = 64#256
# H = 16 # set by q_heads
NUM_PAGES = 10000 # number of pages in cache
NUM_PROCESSORS = 132 # number of processors
@@ -30,12 +30,12 @@ def init_arguments(seq_lengths: List[int], new_tokens: int, q_heads: int=16):
B = len(seq_lengths)

# Need to initialize QRot, QV, K_cache, V_cache, Lengths, Table
QRot = torch.randn(B, new_tokens, q_heads, D_Rot, dtype=torch.bfloat16, device='cuda')
QV = torch.randn(B, new_tokens, q_heads, D_Main, dtype=torch.bfloat16, device='cuda')
K_cache = torch.randn(NUM_PAGES, PAGE_SIZE, D_Rot, dtype=torch.bfloat16, device='cuda')
V_cache = torch.randn(NUM_PAGES, PAGE_SIZE, D_Main, dtype=torch.bfloat16, device='cuda')
Lengths = torch.tensor(seq_lengths, dtype=torch.int32, device='cuda')
Table = torch.randint(0, NUM_PAGES, (B, MAX_NUM_PAGES), dtype=torch.int32, device='cuda')
QRot = torch.randn(B, new_tokens, q_heads, D_Rot, dtype=torch.bfloat16, device='cuda') # B x new_tokens x q_heads x D_Rotary
QV = torch.randn(B, new_tokens, q_heads, D_Main, dtype=torch.bfloat16, device='cuda') # B x new_tokens x q_heads x D_Main
K_cache = torch.randn(NUM_PAGES, PAGE_SIZE, D_Rot, dtype=torch.bfloat16, device='cuda') # NUM_PAGES x PAGE_SIZE x D_Rotary
V_cache = torch.randn(NUM_PAGES, PAGE_SIZE, D_Main, dtype=torch.bfloat16, device='cuda') # NUM_PAGES x PAGE_SIZE x D_Main
Lengths = torch.tensor(seq_lengths, dtype=torch.int32, device='cuda') # B
Table = torch.randint(0, NUM_PAGES, (B, MAX_NUM_PAGES), dtype=torch.int32, device='cuda') # B x MAX_NUM_PAGES

return QRot, QV, K_cache, V_cache, Lengths, Table
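
For context on the tensors built above: the KV cache is paged, so logical token t of batch element b lives at physical page Table[b, t // PAGE_SIZE], row t % PAGE_SIZE, which is why the Python-side PAGE_SIZE must match the constant compiled into the kernel (both changed from 256 to 64 in this PR). A minimal illustrative lookup, assuming the tensors returned by init_arguments:

# Illustrative sketch only: fetch the cached rotary-key entry for one token through the page table.
def lookup_k(K_cache, Table, b, t):
    page = Table[b, t // PAGE_SIZE]   # physical page holding token t of batch element b
    row = t % PAGE_SIZE               # offset of token t inside that page
    return K_cache[page, row]         # shape (D_Rot,)

# e.g. k_rot = lookup_k(K_cache, Table, 0, 100) for the 101st cached token of batch element 0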

@@ -89,7 +89,7 @@ def run_thundermla(QRot, QV, K_cache, V_cache, Lengths, Table, Instructions, O_s
KV_all = torch.cat([V_cache, K_cache], dim=-1).contiguous()
softmax_scale = 1.0 / math.sqrt(D_Main+D_Rot)
torch.cuda.synchronize()
mla_decode_fn = mla_decode.mla_decode_8_heads if q_heads == 8 else mla_decode.mla_decode
mla_decode_fn = mla_decode.mla_decode_128_heads if q_heads == 128 else mla_decode.mla_decode
if Timings is not None:
mla_decode_fn(Instructions, QRot, QV, K_cache, V_cache, Table, O, O_scratch, Lvec_scratch, Semaphore, softmax_scale, tic, Timings)
mla_decode_fn(Instructions, QRot, QV, K_cache, V_cache, Table, O, O_scratch, Lvec_scratch, Semaphore, softmax_scale, 1-tic, Timings)
@@ -105,7 +105,7 @@ def profile_thundermla(QRot, QV, K_cache, V_cache, Lengths, Table, Instructions,
O = torch.zeros_like(QV)
softmax_scale = 1.0 / math.sqrt(D_Main+D_Rot)
# execute once to warm up
mla_decode_fn = mla_decode.mla_decode_8_heads if q_heads == 8 else mla_decode.mla_decode
mla_decode_fn = mla_decode.mla_decode_128_heads if q_heads == 128 else mla_decode.mla_decode
if Timings is not None:
mla_decode_fn(Instructions, QRot, QV, K_cache, V_cache, Table, O, O_scratch, Lvec_scratch, Semaphore, softmax_scale, 1, Timings)
else:
@@ -124,7 +124,7 @@ def profile_thundermla(QRot, QV, K_cache, V_cache, Lengths, Table, Instructions,
def run_mla_torch(QRot, QV, K_cache, V_cache, Lengths, Table):
Q = torch.concat([QRot, QV], dim=-1)
q_heads = Q.shape[2]
full_K = torch.cat([K_cache, V_cache], dim=-1)[Table].reshape(Q.shape[0], -1, Q.shape[-1])
full_K = torch.cat([K_cache, V_cache], dim=-1)[Table].reshape(Q.shape[0], -1, Q.shape[-1]) #(B, MAX_NUM_PAGES * PAGE_SIZE, D_Rot + D_Main)
full_V = V_cache[Table].reshape(Q.shape[0], -1, QV.shape[-1])
softmax_scale = 1.0 / math.sqrt(D_Main+D_Rot)
O = torch.zeros_like(QV)
@@ -155,18 +155,54 @@ def main(seq_lengths, new_tokens, q_heads=16):

time_per_iter = profile_thundermla(QRot, QV, K_cache, V_cache, Lengths, Table, Instructions, O_scratch, Lvec_scratch, Semaphore, Timings)
print(f"Time per iter: {time_per_iter*1000} ms")

# Calculate TFLOPS and memory bandwidth
b = len(seq_lengths)
s_q = new_tokens
h_q = q_heads
d = D_Rot
dv = D_Main
h_kv = h_q # Assuming same number of heads for K/V as Q
total_seqlens = sum(seq_lengths)

# Calculate FLOPS and memory bytes
FLOPS = s_q * total_seqlens * h_q * (d + dv) * 2
bytes = (total_seqlens * h_kv * d + b * s_q * h_q * d + b * s_q * h_q * dv) * (torch.finfo(QRot.dtype).bits // 8)

# Print performance metrics
print(f"Performance: {FLOPS / 10**12 / time_per_iter:.2f} TFLOPS, {bytes / 10**9 / time_per_iter:.2f} GB/s")

# save_gantt_chart(Timings, Instructions, name='new')

if __name__ == "__main__":
main([4641,45118,1730,1696], 4, 16)
main([65536], 1, 16)
main([512]*64, 2, 16)
main([4096]*132, 4, 16)
main([871,568,711,329,617,1015,348,978,543,837,650,1020,924,679,560,497,650,406,381,423,511,423,569,943,645,820,829,883,937,765,711,847,722,546,519,279,516,315,664,845,850,546,670,871,527,329,446,764,582,1011,453,655,532,985,1019,810,317,305,949,317,669,768,530,349], 4, 16)

main([4641,45118,1730,1696], 4, 8)
main([65536], 1, 8)
main([512]*64, 2, 8)
main([4096]*132, 4, 8)
main([871,568,711,329,617,1015,348,978,543,837,650,1020,924,679,560,497,650,406,381,423,511,423,569,943,645,820,829,883,937,765,711,847,722,546,519,279,516,315,664,845,850,546,670,871,527,329,446,764,582,1011,453,655,532,985,1019,810,317,305,949,317,669,768,530,349], 4, 8)
# main([4641,45118,1730,1696], 4, 16)

# Set the CUDA device (e.g., device 0)
CUDA_DEVICE = 3
torch.cuda.set_device(CUDA_DEVICE)
print(f"Using CUDA device: {torch.cuda.get_device_name(CUDA_DEVICE)}")

batch_size = 32
new_tokens = 1
heads = 128

for seq_len in [1024, 2048, 4096, 8192, 8192*2, 8192*4]:
seq_lengths = [seq_len] * batch_size
print(f' ----------- starting seq_lengths: {seq_lengths} new_tokens: {new_tokens} q_heads: {heads} -----------')
main([seq_len]*batch_size, new_tokens, heads)

# for seq_len in [1024, 2048, 4096, 8192, 8192*2, 8192*4]:
# # Create a list of identical sequence lengths
# seq_lengths = [seq_len] * batch_size
# main(seq_lengths, new_tokens, heads)

# main([65536], 1, 16)
# main([512]*64, 2, 16)
# main([4096]*132, 4, 16)
# main([871,568,711,329,617,1015,348,978,543,837,650,1020,924,679,560,497,650,406,381,423,511,423,569,943,645,820,829,883,937,765,711,847,722,546,519,279,516,315,664,845,850,546,670,871,527,329,446,764,582,1011,453,655,532,985,1019,810,317,305,949,317,669,768,530,349], 4, 16)

# main([4641,45118,1730,1696], 4, 8)
# main([65536], 1, 8)
# main([512]*64, 2, 8)
# main([4096]*132, 4, 8)
# main([871,568,711,329,617,1015,348,978,543,837,650,1020,924,679,560,497,650,406,381,423,511,423,569,943,645,820,829,883,937,765,711,847,722,546,519,279,516,315,664,845,850,546,670,871,527,329,446,764,582,1011,453,655,532,985,1019,810,317,305,949,317,669,768,530,349], 4, 8)
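
As a sanity check on the FLOPS/bytes estimate added to main() above, here is a worked instance for one of the new benchmark configurations (batch_size=32, new_tokens=1, heads=128, seq_len=1024, bf16); the 250 microsecond timing is hypothetical and only there to make the arithmetic concrete:

# Worked example of the metric formulas above; the timing is assumed, not measured.
b, s_q, h_q, d, dv = 32, 1, 128, 64, 512
h_kv = h_q
total_seqlens = 32 * 1024                       # seq_len = 1024 for every batch element

flops = s_q * total_seqlens * h_q * (d + dv) * 2
# = 1 * 32768 * 128 * 576 * 2 ≈ 4.83e9 floating-point operations per iteration

bytes_moved = (total_seqlens * h_kv * d + b * s_q * h_q * d + b * s_q * h_q * dv) * 2   # bf16 = 2 bytes
# = (268435456 + 262144 + 2097152) * 2 ≈ 5.42e8 bytes per iteration

time_per_iter = 250e-6                          # hypothetical 250 us, for illustration only
print(f"{flops / 1e12 / time_per_iter:.2f} TFLOPS, {bytes_moved / 1e9 / time_per_iter:.2f} GB/s")
# -> about 19.33 TFLOPS and 2166.36 GB/s for this assumed timing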