@@ -28,20 +28,20 @@ template<int _headdim> struct rotary_template {
28
28
using layout = rotary_layout<headdim, NUM_CONSUMER_WARPS>;
29
29
__device__ static inline void common_setup (common_setup_args<layout> args) {
30
30
if (args.task_iter == 0 ) {
31
- args.num_iters = min (args.globals .batches , (int )(args.globals .x .batch -blockIdx .y *args.globals .batches )) * args.globals .x .depth ; // batches*heads handled by block
31
+ args.num_iters = min (args.globals .batches , (int )(args.globals .x .batch () -blockIdx .y *args.globals .batches )) * args.globals .x .depth () ; // batches*heads handled by block
32
32
}
33
33
else args.num_iters = -1 ;
34
34
}
35
35
struct producer {
36
36
__device__ static void setup (producer_setup_args<layout> args) {
37
37
warpgroup::producer_registers ();
38
38
args.state .active_warps = min ((int )NUM_CONSUMER_WARPS,
39
- (int )(args.globals .x .rows /16 - blockIdx .x *NUM_CONSUMER_WARPS));
39
+ (int )(args.globals .x .rows () /16 - blockIdx .x *NUM_CONSUMER_WARPS));
40
40
}
41
41
__device__ static void load (producer_load_args<layout> args) {
42
42
if (warpgroup::warpid () == args.iter %4 ) {
43
- kittens::coord idx = { blockIdx .y *args.globals .batches +args.iter /args.globals .x .depth ,
44
- args.iter %args.globals .x .depth ,
43
+ kittens::coord idx = { blockIdx .y *args.globals .batches +args.iter /args.globals .x .depth () ,
44
+ args.iter %args.globals .x .depth () ,
45
45
blockIdx .x *NUM_CONSUMER_WARPS,
46
46
0 };
47
47
tma::expect_bytes (args.inputs_arrived , sizeof (layout::seq_tile)*args.state .active_warps );
@@ -54,8 +54,8 @@ template<int _headdim> struct rotary_template {
54
54
}
55
55
__device__ static void store (producer_store_args<layout> args) {
56
56
if (warpgroup::warpid () == args.iter %4 ) {
57
- kittens::coord idx = { blockIdx .y *args.globals .batches +args.iter /args.globals .x .depth ,
58
- args.iter %args.globals .x .depth ,
57
+ kittens::coord idx = { blockIdx .y *args.globals .batches +args.iter /args.globals .x .depth () ,
58
+ args.iter %args.globals .x .depth () ,
59
59
blockIdx .x *NUM_CONSUMER_WARPS,
60
60
0 };
61
61
for (int i = 0 ; i < args.state .active_warps ; i++) {
0 commit comments