Skip to content

Commit ce32b31

Browse files
committed
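fix rotary kernel: call x.batch(), x.depth() and x.rows() as accessor functions instead of reading them as data members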
1 parent a8138c8 commit ce32b31

File tree

1 file changed

+6
-6
lines changed

1 file changed

+6
-6
lines changed

kernels/rotary/pc.cu

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -28,20 +28,20 @@ template<int _headdim> struct rotary_template {
2828
using layout = rotary_layout<headdim, NUM_CONSUMER_WARPS>;
2929
__device__ static inline void common_setup(common_setup_args<layout> args) {
3030
if(args.task_iter == 0) {
31-
args.num_iters = min(args.globals.batches, (int)(args.globals.x.batch-blockIdx.y*args.globals.batches)) * args.globals.x.depth; // batches*heads handled by block
31+
args.num_iters = min(args.globals.batches, (int)(args.globals.x.batch()-blockIdx.y*args.globals.batches)) * args.globals.x.depth(); // batches*heads handled by block
3232
}
3333
else args.num_iters = -1;
3434
}
3535
struct producer {
3636
__device__ static void setup(producer_setup_args<layout> args) {
3737
warpgroup::producer_registers();
3838
args.state.active_warps = min((int)NUM_CONSUMER_WARPS,
39-
(int)(args.globals.x.rows/16 - blockIdx.x*NUM_CONSUMER_WARPS));
39+
(int)(args.globals.x.rows()/16 - blockIdx.x*NUM_CONSUMER_WARPS));
4040
}
4141
__device__ static void load(producer_load_args<layout> args) {
4242
if(warpgroup::warpid() == args.iter%4) {
43-
kittens::coord idx = { blockIdx.y*args.globals.batches+args.iter/args.globals.x.depth,
44-
args.iter%args.globals.x.depth,
43+
kittens::coord idx = { blockIdx.y*args.globals.batches+args.iter/args.globals.x.depth(),
44+
args.iter%args.globals.x.depth(),
4545
blockIdx.x*NUM_CONSUMER_WARPS,
4646
0 };
4747
tma::expect_bytes(args.inputs_arrived, sizeof(layout::seq_tile)*args.state.active_warps);
@@ -54,8 +54,8 @@ template<int _headdim> struct rotary_template {
5454
}
5555
__device__ static void store(producer_store_args<layout> args) {
5656
if(warpgroup::warpid() == args.iter%4) {
57-
kittens::coord idx = { blockIdx.y*args.globals.batches+args.iter/args.globals.x.depth,
58-
args.iter%args.globals.x.depth,
57+
kittens::coord idx = { blockIdx.y*args.globals.batches+args.iter/args.globals.x.depth(),
58+
args.iter%args.globals.x.depth(),
5959
blockIdx.x*NUM_CONSUMER_WARPS,
6060
0 };
6161
for(int i = 0; i < args.state.active_warps; i++) {

0 commit comments

Comments
 (0)