
Conversation

chengjunlu
Contributor

The layout propagation across the scf.for op in RemoveLayout is not implemented well in the following respects:

  1. There is no cost-model analysis for using different layouts for the operations (i.e. for choosing different tiling patterns for Triton ops); the pass only relies on ad-hoc anchors.
  2. It does not handle ops with multiple results well.
  3. It does not handle ops with nested basic blocks well.
  4. The remove-layout pass does not support propagating a layout through scf.for ops.

Because of these limitations, the scf.for operation becomes the efficiency bottleneck after the remove-layout pass.
This is not an issue on NVIDIA GPUs, because there the layout-conversion operations are turned into asynchronous copies (cp.async) by the software pipeliner.

It is an issue for Intel GPUs, however: we rely on remove layout to get a simple program with fewer convert-layout operations.

The plan to enhance remove layout and address these limitations:

  1. Refactor the remove-layout implementation to properly support ops with multiple results and ops with nested basic blocks.
  2. Support propagating layouts through scf.for ops on demand (see the sketch after this list).
  3. Add a cost-model analysis pass that estimates the costs of the different tiling patterns across the kernel program.
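
A minimal sketch of what item 2 could look like, assuming the propagation ties a scf.for result to the corresponding yielded value, region iter-arg, and init operand so that all of them adopt the same encoding; the helper name and structure below are illustrative only, not the code in this PR.

// Illustrative sketch only: the values tied to one scf.for result that would
// all have to adopt the same encoding when a layout is propagated through
// the loop.
#include "mlir/Dialect/SCF/IR/SCF.h"
#include "llvm/ADT/SmallVector.h"

using namespace mlir;

static SmallVector<Value> collectTiedLoopValues(scf::ForOp forOp,
                                                OpResult result) {
  unsigned idx = result.getResultNumber();
  Operation *yield = forOp.getBody()->getTerminator();
  return {forOp.getInitArgs()[idx],        // value entering the loop
          forOp.getRegionIterArgs()[idx],  // block argument inside the body
          yield->getOperand(idx),          // value carried to the next trip
          result};                         // value leaving the loop
}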

This is a PR for CI.

@chengjunlu chengjunlu linked an issue Jun 18, 2025 that may be closed by this pull request
@chengjunlu chengjunlu force-pushed the chengjun/enhance_remove_layout branch from 486ed4a to f42bd66 on June 18, 2025, 07:16
@etiotto etiotto marked this pull request as draft June 18, 2025 18:06
@etiotto etiotto requested a review from Copilot August 21, 2025 17:12

@Copilot Copilot AI left a comment

Pull Request Overview

This PR enhances the remove layout implementation to better handle layout propagation across scf.for operations, addressing limitations that create performance bottlenecks on Intel GPUs. The changes focus on reducing duplicated layout conversion operations by improving support for multi-result operations and nested basic blocks.

  • Adds support for propagating layouts through scf.for operations with a new includeForOp parameter
  • Refactors mappedValues to handle multiple attribute mappings per value instead of single mappings (see the sketch below)
  • Includes debug output and unreachable code handling for scf.for operations
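
A sketch of the mappedValues change described above, assuming the refactoring replaces a one-encoding-per-value map with a set of candidate encodings per value; the exact containers used in the PR may differ.

#include "mlir/IR/Attributes.h"
#include "mlir/IR/Value.h"
#include "llvm/ADT/DenseMap.h"
#include "llvm/ADT/SetVector.h"

using namespace mlir;

// Before: each value reached by the rewrite maps to exactly one encoding.
using SingleMapping = DenseMap<Value, Attribute>;

// After (per the summary above): a value can be reached from several anchors,
// e.g. across scf.for iter-args, so it maps to a set of candidate encodings.
using MultiMapping = DenseMap<Value, SetVector<Attribute>>;

// Hypothetical helper: record an encoding without dropping earlier ones.
inline void recordEncoding(MultiMapping &mappedValues, Value v, Attribute enc) {
  mappedValues[v].insert(enc);
}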

Reviewed Changes

Copilot reviewed 3 out of 3 changed files in this pull request and generated 4 comments.

File                          Description
Utility.h                     Adds the includeForOp parameter to the getConvertBackwardSlice function signature (sketched below)
Utility.cpp                   Implements the scf.for layout propagation logic with an early-return check and debug output
RemoveLayoutConversions.cpp   Updates data structures to support multiple encodings per value and enables scf.for processing
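
For context, a hedged sketch of what the extended declaration in Utility.h might look like; only the function name and the new includeForOp parameter come from the table above, while the namespace, return type, and the remaining parameters are assumptions.

#include <functional>

#include "mlir/IR/Attributes.h"
#include "mlir/IR/Operation.h"
#include "mlir/Support/LogicalResult.h"
#include "llvm/ADT/DenseMap.h"
#include "llvm/ADT/SetVector.h"

namespace mlir::triton::gpu::intel {

// Sketch: walk backward from `root`, collecting into `slice` the values that
// must adopt `rootEncoding`; when `includeForOp` is true, the walk is allowed
// to cross scf.for boundaries (result -> yield operand -> iter-arg -> init).
LogicalResult getConvertBackwardSlice(
    OpOperand &root, SetVector<Value> &slice, Attribute rootEncoding,
    DenseMap<Value, Attribute> &layout,
    std::function<bool(Operation *)> stopPropagation = nullptr,
    bool includeForOp = false);

} // namespace mlir::triton::gpu::intel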


@chengjunlu chengjunlu force-pushed the chengjun/enhance_remove_layout branch from f42bd66 to 7a8f858 on September 2, 2025, 13:03
@chengjunlu
Contributor Author

chengjunlu commented Sep 2, 2025

The flex attention backward TTGIR has been simplified by these changes.

There are only two root tiling layouts: the DPAS layout and the transpose of the dot operand of DPAS.

Another major inefficiency on Xe-Xe3 is that the same regular pointer is loaded under different layouts, for example:

%152 = tt.load %kT_ptrs_79, %150, %cst_0 {ttig.block_io = "column_major", ttig.one_matrix_per_load} : tensor<64x64x!tt.ptr<f16>, #ttg.dot_op<{opIdx = 1, parent = #mma, kWidth = 2}>> loc(#loc)
%153 = tt.load %kT_ptrs_81, %151, %cst {ttig.block_io = "column_major", ttig.one_matrix_per_load} : tensor<64x64x!tt.ptr<f16>, #linear> loc(#loc)
%154 = tt.trans %153 {order = array<i32: 1, 0>} : tensor<64x64xf16, #linear> -> tensor<64x64xf16, #ttg.dot_op<{opIdx = 1, parent = #mma, kWidth = 2}>> loc(#loc)
The simplified TTGIR:
#linear = #ttg.linear<{register = [[0, 1], [0, 2], [0, 4], [0, 8], [16, 0], [0, 16], [0, 32], [32, 0]], lane = [[1, 0], [2, 0], [4, 0], [8, 0]], warp = [[0, 0], [0, 0], [0, 0]], block = []}>
#loc1 = loc("arg_Q")
#loc2 = loc("arg_K")
#loc3 = loc("arg_V")
#loc4 = loc("arg_LSE")
#loc5 = loc("arg_DELTA")
#loc6 = loc("arg_DO")
#loc7 = loc("arg_DQ")
#loc8 = loc("arg_DV")
#loc9 = loc("arg_KV_NUM_BLKS")
#loc10 = loc("arg_KV_IDX")
#loc11 = loc("arg_Q_NUM_BLKS")
#loc12 = loc("arg_Q_IDX")
#loc13 = loc("arg_FULL_KV_NUM_BLKS")
#loc14 = loc("arg_FULL_KV_IDX")
#loc15 = loc("arg_FULL_Q_NUM_BLKS")
#loc16 = loc("arg_FULL_Q_IDX")
#loc17 = loc("in_ptr16")
#loc18 = loc("out_ptr0")
#mma = #ttig.dpas<{repeatCount = 8, systolicDepth = 8, executionSize = 16, opsPerChan = 2, threadsPerWarp = 16, warpsPerCTA = [8, 1], repCluster = [1, 2], A = [8, 16], B = [16, 32], C = [8, 32]}>
module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 8 : i32, ttg.target = "xpu", "ttg.threads-per-warp" = 16 : i32, ttig.min_sg_size = 16 : i32, ttig.support_bf16_conversion, ttig.support_block_scale_dpas, ttig.support_dpas, ttig.support_sg_2d_block, ttig.target_arch = "spir64"} {
  tt.func public @triton_tem_fused_zeros_1(%arg_Q: !tt.ptr<f16> {tt.divisibility = 16 : i32} loc("arg_Q"), %arg_K: !tt.ptr<f16> {tt.divisibility = 16 : i32} loc("arg_K"), %arg_V: !tt.ptr<f16> {tt.divisibility = 16 : i32} loc("arg_V"), %arg_LSE: !tt.ptr<f32> {tt.divisibility = 16 : i32} loc("arg_LSE"), %arg_DELTA: !tt.ptr<f32> {tt.divisibility = 16 : i32} loc("arg_DELTA"), %arg_DO: !tt.ptr<f16> {tt.divisibility = 16 : i32} loc("arg_DO"), %arg_DQ: !tt.ptr<f16> {tt.divisibility = 16 : i32} loc("arg_DQ"), %arg_DV: !tt.ptr<f16> {tt.divisibility = 16 : i32} loc("arg_DV"), %arg_KV_NUM_BLKS: !tt.ptr<i32> {tt.divisibility = 16 : i32} loc("arg_KV_NUM_BLKS"), %arg_KV_IDX: !tt.ptr<i32> {tt.divisibility = 16 : i32} loc("arg_KV_IDX"), %arg_Q_NUM_BLKS: !tt.ptr<i32> {tt.divisibility = 16 : i32} loc("arg_Q_NUM_BLKS"), %arg_Q_IDX: !tt.ptr<i32> {tt.divisibility = 16 : i32} loc("arg_Q_IDX"), %arg_FULL_KV_NUM_BLKS: !tt.ptr<i32> {tt.divisibility = 16 : i32} loc("arg_FULL_KV_NUM_BLKS"), %arg_FULL_KV_IDX: !tt.ptr<i32> {tt.divisibility = 16 : i32} loc("arg_FULL_KV_IDX"), %arg_FULL_Q_NUM_BLKS: !tt.ptr<i32> {tt.divisibility = 16 : i32} loc("arg_FULL_Q_NUM_BLKS"), %arg_FULL_Q_IDX: !tt.ptr<i32> {tt.divisibility = 16 : i32} loc("arg_FULL_Q_IDX"), %in_ptr16: !tt.ptr<i64> {tt.divisibility = 16 : i32} loc("in_ptr16"), %out_ptr0: !tt.ptr<f16> {tt.divisibility = 16 : i32} loc("out_ptr0")) attributes {noinline = false} {
    %cst = arith.constant dense<0.000000e+00> : tensor<64x64xf16, #linear> loc(#loc)
    %cst_0 = arith.constant dense<0.000000e+00> : tensor<64x64xf16, #ttg.dot_op<{opIdx = 1, parent = #mma, kWidth = 2}>> loc(#loc)
    %cst_1 = arith.constant dense<0.000000e+00> : tensor<64x64xf16, #ttg.dot_op<{opIdx = 0, parent = #mma, kWidth = 1}>> loc(#loc)
    %cst_2 = arith.constant dense<1023> : tensor<64xi32, #ttg.slice<{dim = 1, parent = #mma}>> loc(#loc)
    %cst_3 = arith.constant dense<1023> : tensor<64xi32, #ttg.slice<{dim = 0, parent = #mma}>> loc(#loc)
    %cst_4 = arith.constant dense<0xFF800000> : tensor<64xf32, #ttg.slice<{dim = 1, parent = #mma}>> loc(#loc)
    %cst_5 = arith.constant dense<0xFF800000> : tensor<64xf32, #ttg.slice<{dim = 0, parent = #mma}>> loc(#loc)
    %cst_6 = arith.constant dense<0.000000e+00> : tensor<64xf32, #ttg.slice<{dim = 1, parent = #mma}>> loc(#loc)
    %cst_7 = arith.constant dense<0.000000e+00> : tensor<64xf32, #ttg.slice<{dim = 0, parent = #mma}>> loc(#loc)
    %cst_8 = arith.constant dense<64> : tensor<64x1xi32, #linear> loc(#loc)
    %cst_9 = arith.constant dense<64> : tensor<64x1xi32, #ttg.dot_op<{opIdx = 1, parent = #mma, kWidth = 2}>> loc(#loc)
    %cst_10 = arith.constant dense<64> : tensor<64x1xi32, #ttg.dot_op<{opIdx = 0, parent = #mma, kWidth = 1}>> loc(#loc)
    %cst_11 = arith.constant dense<64> : tensor<64x1xi32, #mma> loc(#loc)
    %cst_12 = arith.constant dense<1023> : tensor<64x1xi32, #mma> loc(#loc)
    %cst_13 = arith.constant dense<1023> : tensor<64x1xi32, #linear> loc(#loc)
    %cst_14 = arith.constant dense<1023> : tensor<64x1xi32, #ttg.dot_op<{opIdx = 1, parent = #mma, kWidth = 2}>> loc(#loc)
    %cst_15 = arith.constant dense<1023> : tensor<64x1xi32, #ttg.dot_op<{opIdx = 0, parent = #mma, kWidth = 1}>> loc(#loc)
    %cst_16 = arith.constant dense<64> : tensor<1x64xi32, #linear> loc(#loc)
    %cst_17 = arith.constant dense<64> : tensor<1x64xi32, #ttg.dot_op<{opIdx = 1, parent = #mma, kWidth = 2}>> loc(#loc)
    %cst_18 = arith.constant dense<64> : tensor<1x64xi32, #mma> loc(#loc)
    %cst_19 = arith.constant dense<1023> : tensor<1x64xi32, #linear> loc(#loc)
    %cst_20 = arith.constant dense<1023> : tensor<1x64xi32, #ttg.dot_op<{opIdx = 1, parent = #mma, kWidth = 2}>> loc(#loc)
    %cst_21 = arith.constant dense<1023> : tensor<1x64xi32, #mma> loc(#loc)
    %cst_22 = arith.constant dense<false> : tensor<64x64xi1, #mma> loc(#loc)
    %c8_i32 = arith.constant 8 : i32 loc(#loc)
    %c261888_i32 = arith.constant 261888 : i32 loc(#loc)
    %c65472_i32 = arith.constant 65472 : i32 loc(#loc)
    %c64_i32 = arith.constant 64 : i32 loc(#loc)
    %c1_i32 = arith.constant 1 : i32 loc(#loc)
    %c2_i32 = arith.constant 2 : i32 loc(#loc)
    %c4_i32 = arith.constant 4 : i32 loc(#loc)
    %c1023_i32 = arith.constant 1023 : i32 loc(#loc)
    %c128_i32 = arith.constant 128 : i32 loc(#loc)
    %c16_i32 = arith.constant 16 : i32 loc(#loc)
    %c0_i32 = arith.constant 0 : i32 loc(#loc)
    %cst_23 = arith.constant dense<1.250000e-01> : tensor<64x64xf32, #mma> loc(#loc)
    %cst_24 = arith.constant dense<0.000000e+00> : tensor<64x64xf32, #mma> loc(#loc)
    %cst_25 = arith.constant dense<0xFF800000> : tensor<64x64xf32, #mma> loc(#loc)
    %cst_26 = arith.constant dense<1.44269502> : tensor<64x64xf32, #mma> loc(#loc)
    %0 = tt.get_program_id x : i32 loc(#loc)
    %1 = tt.get_program_id y : i32 loc(#loc)
    %2 = tt.get_program_id z : i32 loc(#loc)
    %3 = arith.remsi %1, %c2_i32 : i32 loc(#loc)
    %4 = arith.muli %2, %c65472_i32 : i32 loc(#loc)
    %5 = arith.muli %3, %c261888_i32 : i32 loc(#loc)
    %6 = arith.addi %4, %5 : i32 loc(#loc)
    %7 = arith.extsi %6 : i32 to i64 loc(#loc)
    %8 = arith.muli %1, %c261888_i32 : i32 loc(#loc)
    %9 = arith.addi %4, %8 : i32 loc(#loc)
    %10 = arith.extsi %9 : i32 to i64 loc(#loc)
    %11 = tt.addptr %arg_K, %7 : !tt.ptr<f16>, i64 loc(#loc)
    %12 = tt.addptr %arg_V, %7 : !tt.ptr<f16>, i64 loc(#loc)
    %13 = tt.addptr %arg_DV, %10 : !tt.ptr<f16>, i64 loc(#loc)
    %14 = tt.make_range {end = 64 : i32, start = 0 : i32} : tensor<64xi32, #ttg.slice<{dim = 1, parent = #ttg.dot_op<{opIdx = 0, parent = #mma, kWidth = 1}>}>> loc(#loc)
    %15 = tt.make_range {end = 64 : i32, start = 0 : i32} : tensor<64xi32, #ttg.slice<{dim = 1, parent = #ttg.dot_op<{opIdx = 1, parent = #mma, kWidth = 2}>}>> loc(#loc)
    %16 = tt.make_range {end = 64 : i32, start = 0 : i32} : tensor<64xi32, #ttg.slice<{dim = 1, parent = #linear}>> loc(#loc)
    %17 = tt.make_range {end = 64 : i32, start = 0 : i32} : tensor<64xi32, #ttg.slice<{dim = 1, parent = #mma}>> loc(#loc)
    %18 = tt.make_range {end = 64 : i32, start = 0 : i32} : tensor<64xi32, #ttg.slice<{dim = 0, parent = #ttg.dot_op<{opIdx = 1, parent = #mma, kWidth = 2}>}>> loc(#loc)
    %19 = tt.make_range {end = 64 : i32, start = 0 : i32} : tensor<64xi32, #ttg.slice<{dim = 0, parent = #linear}>> loc(#loc)
    %20 = tt.make_range {end = 64 : i32, start = 0 : i32} : tensor<64xi32, #ttg.slice<{dim = 0, parent = #mma}>> loc(#loc)
    %21 = arith.cmpi sge, %0, %c16_i32 : i32 loc(#loc)
    scf.if %21 {
      %22 = arith.subi %0, %c16_i32 : i32 loc(#loc)
      %23 = arith.divsi %22, %c16_i32 : i32 loc(#loc)
      %24 = arith.addi %23, %2 : i32 loc(#loc)
      %25 = arith.remsi %22, %c16_i32 : i32 loc(#loc)
      %26 = arith.divsi %25, %c2_i32 : i32 loc(#loc)
      %27 = arith.muli %26, %c8_i32 : i32 loc(#loc)
      %28 = arith.muli %24, %c65472_i32 : i32 loc(#loc)
      %29 = arith.addi %28, %8 : i32 loc(#loc)
      %30 = arith.extsi %29 : i32 to i64 loc(#loc)
      %31 = arith.muli %1, %c4_i32 : i32 loc(#loc)
      %32 = arith.addi %31, %24 : i32 loc(#loc)
      %33 = arith.muli %32, %c1023_i32 : i32 loc(#loc)
      %34 = arith.extsi %33 : i32 to i64 loc(#loc)
      %35 = tt.addptr %arg_Q, %30 : !tt.ptr<f16>, i64 loc(#loc)
      %36 = tt.addptr %arg_DO, %30 : !tt.ptr<f16>, i64 loc(#loc)
      %37 = tt.addptr %arg_DQ, %30 : !tt.ptr<f16>, i64 loc(#loc)
      %38 = tt.addptr %arg_LSE, %34 : !tt.ptr<f32>, i64 loc(#loc)
      %39 = tt.addptr %arg_DELTA, %34 : !tt.ptr<f32>, i64 loc(#loc)
      %40 = arith.muli %25, %c64_i32 : i32 loc(#loc)
      %41 = tt.splat %40 : i32 -> tensor<64xi32, #ttg.slice<{dim = 1, parent = #ttg.dot_op<{opIdx = 0, parent = #mma, kWidth = 1}>}>> loc(#loc)
      %42 = tt.splat %40 : i32 -> tensor<64xi32, #ttg.slice<{dim = 1, parent = #mma}>> loc(#loc)
      %43 = arith.addi %41, %14 : tensor<64xi32, #ttg.slice<{dim = 1, parent = #ttg.dot_op<{opIdx = 0, parent = #mma, kWidth = 1}>}>> loc(#loc)
      %44 = arith.addi %42, %17 : tensor<64xi32, #ttg.slice<{dim = 1, parent = #mma}>> loc(#loc)
      %dq = tt.expand_dims %43 {axis = 1 : i32} : tensor<64xi32, #ttg.slice<{dim = 1, parent = #ttg.dot_op<{opIdx = 0, parent = #mma, kWidth = 1}>}>> -> tensor<64x1xi32, #ttg.dot_op<{opIdx = 0, parent = #mma, kWidth = 1}>> loc(#loc25)
      %dq_27 = tt.expand_dims %44 {axis = 1 : i32} : tensor<64xi32, #ttg.slice<{dim = 1, parent = #mma}>> -> tensor<64x1xi32, #mma> loc(#loc25)
      %45 = arith.muli %dq, %cst_10 : tensor<64x1xi32, #ttg.dot_op<{opIdx = 0, parent = #mma, kWidth = 1}>> loc(#loc)
      %46 = arith.muli %dq_27, %cst_11 : tensor<64x1xi32, #mma> loc(#loc)
      %47 = tt.splat %35 : !tt.ptr<f16> -> tensor<64x1x!tt.ptr<f16>, #ttg.dot_op<{opIdx = 0, parent = #mma, kWidth = 1}>> loc(#loc)
      %48 = tt.addptr %47, %45 : tensor<64x1x!tt.ptr<f16>, #ttg.dot_op<{opIdx = 0, parent = #mma, kWidth = 1}>>, tensor<64x1xi32, #ttg.dot_op<{opIdx = 0, parent = #mma, kWidth = 1}>> loc(#loc)
      %49 = tt.make_range {end = 64 : i32, start = 0 : i32} : tensor<64xi32, #ttg.slice<{dim = 0, parent = #ttg.dot_op<{opIdx = 0, parent = #mma, kWidth = 1}>}>> loc(#loc)
      %50 = tt.expand_dims %49 {axis = 0 : i32} : tensor<64xi32, #ttg.slice<{dim = 0, parent = #ttg.dot_op<{opIdx = 0, parent = #mma, kWidth = 1}>}>> -> tensor<1x64xi32, #ttg.dot_op<{opIdx = 0, parent = #mma, kWidth = 1}>> loc(#loc)
      %51 = tt.expand_dims %20 {axis = 0 : i32} : tensor<64xi32, #ttg.slice<{dim = 0, parent = #mma}>> -> tensor<1x64xi32, #mma> loc(#loc)
      %52 = tt.broadcast %48 : tensor<64x1x!tt.ptr<f16>, #ttg.dot_op<{opIdx = 0, parent = #mma, kWidth = 1}>> -> tensor<64x64x!tt.ptr<f16>, #ttg.dot_op<{opIdx = 0, parent = #mma, kWidth = 1}>> loc(#loc)
      %53 = tt.broadcast %50 : tensor<1x64xi32, #ttg.dot_op<{opIdx = 0, parent = #mma, kWidth = 1}>> -> tensor<64x64xi32, #ttg.dot_op<{opIdx = 0, parent = #mma, kWidth = 1}>> loc(#loc)
      %54 = tt.broadcast %51 : tensor<1x64xi32, #mma> -> tensor<64x64xi32, #mma> loc(#loc)
      %55 = tt.addptr %52, %53 : tensor<64x64x!tt.ptr<f16>, #ttg.dot_op<{opIdx = 0, parent = #mma, kWidth = 1}>>, tensor<64x64xi32, #ttg.dot_op<{opIdx = 0, parent = #mma, kWidth = 1}>> loc(#loc)
      %56 = arith.cmpi slt, %dq_27, %cst_12 : tensor<64x1xi32, #mma> loc(#loc)
      %57 = arith.cmpi slt, %dq, %cst_15 : tensor<64x1xi32, #ttg.dot_op<{opIdx = 0, parent = #mma, kWidth = 1}>> loc(#loc)
      %58 = tt.broadcast %56 : tensor<64x1xi1, #mma> -> tensor<64x64xi1, #mma> loc(#loc)
      %59 = tt.broadcast %57 : tensor<64x1xi1, #ttg.dot_op<{opIdx = 0, parent = #mma, kWidth = 1}>> -> tensor<64x64xi1, #ttg.dot_op<{opIdx = 0, parent = #mma, kWidth = 1}>> loc(#loc)
      %60 = tt.load %55, %59, %cst_1 {ttig.block_io = "row_major"} : tensor<64x64x!tt.ptr<f16>, #ttg.dot_op<{opIdx = 0, parent = #mma, kWidth = 1}>> loc(#loc)
      %61 = tt.splat %36 : !tt.ptr<f16> -> tensor<64x1x!tt.ptr<f16>, #ttg.dot_op<{opIdx = 0, parent = #mma, kWidth = 1}>> loc(#loc)
      %62 = tt.addptr %61, %45 : tensor<64x1x!tt.ptr<f16>, #ttg.dot_op<{opIdx = 0, parent = #mma, kWidth = 1}>>, tensor<64x1xi32, #ttg.dot_op<{opIdx = 0, parent = #mma, kWidth = 1}>> loc(#loc)
      %63 = tt.broadcast %62 : tensor<64x1x!tt.ptr<f16>, #ttg.dot_op<{opIdx = 0, parent = #mma, kWidth = 1}>> -> tensor<64x64x!tt.ptr<f16>, #ttg.dot_op<{opIdx = 0, parent = #mma, kWidth = 1}>> loc(#loc)
      %64 = tt.addptr %63, %53 : tensor<64x64x!tt.ptr<f16>, #ttg.dot_op<{opIdx = 0, parent = #mma, kWidth = 1}>>, tensor<64x64xi32, #ttg.dot_op<{opIdx = 0, parent = #mma, kWidth = 1}>> loc(#loc)
      %65 = tt.load %64, %59, %cst_1 {ttig.block_io = "row_major"} : tensor<64x64x!tt.ptr<f16>, #ttg.dot_op<{opIdx = 0, parent = #mma, kWidth = 1}>> loc(#loc)
      %66 = arith.cmpi slt, %44, %cst_2 : tensor<64xi32, #ttg.slice<{dim = 1, parent = #mma}>> loc(#loc)
      %67 = tt.splat %39 : !tt.ptr<f32> -> tensor<64x!tt.ptr<f32>, #ttg.slice<{dim = 1, parent = #mma}>> loc(#loc)
      %68 = tt.addptr %67, %44 : tensor<64x!tt.ptr<f32>, #ttg.slice<{dim = 1, parent = #mma}>>, tensor<64xi32, #ttg.slice<{dim = 1, parent = #mma}>> loc(#loc)
      %69 = tt.load %68, %66 : tensor<64x!tt.ptr<f32>, #ttg.slice<{dim = 1, parent = #mma}>> loc(#loc)
      %70 = tt.splat %38 : !tt.ptr<f32> -> tensor<64x!tt.ptr<f32>, #ttg.slice<{dim = 1, parent = #mma}>> loc(#loc)
      %71 = tt.addptr %70, %44 : tensor<64x!tt.ptr<f32>, #ttg.slice<{dim = 1, parent = #mma}>>, tensor<64xi32, #ttg.slice<{dim = 1, parent = #mma}>> loc(#loc)
      %72 = tt.load %71, %66 : tensor<64x!tt.ptr<f32>, #ttg.slice<{dim = 1, parent = #mma}>> loc(#loc)
      %73 = arith.cmpf oeq, %72, %cst_4 : tensor<64xf32, #ttg.slice<{dim = 1, parent = #mma}>> loc(#loc)
      %74 = arith.select %73, %cst_6, %72 : tensor<64xi1, #ttg.slice<{dim = 1, parent = #mma}>>, tensor<64xf32, #ttg.slice<{dim = 1, parent = #mma}>> loc(#loc)
      %75 = tt.expand_dims %74 {axis = 1 : i32} : tensor<64xf32, #ttg.slice<{dim = 1, parent = #mma}>> -> tensor<64x1xf32, #mma> loc(#loc)
      %76 = tt.addptr %arg_KV_IDX, %27 : !tt.ptr<i32>, i32 loc(#loc)
      %77 = tt.load %76 : !tt.ptr<i32> loc(#loc)
      %78 = arith.muli %77, %c128_i32 : i32 loc(#loc)
      %79 = tt.addptr %arg_KV_NUM_BLKS, %26 : !tt.ptr<i32>, i32 loc(#loc)
      %80 = tt.load %79 : !tt.ptr<i32> loc(#loc)
      %81 = tt.splat %78 : i32 -> tensor<64xi32, #ttg.slice<{dim = 0, parent = #ttg.dot_op<{opIdx = 1, parent = #mma, kWidth = 2}>}>> loc(#loc)
      %82 = tt.splat %78 : i32 -> tensor<64xi32, #ttg.slice<{dim = 0, parent = #linear}>> loc(#loc)
      %83 = tt.splat %78 : i32 -> tensor<64xi32, #ttg.slice<{dim = 0, parent = #mma}>> loc(#loc)
      %84 = arith.addi %81, %18 : tensor<64xi32, #ttg.slice<{dim = 0, parent = #ttg.dot_op<{opIdx = 1, parent = #mma, kWidth = 2}>}>> loc(#loc)
      %85 = arith.addi %82, %19 : tensor<64xi32, #ttg.slice<{dim = 0, parent = #linear}>> loc(#loc)
      %86 = arith.addi %83, %20 : tensor<64xi32, #ttg.slice<{dim = 0, parent = #mma}>> loc(#loc)
      %87 = tt.expand_dims %84 {axis = 0 : i32} : tensor<64xi32, #ttg.slice<{dim = 0, parent = #ttg.dot_op<{opIdx = 1, parent = #mma, kWidth = 2}>}>> -> tensor<1x64xi32, #ttg.dot_op<{opIdx = 1, parent = #mma, kWidth = 2}>> loc(#loc)
      %88 = tt.expand_dims %85 {axis = 0 : i32} : tensor<64xi32, #ttg.slice<{dim = 0, parent = #linear}>> -> tensor<1x64xi32, #linear> loc(#loc)
      %89 = arith.muli %87, %cst_17 : tensor<1x64xi32, #ttg.dot_op<{opIdx = 1, parent = #mma, kWidth = 2}>> loc(#loc)
      %90 = arith.muli %88, %cst_16 : tensor<1x64xi32, #linear> loc(#loc)
      %91 = tt.splat %11 : !tt.ptr<f16> -> tensor<1x64x!tt.ptr<f16>, #ttg.dot_op<{opIdx = 1, parent = #mma, kWidth = 2}>> loc(#loc)
      %92 = tt.splat %11 : !tt.ptr<f16> -> tensor<1x64x!tt.ptr<f16>, #linear> loc(#loc)
      %93 = tt.addptr %91, %89 : tensor<1x64x!tt.ptr<f16>, #ttg.dot_op<{opIdx = 1, parent = #mma, kWidth = 2}>>, tensor<1x64xi32, #ttg.dot_op<{opIdx = 1, parent = #mma, kWidth = 2}>> loc(#loc)
      %94 = tt.addptr %92, %90 : tensor<1x64x!tt.ptr<f16>, #linear>, tensor<1x64xi32, #linear> loc(#loc)
      %95 = tt.expand_dims %15 {axis = 1 : i32} : tensor<64xi32, #ttg.slice<{dim = 1, parent = #ttg.dot_op<{opIdx = 1, parent = #mma, kWidth = 2}>}>> -> tensor<64x1xi32, #ttg.dot_op<{opIdx = 1, parent = #mma, kWidth = 2}>> loc(#loc)
      %96 = tt.expand_dims %16 {axis = 1 : i32} : tensor<64xi32, #ttg.slice<{dim = 1, parent = #linear}>> -> tensor<64x1xi32, #linear> loc(#loc)
      %97 = tt.broadcast %93 : tensor<1x64x!tt.ptr<f16>, #ttg.dot_op<{opIdx = 1, parent = #mma, kWidth = 2}>> -> tensor<64x64x!tt.ptr<f16>, #ttg.dot_op<{opIdx = 1, parent = #mma, kWidth = 2}>> loc(#loc)
      %98 = tt.broadcast %94 : tensor<1x64x!tt.ptr<f16>, #linear> -> tensor<64x64x!tt.ptr<f16>, #linear> loc(#loc)
      %99 = tt.broadcast %95 : tensor<64x1xi32, #ttg.dot_op<{opIdx = 1, parent = #mma, kWidth = 2}>> -> tensor<64x64xi32, #ttg.dot_op<{opIdx = 1, parent = #mma, kWidth = 2}>> loc(#loc)
      %100 = tt.broadcast %96 : tensor<64x1xi32, #linear> -> tensor<64x64xi32, #linear> loc(#loc)
      %kT_ptrs = tt.addptr %97, %99 : tensor<64x64x!tt.ptr<f16>, #ttg.dot_op<{opIdx = 1, parent = #mma, kWidth = 2}>>, tensor<64x64xi32, #ttg.dot_op<{opIdx = 1, parent = #mma, kWidth = 2}>> loc(#loc20)
      %kT_ptrs_28 = tt.addptr %98, %100 : tensor<64x64x!tt.ptr<f16>, #linear>, tensor<64x64xi32, #linear> loc(#loc20)
      %101 = tt.splat %12 : !tt.ptr<f16> -> tensor<1x64x!tt.ptr<f16>, #ttg.dot_op<{opIdx = 1, parent = #mma, kWidth = 2}>> loc(#loc)
      %102 = tt.addptr %101, %89 : tensor<1x64x!tt.ptr<f16>, #ttg.dot_op<{opIdx = 1, parent = #mma, kWidth = 2}>>, tensor<1x64xi32, #ttg.dot_op<{opIdx = 1, parent = #mma, kWidth = 2}>> loc(#loc)
      %103 = tt.broadcast %102 : tensor<1x64x!tt.ptr<f16>, #ttg.dot_op<{opIdx = 1, parent = #mma, kWidth = 2}>> -> tensor<64x64x!tt.ptr<f16>, #ttg.dot_op<{opIdx = 1, parent = #mma, kWidth = 2}>> loc(#loc)
      %vT_ptrs = tt.addptr %103, %99 : tensor<64x64x!tt.ptr<f16>, #ttg.dot_op<{opIdx = 1, parent = #mma, kWidth = 2}>>, tensor<64x64xi32, #ttg.dot_op<{opIdx = 1, parent = #mma, kWidth = 2}>> loc(#loc21)
      %104 = arith.muli %80, %c2_i32 : i32 loc(#loc)
      %105 = arith.minsi %104, %c16_i32 : i32 loc(#loc)
      %106 = arith.cmpi sge, %105, %c1_i32 : i32 loc(#loc)
      %107 = scf.if %106 -> (tensor<64x64xf32, #mma>) {
        %142 = arith.subi %105, %c1_i32 : i32 loc(#loc)
        %dq_32 = arith.remsi %dq_27, %cst_12 : tensor<64x1xi32, #mma> loc(#loc25)
        %dq_33 = tt.splat %in_ptr16 : !tt.ptr<i64> -> tensor<64x1x!tt.ptr<i64>, #mma> loc(#loc25)
        %dq_34 = tt.addptr %dq_33, %dq_32 : tensor<64x1x!tt.ptr<i64>, #mma>, tensor<64x1xi32, #mma> loc(#loc25)
        %dq_35 = tt.splat %in_ptr16 : !tt.ptr<i64> -> tensor<1x64x!tt.ptr<i64>, #mma> loc(#loc25)
        %dq_36 = tt.broadcast %75 : tensor<64x1xf32, #mma> -> tensor<64x64xf32, #mma> loc(#loc25)
        %dq_37 = tt.expand_dims %69 {axis = 1 : i32} : tensor<64xf32, #ttg.slice<{dim = 1, parent = #mma}>> -> tensor<64x1xf32, #mma> loc(#loc25)
        %dq_38 = tt.broadcast %dq_37 : tensor<64x1xf32, #mma> -> tensor<64x64xf32, #mma> loc(#loc25)
        %vT_ptrs_39 = arith.cmpi sgt, %142, %c0_i32 : i32 loc(#loc37)
        %vT_ptrs_40 = tt.splat %vT_ptrs_39 : i1 -> tensor<64x1xi1, #mma> loc(#loc37)
        %143 = tt.load %dq_34, %vT_ptrs_40 : tensor<64x1x!tt.ptr<i64>, #mma> loc(#loc)
        %144 = tt.broadcast %143 : tensor<64x1xi64, #mma> -> tensor<64x64xi64, #mma> loc(#loc)
        %vT_ptrs_41:7 = scf.for %vT_ptrs_78 = %c0_i32 to %142 step %c1_i32 iter_args(%arg19 = %cst_24, %kT_ptrs_79 = %kT_ptrs, %arg21 = %84, %arg22 = %86, %vT_ptrs_80 = %vT_ptrs, %kT_ptrs_81 = %kT_ptrs_28, %arg25 = %85) -> (tensor<64x64xf32, #mma>, tensor<64x64x!tt.ptr<f16>, #ttg.dot_op<{opIdx = 1, parent = #mma, kWidth = 2}>>, tensor<64xi32, #ttg.slice<{dim = 0, parent = #ttg.dot_op<{opIdx = 1, parent = #mma, kWidth = 2}>}>>, tensor<64xi32, #ttg.slice<{dim = 0, parent = #mma}>>, tensor<64x64x!tt.ptr<f16>, #ttg.dot_op<{opIdx = 1, parent = #mma, kWidth = 2}>>, tensor<64x64x!tt.ptr<f16>, #linear>, tensor<64xi32, #ttg.slice<{dim = 0, parent = #linear}>>)  : i32 {
          %145 = tt.expand_dims %arg21 {axis = 0 : i32} : tensor<64xi32, #ttg.slice<{dim = 0, parent = #ttg.dot_op<{opIdx = 1, parent = #mma, kWidth = 2}>}>> -> tensor<1x64xi32, #ttg.dot_op<{opIdx = 1, parent = #mma, kWidth = 2}>> loc(#loc)
          %146 = tt.expand_dims %arg25 {axis = 0 : i32} : tensor<64xi32, #ttg.slice<{dim = 0, parent = #linear}>> -> tensor<1x64xi32, #linear> loc(#loc)
          %147 = tt.expand_dims %arg22 {axis = 0 : i32} : tensor<64xi32, #ttg.slice<{dim = 0, parent = #mma}>> -> tensor<1x64xi32, #mma> loc(#loc)
          %148 = arith.cmpi slt, %145, %cst_20 : tensor<1x64xi32, #ttg.dot_op<{opIdx = 1, parent = #mma, kWidth = 2}>> loc(#loc)
          %149 = arith.cmpi slt, %146, %cst_19 : tensor<1x64xi32, #linear> loc(#loc)
          %150 = tt.broadcast %148 : tensor<1x64xi1, #ttg.dot_op<{opIdx = 1, parent = #mma, kWidth = 2}>> -> tensor<64x64xi1, #ttg.dot_op<{opIdx = 1, parent = #mma, kWidth = 2}>> loc(#loc)
          %151 = tt.broadcast %149 : tensor<1x64xi1, #linear> -> tensor<64x64xi1, #linear> loc(#loc)
          %152 = tt.load %kT_ptrs_79, %150, %cst_0 {ttig.block_io = "column_major", ttig.one_matrix_per_load} : tensor<64x64x!tt.ptr<f16>, #ttg.dot_op<{opIdx = 1, parent = #mma, kWidth = 2}>> loc(#loc)
          %153 = tt.load %kT_ptrs_81, %151, %cst {ttig.block_io = "column_major", ttig.one_matrix_per_load} : tensor<64x64x!tt.ptr<f16>, #linear> loc(#loc)
          %154 = tt.trans %153 {order = array<i32: 1, 0>} : tensor<64x64xf16, #linear> -> tensor<64x64xf16, #ttg.dot_op<{opIdx = 1, parent = #mma, kWidth = 2}>> loc(#loc)
          %155 = tt.dot %60, %152, %cst_24, inputPrecision = tf32 : tensor<64x64xf16, #ttg.dot_op<{opIdx = 0, parent = #mma, kWidth = 1}>> * tensor<64x64xf16, #ttg.dot_op<{opIdx = 1, parent = #mma, kWidth = 2}>> -> tensor<64x64xf32, #mma> loc(#loc)
          %156 = arith.mulf %155, %cst_23 : tensor<64x64xf32, #mma> loc(#loc)
          %157 = tt.addptr %dq_35, %147 : tensor<1x64x!tt.ptr<i64>, #mma>, tensor<1x64xi32, #mma> loc(#loc)
          %158 = tt.load %157 {ttig.block_io = "row_major"} : tensor<1x64x!tt.ptr<i64>, #mma> loc(#loc)
          %159 = tt.broadcast %158 : tensor<1x64xi64, #mma> -> tensor<64x64xi64, #mma> loc(#loc)
          %160 = arith.cmpi eq, %144, %159 : tensor<64x64xi64, #mma> loc(#loc)
          %161 = arith.select %160, %156, %cst_25 : tensor<64x64xi1, #mma>, tensor<64x64xf32, #mma> loc(#loc)
          %162 = arith.mulf %161, %cst_26 : tensor<64x64xf32, #mma> loc(#loc)
          %163 = arith.subf %162, %dq_36 : tensor<64x64xf32, #mma> loc(#loc)
          %164 = math.exp2 %163 : tensor<64x64xf32, #mma> loc(#loc)
          %165 = tt.load %vT_ptrs_80, %150, %cst_0 {ttig.block_io = "column_major", ttig.one_matrix_per_load} : tensor<64x64x!tt.ptr<f16>, #ttg.dot_op<{opIdx = 1, parent = #mma, kWidth = 2}>> loc(#loc)
          %166 = tt.dot %65, %165, %cst_24, inputPrecision = tf32 : tensor<64x64xf16, #ttg.dot_op<{opIdx = 0, parent = #mma, kWidth = 1}>> * tensor<64x64xf16, #ttg.dot_op<{opIdx = 1, parent = #mma, kWidth = 2}>> -> tensor<64x64xf32, #mma> loc(#loc)
          %167 = arith.subf %166, %dq_38 : tensor<64x64xf32, #mma> loc(#loc)
          %168 = arith.mulf %164, %167 : tensor<64x64xf32, #mma> loc(#loc)
          %169 = arith.select %160, %168, %cst_24 : tensor<64x64xi1, #mma>, tensor<64x64xf32, #mma> loc(#loc)
          %170 = arith.truncf %169 : tensor<64x64xf32, #mma> to tensor<64x64xf16, #mma> loc(#loc)
          %171 = ttg.convert_layout %170 : tensor<64x64xf16, #mma> -> tensor<64x64xf16, #ttg.dot_op<{opIdx = 0, parent = #mma, kWidth = 1}>> loc(#loc)
          %172 = tt.dot %171, %154, %arg19, inputPrecision = tf32 : tensor<64x64xf16, #ttg.dot_op<{opIdx = 0, parent = #mma, kWidth = 1}>> * tensor<64x64xf16, #ttg.dot_op<{opIdx = 1, parent = #mma, kWidth = 2}>> -> tensor<64x64xf32, #mma> loc(#loc)
          %173 = arith.divsi %vT_ptrs_78, %c2_i32 : i32 loc(#loc)
          %174 = tt.addptr %76, %173 : !tt.ptr<i32>, i32 loc(#loc)
          %175 = tt.load %174 evictionPolicy = evict_last : !tt.ptr<i32> loc(#loc)
          %176 = arith.addi %173, %c1_i32 : i32 loc(#loc)
          %177 = arith.cmpi slt, %176, %80 : i32 loc(#loc)
          %178 = tt.addptr %174, %c1_i32 : !tt.ptr<i32>, i32 loc(#loc)
          %179 = tt.load %178, %177 evictionPolicy = evict_last : !tt.ptr<i32> loc(#loc)
          %180 = arith.addi %vT_ptrs_78, %c1_i32 : i32 loc(#loc)
          %181 = arith.remsi %180, %c2_i32 : i32 loc(#loc)
          %182 = arith.cmpi eq, %181, %c0_i32 : i32 loc(#loc)
          %183 = arith.subi %179, %175 : i32 loc(#loc)
          %184 = arith.muli %183, %c128_i32 : i32 loc(#loc)
          %185 = arith.subi %184, %c64_i32 : i32 loc(#loc)
          %186 = arith.extui %182 : i1 to i32 loc(#loc)
          %187 = arith.muli %185, %186 : i32 loc(#loc)
          %188 = arith.subi %c1_i32, %186 : i32 loc(#loc)
          %189 = arith.muli %188, %c64_i32 : i32 loc(#loc)
          %190 = arith.addi %187, %189 : i32 loc(#loc)
          %191 = arith.muli %190, %c64_i32 : i32 loc(#loc)
          %192 = tt.splat %191 : i32 -> tensor<64x64xi32, #ttg.dot_op<{opIdx = 1, parent = #mma, kWidth = 2}>> loc(#loc)
          %193 = tt.splat %191 : i32 -> tensor<64x64xi32, #linear> loc(#loc)
          %194 = tt.addptr %kT_ptrs_79, %192 : tensor<64x64x!tt.ptr<f16>, #ttg.dot_op<{opIdx = 1, parent = #mma, kWidth = 2}>>, tensor<64x64xi32, #ttg.dot_op<{opIdx = 1, parent = #mma, kWidth = 2}>> loc(#loc)
          %195 = tt.addptr %kT_ptrs_81, %193 : tensor<64x64x!tt.ptr<f16>, #linear>, tensor<64x64xi32, #linear> loc(#loc)
          %196 = tt.addptr %vT_ptrs_80, %192 : tensor<64x64x!tt.ptr<f16>, #ttg.dot_op<{opIdx = 1, parent = #mma, kWidth = 2}>>, tensor<64x64xi32, #ttg.dot_op<{opIdx = 1, parent = #mma, kWidth = 2}>> loc(#loc)
          %197 = tt.splat %190 : i32 -> tensor<64xi32, #ttg.slice<{dim = 0, parent = #ttg.dot_op<{opIdx = 1, parent = #mma, kWidth = 2}>}>> loc(#loc)
          %198 = tt.splat %190 : i32 -> tensor<64xi32, #ttg.slice<{dim = 0, parent = #linear}>> loc(#loc)
          %199 = tt.splat %190 : i32 -> tensor<64xi32, #ttg.slice<{dim = 0, parent = #mma}>> loc(#loc)
          %200 = arith.addi %arg21, %197 : tensor<64xi32, #ttg.slice<{dim = 0, parent = #ttg.dot_op<{opIdx = 1, parent = #mma, kWidth = 2}>}>> loc(#loc)
          %201 = arith.addi %arg25, %198 : tensor<64xi32, #ttg.slice<{dim = 0, parent = #linear}>> loc(#loc)
          %202 = arith.addi %arg22, %199 : tensor<64xi32, #ttg.slice<{dim = 0, parent = #mma}>> loc(#loc)
          scf.yield %172, %194, %200, %202, %196, %195, %201 : tensor<64x64xf32, #mma>, tensor<64x64x!tt.ptr<f16>, #ttg.dot_op<{opIdx = 1, parent = #mma, kWidth = 2}>>, tensor<64xi32, #ttg.slice<{dim = 0, parent = #ttg.dot_op<{opIdx = 1, parent = #mma, kWidth = 2}>}>>, tensor<64xi32, #ttg.slice<{dim = 0, parent = #mma}>>, tensor<64x64x!tt.ptr<f16>, #ttg.dot_op<{opIdx = 1, parent = #mma, kWidth = 2}>>, tensor<64x64x!tt.ptr<f16>, #linear>, tensor<64xi32, #ttg.slice<{dim = 0, parent = #linear}>> loc(#loc)
        } loc(#loc37)
        %dq_42 = tt.expand_dims %vT_ptrs_41#3 {axis = 0 : i32} : tensor<64xi32, #ttg.slice<{dim = 0, parent = #mma}>> -> tensor<1x64xi32, #mma> loc(#loc25)
        %dq_43 = tt.expand_dims %vT_ptrs_41#2 {axis = 0 : i32} : tensor<64xi32, #ttg.slice<{dim = 0, parent = #ttg.dot_op<{opIdx = 1, parent = #mma, kWidth = 2}>}>> -> tensor<1x64xi32, #ttg.dot_op<{opIdx = 1, parent = #mma, kWidth = 2}>> loc(#loc25)
        %dq_44 = tt.expand_dims %vT_ptrs_41#6 {axis = 0 : i32} : tensor<64xi32, #ttg.slice<{dim = 0, parent = #linear}>> -> tensor<1x64xi32, #linear> loc(#loc25)
        %dq_45 = arith.cmpi slt, %dq_42, %cst_21 : tensor<1x64xi32, #mma> loc(#loc25)
        %dq_46 = arith.cmpi slt, %dq_43, %cst_20 : tensor<1x64xi32, #ttg.dot_op<{opIdx = 1, parent = #mma, kWidth = 2}>> loc(#loc25)
        %dq_47 = arith.cmpi slt, %dq_44, %cst_19 : tensor<1x64xi32, #linear> loc(#loc25)
        %dq_48 = tt.broadcast %dq_45 : tensor<1x64xi1, #mma> -> tensor<64x64xi1, #mma> loc(#loc25)
        %dq_49 = tt.broadcast %dq_46 : tensor<1x64xi1, #ttg.dot_op<{opIdx = 1, parent = #mma, kWidth = 2}>> -> tensor<64x64xi1, #ttg.dot_op<{opIdx = 1, parent = #mma, kWidth = 2}>> loc(#loc25)
        %dq_50 = tt.broadcast %dq_47 : tensor<1x64xi1, #linear> -> tensor<64x64xi1, #linear> loc(#loc25)
        %dq_51 = tt.load %vT_ptrs_41#1, %dq_49, %cst_0 {ttig.block_io = "column_major", ttig.one_matrix_per_load} : tensor<64x64x!tt.ptr<f16>, #ttg.dot_op<{opIdx = 1, parent = #mma, kWidth = 2}>> loc(#loc25)
        %dq_52 = tt.load %vT_ptrs_41#5, %dq_50, %cst {ttig.block_io = "column_major", ttig.one_matrix_per_load} : tensor<64x64x!tt.ptr<f16>, #linear> loc(#loc25)
        %dq_53 = tt.trans %dq_52 {order = array<i32: 1, 0>} : tensor<64x64xf16, #linear> -> tensor<64x64xf16, #ttg.dot_op<{opIdx = 1, parent = #mma, kWidth = 2}>> loc(#loc25)
        %dq_54 = tt.dot %60, %dq_51, %cst_24, inputPrecision = tf32 : tensor<64x64xf16, #ttg.dot_op<{opIdx = 0, parent = #mma, kWidth = 1}>> * tensor<64x64xf16, #ttg.dot_op<{opIdx = 1, parent = #mma, kWidth = 2}>> -> tensor<64x64xf32, #mma> loc(#loc25)
        %dq_55 = arith.mulf %dq_54, %cst_23 : tensor<64x64xf32, #mma> loc(#loc25)
        %dq_56 = arith.remsi %dq_42, %cst_21 : tensor<1x64xi32, #mma> loc(#loc25)
        %dq_57 = arith.select %dq_48, %dq_55, %cst_25 : tensor<64x64xi1, #mma>, tensor<64x64xf32, #mma> loc(#loc25)
        %dq_58 = tt.load %dq_34 : tensor<64x1x!tt.ptr<i64>, #mma> loc(#loc25)
        %dq_59 = tt.addptr %dq_35, %dq_56 : tensor<1x64x!tt.ptr<i64>, #mma>, tensor<1x64xi32, #mma> loc(#loc25)
        %dq_60 = tt.load %dq_59 : tensor<1x64x!tt.ptr<i64>, #mma> loc(#loc25)
        %dq_61 = tt.broadcast %dq_58 : tensor<64x1xi64, #mma> -> tensor<64x64xi64, #mma> loc(#loc25)
        %dq_62 = tt.broadcast %dq_60 : tensor<1x64xi64, #mma> -> tensor<64x64xi64, #mma> loc(#loc25)
        %dq_63 = arith.cmpi eq, %dq_61, %dq_62 : tensor<64x64xi64, #mma> loc(#loc25)
        %dq_64 = arith.select %dq_48, %dq_63, %cst_22 : tensor<64x64xi1, #mma>, tensor<64x64xi1, #mma> loc(#loc25)
        %dq_65 = arith.select %dq_64, %dq_57, %cst_25 : tensor<64x64xi1, #mma>, tensor<64x64xf32, #mma> loc(#loc25)
        %dq_66 = arith.mulf %dq_65, %cst_26 : tensor<64x64xf32, #mma> loc(#loc25)
        %dq_67 = arith.subf %dq_66, %dq_36 : tensor<64x64xf32, #mma> loc(#loc25)
        %dq_68 = math.exp2 %dq_67 : tensor<64x64xf32, #mma> loc(#loc25)
        %dq_69 = tt.load %vT_ptrs_41#4, %dq_49, %cst_0 {ttig.block_io = "column_major", ttig.one_matrix_per_load} : tensor<64x64x!tt.ptr<f16>, #ttg.dot_op<{opIdx = 1, parent = #mma, kWidth = 2}>> loc(#loc25)
        %dq_70 = tt.dot %65, %dq_69, %cst_24, inputPrecision = tf32 : tensor<64x64xf16, #ttg.dot_op<{opIdx = 0, parent = #mma, kWidth = 1}>> * tensor<64x64xf16, #ttg.dot_op<{opIdx = 1, parent = #mma, kWidth = 2}>> -> tensor<64x64xf32, #mma> loc(#loc25)
        %dq_71 = arith.subf %dq_70, %dq_38 : tensor<64x64xf32, #mma> loc(#loc25)
        %dq_72 = arith.mulf %dq_68, %dq_71 : tensor<64x64xf32, #mma> loc(#loc25)
        %dq_73 = arith.select %dq_48, %dq_72, %cst_24 : tensor<64x64xi1, #mma>, tensor<64x64xf32, #mma> loc(#loc25)
        %dq_74 = arith.select %dq_64, %dq_73, %cst_24 : tensor<64x64xi1, #mma>, tensor<64x64xf32, #mma> loc(#loc25)
        %dq_75 = arith.truncf %dq_74 : tensor<64x64xf32, #mma> to tensor<64x64xf16, #mma> loc(#loc25)
        %dq_76 = ttg.convert_layout %dq_75 : tensor<64x64xf16, #mma> -> tensor<64x64xf16, #ttg.dot_op<{opIdx = 0, parent = #mma, kWidth = 1}>> loc(#loc25)
        %dq_77 = tt.dot %dq_76, %dq_53, %vT_ptrs_41#0, inputPrecision = tf32 : tensor<64x64xf16, #ttg.dot_op<{opIdx = 0, parent = #mma, kWidth = 1}>> * tensor<64x64xf16, #ttg.dot_op<{opIdx = 1, parent = #mma, kWidth = 2}>> -> tensor<64x64xf32, #mma> loc(#loc25)
        scf.yield %dq_77 : tensor<64x64xf32, #mma> loc(#loc)
      } else {
        scf.yield %cst_24 : tensor<64x64xf32, #mma> loc(#loc)
      } loc(#loc)
      %108 = tt.addptr %arg_FULL_KV_IDX, %27 : !tt.ptr<i32>, i32 loc(#loc)
      %109 = tt.load %108 : !tt.ptr<i32> loc(#loc)
      %110 = arith.muli %109, %c128_i32 : i32 loc(#loc)
      %111 = tt.addptr %arg_FULL_KV_NUM_BLKS, %26 : !tt.ptr<i32>, i32 loc(#loc)
      %112 = tt.load %111 : !tt.ptr<i32> loc(#loc)
      %113 = tt.splat %110 : i32 -> tensor<64xi32, #ttg.slice<{dim = 0, parent = #ttg.dot_op<{opIdx = 1, parent = #mma, kWidth = 2}>}>> loc(#loc)
      %114 = tt.splat %110 : i32 -> tensor<64xi32, #ttg.slice<{dim = 0, parent = #linear}>> loc(#loc)
      %115 = tt.splat %110 : i32 -> tensor<64xi32, #ttg.slice<{dim = 0, parent = #mma}>> loc(#loc)
      %116 = arith.addi %113, %18 : tensor<64xi32, #ttg.slice<{dim = 0, parent = #ttg.dot_op<{opIdx = 1, parent = #mma, kWidth = 2}>}>> loc(#loc)
      %117 = arith.addi %114, %19 : tensor<64xi32, #ttg.slice<{dim = 0, parent = #linear}>> loc(#loc)
      %118 = arith.addi %115, %20 : tensor<64xi32, #ttg.slice<{dim = 0, parent = #mma}>> loc(#loc)
      %119 = tt.expand_dims %116 {axis = 0 : i32} : tensor<64xi32, #ttg.slice<{dim = 0, parent = #ttg.dot_op<{opIdx = 1, parent = #mma, kWidth = 2}>}>> -> tensor<1x64xi32, #ttg.dot_op<{opIdx = 1, parent = #mma, kWidth = 2}>> loc(#loc)
      %120 = tt.expand_dims %117 {axis = 0 : i32} : tensor<64xi32, #ttg.slice<{dim = 0, parent = #linear}>> -> tensor<1x64xi32, #linear> loc(#loc)
      %121 = arith.muli %119, %cst_17 : tensor<1x64xi32, #ttg.dot_op<{opIdx = 1, parent = #mma, kWidth = 2}>> loc(#loc)
      %122 = arith.muli %120, %cst_16 : tensor<1x64xi32, #linear> loc(#loc)
      %123 = tt.addptr %91, %121 : tensor<1x64x!tt.ptr<f16>, #ttg.dot_op<{opIdx = 1, parent = #mma, kWidth = 2}>>, tensor<1x64xi32, #ttg.dot_op<{opIdx = 1, parent = #mma, kWidth = 2}>> loc(#loc)
      %124 = tt.addptr %92, %122 : tensor<1x64x!tt.ptr<f16>, #linear>, tensor<1x64xi32, #linear> loc(#loc)
      %125 = tt.broadcast %123 : tensor<1x64x!tt.ptr<f16>, #ttg.dot_op<{opIdx = 1, parent = #mma, kWidth = 2}>> -> tensor<64x64x!tt.ptr<f16>, #ttg.dot_op<{opIdx = 1, parent = #mma, kWidth = 2}>> loc(#loc)
      %126 = tt.broadcast %124 : tensor<1x64x!tt.ptr<f16>, #linear> -> tensor<64x64x!tt.ptr<f16>, #linear> loc(#loc)
      %kT_ptrs_29 = tt.addptr %125, %99 : tensor<64x64x!tt.ptr<f16>, #ttg.dot_op<{opIdx = 1, parent = #mma, kWidth = 2}>>, tensor<64x64xi32, #ttg.dot_op<{opIdx = 1, parent = #mma, kWidth = 2}>> loc(#loc20)
      %kT_ptrs_30 = tt.addptr %126, %100 : tensor<64x64x!tt.ptr<f16>, #linear>, tensor<64x64xi32, #linear> loc(#loc20)
      %127 = tt.addptr %101, %121 : tensor<1x64x!tt.ptr<f16>, #ttg.dot_op<{opIdx = 1, parent = #mma, kWidth = 2}>>, tensor<1x64xi32, #ttg.dot_op<{opIdx = 1, parent = #mma, kWidth = 2}>> loc(#loc)
      %128 = tt.broadcast %127 : tensor<1x64x!tt.ptr<f16>, #ttg.dot_op<{opIdx = 1, parent = #mma, kWidth = 2}>> -> tensor<64x64x!tt.ptr<f16>, #ttg.dot_op<{opIdx = 1, parent = #mma, kWidth = 2}>> loc(#loc)
      %vT_ptrs_31 = tt.addptr %128, %99 : tensor<64x64x!tt.ptr<f16>, #ttg.dot_op<{opIdx = 1, parent = #mma, kWidth = 2}>>, tensor<64x64xi32, #ttg.dot_op<{opIdx = 1, parent = #mma, kWidth = 2}>> loc(#loc21)
      %129 = arith.muli %112, %c2_i32 : i32 loc(#loc)
      %130 = arith.minsi %129, %c16_i32 : i32 loc(#loc)
      %131 = arith.cmpi sge, %130, %c1_i32 : i32 loc(#loc)
      %132 = scf.if %131 -> (tensor<64x64xf32, #mma>) {
        %142 = arith.subi %130, %c1_i32 : i32 loc(#loc)
        %dq_32 = tt.broadcast %75 : tensor<64x1xf32, #mma> -> tensor<64x64xf32, #mma> loc(#loc25)
        %dq_33 = tt.expand_dims %69 {axis = 1 : i32} : tensor<64xf32, #ttg.slice<{dim = 1, parent = #mma}>> -> tensor<64x1xf32, #mma> loc(#loc25)
        %dq_34 = tt.broadcast %dq_33 : tensor<64x1xf32, #mma> -> tensor<64x64xf32, #mma> loc(#loc25)
        %vT_ptrs_35:7 = scf.for %vT_ptrs_62 = %c0_i32 to %142 step %c1_i32 iter_args(%arg19 = %107, %kT_ptrs_63 = %kT_ptrs_29, %arg21 = %116, %vT_ptrs_64 = %vT_ptrs_31, %kT_ptrs_65 = %kT_ptrs_30, %arg24 = %117, %arg25 = %118) -> (tensor<64x64xf32, #mma>, tensor<64x64x!tt.ptr<f16>, #ttg.dot_op<{opIdx = 1, parent = #mma, kWidth = 2}>>, tensor<64xi32, #ttg.slice<{dim = 0, parent = #ttg.dot_op<{opIdx = 1, parent = #mma, kWidth = 2}>}>>, tensor<64x64x!tt.ptr<f16>, #ttg.dot_op<{opIdx = 1, parent = #mma, kWidth = 2}>>, tensor<64x64x!tt.ptr<f16>, #linear>, tensor<64xi32, #ttg.slice<{dim = 0, parent = #linear}>>, tensor<64xi32, #ttg.slice<{dim = 0, parent = #mma}>>)  : i32 {
          %143 = tt.expand_dims %arg21 {axis = 0 : i32} : tensor<64xi32, #ttg.slice<{dim = 0, parent = #ttg.dot_op<{opIdx = 1, parent = #mma, kWidth = 2}>}>> -> tensor<1x64xi32, #ttg.dot_op<{opIdx = 1, parent = #mma, kWidth = 2}>> loc(#loc)
          %144 = tt.expand_dims %arg24 {axis = 0 : i32} : tensor<64xi32, #ttg.slice<{dim = 0, parent = #linear}>> -> tensor<1x64xi32, #linear> loc(#loc)
          %145 = arith.cmpi slt, %143, %cst_20 : tensor<1x64xi32, #ttg.dot_op<{opIdx = 1, parent = #mma, kWidth = 2}>> loc(#loc)
          %146 = arith.cmpi slt, %144, %cst_19 : tensor<1x64xi32, #linear> loc(#loc)
          %147 = tt.broadcast %145 : tensor<1x64xi1, #ttg.dot_op<{opIdx = 1, parent = #mma, kWidth = 2}>> -> tensor<64x64xi1, #ttg.dot_op<{opIdx = 1, parent = #mma, kWidth = 2}>> loc(#loc)
          %148 = tt.broadcast %146 : tensor<1x64xi1, #linear> -> tensor<64x64xi1, #linear> loc(#loc)
          %149 = tt.load %kT_ptrs_63, %147, %cst_0 {ttig.block_io = "column_major", ttig.one_matrix_per_load} : tensor<64x64x!tt.ptr<f16>, #ttg.dot_op<{opIdx = 1, parent = #mma, kWidth = 2}>> loc(#loc)
          %150 = tt.load %kT_ptrs_65, %148, %cst {ttig.block_io = "column_major", ttig.one_matrix_per_load} : tensor<64x64x!tt.ptr<f16>, #linear> loc(#loc)
          %151 = tt.trans %150 {order = array<i32: 1, 0>} : tensor<64x64xf16, #linear> -> tensor<64x64xf16, #ttg.dot_op<{opIdx = 1, parent = #mma, kWidth = 2}>> loc(#loc)
          %152 = tt.dot %60, %149, %cst_24, inputPrecision = tf32 : tensor<64x64xf16, #ttg.dot_op<{opIdx = 0, parent = #mma, kWidth = 1}>> * tensor<64x64xf16, #ttg.dot_op<{opIdx = 1, parent = #mma, kWidth = 2}>> -> tensor<64x64xf32, #mma> loc(#loc)
          %153 = arith.mulf %152, %cst_23 : tensor<64x64xf32, #mma> loc(#loc)
          %154 = arith.mulf %153, %cst_26 : tensor<64x64xf32, #mma> loc(#loc)
          %155 = arith.subf %154, %dq_32 : tensor<64x64xf32, #mma> loc(#loc)
          %156 = math.exp2 %155 : tensor<64x64xf32, #mma> loc(#loc)
          %157 = tt.load %vT_ptrs_64, %147, %cst_0 {ttig.block_io = "column_major", ttig.one_matrix_per_load} : tensor<64x64x!tt.ptr<f16>, #ttg.dot_op<{opIdx = 1, parent = #mma, kWidth = 2}>> loc(#loc)
          %158 = tt.dot %65, %157, %cst_24, inputPrecision = tf32 : tensor<64x64xf16, #ttg.dot_op<{opIdx = 0, parent = #mma, kWidth = 1}>> * tensor<64x64xf16, #ttg.dot_op<{opIdx = 1, parent = #mma, kWidth = 2}>> -> tensor<64x64xf32, #mma> loc(#loc)
          %159 = arith.subf %158, %dq_34 : tensor<64x64xf32, #mma> loc(#loc)
          %160 = arith.mulf %156, %159 : tensor<64x64xf32, #mma> loc(#loc)
          %161 = arith.truncf %160 : tensor<64x64xf32, #mma> to tensor<64x64xf16, #mma> loc(#loc)
          %162 = ttg.convert_layout %161 : tensor<64x64xf16, #mma> -> tensor<64x64xf16, #ttg.dot_op<{opIdx = 0, parent = #mma, kWidth = 1}>> loc(#loc)
          %163 = tt.dot %162, %151, %arg19, inputPrecision = tf32 : tensor<64x64xf16, #ttg.dot_op<{opIdx = 0, parent = #mma, kWidth = 1}>> * tensor<64x64xf16, #ttg.dot_op<{opIdx = 1, parent = #mma, kWidth = 2}>> -> tensor<64x64xf32, #mma> loc(#loc)
          %164 = arith.divsi %vT_ptrs_62, %c2_i32 : i32 loc(#loc)
          %165 = tt.addptr %108, %164 : !tt.ptr<i32>, i32 loc(#loc)
          %166 = tt.load %165 evictionPolicy = evict_last : !tt.ptr<i32> loc(#loc)
          %167 = arith.addi %164, %c1_i32 : i32 loc(#loc)
          %168 = arith.cmpi slt, %167, %112 : i32 loc(#loc)
          %169 = tt.addptr %165, %c1_i32 : !tt.ptr<i32>, i32 loc(#loc)
          %170 = tt.load %169, %168 evictionPolicy = evict_last : !tt.ptr<i32> loc(#loc)
          %171 = arith.addi %vT_ptrs_62, %c1_i32 : i32 loc(#loc)
          %172 = arith.remsi %171, %c2_i32 : i32 loc(#loc)
          %173 = arith.cmpi eq, %172, %c0_i32 : i32 loc(#loc)
          %174 = arith.subi %170, %166 : i32 loc(#loc)
          %175 = arith.muli %174, %c128_i32 : i32 loc(#loc)
          %176 = arith.subi %175, %c64_i32 : i32 loc(#loc)
          %177 = arith.extui %173 : i1 to i32 loc(#loc)
          %178 = arith.muli %176, %177 : i32 loc(#loc)
          %179 = arith.subi %c1_i32, %177 : i32 loc(#loc)
          %180 = arith.muli %179, %c64_i32 : i32 loc(#loc)
          %181 = arith.addi %178, %180 : i32 loc(#loc)
          %182 = arith.muli %181, %c64_i32 : i32 loc(#loc)
          %183 = tt.splat %182 : i32 -> tensor<64x64xi32, #ttg.dot_op<{opIdx = 1, parent = #mma, kWidth = 2}>> loc(#loc)
          %184 = tt.splat %182 : i32 -> tensor<64x64xi32, #linear> loc(#loc)
          %185 = tt.addptr %kT_ptrs_63, %183 : tensor<64x64x!tt.ptr<f16>, #ttg.dot_op<{opIdx = 1, parent = #mma, kWidth = 2}>>, tensor<64x64xi32, #ttg.dot_op<{opIdx = 1, parent = #mma, kWidth = 2}>> loc(#loc)
          %186 = tt.addptr %kT_ptrs_65, %184 : tensor<64x64x!tt.ptr<f16>, #linear>, tensor<64x64xi32, #linear> loc(#loc)
          %187 = tt.addptr %vT_ptrs_64, %183 : tensor<64x64x!tt.ptr<f16>, #ttg.dot_op<{opIdx = 1, parent = #mma, kWidth = 2}>>, tensor<64x64xi32, #ttg.dot_op<{opIdx = 1, parent = #mma, kWidth = 2}>> loc(#loc)
          %188 = tt.splat %181 : i32 -> tensor<64xi32, #ttg.slice<{dim = 0, parent = #ttg.dot_op<{opIdx = 1, parent = #mma, kWidth = 2}>}>> loc(#loc)
          %189 = tt.splat %181 : i32 -> tensor<64xi32, #ttg.slice<{dim = 0, parent = #linear}>> loc(#loc)
          %190 = tt.splat %181 : i32 -> tensor<64xi32, #ttg.slice<{dim = 0, parent = #mma}>> loc(#loc)
          %191 = arith.addi %arg21, %188 : tensor<64xi32, #ttg.slice<{dim = 0, parent = #ttg.dot_op<{opIdx = 1, parent = #mma, kWidth = 2}>}>> loc(#loc)
          %192 = arith.addi %arg24, %189 : tensor<64xi32, #ttg.slice<{dim = 0, parent = #linear}>> loc(#loc)
          %193 = arith.addi %arg25, %190 : tensor<64xi32, #ttg.slice<{dim = 0, parent = #mma}>> loc(#loc)
          scf.yield %163, %185, %191, %187, %186, %192, %193 : tensor<64x64xf32, #mma>, tensor<64x64x!tt.ptr<f16>, #ttg.dot_op<{opIdx = 1, parent = #mma, kWidth = 2}>>, tensor<64xi32, #ttg.slice<{dim = 0, parent = #ttg.dot_op<{opIdx = 1, parent = #mma, kWidth = 2}>}>>, tensor<64x64x!tt.ptr<f16>, #ttg.dot_op<{opIdx = 1, parent = #mma, kWidth = 2}>>, tensor<64x64x!tt.ptr<f16>, #linear>, tensor<64xi32, #ttg.slice<{dim = 0, parent = #linear}>>, tensor<64xi32, #ttg.slice<{dim = 0, parent = #mma}>> loc(#loc)
        } loc(#loc37)
        %dq_36 = tt.expand_dims %vT_ptrs_35#6 {axis = 0 : i32} : tensor<64xi32, #ttg.slice<{dim = 0, parent = #mma}>> -> tensor<1x64xi32, #mma> loc(#loc25)
        %dq_37 = tt.expand_dims %vT_ptrs_35#2 {axis = 0 : i32} : tensor<64xi32, #ttg.slice<{dim = 0, parent = #ttg.dot_op<{opIdx = 1, parent = #mma, kWidth = 2}>}>> -> tensor<1x64xi32, #ttg.dot_op<{opIdx = 1, parent = #mma, kWidth = 2}>> loc(#loc25)
        %dq_38 = tt.expand_dims %vT_ptrs_35#5 {axis = 0 : i32} : tensor<64xi32, #ttg.slice<{dim = 0, parent = #linear}>> -> tensor<1x64xi32, #linear> loc(#loc25)
        %dq_39 = arith.cmpi slt, %dq_36, %cst_21 : tensor<1x64xi32, #mma> loc(#loc25)
        %dq_40 = arith.cmpi slt, %dq_37, %cst_20 : tensor<1x64xi32, #ttg.dot_op<{opIdx = 1, parent = #mma, kWidth = 2}>> loc(#loc25)
        %dq_41 = arith.cmpi slt, %dq_38, %cst_19 : tensor<1x64xi32, #linear> loc(#loc25)
        %dq_42 = tt.broadcast %dq_39 : tensor<1x64xi1, #mma> -> tensor<64x64xi1, #mma> loc(#loc25)
        %dq_43 = tt.broadcast %dq_40 : tensor<1x64xi1, #ttg.dot_op<{opIdx = 1, parent = #mma, kWidth = 2}>> -> tensor<64x64xi1, #ttg.dot_op<{opIdx = 1, parent = #mma, kWidth = 2}>> loc(#loc25)
        %dq_44 = tt.broadcast %dq_41 : tensor<1x64xi1, #linear> -> tensor<64x64xi1, #linear> loc(#loc25)
        %dq_45 = tt.load %vT_ptrs_35#1, %dq_43, %cst_0 {ttig.block_io = "column_major", ttig.one_matrix_per_load} : tensor<64x64x!tt.ptr<f16>, #ttg.dot_op<{opIdx = 1, parent = #mma, kWidth = 2}>> loc(#loc25)
        %dq_46 = tt.load %vT_ptrs_35#4, %dq_44, %cst {ttig.block_io = "column_major", ttig.one_matrix_per_load} : tensor<64x64x!tt.ptr<f16>, #linear> loc(#loc25)
        %dq_47 = tt.trans %dq_46 {order = array<i32: 1, 0>} : tensor<64x64xf16, #linear> -> tensor<64x64xf16, #ttg.dot_op<{opIdx = 1, parent = #mma, kWidth = 2}>> loc(#loc25)
        %dq_48 = tt.dot %60, %dq_45, %cst_24, inputPrecision = tf32 : tensor<64x64xf16, #ttg.dot_op<{opIdx = 0, parent = #mma, kWidth = 1}>> * tensor<64x64xf16, #ttg.dot_op<{opIdx = 1, parent = #mma, kWidth = 2}>> -> tensor<64x64xf32, #mma> loc(#loc25)
        %dq_49 = arith.mulf %dq_48, %cst_23 : tensor<64x64xf32, #mma> loc(#loc25)
        %dq_50 = arith.select %dq_42, %dq_49, %cst_25 : tensor<64x64xi1, #mma>, tensor<64x64xf32, #mma> loc(#loc25)
        %dq_51 = arith.mulf %dq_50, %cst_26 : tensor<64x64xf32, #mma> loc(#loc25)
        %dq_52 = arith.subf %dq_51, %dq_32 : tensor<64x64xf32, #mma> loc(#loc25)
        %dq_53 = math.exp2 %dq_52 : tensor<64x64xf32, #mma> loc(#loc25)
        %dq_54 = tt.load %vT_ptrs_35#3, %dq_43, %cst_0 {ttig.block_io = "column_major", ttig.one_matrix_per_load} : tensor<64x64x!tt.ptr<f16>, #ttg.dot_op<{opIdx = 1, parent = #mma, kWidth = 2}>> loc(#loc25)
        %dq_55 = tt.dot %65, %dq_54, %cst_24, inputPrecision = tf32 : tensor<64x64xf16, #ttg.dot_op<{opIdx = 0, parent = #mma, kWidth = 1}>> * tensor<64x64xf16, #ttg.dot_op<{opIdx = 1, parent = #mma, kWidth = 2}>> -> tensor<64x64xf32, #mma> loc(#loc25)
        %dq_56 = arith.subf %dq_55, %dq_34 : tensor<64x64xf32, #mma> loc(#loc25)
        %dq_57 = arith.mulf %dq_53, %dq_56 : tensor<64x64xf32, #mma> loc(#loc25)
        %dq_58 = arith.select %dq_42, %dq_57, %cst_24 : tensor<64x64xi1, #mma>, tensor<64x64xf32, #mma> loc(#loc25)
        %dq_59 = arith.truncf %dq_58 : tensor<64x64xf32, #mma> to tensor<64x64xf16, #mma> loc(#loc25)
        %dq_60 = ttg.convert_layout %dq_59 : tensor<64x64xf16, #mma> -> tensor<64x64xf16, #ttg.dot_op<{opIdx = 0, parent = #mma, kWidth = 1}>> loc(#loc25)
        %dq_61 = tt.dot %dq_60, %dq_47, %vT_ptrs_35#0, inputPrecision = tf32 : tensor<64x64xf16, #ttg.dot_op<{opIdx = 0, parent = #mma, kWidth = 1}>> * tensor<64x64xf16, #ttg.dot_op<{opIdx = 1, parent = #mma, kWidth = 2}>> -> tensor<64x64xf32, #mma> loc(#loc25)
        scf.yield %dq_61 : tensor<64x64xf32, #mma> loc(#loc)
      } else {
        scf.yield %107 : tensor<64x64xf32, #mma> loc(#loc)
      } loc(#loc)
      %133 = tt.splat %37 : !tt.ptr<f16> -> tensor<64x1x!tt.ptr<f16>, #mma> loc(#loc)
      %134 = tt.addptr %133, %46 : tensor<64x1x!tt.ptr<f16>, #mma>, tensor<64x1xi32, #mma> loc(#loc)
      %135 = tt.broadcast %134 : tensor<64x1x!tt.ptr<f16>, #mma> -> tensor<64x64x!tt.ptr<f16>, #mma> loc(#loc)
      %136 = tt.addptr %135, %54 : tensor<64x64x!tt.ptr<f16>, #mma>, tensor<64x64xi32, #mma> loc(#loc)
      %137 = arith.mulf %132, %cst_23 : tensor<64x64xf32, #mma> loc(#loc)
      %138 = arith.cmpi slt, %51, %cst_18 : tensor<1x64xi32, #mma> loc(#loc)
      %139 = tt.broadcast %138 : tensor<1x64xi1, #mma> -> tensor<64x64xi1, #mma> loc(#loc)
      %140 = arith.andi %58, %139 : tensor<64x64xi1, #mma> loc(#loc)
      %141 = arith.truncf %137 : tensor<64x64xf32, #mma> to tensor<64x64xf16, #mma> loc(#loc)
      tt.store %136, %141, %140 {ttig.block_io = "row_major"} : tensor<64x64x!tt.ptr<f16>, #mma> loc(#loc)
    } else {
      %22 = arith.divsi %0, %c2_i32 : i32 loc(#loc)
      %23 = arith.muli %0, %c64_i32 : i32 loc(#loc)
      %24 = tt.splat %23 : i32 -> tensor<64xi32, #ttg.slice<{dim = 1, parent = #ttg.dot_op<{opIdx = 0, parent = #mma, kWidth = 1}>}>> loc(#loc)
      %25 = tt.splat %23 : i32 -> tensor<64xi32, #ttg.slice<{dim = 1, parent = #mma}>> loc(#loc)
      %26 = arith.addi %24, %14 : tensor<64xi32, #ttg.slice<{dim = 1, parent = #ttg.dot_op<{opIdx = 0, parent = #mma, kWidth = 1}>}>> loc(#loc)
      %27 = arith.addi %25, %17 : tensor<64xi32, #ttg.slice<{dim = 1, parent = #mma}>> loc(#loc)
      %dv = tt.expand_dims %26 {axis = 1 : i32} : tensor<64xi32, #ttg.slice<{dim = 1, parent = #ttg.dot_op<{opIdx = 0, parent = #mma, kWidth = 1}>}>> -> tensor<64x1xi32, #ttg.dot_op<{opIdx = 0, parent = #mma, kWidth = 1}>> loc(#loc29)
      %dv_27 = tt.expand_dims %27 {axis = 1 : i32} : tensor<64xi32, #ttg.slice<{dim = 1, parent = #mma}>> -> tensor<64x1xi32, #mma> loc(#loc29)
      %28 = arith.muli %dv, %cst_10 : tensor<64x1xi32, #ttg.dot_op<{opIdx = 0, parent = #mma, kWidth = 1}>> loc(#loc)
      %29 = arith.muli %dv_27, %cst_11 : tensor<64x1xi32, #mma> loc(#loc)
      %30 = tt.splat %11 : !tt.ptr<f16> -> tensor<64x1x!tt.ptr<f16>, #ttg.dot_op<{opIdx = 0, parent = #mma, kWidth = 1}>> loc(#loc)
      %31 = tt.addptr %30, %28 : tensor<64x1x!tt.ptr<f16>, #ttg.dot_op<{opIdx = 0, parent = #mma, kWidth = 1}>>, tensor<64x1xi32, #ttg.dot_op<{opIdx = 0, parent = #mma, kWidth = 1}>> loc(#loc)
      %32 = tt.make_range {end = 64 : i32, start = 0 : i32} : tensor<64xi32, #ttg.slice<{dim = 0, parent = #ttg.dot_op<{opIdx = 0, parent = #mma, kWidth = 1}>}>> loc(#loc)
      %33 = tt.expand_dims %32 {axis = 0 : i32} : tensor<64xi32, #ttg.slice<{dim = 0, parent = #ttg.dot_op<{opIdx = 0, parent = #mma, kWidth = 1}>}>> -> tensor<1x64xi32, #ttg.dot_op<{opIdx = 0, parent = #mma, kWidth = 1}>> loc(#loc)
      %34 = tt.expand_dims %18 {axis = 0 : i32} : tensor<64xi32, #ttg.slice<{dim = 0, parent = #ttg.dot_op<{opIdx = 1, parent = #mma, kWidth = 2}>}>> -> tensor<1x64xi32, #ttg.dot_op<{opIdx = 1, parent = #mma, kWidth = 2}>> loc(#loc)
      %35 = tt.expand_dims %19 {axis = 0 : i32} : tensor<64xi32, #ttg.slice<{dim = 0, parent = #linear}>> -> tensor<1x64xi32, #linear> loc(#loc)
      %36 = tt.expand_dims %20 {axis = 0 : i32} : tensor<64xi32, #ttg.slice<{dim = 0, parent = #mma}>> -> tensor<1x64xi32, #mma> loc(#loc)
      %37 = tt.broadcast %31 : tensor<64x1x!tt.ptr<f16>, #ttg.dot_op<{opIdx = 0, parent = #mma, kWidth = 1}>> -> tensor<64x64x!tt.ptr<f16>, #ttg.dot_op<{opIdx = 0, parent = #mma, kWidth = 1}>> loc(#loc)
      %38 = tt.broadcast %33 : tensor<1x64xi32, #ttg.dot_op<{opIdx = 0, parent = #mma, kWidth = 1}>> -> tensor<64x64xi32, #ttg.dot_op<{opIdx = 0, parent = #mma, kWidth = 1}>> loc(#loc)
      %39 = tt.broadcast %34 : tensor<1x64xi32, #ttg.dot_op<{opIdx = 1, parent = #mma, kWidth = 2}>> -> tensor<64x64xi32, #ttg.dot_op<{opIdx = 1, parent = #mma, kWidth = 2}>> loc(#loc)
      %40 = tt.broadcast %35 : tensor<1x64xi32, #linear> -> tensor<64x64xi32, #linear> loc(#loc)
      %41 = tt.broadcast %36 : tensor<1x64xi32, #mma> -> tensor<64x64xi32, #mma> loc(#loc)
      %42 = tt.addptr %37, %38 : tensor<64x64x!tt.ptr<f16>, #ttg.dot_op<{opIdx = 0, parent = #mma, kWidth = 1}>>, tensor<64x64xi32, #ttg.dot_op<{opIdx = 0, parent = #mma, kWidth = 1}>> loc(#loc)
      %dv_28 = arith.cmpi slt, %dv_27, %cst_12 : tensor<64x1xi32, #mma> loc(#loc29)
      %dv_29 = arith.cmpi slt, %dv, %cst_15 : tensor<64x1xi32, #ttg.dot_op<{opIdx = 0, parent = #mma, kWidth = 1}>> loc(#loc29)
      %dv_30 = tt.broadcast %dv_28 : tensor<64x1xi1, #mma> -> tensor<64x64xi1, #mma> loc(#loc29)
      %dv_31 = tt.broadcast %dv_29 : tensor<64x1xi1, #ttg.dot_op<{opIdx = 0, parent = #mma, kWidth = 1}>> -> tensor<64x64xi1, #ttg.dot_op<{opIdx = 0, parent = #mma, kWidth = 1}>> loc(#loc29)
      %43 = tt.load %42, %dv_31, %cst_1 {ttig.block_io = "row_major"} : tensor<64x64x!tt.ptr<f16>, #ttg.dot_op<{opIdx = 0, parent = #mma, kWidth = 1}>> loc(#loc)
      %44 = tt.splat %12 : !tt.ptr<f16> -> tensor<64x1x!tt.ptr<f16>, #ttg.dot_op<{opIdx = 0, parent = #mma, kWidth = 1}>> loc(#loc)
      %45 = tt.addptr %44, %28 : tensor<64x1x!tt.ptr<f16>, #ttg.dot_op<{opIdx = 0, parent = #mma, kWidth = 1}>>, tensor<64x1xi32, #ttg.dot_op<{opIdx = 0, parent = #mma, kWidth = 1}>> loc(#loc)
      %46 = tt.broadcast %45 : tensor<64x1x!tt.ptr<f16>, #ttg.dot_op<{opIdx = 0, parent = #mma, kWidth = 1}>> -> tensor<64x64x!tt.ptr<f16>, #ttg.dot_op<{opIdx = 0, parent = #mma, kWidth = 1}>> loc(#loc)
      %47 = tt.addptr %46, %38 : tensor<64x64x!tt.ptr<f16>, #ttg.dot_op<{opIdx = 0, parent = #mma, kWidth = 1}>>, tensor<64x64xi32, #ttg.dot_op<{opIdx = 0, parent = #mma, kWidth = 1}>> loc(#loc)
      %48 = tt.load %47, %dv_31, %cst_1 {ttig.block_io = "row_major"} : tensor<64x64x!tt.ptr<f16>, #ttg.dot_op<{opIdx = 0, parent = #mma, kWidth = 1}>> loc(#loc)
      %49 = arith.muli %1, %c4_i32 : i32 loc(#loc)
      %50 = arith.addi %49, %2 : i32 loc(#loc)
      %51 = arith.muli %50, %c1023_i32 : i32 loc(#loc)
      %52 = arith.extsi %51 : i32 to i64 loc(#loc)
      %53 = tt.addptr %arg_Q, %10 : !tt.ptr<f16>, i64 loc(#loc)
      %54 = tt.addptr %arg_DO, %10 : !tt.ptr<f16>, i64 loc(#loc)
      %55 = tt.addptr %arg_LSE, %52 : !tt.ptr<f32>, i64 loc(#loc)
      %56 = tt.addptr %arg_DELTA, %52 : !tt.ptr<f32>, i64 loc(#loc)
      %57 = arith.muli %22, %c8_i32 : i32 loc(#loc)
      %58 = tt.addptr %arg_Q_IDX, %57 : !tt.ptr<i32>, i32 loc(#loc)
      %59 = tt.load %58 : !tt.ptr<i32> loc(#loc)
      %60 = arith.muli %59, %c128_i32 : i32 loc(#loc)
      %61 = tt.addptr %arg_Q_NUM_BLKS, %22 : !tt.ptr<i32>, i32 loc(#loc)
      %62 = tt.load %61 : !tt.ptr<i32> loc(#loc)
      %63 = tt.splat %60 : i32 -> tensor<64xi32, #ttg.slice<{dim = 0, parent = #ttg.dot_op<{opIdx = 1, parent = #mma, kWidth = 2}>}>> loc(#loc)
      %64 = tt.splat %60 : i32 -> tensor<64xi32, #ttg.slice<{dim = 0, parent = #linear}>> loc(#loc)
      %65 = tt.splat %60 : i32 -> tensor<64xi32, #ttg.slice<{dim = 0, parent = #mma}>> loc(#loc)
      %66 = tt.splat %60 : i32 -> tensor<64xi32, #ttg.slice<{dim = 1, parent = #ttg.dot_op<{opIdx = 1, parent = #mma, kWidth = 2}>}>> loc(#loc)
      %67 = tt.splat %60 : i32 -> tensor<64xi32, #ttg.slice<{dim = 1, parent = #linear}>> loc(#loc)
      %68 = arith.addi %63, %18 : tensor<64xi32, #ttg.slice<{dim = 0, parent = #ttg.dot_op<{opIdx = 1, parent = #mma, kWidth = 2}>}>> loc(#loc)
      %69 = arith.addi %64, %19 : tensor<64xi32, #ttg.slice<{dim = 0, parent = #linear}>> loc(#loc)
      %70 = arith.addi %65, %20 : tensor<64xi32, #ttg.slice<{dim = 0, parent = #mma}>> loc(#loc)
      %71 = arith.addi %66, %15 : tensor<64xi32, #ttg.slice<{dim = 1, parent = #ttg.dot_op<{opIdx = 1, parent = #mma, kWidth = 2}>}>> loc(#loc)
      %72 = arith.addi %67, %16 : tensor<64xi32, #ttg.slice<{dim = 1, parent = #linear}>> loc(#loc)
      %73 = tt.expand_dims %68 {axis = 0 : i32} : tensor<64xi32, #ttg.slice<{dim = 0, parent = #ttg.dot_op<{opIdx = 1, parent = #mma, kWidth = 2}>}>> -> tensor<1x64xi32, #ttg.dot_op<{opIdx = 1, parent = #mma, kWidth = 2}>> loc(#loc)
      %74 = tt.expand_dims %69 {axis = 0 : i32} : tensor<64xi32, #ttg.slice<{dim = 0, parent = #linear}>> -> tensor<1x64xi32, #linear> loc(#loc)
      %75 = arith.muli %73, %cst_17 : tensor<1x64xi32, #ttg.dot_op<{opIdx = 1, parent = #mma, kWidth = 2}>> loc(#loc)
      %76 = arith.muli %74, %cst_16 : tensor<1x64xi32, #linear> loc(#loc)
      %77 = tt.splat %53 : !tt.ptr<f16> -> tensor<1x64x!tt.ptr<f16>, #ttg.dot_op<{opIdx = 1, parent = #mma, kWidth = 2}>> loc(#loc)
      %78 = tt.splat %53 : !tt.ptr<f16> -> tensor<1x64x!tt.ptr<f16>, #linear> loc(#loc)
      %79 = tt.addptr %77, %75 : tensor<1x64x!tt.ptr<f16>, #ttg.dot_op<{opIdx = 1, parent = #mma, kWidth = 2}>>, tensor<1x64xi32, #ttg.dot_op<{opIdx = 1, parent = #mma, kWidth = 2}>> loc(#loc)
      %80 = tt.addptr %78, %76 : tensor<1x64x!tt.ptr<f16>, #linear>, tensor<1x64xi32, #linear> loc(#loc)
      %81 = tt.expand_dims %15 {axis = 1 : i32} : tensor<64xi32, #ttg.slice<{dim = 1, parent = #ttg.dot_op<{opIdx = 1, parent = #mma, kWidth = 2}>}>> -> tensor<64x1xi32, #ttg.dot_op<{opIdx = 1, parent = #mma, kWidth = 2}>> loc(#loc)
      %82 = tt.expand_dims %16 {axis = 1 : i32} : tensor<64xi32, #ttg.slice<{dim = 1, parent = #linear}>> -> tensor<64x1xi32, #linear> loc(#loc)
      %83 = tt.broadcast %79 : tensor<1x64x!tt.ptr<f16>, #ttg.dot_op<{opIdx = 1, parent = #mma, kWidth = 2}>> -> tensor<64x64x!tt.ptr<f16>, #ttg.dot_op<{opIdx = 1, parent = #mma, kWidth = 2}>> loc(#loc)
      %84 = tt.broadcast %80 : tensor<1x64x!tt.ptr<f16>, #linear> -> tensor<64x64x!tt.ptr<f16>, #linear> loc(#loc)
      %85 = tt.broadcast %81 : tensor<64x1xi32, #ttg.dot_op<{opIdx = 1, parent = #mma, kWidth = 2}>> -> tensor<64x64xi32, #ttg.dot_op<{opIdx = 1, parent = #mma, kWidth = 2}>> loc(#loc)
      %86 = tt.broadcast %82 : tensor<64x1xi32, #linear> -> tensor<64x64xi32, #linear> loc(#loc)
      %qT_ptrs = tt.addptr %83, %85 : tensor<64x64x!tt.ptr<f16>, #ttg.dot_op<{opIdx = 1, parent = #mma, kWidth = 2}>>, tensor<64x64xi32, #ttg.dot_op<{opIdx = 1, parent = #mma, kWidth = 2}>> loc(#loc23)
      %qT_ptrs_32 = tt.addptr %84, %86 : tensor<64x64x!tt.ptr<f16>, #linear>, tensor<64x64xi32, #linear> loc(#loc23)
      %87 = tt.expand_dims %71 {axis = 1 : i32} : tensor<64xi32, #ttg.slice<{dim = 1, parent = #ttg.dot_op<{opIdx = 1, parent = #mma, kWidth = 2}>}>> -> tensor<64x1xi32, #ttg.dot_op<{opIdx = 1, parent = #mma, kWidth = 2}>> loc(#loc)
      %88 = tt.expand_dims %72 {axis = 1 : i32} : tensor<64xi32, #ttg.slice<{dim = 1, parent = #linear}>> -> tensor<64x1xi32, #linear> loc(#loc)
      %89 = arith.muli %87, %cst_9 : tensor<64x1xi32, #ttg.dot_op<{opIdx = 1, parent = #mma, kWidth = 2}>> loc(#loc)
      %90 = arith.muli %88, %cst_8 : tensor<64x1xi32, #linear> loc(#loc)
      %91 = tt.splat %54 : !tt.ptr<f16> -> tensor<64x1x!tt.ptr<f16>, #ttg.dot_op<{opIdx = 1, parent = #mma, kWidth = 2}>> loc(#loc)
      %92 = tt.splat %54 : !tt.ptr<f16> -> tensor<64x1x!tt.ptr<f16>, #linear> loc(#loc)
      %93 = tt.addptr %91, %89 : tensor<64x1x!tt.ptr<f16>, #ttg.dot_op<{opIdx = 1, parent = #mma, kWidth = 2}>>, tensor<64x1xi32, #ttg.dot_op<{opIdx = 1, parent = #mma, kWidth = 2}>> loc(#loc)
      %94 = tt.addptr %92, %90 : tensor<64x1x!tt.ptr<f16>, #linear>, tensor<64x1xi32, #linear> loc(#loc)
      %95 = tt.broadcast %93 : tensor<64x1x!tt.ptr<f16>, #ttg.dot_op<{opIdx = 1, parent = #mma, kWidth = 2}>> -> tensor<64x64x!tt.ptr<f16>, #ttg.dot_op<{opIdx = 1, parent = #mma, kWidth = 2}>> loc(#loc)
      %96 = tt.broadcast %94 : tensor<64x1x!tt.ptr<f16>, #linear> -> tensor<64x64x!tt.ptr<f16>, #linear> loc(#loc)
      %do_ptrs = tt.addptr %95, %39 : tensor<64x64x!tt.ptr<f16>, #ttg.dot_op<{opIdx = 1, parent = #mma, kWidth = 2}>>, tensor<64x64xi32, #ttg.dot_op<{opIdx = 1, parent = #mma, kWidth = 2}>> loc(#loc24)
      %do_ptrs_33 = tt.addptr %96, %40 : tensor<64x64x!tt.ptr<f16>, #linear>, tensor<64x64xi32, #linear> loc(#loc24)
      %97 = arith.muli %62, %c2_i32 : i32 loc(#loc)
      %98 = arith.minsi %97, %c16_i32 : i32 loc(#loc)
      %99 = arith.cmpi sge, %98, %c1_i32 : i32 loc(#loc)
      %100:2 = scf.if %99 -> (tensor<64x64xf32, #mma>, tensor<64x64xf32, #mma>) {
        %154 = arith.subi %98, %c1_i32 : i32 loc(#loc)
        %dv_38 = tt.splat %55 : !tt.ptr<f32> -> tensor<64x!tt.ptr<f32>, #ttg.slice<{dim = 0, parent = #mma}>> loc(#loc29)
        %dv_39 = arith.remsi %dv_27, %cst_12 : tensor<64x1xi32, #mma> loc(#loc29)
        %dv_40 = tt.splat %in_ptr16 : !tt.ptr<i64> -> tensor<1x64x!tt.ptr<i64>, #mma> loc(#loc29)
        %dv_41 = tt.splat %in_ptr16 : !tt.ptr<i64> -> tensor<64x1x!tt.ptr<i64>, #mma> loc(#loc29)
        %dv_42 = tt.addptr %dv_41, %dv_39 : tensor<64x1x!tt.ptr<i64>, #mma>, tensor<64x1xi32, #mma> loc(#loc29)
        %dv_43 = tt.splat %56 : !tt.ptr<f32> -> tensor<64x!tt.ptr<f32>, #ttg.slice<{dim = 0, parent = #mma}>> loc(#loc29)
        %do_ptrs_44 = arith.cmpi sgt, %154, %c0_i32 : i32 loc(#loc39)
        %do_ptrs_45 = tt.splat %do_ptrs_44 : i1 -> tensor<64x1xi1, #mma> loc(#loc39)
        %155 = tt.load %dv_42, %do_ptrs_45 : tensor<64x1x!tt.ptr<i64>, #mma> loc(#loc)
        %156 = tt.broadcast %155 : tensor<64x1xi64, #mma> -> tensor<64x64xi64, #mma> loc(#loc)
        %do_ptrs_46:11 = scf.for %do_ptrs_103 = %c0_i32 to %154 step %c1_i32 iter_args(%arg19 = %cst_24, %arg20 = %cst_24, %qT_ptrs_104 = %qT_ptrs, %arg22 = %68, %arg23 = %70, %do_ptrs_105 = %do_ptrs, %arg25 = %71, %do_ptrs_106 = %do_ptrs_33, %arg27 = %72, %qT_ptrs_107 = %qT_ptrs_32, %arg29 = %69) -> (tensor<64x64xf32, #mma>, tensor<64x64xf32, #mma>, tensor<64x64x!tt.ptr<f16>, #ttg.dot_op<{opIdx = 1, parent = #mma, kWidth = 2}>>, tensor<64xi32, #ttg.slice<{dim = 0, parent = #ttg.dot_op<{opIdx = 1, parent = #mma, kWidth = 2}>}>>, tensor<64xi32, #ttg.slice<{dim = 0, parent = #mma}>>, tensor<64x64x!tt.ptr<f16>, #ttg.dot_op<{opIdx = 1, parent = #mma, kWidth = 2}>>, tensor<64xi32, #ttg.slice<{dim = 1, parent = #ttg.dot_op<{opIdx = 1, parent = #mma, kWidth = 2}>}>>, tensor<64x64x!tt.ptr<f16>, #linear>, tensor<64xi32, #ttg.slice<{dim = 1, parent = #linear}>>, tensor<64x64x!tt.ptr<f16>, #linear>, tensor<64xi32, #ttg.slice<{dim = 0, parent = #linear}>>)  : i32 {
          %157 = tt.expand_dims %arg22 {axis = 0 : i32} : tensor<64xi32, #ttg.slice<{dim = 0, parent = #ttg.dot_op<{opIdx = 1, parent = #mma, kWidth = 2}>}>> -> tensor<1x64xi32, #ttg.dot_op<{opIdx = 1, parent = #mma, kWidth = 2}>> loc(#loc)
          %158 = tt.expand_dims %arg29 {axis = 0 : i32} : tensor<64xi32, #ttg.slice<{dim = 0, parent = #linear}>> -> tensor<1x64xi32, #linear> loc(#loc)
          %159 = tt.expand_dims %arg23 {axis = 0 : i32} : tensor<64xi32, #ttg.slice<{dim = 0, parent = #mma}>> -> tensor<1x64xi32, #mma> loc(#loc)
          %160 = arith.cmpi slt, %157, %cst_20 : tensor<1x64xi32, #ttg.dot_op<{opIdx = 1, parent = #mma, kWidth = 2}>> loc(#loc)
          %161 = arith.cmpi slt, %158, %cst_19 : tensor<1x64xi32, #linear> loc(#loc)
          %162 = tt.broadcast %160 : tensor<1x64xi1, #ttg.dot_op<{opIdx = 1, parent = #mma, kWidth = 2}>> -> tensor<64x64xi1, #ttg.dot_op<{opIdx = 1, parent = #mma, kWidth = 2}>> loc(#loc)
          %163 = tt.broadcast %161 : tensor<1x64xi1, #linear> -> tensor<64x64xi1, #linear> loc(#loc)
          %164 = tt.load %qT_ptrs_104, %162, %cst_0 {ttig.block_io = "column_major", ttig.one_matrix_per_load} : tensor<64x64x!tt.ptr<f16>, #ttg.dot_op<{opIdx = 1, parent = #mma, kWidth = 2}>> loc(#loc)
          %165 = tt.load %qT_ptrs_107, %163, %cst {ttig.block_io = "column_major", ttig.one_matrix_per_load} : tensor<64x64x!tt.ptr<f16>, #linear> loc(#loc)
          %166 = tt.trans %165 {order = array<i32: 1, 0>} : tensor<64x64xf16, #linear> -> tensor<64x64xf16, #ttg.dot_op<{opIdx = 1, parent = #mma, kWidth = 2}>> loc(#loc)
          %167 = arith.cmpi slt, %arg23, %cst_3 : tensor<64xi32, #ttg.slice<{dim = 0, parent = #mma}>> loc(#loc)
          %168 = tt.addptr %dv_38, %arg23 : tensor<64x!tt.ptr<f32>, #ttg.slice<{dim = 0, parent = #mma}>>, tensor<64xi32, #ttg.slice<{dim = 0, parent = #mma}>> loc(#loc)
          %169 = tt.load %168, %167 : tensor<64x!tt.ptr<f32>, #ttg.slice<{dim = 0, parent = #mma}>> loc(#loc)
          %170 = arith.cmpf oeq, %169, %cst_5 : tensor<64xf32, #ttg.slice<{dim = 0, parent = #mma}>> loc(#loc)
          %171 = arith.select %170, %cst_7, %169 : tensor<64xi1, #ttg.slice<{dim = 0, parent = #mma}>>, tensor<64xf32, #ttg.slice<{dim = 0, parent = #mma}>> loc(#loc)
          %172 = tt.dot %43, %164, %cst_24, inputPrecision = tf32 : tensor<64x64xf16, #ttg.dot_op<{opIdx = 0, parent = #mma, kWidth = 1}>> * tensor<64x64xf16, #ttg.dot_op<{opIdx = 1, parent = #mma, kWidth = 2}>> -> tensor<64x64xf32, #mma> loc(#loc)
          %173 = arith.mulf %172, %cst_23 : tensor<64x64xf32, #mma> loc(#loc)
          %174 = tt.addptr %dv_40, %159 : tensor<1x64x!tt.ptr<i64>, #mma>, tensor<1x64xi32, #mma> loc(#loc)
          %175 = tt.load %174 {ttig.block_io = "row_major"} : tensor<1x64x!tt.ptr<i64>, #mma> loc(#loc)
          %176 = tt.broadcast %175 : tensor<1x64xi64, #mma> -> tensor<64x64xi64, #mma> loc(#loc)
          %177 = arith.cmpi eq, %176, %156 : tensor<64x64xi64, #mma> loc(#loc)
          %178 = arith.select %177, %173, %cst_25 : tensor<64x64xi1, #mma>, tensor<64x64xf32, #mma> loc(#loc)
          %179 = arith.mulf %178, %cst_26 : tensor<64x64xf32, #mma> loc(#loc)
          %180 = tt.expand_dims %171 {axis = 0 : i32} : tensor<64xf32, #ttg.slice<{dim = 0, parent = #mma}>> -> tensor<1x64xf32, #mma> loc(#loc)
          %181 = tt.broadcast %180 : tensor<1x64xf32, #mma> -> tensor<64x64xf32, #mma> loc(#loc)
          %182 = arith.subf %179, %181 : tensor<64x64xf32, #mma> loc(#loc)
          %183 = math.exp2 %182 : tensor<64x64xf32, #mma> loc(#loc)
          %184 = tt.expand_dims %arg25 {axis = 1 : i32} : tensor<64xi32, #ttg.slice<{dim = 1, parent = #ttg.dot_op<{opIdx = 1, parent = #mma, kWidth = 2}>}>> -> tensor<64x1xi32, #ttg.dot_op<{opIdx = 1, parent = #mma, kWidth = 2}>> loc(#loc)
          %185 = tt.expand_dims %arg27 {axis = 1 : i32} : tensor<64xi32, #ttg.slice<{dim = 1, parent = #linear}>> -> tensor<64x1xi32, #linear> loc(#loc)
          %186 = arith.cmpi slt, %184, %cst_14 : tensor<64x1xi32, #ttg.dot_op<{opIdx = 1, parent = #mma, kWidth = 2}>> loc(#loc)
          %187 = arith.cmpi slt, %185, %cst_13 : tensor<64x1xi32, #linear> loc(#loc)
          %188 = tt.broadcast %186 : tensor<64x1xi1, #ttg.dot_op<{opIdx = 1, parent = #mma, kWidth = 2}>> -> tensor<64x64xi1, #ttg.dot_op<{opIdx = 1, parent = #mma, kWidth = 2}>> loc(#loc)
          %189 = tt.broadcast %187 : tensor<64x1xi1, #linear> -> tensor<64x64xi1, #linear> loc(#loc)
          %190 = tt.load %do_ptrs_105, %188, %cst_0 {ttig.block_io = "row_major", ttig.one_matrix_per_load} : tensor<64x64x!tt.ptr<f16>, #ttg.dot_op<{opIdx = 1, parent = #mma, kWidth = 2}>> loc(#loc)
          %191 = tt.load %do_ptrs_106, %189, %cst {ttig.block_io = "row_major", ttig.one_matrix_per_load} : tensor<64x64x!tt.ptr<f16>, #linear> loc(#loc)
          %192 = tt.trans %191 {order = array<i32: 1, 0>} : tensor<64x64xf16, #linear> -> tensor<64x64xf16, #ttg.dot_op<{opIdx = 1, parent = #mma, kWidth = 2}>> loc(#loc)
          %193 = arith.truncf %183 : tensor<64x64xf32, #mma> to tensor<64x64xf16, #mma> loc(#loc)
          %194 = ttg.convert_layout %193 : tensor<64x64xf16, #mma> -> tensor<64x64xf16, #ttg.dot_op<{opIdx = 0, parent = #mma, kWidth = 1}>> loc(#loc)
          %195 = tt.dot %194, %190, %arg20, inputPrecision = tf32 : tensor<64x64xf16, #ttg.dot_op<{opIdx = 0, parent = #mma, kWidth = 1}>> * tensor<64x64xf16, #ttg.dot_op<{opIdx = 1, parent = #mma, kWidth = 2}>> -> tensor<64x64xf32, #mma> loc(#loc)
          %196 = tt.addptr %dv_43, %arg23 : tensor<64x!tt.ptr<f32>, #ttg.slice<{dim = 0, parent = #mma}>>, tensor<64xi32, #ttg.slice<{dim = 0, parent = #mma}>> loc(#loc)
          %197 = tt.load %196, %167 : tensor<64x!tt.ptr<f32>, #ttg.slice<{dim = 0, parent = #mma}>> loc(#loc)
          %198 = tt.dot %48, %192, %cst_24, inputPrecision = tf32 : tensor<64x64xf16, #ttg.dot_op<{opIdx = 0, parent = #mma, kWidth = 1}>> * tensor<64x64xf16, #ttg.dot_op<{opIdx = 1, parent = #mma, kWidth = 2}>> -> tensor<64x64xf32, #mma> loc(#loc)
          %199 = tt.expand_dims %197 {axis = 0 : i32} : tensor<64xf32, #ttg.slice<{dim = 0, parent = #mma}>> -> tensor<1x64xf32, #mma> loc(#loc)
          %200 = tt.broadcast %199 : tensor<1x64xf32, #mma> -> tensor<64x64xf32, #mma> loc(#loc)
          %201 = arith.subf %198, %200 : tensor<64x64xf32, #mma> loc(#loc)
          %202 = arith.mulf %183, %201 : tensor<64x64xf32, #mma> loc(#loc)
          %203 = arith.select %177, %202, %cst_24 : tensor<64x64xi1, #mma>, tensor<64x64xf32, #mma> loc(#loc)
          %204 = arith.truncf %203 : tensor<64x64xf32, #mma> to tensor<64x64xf16, #mma> loc(#loc)
          %205 = ttg.convert_layout %204 : tensor<64x64xf16, #mma> -> tensor<64x64xf16, #ttg.dot_op<{opIdx = 0, parent = #mma, kWidth = 1}>> loc(#loc)
          %206 = tt.dot %205, %166, %arg19, inputPrecision = tf32 : tensor<64x64xf16, #ttg.dot_op<{opIdx = 0, parent = #mma, kWidth = 1}>> * tensor<64x64xf16, #ttg.dot_op<{opIdx = 1, parent = #mma, kWidth = 2}>> -> tensor<64x64xf32, #mma> loc(#loc)
          %207 = arith.divsi %do_ptrs_103, %c2_i32 : i32 loc(#loc)
          %208 = tt.addptr %58, %207 : !tt.ptr<i32>, i32 loc(#loc)
          %209 = tt.load %208 evictionPolicy = evict_last : !tt.ptr<i32> loc(#loc)
          %210 = arith.addi %207, %c1_i32 : i32 loc(#loc)
          %211 = arith.cmpi slt, %210, %62 : i32 loc(#loc)
          %212 = tt.addptr %208, %c1_i32 : !tt.ptr<i32>, i32 loc(#loc)
          %213 = tt.load %212, %211 evictionPolicy = evict_last : !tt.ptr<i32> loc(#loc)
          %214 = arith.addi %do_ptrs_103, %c1_i32 : i32 loc(#loc)
          %215 = arith.remsi %214, %c2_i32 : i32 loc(#loc)
          %216 = arith.cmpi eq, %215, %c0_i32 : i32 loc(#loc)
          %217 = arith.subi %213, %209 : i32 loc(#loc)
          %218 = arith.muli %217, %c128_i32 : i32 loc(#loc)
          %219 = arith.subi %218, %c64_i32 : i32 loc(#loc)
          %220 = arith.extui %216 : i1 to i32 loc(#loc)
          %221 = arith.muli %219, %220 : i32 loc(#loc)
          %222 = arith.subi %c1_i32, %220 : i32 loc(#loc)
          %223 = arith.muli %222, %c64_i32 : i32 loc(#loc)
          %224 = arith.addi %221, %223 : i32 loc(#loc)
          %225 = arith.muli %224, %c64_i32 : i32 loc(#loc)
          %226 = tt.splat %225 : i32 -> tensor<64x64xi32, #ttg.dot_op<{opIdx = 1, parent = #mma, kWidth = 2}>> loc(#loc)
          %227 = tt.splat %225 : i32 -> tensor<64x64xi32, #linear> loc(#loc)
          %228 = tt.addptr %qT_ptrs_104, %226 : tensor<64x64x!tt.ptr<f16>, #ttg.dot_op<{opIdx = 1, parent = #mma, kWidth = 2}>>, tensor<64x64xi32, #ttg.dot_op<{opIdx = 1, parent = #mma, kWidth = 2}>> loc(#loc)
          %229 = tt.addptr %qT_ptrs_107, %227 : tensor<64x64x!tt.ptr<f16>, #linear>, tensor<64x64xi32, #linear> loc(#loc)
          %230 = tt.addptr %do_ptrs_105, %226 : tensor<64x64x!tt.ptr<f16>, #ttg.dot_op<{opIdx = 1, parent = #mma, kWidth = 2}>>, tensor<64x64xi32, #ttg.dot_op<{opIdx = 1, parent = #mma, kWidth = 2}>> loc(#loc)
          %231 = tt.addptr %do_ptrs_106, %227 : tensor<64x64x!tt.ptr<f16>, #linear>, tensor<64x64xi32, #linear> loc(#loc)
          %232 = tt.splat %224 : i32 -> tensor<64xi32, #ttg.slice<{dim = 0, parent = #ttg.dot_op<{opIdx = 1, parent = #mma, kWidth = 2}>}>> loc(#loc)
          %233 = tt.splat %224 : i32 -> tensor<64xi32, #ttg.slice<{dim = 0, parent = #linear}>> loc(#loc)
          %234 = tt.splat %224 : i32 -> tensor<64xi32, #ttg.slice<{dim = 0, parent = #mma}>> loc(#loc)
          %235 = tt.splat %224 : i32 -> tensor<64xi32, #ttg.slice<{dim = 1, parent = #ttg.dot_op<{opIdx = 1, parent = #mma, kWidth = 2}>}>> loc(#loc)
          %236 = tt.splat %224 : i32 -> tensor<64xi32, #ttg.slice<{dim = 1, parent = #linear}>> loc(#loc)
          %237 = arith.addi %arg22, %232 : tensor<64xi32, #ttg.slice<{dim = 0, parent = #ttg.dot_op<{opIdx = 1, parent = #mma, kWidth = 2}>}>> loc(#loc)
          %238 = arith.addi %arg29, %233 : tensor<64xi32, #ttg.slice<{dim = 0, parent = #linear}>> loc(#loc)
          %239 = arith.addi %arg23, %234 : tensor<64xi32, #ttg.slice<{dim = 0, parent = #mma}>> loc(#loc)
          %240 = arith.addi %arg25, %235 : tensor<64xi32, #ttg.slice<{dim = 1, parent = #ttg.dot_op<{opIdx = 1, parent = #mma, kWidth = 2}>}>> loc(#loc)
          %241 = arith.addi %arg27, %236 : tensor<64xi32, #ttg.slice<{dim = 1, parent = #linear}>> loc(#loc)
          scf.yield %206, %195, %228, %237, %239, %230, %240, %231, %241, %229, %238 : tensor<64x64xf32, #mma>, tensor<64x64xf32, #mma>, tensor<64x64x!tt.ptr<f16>, #ttg.dot_op<{opIdx = 1, parent = #mma, kWidth = 2}>>, tensor<64xi32, #ttg.slice<{dim = 0, parent = #ttg.dot_op<{opIdx = 1, parent = #mma, kWidth = 2}>}>>, tensor<64xi32, #ttg.slice<{dim = 0, parent = #mma}>>, tensor<64x64x!tt.ptr<f16>, #ttg.dot_op<{opIdx = 1, parent = #mma, kWidth = 2}>>, tensor<64xi32, #ttg.slice<{dim = 1, parent = #ttg.dot_op<{opIdx = 1, parent = #mma, kWidth = 2}>}>>, tensor<64x64x!tt.ptr<f16>, #linear>, tensor<64xi32, #ttg.slice<{dim = 1, parent = #linear}>>, tensor<64x64x!tt.ptr<f16>, #linear>, tensor<64xi32, #ttg.slice<{dim = 0, parent = #linear}>> loc(#loc)
        } loc(#loc39)
        %dv_47 = tt.expand_dims %do_ptrs_46#3 {axis = 0 : i32} : tensor<64xi32, #ttg.slice<{dim = 0, parent = #ttg.dot_op<{opIdx = 1, parent = #mma, kWidth = 2}>}>> -> tensor<1x64xi32, #ttg.dot_op<{opIdx = 1, parent = #mma, kWidth = 2}>> loc(#loc29)
        %dv_48 = tt.expand_dims %do_ptrs_46#10 {axis = 0 : i32} : tensor<64xi32, #ttg.slice<{dim = 0, parent = #linear}>> -> tensor<1x64xi32, #linear> loc(#loc29)
        %dv_49 = tt.expand_dims %do_ptrs_46#4 {axis = 0 : i32} : tensor<64xi32, #ttg.slice<{dim = 0, parent = #mma}>> -> tensor<1x64xi32, #mma> loc(#loc29)
        %dv_50 = arith.cmpi slt, %dv_47, %cst_20 : tensor<1x64xi32, #ttg.dot_op<{opIdx = 1, parent = #mma, kWidth = 2}>> loc(#loc29)
        %dv_51 = arith.cmpi slt, %dv_48, %cst_19 : tensor<1x64xi32, #linear> loc(#loc29)
        %dv_52 = tt.broadcast %dv_50 : tensor<1x64xi1, #ttg.dot_op<{opIdx = 1, parent = #mma, kWidth = 2}>> -> tensor<64x64xi1, #ttg.dot_op<{opIdx = 1, parent = #mma, kWidth = 2}>> loc(#loc29)
        %dv_53 = tt.broadcast %dv_51 : tensor<1x64xi1, #linear> -> tensor<64x64xi1, #linear> loc(#loc29)
        %dv_54 = tt.load %do_ptrs_46#2, %dv_52, %cst_0 {ttig.block_io = "column_major", ttig.one_matrix_per_load} : tensor<64x64x!tt.ptr<f16>, #ttg.dot_op<{opIdx = 1, parent = #mma, kWidth = 2}>> loc(#loc29)
        %dv_55 = tt.load %do_ptrs_46#9, %dv_53, %cst {ttig.block_io = "column_major", ttig.one_matrix_per_load} : tensor<64x64x!tt.ptr<f16>, #linear> loc(#loc29)
        %dv_56 = tt.trans %dv_55 {order = array<i32: 1, 0>} : tensor<64x64xf16, #linear> -> tensor<64x64xf16, #ttg.dot_op<{opIdx = 1, parent = #mma, kWidth = 2}>> loc(#loc29)
        %dv_57 = arith.cmpi slt, %do_ptrs_46#4, %cst_3 : tensor<64xi32, #ttg.slice<{dim = 0, parent = #mma}>> loc(#loc29)
        %dv_58 = tt.addptr %dv_38, %do_ptrs_46#4 : tensor<64x!tt.ptr<f32>, #ttg.slice<{dim = 0, parent = #mma}>>, tensor<64xi32, #ttg.slice<{dim = 0, parent = #mma}>> loc(#loc29)
        %dv_59 = tt.load %dv_58, %dv_57 : tensor<64x!tt.ptr<f32>, #ttg.slice<{dim = 0, parent = #mma}>> loc(#loc29)
        %dv_60 = arith.cmpf oeq, %dv_59, %cst_5 : tensor<64xf32, #ttg.slice<{dim = 0, parent = #mma}>> loc(#loc29)
        %dv_61 = arith.select %dv_60, %cst_7, %dv_59 : tensor<64xi1, #ttg.slice<{dim = 0, parent = #mma}>>, tensor<64xf32, #ttg.slice<{dim = 0, parent = #mma}>> loc(#loc29)
        %dv_62 = tt.dot %43, %dv_54, %cst_24, inputPrecision = tf32 : tensor<64x64xf16, #ttg.dot_op<{opIdx = 0, parent = #mma, kWidth = 1}>> * tensor<64x64xf16, #ttg.dot_op<{opIdx = 1, parent = #mma, kWidth = 2}>> -> tensor<64x64xf32, #mma> loc(#loc29)
        %dv_63 = arith.mulf %dv_62, %cst_23 : tensor<64x64xf32, #mma> loc(#loc29)
        %dv_64 = arith.remsi %dv_49, %cst_21 : tensor<1x64xi32, #mma> loc(#loc29)
        %dv_65 = arith.select %dv_30, %dv_63, %cst_25 : tensor<64x64xi1, #mma>, tensor<64x64xf32, #mma> loc(#loc29)
        %dv_66 = tt.addptr %dv_40, %dv_64 : tensor<1x64x!tt.ptr<i64>, #mma>, tensor<1x64xi32, #mma> loc(#loc29)
        %dv_67 = tt.load %dv_66 : tensor<1x64x!tt.ptr<i64>, #mma> loc(#loc29)
        %dv_68 = tt.load %dv_42 : tensor<64x1x!tt.ptr<i64>, #mma> loc(#loc29)
        %dv_69 = tt.broadcast %dv_67 : tensor<1x64xi64, #mma> -> tensor<64x64xi64, #mma> loc(#loc29)
        %dv_70 = tt.broadcast %dv_68 : tensor<64x1xi64, #mma> -> tensor<64x64xi64, #mma> loc(#loc29)
        %dv_71 = arith.cmpi eq, %dv_69, %dv_70 : tensor<64x64xi64, #mma> loc(#loc29)
        %dv_72 = arith.select %dv_30, %dv_71, %cst_22 : tensor<64x64xi1, #mma>, tensor<64x64xi1, #mma> loc(#loc29)
        %dv_73 = arith.select %dv_72, %dv_65, %cst_25 : tensor<64x64xi1, #mma>, tensor<64x64xf32, #mma> loc(#loc29)
        %dv_74 = arith.mulf %dv_73, %cst_26 : tensor<64x64xf32, #mma> loc(#loc29)
        %dv_75 = tt.expand_dims %dv_61 {axis = 0 : i32} : tensor<64xf32, #ttg.slice<{dim = 0, parent = #mma}>> -> tensor<1x64xf32, #mma> loc(#loc29)
        %dv_76 = tt.broadcast %dv_75 : tensor<1x64xf32, #mma> -> tensor<64x64xf32, #mma> loc(#loc29)
        %dv_77 = arith.subf %dv_74, %dv_76 : tensor<64x64xf32, #mma> loc(#loc29)
        %dv_78 = math.exp2 %dv_77 : tensor<64x64xf32, #mma> loc(#loc29)
        %dv_79 = tt.expand_dims %do_ptrs_46#6 {axis = 1 : i32} : tensor<64xi32, #ttg.slice<{dim = 1, parent = #ttg.dot_op<{opIdx = 1, parent = #mma, kWidth = 2}>}>> -> tensor<64x1xi32, #ttg.dot_op<{opIdx = 1, parent = #mma, kWidth = 2}>> loc(#loc29)
        %dv_80 = tt.expand_dims %do_ptrs_46#8 {axis = 1 : i32} : tensor<64xi32, #ttg.slice<{dim = 1, parent = #linear}>> -> tensor<64x1xi32, #linear> loc(#loc29)
        %dv_81 = arith.cmpi slt, %dv_79, %cst_14 : tensor<64x1xi32, #ttg.dot_op<{opIdx = 1, parent = #mma, kWidth = 2}>> loc(#loc29)
        %dv_82 = arith.cmpi slt, %dv_80, %cst_13 : tensor<64x1xi32, #linear> loc(#loc29)
        %dv_83 = tt.broadcast %dv_81 : tensor<64x1xi1, #ttg.dot_op<{opIdx = 1, parent = #mma, kWidth = 2}>> -> tensor<64x64xi1, #ttg.dot_op<{opIdx = 1, parent = #mma, kWidth = 2}>> loc(#loc29)
        %dv_84 = tt.broadcast %dv_82 : tensor<64x1xi1, #linear> -> tensor<64x64xi1, #linear> loc(#loc29)
        %dv_85 = tt.load %do_ptrs_46#5, %dv_83, %cst_0 {ttig.block_io = "row_major", ttig.one_matrix_per_load} : tensor<64x64x!tt.ptr<f16>, #ttg.dot_op<{opIdx = 1, parent = #mma, kWidth = 2}>> loc(#loc29)
        %dv_86 = tt.load %do_ptrs_46#7, %dv_84, %cst {ttig.block_io = "row_major", ttig.one_matrix_per_load} : tensor<64x64x!tt.ptr<f16>, #linear> loc(#loc29)
        %dv_87 = tt.trans %dv_86 {order = array<i32: 1, 0>} : tensor<64x64xf16, #linear> -> tensor<64x64xf16, #ttg.dot_op<{opIdx = 1, parent = #mma, kWidth = 2}>> loc(#loc29)
        %dv_88 = arith.truncf %dv_78 : tensor<64x64xf32, #mma> to tensor<64x64xf16, #mma> loc(#loc29)
        %dv_89 = ttg.convert_layout %dv_88 : tensor<64x64xf16, #mma> -> tensor<64x64xf16, #ttg.dot_op<{opIdx = 0, parent = #mma, kWidth = 1}>> loc(#loc29)
        %dv_90 = tt.dot %dv_89, %dv_85, %do_ptrs_46#1, inputPrecision = tf32 : tensor<64x64xf16, #ttg.dot_op<{opIdx = 0, parent = #mma, kWidth = 1}>> * tensor<64x64xf16, #ttg.dot_op<{opIdx = 1, parent = #mma, kWidth = 2}>> -> tensor<64x64xf32, #mma> loc(#loc29)
        %dv_91 = tt.addptr %dv_43, %do_ptrs_46#4 : tensor<64x!tt.ptr<f32>, #ttg.slice<{dim = 0, parent = #mma}>>, tensor<64xi32, #ttg.slice<{dim = 0, parent = #mma}>> loc(#loc29)
        %dv_92 = tt.load %dv_91, %dv_57 : tensor<64x!tt.ptr<f32>, #ttg.slice<{dim = 0, parent = #mma}>> loc(#loc29)
        %dv_93 = tt.dot %48, %dv_87, %cst_24, inputPrecision = tf32 : tensor<64x64xf16, #ttg.dot_op<{opIdx = 0, parent = #mma, kWidth = 1}>> * tensor<64x64xf16, #ttg.dot_op<{opIdx = 1, parent = #mma, kWidth = 2}>> -> tensor<64x64xf32, #mma> loc(#loc29)
        %dv_94 = tt.expand_dims %dv_92 {axis = 0 : i32} : tensor<64xf32, #ttg.slice<{dim = 0, parent = #mma}>> -> tensor<1x64xf32, #mma> loc(#loc29)
        %dv_95 = tt.broadcast %dv_94 : tensor<1x64xf32, #mma> -> tensor<64x64xf32, #mma> loc(#loc29)
        %dv_96 = arith.subf %dv_93, %dv_95 : tensor<64x64xf32, #mma> loc(#loc29)
        %dv_97 = arith.mulf %dv_78, %dv_96 : tensor<64x64xf32, #mma> loc(#loc29)
        %dv_98 = arith.select %dv_30, %dv_97, %cst_24 : tensor<64x64xi1, #mma>, tensor<64x64xf32, #mma> loc(#loc29)
        %dv_99 = arith.select %dv_72, %dv_98, %cst_24 : tensor<64x64xi1, #mma>, tensor<64x64xf32, #mma> loc(#loc29)
        %dv_100 = arith.truncf %dv_99 : tensor<64x64xf32, #mma> to tensor<64x64xf16, #mma> loc(#loc29)
        %dv_101 = ttg.convert_layout %dv_100 : tensor<64x64xf16, #mma> -> tensor<64x64xf16, #ttg.dot_op<{opIdx = 0, parent = #mma, kWidth = 1}>> loc(#loc29)
        %dv_102 = tt.dot %dv_101, %dv_56, %do_ptrs_46#0, inputPrecision = tf32 : tensor<64x64xf16, #ttg.dot_op<{opIdx = 0, parent = #mma, kWidth = 1}>> * tensor<64x64xf16, #ttg.dot_op<{opIdx = 1, parent = #mma, kWidth = 2}>> -> tensor<64x64xf32, #mma> loc(#loc29)
        scf.yield %dv_102, %dv_90 : tensor<64x64xf32, #mma>, tensor<64x64xf32, #mma> loc(#loc)
      } else {
        scf.yield %cst_24, %cst_24 : tensor<64x64xf32, #mma>, tensor<64x64xf32, #mma> loc(#loc)
      } loc(#loc)
      %101 = tt.addptr %arg_FULL_Q_IDX, %57 : !tt.ptr<i32>, i32 loc(#loc)
      %102 = tt.load %101 : !tt.ptr<i32> loc(#loc)
      %103 = arith.muli %102, %c128_i32 : i32 loc(#loc)
      %104 = tt.addptr %arg_FULL_Q_NUM_BLKS, %22 : !tt.ptr<i32>, i32 loc(#loc)
      %105 = tt.load %104 : !tt.ptr<i32> loc(#loc)
      %106 = tt.splat %103 : i32 -> tensor<64xi32, #ttg.slice<{dim = 0, parent = #ttg.dot_op<{opIdx = 1, parent = #mma, kWidth = 2}>}>> loc(#loc)
      %107 = tt.splat %103 : i32 -> tensor<64xi32, #ttg.slice<{dim = 0, parent = #linear}>> loc(#loc)
      %108 = tt.splat %103 : i32 -> tensor<64xi32, #ttg.slice<{dim = 0, parent = #mma}>> loc(#loc)
      %109 = tt.splat %103 : i32 -> tensor<64xi32, #ttg.slice<{dim = 1, parent = #ttg.dot_op<{opIdx = 1, parent = #mma, kWidth = 2}>}>> loc(#loc)
      %110 = tt.splat %103 : i32 -> tensor<64xi32, #ttg.slice<{dim = 1, parent = #linear}>> loc(#loc)
      %111 = arith.addi %106, %18 : tensor<64xi32, #ttg.slice<{dim = 0, parent = #ttg.dot_op<{opIdx = 1, parent = #mma, kWidth = 2}>}>> loc(#loc)
      %112 = arith.addi %107, %19 : tensor<64xi32, #ttg.slice<{dim = 0, parent = #linear}>> loc(#loc)
      %113 = arith.addi %108, %20 : tensor<64xi32, #ttg.slice<{dim = 0, parent = #mma}>> loc(#loc)
      %114 = arith.addi %109, %15 : tensor<64xi32, #ttg.slice<{dim = 1, parent = #ttg.dot_op<{opIdx = 1, parent = #mma, kWidth = 2}>}>> loc(#loc)
      %115 = arith.addi %110, %16 : tensor<64xi32, #ttg.slice<{dim = 1, parent = #linear}>> loc(#loc)
      %116 = tt.expand_dims %111 {axis = 0 : i32} : tensor<64xi32, #ttg.slice<{dim = 0, parent = #ttg.dot_op<{opIdx = 1, parent = #mma, kWidth = 2}>}>> -> tensor<1x64xi32, #ttg.dot_op<{opIdx = 1, parent = #mma, kWidth = 2}>> loc(#loc)
      %117 = tt.expand_dims %112 {axis = 0 : i32} : tensor<64xi32, #ttg.slice<{dim = 0, parent = #linear}>> -> tensor<1x64xi32, #linear> loc(#loc)
      %118 = arith.muli %116, %cst_17 : tensor<1x64xi32, #ttg.dot_op<{opIdx = 1, parent = #mma, kWidth = 2}>> loc(#loc)
      %119 = arith.muli %117, %cst_16 : tensor<1x64xi32, #linear> loc(#loc)
      %120 = tt.addptr %77, %118 : tensor<1x64x!tt.ptr<f16>, #ttg.dot_op<{opIdx = 1, parent = #mma, kWidth = 2}>>, tensor<1x64xi32, #ttg.dot_op<{opIdx = 1, parent = #mma, kWidth = 2}>> loc(#loc)
      %121 = tt.addptr %78, %119 : tensor<1x64x!tt.ptr<f16>, #linear>, tensor<1x64xi32, #linear> loc(#loc)
      %122 = tt.broadcast %120 : tensor<1x64x!tt.ptr<f16>, #ttg.dot_op<{opIdx = 1, parent = #mma, kWidth = 2}>> -> tensor<64x64x!tt.ptr<f16>, #ttg.dot_op<{opIdx = 1, parent = #mma, kWidth = 2}>> loc(#loc)
      %123 = tt.broadcast %121 : tensor<1x64x!tt.ptr<f16>, #linear> -> tensor<64x64x!tt.ptr<f16>, #linear> loc(#loc)
      %qT_ptrs_34 = tt.addptr %122, %85 : tensor<64x64x!tt.ptr<f16>, #ttg.dot_op<{opIdx = 1, parent = #mma, kWidth = 2}>>, tensor<64x64xi32, #ttg.dot_op<{opIdx = 1, parent = #mma, kWidth = 2}>> loc(#loc23)
      %qT_ptrs_35 = tt.addptr %123, %86 : tensor<64x64x!tt.ptr<f16>, #linear>, tensor<64x64xi32, #linear> loc(#loc23)
      %124 = tt.expand_dims %114 {axis = 1 : i32} : tensor<64xi32, #ttg.slice<{dim = 1, parent = #ttg.dot_op<{opIdx = 1, parent = #mma, kWidth = 2}>}>> -> tensor<64x1xi32, #ttg.dot_op<{opIdx = 1, parent = #mma, kWidth = 2}>> loc(#loc)
      %125 = tt.expand_dims %115 {axis = 1 : i32} : tensor<64xi32, #ttg.slice<{dim = 1, parent = #linear}>> -> tensor<64x1xi32, #linear> loc(#loc)
      %126 = arith.muli %124, %cst_9 : tensor<64x1xi32, #ttg.dot_op<{opIdx = 1, parent = #mma, kWidth = 2}>> loc(#loc)
      %127 = arith.muli %125, %cst_8 : tensor<64x1xi32, #linear> loc(#loc)
      %128 = tt.addptr %91, %126 : tensor<64x1x!tt.ptr<f16>, #ttg.dot_op<{opIdx = 1, parent = #mma, kWidth = 2}>>, tensor<64x1xi32, #ttg.dot_op<{opIdx = 1, parent = #mma, kWidth = 2}>> loc(#loc)
      %129 = tt.addptr %92, %127 : tensor<64x1x!tt.ptr<f16>, #linear>, tensor<64x1xi32, #linear> loc(#loc)
      %130 = tt.broadcast %128 : tensor<64x1x!tt.ptr<f16>, #ttg.dot_op<{opIdx = 1, parent = #mma, kWidth = 2}>> -> tensor<64x64x!tt.ptr<f16>, #ttg.dot_op<{opIdx = 1, parent = #mma, kWidth = 2}>> loc(#loc)
      %131 = tt.broadcast %129 : tensor<64x1x!tt.ptr<f16>, #linear> -> tensor<64x64x!tt.ptr<f16>, #linear> loc(#loc)
      %do_ptrs_36 = tt.addptr %130, %39 : tensor<64x64x!tt.ptr<f16>, #ttg.dot_op<{opIdx = 1, parent = #mma, kWidth = 2}>>, tensor<64x64xi32, #ttg.dot_op<{opIdx = 1, parent = #mma, kWidth = 2}>> loc(#loc24)
      %do_ptrs_37 = tt.addptr %131, %40 : tensor<64x64x!tt.ptr<f16>, #linear>, tensor<64x64xi32, #linear> loc(#loc24)
      %132 = arith.muli %105, %c2_i32 : i32 loc(#loc)
      %133 = arith.minsi %132, %c16_i32 : i32 loc(#loc)
      %134 = arith.cmpi sge, %133, %c1_i32 : i32 loc(#loc)
      %135:2 = scf.if %134 -> (tensor<64x64xf32, #mma>, tensor<64x64xf32, #mma>) {
        %154 = arith.subi %133, %c1_i32 : i32 loc(#loc)
        %dv_38 = tt.splat %55 : !tt.ptr<f32> -> tensor<64x!tt.ptr<f32>, #ttg.slice<{dim = 0, parent = #mma}>> loc(#loc29)
        %dv_39 = tt.splat %56 : !tt.ptr<f32> -> tensor<64x!tt.ptr<f32>, #ttg.slice<{dim = 0, parent = #mma}>> loc(#loc29)
        %do_ptrs_40:11 = scf.for %do_ptrs_86 = %c0_i32 to %154 step %c1_i32 iter_args(%arg19 = %100#0, %arg20 = %100#1, %qT_ptrs_87 = %qT_ptrs_34, %arg22 = %111, %arg23 = %113, %do_ptrs_88 = %do_ptrs_36, %arg25 = %114, %do_ptrs_89 = %do_ptrs_37, %arg27 = %115, %qT_ptrs_90 = %qT_ptrs_35, %arg29 = %112) -> (tensor<64x64xf32, #mma>, tensor<64x64xf32, #mma>, tensor<64x64x!tt.ptr<f16>, #ttg.dot_op<{opIdx = 1, parent = #mma, kWidth = 2}>>, tensor<64xi32, #ttg.slice<{dim = 0, parent = #ttg.dot_op<{opIdx = 1, parent = #mma, kWidth = 2}>}>>, tensor<64xi32, #ttg.slice<{dim = 0, parent = #mma}>>, tensor<64x64x!tt.ptr<f16>, #ttg.dot_op<{opIdx = 1, parent = #mma, kWidth = 2}>>, tensor<64xi32, #ttg.slice<{dim = 1, parent = #ttg.dot_op<{opIdx = 1, parent = #mma, kWidth = 2}>}>>, tensor<64x64x!tt.ptr<f16>, #linear>, tensor<64xi32, #ttg.slice<{dim = 1, parent = #linear}>>, tensor<64x64x!tt.ptr<f16>, #linear>, tensor<64xi32, #ttg.slice<{dim = 0, parent = #linear}>>)  : i32 {
          %155 = tt.expand_dims %arg22 {axis = 0 : i32} : tensor<64xi32, #ttg.slice<{dim = 0, parent = #ttg.dot_op<{opIdx = 1, parent = #mma, kWidth = 2}>}>> -> tensor<1x64xi32, #ttg.dot_op<{opIdx = 1, parent = #mma, kWidth = 2}>> loc(#loc)
          %156 = tt.expand_dims %arg29 {axis = 0 : i32} : tensor<64xi32, #ttg.slice<{dim = 0, parent = #linear}>> -> tensor<1x64xi32, #linear> loc(#loc)
          %157 = arith.cmpi slt, %155, %cst_20 : tensor<1x64xi32, #ttg.dot_op<{opIdx = 1, parent = #mma, kWidth = 2}>> loc(#loc)
          %158 = arith.cmpi slt, %156, %cst_19 : tensor<1x64xi32, #linear> loc(#loc)
          %159 = tt.broadcast %157 : tensor<1x64xi1, #ttg.dot_op<{opIdx = 1, parent = #mma, kWidth = 2}>> -> tensor<64x64xi1, #ttg.dot_op<{opIdx = 1, parent = #mma, kWidth = 2}>> loc(#loc)
          %160 = tt.broadcast %158 : tensor<1x64xi1, #linear> -> tensor<64x64xi1, #linear> loc(#loc)
          %161 = tt.load %qT_ptrs_87, %159, %cst_0 {ttig.block_io = "column_major", ttig.one_matrix_per_load} : tensor<64x64x!tt.ptr<f16>, #ttg.dot_op<{opIdx = 1, parent = #mma, kWidth = 2}>> loc(#loc)
          %162 = tt.load %qT_ptrs_90, %160, %cst {ttig.block_io = "column_major", ttig.one_matrix_per_load} : tensor<64x64x!tt.ptr<f16>, #linear> loc(#loc)
          %163 = tt.trans %162 {order = array<i32: 1, 0>} : tensor<64x64xf16, #linear> -> tensor<64x64xf16, #ttg.dot_op<{opIdx = 1, parent = #mma, kWidth = 2}>> loc(#loc)
          %164 = arith.cmpi slt, %arg23, %cst_3 : tensor<64xi32, #ttg.slice<{dim = 0, parent = #mma}>> loc(#loc)
          %165 = tt.addptr %dv_38, %arg23 : tensor<64x!tt.ptr<f32>, #ttg.slice<{dim = 0, parent = #mma}>>, tensor<64xi32, #ttg.slice<{dim = 0, parent = #mma}>> loc(#loc)
          %166 = tt.load %165, %164 : tensor<64x!tt.ptr<f32>, #ttg.slice<{dim = 0, parent = #mma}>> loc(#loc)
          %167 = arith.cmpf oeq, %166, %cst_5 : tensor<64xf32, #ttg.slice<{dim = 0, parent = #mma}>> loc(#loc)
          %168 = arith.select %167, %cst_7, %166 : tensor<64xi1, #ttg.slice<{dim = 0, parent = #mma}>>, tensor<64xf32, #ttg.slice<{dim = 0, parent = #mma}>> loc(#loc)
          %169 = tt.dot %43, %161, %cst_24, inputPrecision = tf32 : tensor<64x64xf16, #ttg.dot_op<{opIdx = 0, parent = #mma, kWidth = 1}>> * tensor<64x64xf16, #ttg.dot_op<{opIdx = 1, parent = #mma, kWidth = 2}>> -> tensor<64x64xf32, #mma> loc(#loc)
          %170 = arith.mulf %169, %cst_23 : tensor<64x64xf32, #mma> loc(#loc)
          %171 = arith.mulf %170, %cst_26 : tensor<64x64xf32, #mma> loc(#loc)
          %172 = tt.expand_dims %168 {axis = 0 : i32} : tensor<64xf32, #ttg.slice<{dim = 0, parent = #mma}>> -> tensor<1x64xf32, #mma> loc(#loc)
          %173 = tt.broadcast %172 : tensor<1x64xf32, #mma> -> tensor<64x64xf32, #mma> loc(#loc)
          %174 = arith.subf %171, %173 : tensor<64x64xf32, #mma> loc(#loc)
          %175 = math.exp2 %174 : tensor<64x64xf32, #mma> loc(#loc)
          %176 = tt.expand_dims %arg25 {axis = 1 : i32} : tensor<64xi32, #ttg.slice<{dim = 1, parent = #ttg.dot_op<{opIdx = 1, parent = #mma, kWidth = 2}>}>> -> tensor<64x1xi32, #ttg.dot_op<{opIdx = 1, parent = #mma, kWidth = 2}>> loc(#loc)
          %177 = tt.expand_dims %arg27 {axis = 1 : i32} : tensor<64xi32, #ttg.slice<{dim = 1, parent = #linear}>> -> tensor<64x1xi32, #linear> loc(#loc)
          %178 = arith.cmpi slt, %176, %cst_14 : tensor<64x1xi32, #ttg.dot_op<{opIdx = 1, parent = #mma, kWidth = 2}>> loc(#loc)
          %179 = arith.cmpi slt, %177, %cst_13 : tensor<64x1xi32, #linear> loc(#loc)
          %180 = tt.broadcast %178 : tensor<64x1xi1, #ttg.dot_op<{opIdx = 1, parent = #mma, kWidth = 2}>> -> tensor<64x64xi1, #ttg.dot_op<{opIdx = 1, parent = #mma, kWidth = 2}>> loc(#loc)
          %181 = tt.broadcast %179 : tensor<64x1xi1, #linear> -> tensor<64x64xi1, #linear> loc(#loc)
          %182 = tt.load %do_ptrs_88, %180, %cst_0 {ttig.block_io = "row_major", ttig.one_matrix_per_load} : tensor<64x64x!tt.ptr<f16>, #ttg.dot_op<{opIdx = 1, parent = #mma, kWidth = 2}>> loc(#loc)
          %183 = tt.load %do_ptrs_89, %181, %cst {ttig.block_io = "row_major", ttig.one_matrix_per_load} : tensor<64x64x!tt.ptr<f16>, #linear> loc(#loc)
          %184 = tt.trans %183 {order = array<i32: 1, 0>} : tensor<64x64xf16, #linear> -> tensor<64x64xf16, #ttg.dot_op<{opIdx = 1, parent = #mma, kWidth = 2}>> loc(#loc)
          %185 = arith.truncf %175 : tensor<64x64xf32, #mma> to tensor<64x64xf16, #mma> loc(#loc)
          %186 = ttg.convert_layout %185 : tensor<64x64xf16, #mma> -> tensor<64x64xf16, #ttg.dot_op<{opIdx = 0, parent = #mma, kWidth = 1}>> loc(#loc)
          %187 = tt.dot %186, %182, %arg20, inputPrecision = tf32 : tensor<64x64xf16, #ttg.dot_op<{opIdx = 0, parent = #mma, kWidth = 1}>> * tensor<64x64xf16, #ttg.dot_op<{opIdx = 1, parent = #mma, kWidth = 2}>> -> tensor<64x64xf32, #mma> loc(#loc)
          %188 = tt.addptr %dv_39, %arg23 : tensor<64x!tt.ptr<f32>, #ttg.slice<{dim = 0, parent = #mma}>>, tensor<64xi32, #ttg.slice<{dim = 0, parent = #mma}>> loc(#loc)
          %189 = tt.load %188, %164 : tensor<64x!tt.ptr<f32>, #ttg.slice<{dim = 0, parent = #mma}>> loc(#loc)
          %190 = tt.dot %48, %184, %cst_24, inputPrecision = tf32 : tensor<64x64xf16, #ttg.dot_op<{opIdx = 0, parent = #mma, kWidth = 1}>> * tensor<64x64xf16, #ttg.dot_op<{opIdx = 1, parent = #mma, kWidth = 2}>> -> tensor<64x64xf32, #mma> loc(#loc)
          %191 = tt.expand_dims %189 {axis = 0 : i32} : tensor<64xf32, #ttg.slice<{dim = 0, parent = #mma}>> -> tensor<1x64xf32, #mma> loc(#loc)
          %192 = tt.broadcast %191 : tensor<1x64xf32, #mma> -> tensor<64x64xf32, #mma> loc(#loc)
          %193 = arith.subf %190, %192 : tensor<64x64xf32, #mma> loc(#loc)
          %194 = arith.mulf %175, %193 : tensor<64x64xf32, #mma> loc(#loc)
          %195 = arith.truncf %194 : tensor<64x64xf32, #mma> to tensor<64x64xf16, #mma> loc(#loc)
          %196 = ttg.convert_layout %195 : tensor<64x64xf16, #mma> -> tensor<64x64xf16, #ttg.dot_op<{opIdx = 0, parent = #mma, kWidth = 1}>> loc(#loc)
          %197 = tt.dot %196, %163, %arg19, inputPrecision = tf32 : tensor<64x64xf16, #ttg.dot_op<{opIdx = 0, parent = #mma, kWidth = 1}>> * tensor<64x64xf16, #ttg.dot_op<{opIdx = 1, parent = #mma, kWidth = 2}>> -> tensor<64x64xf32, #mma> loc(#loc)
          %198 = arith.divsi %do_ptrs_86, %c2_i32 : i32 loc(#loc)
          %199 = tt.addptr %101, %198 : !tt.ptr<i32>, i32 loc(#loc)
          %200 = tt.load %199 evictionPolicy = evict_last : !tt.ptr<i32> loc(#loc)
          %201 = arith.addi %198, %c1_i32 : i32 loc(#loc)
          %202 = arith.cmpi slt, %201, %105 : i32 loc(#loc)
          %203 = tt.addptr %199, %c1_i32 : !tt.ptr<i32>, i32 loc(#loc)
          %204 = tt.load %203, %202 evictionPolicy = evict_last : !tt.ptr<i32> loc(#loc)
          %205 = arith.addi %do_ptrs_86, %c1_i32 : i32 loc(#loc)
          %206 = arith.remsi %205, %c2_i32 : i32 loc(#loc)
          %207 = arith.cmpi eq, %206, %c0_i32 : i32 loc(#loc)
          %208 = arith.subi %204, %200 : i32 loc(#loc)
          %209 = arith.muli %208, %c128_i32 : i32 loc(#loc)
          %210 = arith.subi %209, %c64_i32 : i32 loc(#loc)
          %211 = arith.extui %207 : i1 to i32 loc(#loc)
          %212 = arith.muli %210, %211 : i32 loc(#loc)
          %213 = arith.subi %c1_i32, %211 : i32 loc(#loc)
          %214 = arith.muli %213, %c64_i32 : i32 loc(#loc)
          %215 = arith.addi %212, %214 : i32 loc(#loc)
          %216 = arith.muli %215, %c64_i32 : i32 loc(#loc)
          %217 = tt.splat %216 : i32 -> tensor<64x64xi32, #ttg.dot_op<{opIdx = 1, parent = #mma, kWidth = 2}>> loc(#loc)
          %218 = tt.splat %216 : i32 -> tensor<64x64xi32, #linear> loc(#loc)
          %219 = tt.addptr %qT_ptrs_87, %217 : tensor<64x64x!tt.ptr<f16>, #ttg.dot_op<{opIdx = 1, parent = #mma, kWidth = 2}>>, tensor<64x64xi32, #ttg.dot_op<{opIdx = 1, parent = #mma, kWidth = 2}>> loc(#loc)
          %220 = tt.addptr %qT_ptrs_90, %218 : tensor<64x64x!tt.ptr<f16>, #linear>, tensor<64x64xi32, #linear> loc(#loc)
          %221 = tt.addptr %do_ptrs_88, %217 : tensor<64x64x!tt.ptr<f16>, #ttg.dot_op<{opIdx = 1, parent = #mma, kWidth = 2}>>, tensor<64x64xi32, #ttg.dot_op<{opIdx = 1, parent = #mma, kWidth = 2}>> loc(#loc)
          %222 = tt.addptr %do_ptrs_89, %218 : tensor<64x64x!tt.ptr<f16>, #linear>, tensor<64x64xi32, #linear> loc(#loc)
          %223 = tt.splat %215 : i32 -> tensor<64xi32, #ttg.slice<{dim = 0, parent = #ttg.dot_op<{opIdx = 1, parent = #mma, kWidth = 2}>}>> loc(#loc)
          %224 = tt.splat %215 : i32 -> tensor<64xi32, #ttg.slice<{dim = 0, parent = #linear}>> loc(#loc)
          %225 = tt.splat %215 : i32 -> tensor<64xi32, #ttg.slice<{dim = 0, parent = #mma}>> loc(#loc)
          %226 = tt.splat %215 : i32 -> tensor<64xi32, #ttg.slice<{dim = 1, parent = #ttg.dot_op<{opIdx = 1, parent = #mma, kWidth = 2}>}>> loc(#loc)
          %227 = tt.splat %215 : i32 -> tensor<64xi32, #ttg.slice<{dim = 1, parent = #linear}>> loc(#loc)
          %228 = arith.addi %arg22, %223 : tensor<64xi32, #ttg.slice<{dim = 0, parent = #ttg.dot_op<{opIdx = 1, parent = #mma, kWidth = 2}>}>> loc(#loc)
          %229 = arith.addi %arg29, %224 : tensor<64xi32, #ttg.slice<{dim = 0, parent = #linear}>> loc(#loc)
          %230 = arith.addi %arg23, %225 : tensor<64xi32, #ttg.slice<{dim = 0, parent = #mma}>> loc(#loc)
          %231 = arith.addi %arg25, %226 : tensor<64xi32, #ttg.slice<{dim = 1, parent = #ttg.dot_op<{opIdx = 1, parent = #mma, kWidth = 2}>}>> loc(#loc)
          %232 = arith.addi %arg27, %227 : tensor<64xi32, #ttg.slice<{dim = 1, parent = #linear}>> loc(#loc)
          scf.yield %197, %187, %219, %228, %230, %221, %231, %222, %232, %220, %229 : tensor<64x64xf32, #mma>, tensor<64x64xf32, #mma>, tensor<64x64x!tt.ptr<f16>, #ttg.dot_op<{opIdx = 1, parent = #mma, kWidth = 2}>>, tensor<64xi32, #ttg.slice<{dim = 0, parent = #ttg.dot_op<{opIdx = 1, parent = #mma, kWidth = 2}>}>>, tensor<64xi32, #ttg.slice<{dim = 0, parent = #mma}>>, tensor<64x64x!tt.ptr<f16>, #ttg.dot_op<{opIdx = 1, parent = #mma, kWidth = 2}>>, tensor<64xi32, #ttg.slice<{dim = 1, parent = #ttg.dot_op<{opIdx = 1, parent = #mma, kWidth = 2}>}>>, tensor<64x64x!tt.ptr<f16>, #linear>, tensor<64xi32, #ttg.slice<{dim = 1, parent = #linear}>>, tensor<64x64x!tt.ptr<f16>, #linear>, tensor<64xi32, #ttg.slice<{dim = 0, parent = #linear}>> loc(#loc)
        } loc(#loc39)
        %dv_41 = tt.expand_dims %do_ptrs_40#3 {axis = 0 : i32} : tensor<64xi32, #ttg.slice<{dim = 0, parent = #ttg.dot_op<{opIdx = 1, parent = #mma, kWidth = 2}>}>> -> tensor<1x64xi32, #ttg.dot_op<{opIdx = 1, parent = #mma, kWidth = 2}>> loc(#loc29)
        %dv_42 = tt.expand_dims %do_ptrs_40#10 {axis = 0 : i32} : tensor<64xi32, #ttg.slice<{dim = 0, parent = #linear}>> -> tensor<1x64xi32, #linear> loc(#loc29)
        %dv_43 = arith.cmpi slt, %dv_41, %cst_20 : tensor<1x64xi32, #ttg.dot_op<{opIdx = 1, parent = #mma, kWidth = 2}>> loc(#loc29)
        %dv_44 = arith.cmpi slt, %dv_42, %cst_19 : tensor<1x64xi32, #linear> loc(#loc29)
        %dv_45 = tt.broadcast %dv_43 : tensor<1x64xi1, #ttg.dot_op<{opIdx = 1, parent = #mma, kWidth = 2}>> -> tensor<64x64xi1, #ttg.dot_op<{opIdx = 1, parent = #mma, kWidth = 2}>> loc(#loc29)
        %dv_46 = tt.broadcast %dv_44 : tensor<1x64xi1, #linear> -> tensor<64x64xi1, #linear> loc(#loc29)
        %dv_47 = tt.load %do_ptrs_40#2, %dv_45, %cst_0 {ttig.block_io = "column_major", ttig.one_matrix_per_load} : tensor<64x64x!tt.ptr<f16>, #ttg.dot_op<{opIdx = 1, parent = #mma, kWidth = 2}>> loc(#loc29)
        %dv_48 = tt.load %do_ptrs_40#9, %dv_46, %cst {ttig.block_io = "column_major", ttig.one_matrix_per_load} : tensor<64x64x!tt.ptr<f16>, #linear> loc(#loc29)
        %dv_49 = tt.trans %dv_48 {order = array<i32: 1, 0>} : tensor<64x64xf16, #linear> -> tensor<64x64xf16, #ttg.dot_op<{opIdx = 1, parent = #mma, kWidth = 2}>> loc(#loc29)
        %dv_50 = arith.cmpi slt, %do_ptrs_40#4, %cst_3 : tensor<64xi32, #ttg.slice<{dim = 0, parent = #mma}>> loc(#loc29)
        %dv_51 = tt.addptr %dv_38, %do_ptrs_40#4 : tensor<64x!tt.ptr<f32>, #ttg.slice<{dim = 0, parent = #mma}>>, tensor<64xi32, #ttg.slice<{dim = 0, parent = #mma}>> loc(#loc29)
        %dv_52 = tt.load %dv_51, %dv_50 : tensor<64x!tt.ptr<f32>, #ttg.slice<{dim = 0, parent = #mma}>> loc(#loc29)
        %dv_53 = arith.cmpf oeq, %dv_52, %cst_5 : tensor<64xf32, #ttg.slice<{dim = 0, parent = #mma}>> loc(#loc29)
        %dv_54 = arith.select %dv_53, %cst_7, %dv_52 : tensor<64xi1, #ttg.slice<{dim = 0, parent = #mma}>>, tensor<64xf32, #ttg.slice<{dim = 0, parent = #mma}>> loc(#loc29)
        %dv_55 = tt.dot %43, %dv_47, %cst_24, inputPrecision = tf32 : tensor<64x64xf16, #ttg.dot_op<{opIdx = 0, parent = #mma, kWidth = 1}>> * tensor<64x64xf16, #ttg.dot_op<{opIdx = 1, parent = #mma, kWidth = 2}>> -> tensor<64x64xf32, #mma> loc(#loc29)
        %dv_56 = arith.mulf %dv_55, %cst_23 : tensor<64x64xf32, #mma> loc(#loc29)
        %dv_57 = arith.select %dv_30, %dv_56, %cst_25 : tensor<64x64xi1, #mma>, tensor<64x64xf32, #mma> loc(#loc29)
        %dv_58 = arith.mulf %dv_57, %cst_26 : tensor<64x64xf32, #mma> loc(#loc29)
        %dv_59 = tt.expand_dims %dv_54 {axis = 0 : i32} : tensor<64xf32, #ttg.slice<{dim = 0, parent = #mma}>> -> tensor<1x64xf32, #mma> loc(#loc29)
        %dv_60 = tt.broadcast %dv_59 : tensor<1x64xf32, #mma> -> tensor<64x64xf32, #mma> loc(#loc29)
        %dv_61 = arith.subf %dv_58, %dv_60 : tensor<64x64xf32, #mma> loc(#loc29)
        %dv_62 = math.exp2 %dv_61 : tensor<64x64xf32, #mma> loc(#loc29)
        %dv_63 = tt.expand_dims %do_ptrs_40#6 {axis = 1 : i32} : tensor<64xi32, #ttg.slice<{dim = 1, parent = #ttg.dot_op<{opIdx = 1, parent = #mma, kWidth = 2}>}>> -> tensor<64x1xi32, #ttg.dot_op<{opIdx = 1, parent = #mma, kWidth = 2}>> loc(#loc29)
        %dv_64 = tt.expand_dims %do_ptrs_40#8 {axis = 1 : i32} : tensor<64xi32, #ttg.slice<{dim = 1, parent = #linear}>> -> tensor<64x1xi32, #linear> loc(#loc29)
        %dv_65 = arith.cmpi slt, %dv_63, %cst_14 : tensor<64x1xi32, #ttg.dot_op<{opIdx = 1, parent = #mma, kWidth = 2}>> loc(#loc29)
        %dv_66 = arith.cmpi slt, %dv_64, %cst_13 : tensor<64x1xi32, #linear> loc(#loc29)
        %dv_67 = tt.broadcast %dv_65 : tensor<64x1xi1, #ttg.dot_op<{opIdx = 1, parent = #mma, kWidth = 2}>> -> tensor<64x64xi1, #ttg.dot_op<{opIdx = 1, parent = #mma, kWidth = 2}>> loc(#loc29)
        %dv_68 = tt.broadcast %dv_66 : tensor<64x1xi1, #linear> -> tensor<64x64xi1, #linear> loc(#loc29)
        %dv_69 = tt.load %do_ptrs_40#5, %dv_67, %cst_0 {ttig.block_io = "row_major", ttig.one_matrix_per_load} : tensor<64x64x!tt.ptr<f16>, #ttg.dot_op<{opIdx = 1, parent = #mma, kWidth = 2}>> loc(#loc29)
        %dv_70 = tt.load %do_ptrs_40#7, %dv_68, %cst {ttig.block_io = "row_major", ttig.one_matrix_per_load} : tensor<64x64x!tt.ptr<f16>, #linear> loc(#loc29)
        %dv_71 = tt.trans %dv_70 {order = array<i32: 1, 0>} : tensor<64x64xf16, #linear> -> tensor<64x64xf16, #ttg.dot_op<{opIdx = 1, parent = #mma, kWidth = 2}>> loc(#loc29)
        %dv_72 = arith.truncf %dv_62 : tensor<64x64xf32, #mma> to tensor<64x64xf16, #mma> loc(#loc29)
        %dv_73 = ttg.convert_layout %dv_72 : tensor<64x64xf16, #mma> -> tensor<64x64xf16, #ttg.dot_op<{opIdx = 0, parent = #mma, kWidth = 1}>> loc(#loc29)
        %dv_74 = tt.dot %dv_73, %dv_69, %do_ptrs_40#1, inputPrecision = tf32 : tensor<64x64xf16, #ttg.dot_op<{opIdx = 0, parent = #mma, kWidth = 1}>> * tensor<64x64xf16, #ttg.dot_op<{opIdx = 1, parent = #mma, kWidth = 2}>> -> tensor<64x64xf32, #mma> loc(#loc29)
        %dv_75 = tt.addptr %dv_39, %do_ptrs_40#4 : tensor<64x!tt.ptr<f32>, #ttg.slice<{dim = 0, parent = #mma}>>, tensor<64xi32, #ttg.slice<{dim = 0, parent = #mma}>> loc(#loc29)
        %dv_76 = tt.load %dv_75, %dv_50 : tensor<64x!tt.ptr<f32>, #ttg.slice<{dim = 0, parent = #mma}>> loc(#loc29)
        %dv_77 = tt.dot %48, %dv_71, %cst_24, inputPrecision = tf32 : tensor<64x64xf16, #ttg.dot_op<{opIdx = 0, parent = #mma, kWidth = 1}>> * tensor<64x64xf16, #ttg.dot_op<{opIdx = 1, parent = #mma, kWidth = 2}>> -> tensor<64x64xf32, #mma> loc(#loc29)
        %dv_78 = tt.expand_dims %dv_76 {axis = 0 : i32} : tensor<64xf32, #ttg.slice<{dim = 0, parent = #mma}>> -> tensor<1x64xf32, #mma> loc(#loc29)
        %dv_79 = tt.broadcast %dv_78 : tensor<1x64xf32, #mma> -> tensor<64x64xf32, #mma> loc(#loc29)
        %dv_80 = arith.subf %dv_77, %dv_79 : tensor<64x64xf32, #mma> loc(#loc29)
        %dv_81 = arith.mulf %dv_62, %dv_80 : tensor<64x64xf32, #mma> loc(#loc29)
        %dv_82 = arith.select %dv_30, %dv_81, %cst_24 : tensor<64x64xi1, #mma>, tensor<64x64xf32, #mma> loc(#loc29)
        %dv_83 = arith.truncf %dv_82 : tensor<64x64xf32, #mma> to tensor<64x64xf16, #mma> loc(#loc29)
        %dv_84 = ttg.convert_layout %dv_83 : tensor<64x64xf16, #mma> -> tensor<64x64xf16, #ttg.dot_op<{opIdx = 0, parent = #mma, kWidth = 1}>> loc(#loc29)
        %dv_85 = tt.dot %dv_84, %dv_49, %do_ptrs_40#0, inputPrecision = tf32 : tensor<64x64xf16, #ttg.dot_op<{opIdx = 0, parent = #mma, kWidth = 1}>> * tensor<64x64xf16, #ttg.dot_op<{opIdx = 1, parent = #mma, kWidth = 2}>> -> tensor<64x64xf32, #mma> loc(#loc29)
        scf.yield %dv_85, %dv_74 : tensor<64x64xf32, #mma>, tensor<64x64xf32, #mma> loc(#loc)
      } else {
        scf.yield %100#0, %100#1 : tensor<64x64xf32, #mma>, tensor<64x64xf32, #mma> loc(#loc)
      } loc(#loc)
      %136 = tt.splat %13 : !tt.ptr<f16> -> tensor<64x1x!tt.ptr<f16>, #mma> loc(#loc)
      %137 = tt.addptr %136, %29 : tensor<64x1x!tt.ptr<f16>, #mma>, tensor<64x1xi32, #mma> loc(#loc)
      %138 = tt.broadcast %137 : tensor<64x1x!tt.ptr<f16>, #mma> -> tensor<64x64x!tt.ptr<f16>, #mma> loc(#loc)
      %139 = tt.addptr %138, %41 : tensor<64x64x!tt.ptr<f16>, #mma>, tensor<64x64xi32, #mma> loc(#loc)
      %140 = arith.cmpi slt, %36, %cst_18 : tensor<1x64xi32, #mma> loc(#loc)
      %141 = tt.broadcast %140 : tensor<1x64xi1, #mma> -> tensor<64x64xi1, #mma> loc(#loc)
      %142 = arith.andi %dv_30, %141 : tensor<64x64xi1, #mma> loc(#loc)
      %143 = arith.truncf %135#1 : tensor<64x64xf32, #mma> to tensor<64x64xf16, #mma> loc(#loc)
      tt.store %139, %143, %142 {ttig.block_io = "row_major"} : tensor<64x64x!tt.ptr<f16>, #mma> loc(#loc)
      %144 = arith.mulf %135#0, %cst_23 : tensor<64x64xf32, #mma> loc(#loc)
      %145 = tt.broadcast %29 : tensor<64x1xi32, #mma> -> tensor<64x64xi32, #mma> loc(#loc)
      %146 = arith.addi %41, %145 : tensor<64x64xi32, #mma> loc(#loc)
      %147 = tt.splat %4 : i32 -> tensor<64x64xi32, #mma> loc(#loc)
      %148 = arith.addi %146, %147 : tensor<64x64xi32, #mma> loc(#loc)
      %149 = tt.splat %8 : i32 -> tensor<64x64xi32, #mma> loc(#loc)
      %150 = arith.addi %148, %149 : tensor<64x64xi32, #mma> loc(#loc)
      %151 = tt.splat %out_ptr0 : !tt.ptr<f16> -> tensor<64x64x!tt.ptr<f16>, #mma> loc(#loc)
      %152 = tt.addptr %151, %150 : tensor<64x64x!tt.ptr<f16>, #mma>, tensor<64x64xi32, #mma> loc(#loc)
      %153 = arith.truncf %144 : tensor<64x64xf32, #mma> to tensor<64x64xf16, #mma> loc(#loc)
      tt.store %152, %153, %dv_30 {ttig.block_io = "row_major"} : tensor<64x64x!tt.ptr<f16>, #mma> loc(#loc)
    } loc(#loc)
    tt.return loc(#loc)
  } loc(#loc)
} loc(#loc)
#loc = loc(unknown)
#loc19 = loc("dq")
#loc20 = loc("kT_ptrs")
#loc21 = loc("vT_ptrs")
#loc22 = loc("dk")
#loc23 = loc("qT_ptrs")
#loc24 = loc("do_ptrs")
#loc25 = loc(callsite(#loc at #loc19))
#loc26 = loc("offs_n2"(#loc19))
#loc27 = loc("dv"(#loc22))
#loc28 = loc("kT_ptrs"(#loc26))
#loc29 = loc(callsite(#loc at #loc27))
#loc30 = loc("offs_m1"(#loc27))
#loc31 = loc("vT_ptrs"(#loc28))
#loc32 = loc("qT_ptrs"(#loc30))
#loc33 = loc("offs_n2"(#loc31))
#loc34 = loc("do_ptrs"(#loc32))
#loc35 = loc("kT_ptrs"(#loc33))
#loc36 = loc("offs_m1"(#loc34))
#loc37 = loc("vT_ptrs"(#loc35))
#loc38 = loc("qT_ptrs"(#loc36))
#loc39 = loc("do_ptrs"(#loc38))
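As a side note on the duplicated loads visible in the loop bodies above (the paired tt.load in #ttg.dot_op and #linear followed by tt.trans, e.g. around %164/%165/%166): a minimal, hypothetical Triton kernel that reproduces the source-level pattern is sketched below. It is not taken from this PR or from the flex attention template; kernel name, shapes, and the xpu device are assumptions for illustration only. The point is that one loaded tile feeds one tl.dot directly and another tl.dot transposed, so after layout assignment each use gets its own encoding and, today, its own load.

import torch
import triton
import triton.language as tl


@triton.jit
def direct_and_transposed_use(a_ptr, b_ptr, c_ptr, d_ptr, BLOCK: tl.constexpr):
    offs = tl.arange(0, BLOCK)
    idx = offs[:, None] * BLOCK + offs[None, :]
    a = tl.load(a_ptr + idx)
    b = tl.load(b_ptr + idx)
    # `b` is consumed twice: once directly as the RHS of a dot, and once
    # transposed. Each use wants a different operand encoding, which is how
    # the ttgir above ends up with two tt.load ops for the same data
    # (one in #ttg.dot_op, one in #linear followed by tt.trans).
    c = tl.dot(a, b)
    d = tl.dot(a, tl.trans(b))
    tl.store(c_ptr + idx, c)
    tl.store(d_ptr + idx, d)


# usage sketch (assumes an xpu device and fp16 inputs; fp32 accumulators)
if __name__ == "__main__":
    BLOCK = 64
    a = torch.randn(BLOCK, BLOCK, device="xpu", dtype=torch.float16)
    b = torch.randn(BLOCK, BLOCK, device="xpu", dtype=torch.float16)
    c = torch.empty(BLOCK, BLOCK, device="xpu", dtype=torch.float32)
    d = torch.empty(BLOCK, BLOCK, device="xpu", dtype=torch.float32)
    direct_and_transposed_use[(1,)](a, b, c, d, BLOCK=BLOCK)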

@etiotto
Contributor

etiotto commented Sep 3, 2025

The flex attn backward ttgir has been simplified by these changes.

There are only two root tiling layout of the dpas and the transpose of dot of dpas.

Another major in-efficient issue on Xe-Xe3 is that the regular pointer under different layout like:

%152 = tt.load %kT_ptrs_79, %150, %cst_0 {ttig.block_io = "column_major", ttig.one_matrix_per_load} : tensor<64x64x!tt.ptr<f16>, #ttg.dot_op<{opIdx = 1, parent = #mma, kWidth = 2}>> loc(#loc)
%153 = tt.load %kT_ptrs_81, %151, %cst {ttig.block_io = "column_major", ttig.one_matrix_per_load} : tensor<64x64x!tt.ptr<f16>, #linear> loc(#loc)
%154 = tt.trans %153 {order = array<i32: 1, 0>} : tensor<64x64xf16, #linear> -> tensor<64x64xf16, #ttg.dot_op<{opIdx = 1, parent = #mma, kWidth = 2}>> loc(#loc)

The simplified ttgir

I checked out this branch and ran the benchmark code in:

flex attn bwd
import triton
import triton.language as tl

from torch._inductor.runtime import triton_helpers, triton_heuristics
from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math
from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, DeviceProperties

@triton_heuristics.template(

    num_stages=1,
    num_warps=8,
    triton_meta={'signature': {'arg_Q': '*fp16', 'arg_K': '*fp16', 'arg_V': '*fp16', 'arg_LSE': '*fp32', 'arg_DELTA': '*fp32', 'arg_DO': '*fp16', 'arg_DQ': '*fp16', 'arg_DV': '*fp16', 'arg_KV_NUM_BLKS': '*i32', 'arg_KV_IDX': '*i32', 'arg_Q_NUM_BLKS': '*i32', 'arg_Q_IDX': '*i32', 'arg_FULL_KV_NUM_BLKS': '*i32', 'arg_FULL_KV_IDX': '*i32', 'arg_FULL_Q_NUM_BLKS': '*i32', 'arg_FULL_Q_IDX': '*i32', 'in_ptr16': '*i64', 'out_ptr0': '*fp16'}, 'device': DeviceProperties(type='xpu', index=0, multi_processor_count=56, cc={'architecture': 13136561920, 'device_id': 3034, 'driver_version': '1.6.33578+15', 'gpu_eu_count': 448, 'gpu_subslice_count': 56, 'has_atomic64': True, 'has_bfloat16_conversions': True, 'has_fp16': True, 'has_fp64': True, 'has_subgroup_2d_block_io': True, 'has_subgroup_matrix_multiply_accumulate': True, 'has_subgroup_matrix_multiply_accumulate_tensor_float32': False, 'max_compute_units': 448, 'max_num_sub_groups': 64, 'max_work_group_size': 1024, 'name': 'Intel(R) Data Center GPU Max 1100', 'platform_name': 'Intel(R) oneAPI Unified Runtime over Level-Zero', 'sub_group_sizes': [16, 32], 'total_memory': 51522830336, 'type': 'gpu', 'vendor': 'Intel(R) Corporation', 'version': '12.60.7'}, major=None, regs_per_multiprocessor=None, max_threads_per_multi_processor=None, warp_size=32), 'constants': {}, 'configs': [{(0,): [['tt.divisibility', 16]], (1,): [['tt.divisibility', 16]], (2,): [['tt.divisibility', 16]], (3,): [['tt.divisibility', 16]], (4,): [['tt.divisibility', 16]], (5,): [['tt.divisibility', 16]], (6,): [['tt.divisibility', 16]], (7,): [['tt.divisibility', 16]], (8,): [['tt.divisibility', 16]], (9,): [['tt.divisibility', 16]], (10,): [['tt.divisibility', 16]], (11,): [['tt.divisibility', 16]], (12,): [['tt.divisibility', 16]], (13,): [['tt.divisibility', 16]], (14,): [['tt.divisibility', 16]], (15,): [['tt.divisibility', 16]], (16,): [['tt.divisibility', 16]], (17,): [['tt.divisibility', 16]]}]},
    inductor_meta={'kernel_name': 'triton_tem_fused_zeros_1', 'backend_hash': 'F16AA07D54F5BA283BB218BAD77E2CCD7BF40453AAE6EDF06604087A01BF3249', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False, 'grid_type': 'FixedGrid', 'fixed_grid': ['_grid_0', '_grid_1', '_grid_2'], 'extra_launcher_args': ['_grid_0', '_grid_1', '_grid_2']},

)
@triton.jit
def triton_tem_fused_zeros_1(arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, in_ptr16, out_ptr0):
    PRESCALE_QK : tl.constexpr = False
    ROWS_GUARANTEED_SAFE : tl.constexpr = False
    BLOCKS_ARE_CONTIGUOUS : tl.constexpr = False
    WRITE_DQ : tl.constexpr = True
    OUTPUT_LOGSUMEXP : tl.constexpr = True
    FLOAT32_PRECISION : tl.constexpr = 'tf32'
    IS_DIVISIBLE : tl.constexpr = False
    SM_SCALE : tl.constexpr = 0.125
    GQA_SHARED_HEADS : tl.constexpr = 1
    HAS_FULL_BLOCKS : tl.constexpr = True
    QK_HEAD_DIM : tl.constexpr = 64
    QK_HEAD_DIM_ROUNDED : tl.constexpr = 64
    V_HEAD_DIM : tl.constexpr = 64
    V_HEAD_DIM_ROUNDED : tl.constexpr = 64
    SAFE_HEAD_DIM : tl.constexpr = True
    BLOCK_M1 : tl.constexpr = 64
    BLOCK_N1 : tl.constexpr = 64
    BLOCK_M2 : tl.constexpr = 64
    BLOCK_N2 : tl.constexpr = 64
    SPARSE_Q_BLOCK_SIZE : tl.constexpr = 128
    SPARSE_KV_BLOCK_SIZE : tl.constexpr = 128
    INDEX_DTYPE : tl.constexpr = tl.int32
    Q = arg_Q
    K = arg_K
    V = arg_V
    LSE = arg_LSE
    DELTA = arg_DELTA
    DO = arg_DO
    DQ = arg_DQ
    DV = arg_DV
    KV_NUM_BLKS = arg_KV_NUM_BLKS
    KV_IDX = arg_KV_IDX
    Q_NUM_BLKS = arg_Q_NUM_BLKS
    Q_IDX = arg_Q_IDX
    FULL_KV_NUM_BLKS = arg_FULL_KV_NUM_BLKS
    FULL_KV_IDX = arg_FULL_KV_IDX
    FULL_Q_NUM_BLKS = arg_FULL_Q_NUM_BLKS
    FULL_Q_IDX = arg_FULL_Q_IDX

    # Sub notation for this kernel:
    #
    # Q: Query, K: Key, V: Value
    # LSE: logsumexp (logsumexp is always stored in fp32 regardless of the input dtype)
    # DELTA: Precomputed sum(OUT*DO, axis=-1)
    # DO: Derivative of Output, DQ: Derivative of Query, DV: Derivative of Value
    # DK: Derivative of Key, is written to via the store_output call due to some limitations with
    # inductor codegen
    # M: Number of queries, N: Number of keys/values
    # QK_HEAD_DIM: The dimension of the query and key embeddings
    # V_HEAD_DIM: The dimension of the value embeddings
    # z: Batch size, h: Number of heads, m: Number of queries or keys/values, d: Head dim
    # GQA_SHARED_HEADS: number of query heads sharing one kv head in GQA setups.
    # (Modifiable) Performance tuning options
    # BLOCK_M1: when calculating DK & DV, iterate over BLOCK_M1 across the seqlen dim of Q in each thread block.
    # BLOCK_N1: when calculating DK & DV, the thread block size across the seqlen dim of K/V.
    # BLOCK_M2: when calculating DQ, the thread block size across the seqlen dim of Q.
    # BLOCK_N2: when calculating DQ, iterate over BLOCK_N2 across the seqlen dim of K/V in each thread block.
    #
    # The following FULL_* and PARTIAL_* is defined in the block sparse mask grid, rather than the thread block grid.
    # KV_NUM_BLKS: The number of KV blocks (that may or may not require masking) for each query.
    # KV_IDX: The indices of KV blocks (that may or may not require masking) for each query.
    # Q_NUM_BLKS: The number of Q blocks (that may or may not require masking) for each query.
    # Q_IDX: The indices of Q blocks (that may or may not require masking) for each query.
    # FULL_KV_NUM_BLKS: The number of fully unmasked KV blocks (so we don't need masking) for each query.
    # FULL_KV_IDX: The indices of fully unmasked KV blocks (so we don't need masking) for each query.
    # FULL_Q_NUM_BLKS: The number of fully unmasked Q blocks (so we don't need masking) for each query.
    # FULL_Q_IDX: The indices of fully unmasked Q blocks (so we don't need masking) for each query.

    # The below are kernel options that can be applied for certain score_mods,
    # or involve a numerics vs. perf tradeoff
    # PRESCALE_QK: Whether to pre-scale QK by 1/sqrt(d) and change of base. Has
    # about 20% more numerical error, but slightly faster.

    # Define strides of inputs
    stride_qz, stride_qh, stride_qm, stride_qd = 261888, 65472, 64, 1
    stride_kz, stride_kh, stride_kn, stride_kd = 261888, 65472, 64, 1
    stride_vz, stride_vh, stride_vn, stride_vd = 261888, 65472, 64, 1
    stride_doz, stride_doh, stride_dom, stride_dod = 261888, 65472, 64, 1

    stride_dqz, stride_dqh, stride_dqm, stride_dqd = 261888, 65472, 64, 1
    stride_dvz, stride_dvh, stride_dvm, stride_dvd = 261888, 65472, 64, 1

    ZQ = 2
    HQ = 4
    HKV = 4
    Q_LEN = 1023
    ZKV = 2
    KV_LEN = 1023

    MATMUL_PRECISION = Q.dtype.element_ty

    pid = tl.program_id(0).to(INDEX_DTYPE)
    NUM_KV_BLOCKS = tl.cdiv(KV_LEN, BLOCK_N1)
    NUM_Q_BLOCKS = tl.cdiv(Q_LEN, BLOCK_M2)

    off_zq = tl.program_id(1).to(INDEX_DTYPE) # q batch idx
    off_hkv = tl.program_id(2).to(INDEX_DTYPE) # kv head idx
    off_zkv = off_zq % ZKV # kv batch idx

    SPARSE_Z = 1
    SPARSE_HQ = 1

    sparse_idx_z = off_zq % SPARSE_Z

    k_adj = (stride_kh * off_hkv + stride_kz * off_zkv).to(tl.int64)
    v_adj = (stride_vh * off_hkv + stride_vz * off_zkv).to(tl.int64)
    # first compute broadcasted dv of shape [Bq, Hkv, KV_LEN, V_HEAD_DIM]
    # then reduce to dv of shape [Bkv, Hkv, KV_LEN, V_HEAD_DIM]
    dv_adj = (stride_dvh * off_hkv + stride_dvz * off_zq).to(tl.int64)

    # offset K, V, DV pointers for batch/kv-head
    K += k_adj
    V += v_adj
    DV += dv_adj

    RCP_LN2 = 1.44269504
    offs_k = tl.arange(0, QK_HEAD_DIM_ROUNDED)
    offs_v = tl.arange(0, V_HEAD_DIM_ROUNDED)

    if pid >= NUM_KV_BLOCKS:
        off_pid = pid - NUM_KV_BLOCKS
        # THIS BLOCK DOES DQ
        SPARSE_Q_MULTIPLE = (SPARSE_Q_BLOCK_SIZE // BLOCK_M2)
        SPARSE_KV_MULTIPLE = (SPARSE_KV_BLOCK_SIZE // BLOCK_N2)
        off_hq2 = off_pid // NUM_Q_BLOCKS + off_hkv * GQA_SHARED_HEADS
        start_m2_block = off_pid % NUM_Q_BLOCKS
        off_pid_mask = start_m2_block // SPARSE_Q_MULTIPLE
        stride_kv_num_blks_h = 8
        stride_kv_idx_h = 64
        stride_kv_idx_m = 8

        sparse_idx_hq2 = off_hq2 % SPARSE_HQ
        sparse_hz_offset = sparse_idx_z * SPARSE_HQ + sparse_idx_hq2

        sparse_kv_num_blks_offset = sparse_hz_offset * stride_kv_num_blks_h + off_pid_mask
        sparse_kv_idx_offset = sparse_hz_offset * stride_kv_idx_h + off_pid_mask * stride_kv_idx_m  # noqa: B950

        # Offset Q, DQ, DO, DELTA & LSE. These inputs are offsetted by query heads.
        q_adj2 = (stride_qh * off_hq2 + stride_qz * off_zq).to(tl.int64)
        do_adj2 = (stride_doh * off_hq2 + stride_doz * off_zq).to(tl.int64)
        dq_adj2 = (stride_dqh * off_hq2 + stride_dqz * off_zq).to(tl.int64)
        off_chz2 = ((off_zq * HQ + off_hq2) * Q_LEN).to(tl.int64)

        Q2 = Q + q_adj2
        DO2 = DO + do_adj2
        # TODO: This does not work if DQ is not the same layout as Q (for example,
        # if Q is broadcasted)
        DQ2 = DQ + dq_adj2
        LSE2 = LSE + off_chz2
        DELTA2 = DELTA + off_chz2

        # dq = tl.zeros([BLOCK_M2, QK_HEAD_DIM], dtype=tl.float32)
        dq = tl.zeros([BLOCK_M2, QK_HEAD_DIM_ROUNDED], dtype=tl.float32)

        start_m2 = start_m2_block * BLOCK_M2
        offs_m2 = start_m2 + tl.arange(0, BLOCK_M2)

        # load Q and do: they stay in SRAM throughout the inner loop.
        q = load_checked_2d(Q2, offs_m2, offs_k, stride_qm, stride_qd, IS_DIVISIBLE, SAFE_HEAD_DIM, Q_LEN, QK_HEAD_DIM)
        do = load_checked_2d(DO2, offs_m2, offs_v, stride_dom, stride_dod, IS_DIVISIBLE, SAFE_HEAD_DIM, Q_LEN, V_HEAD_DIM)

        if PRESCALE_QK:
            q = (q * SM_SCALE * RCP_LN2).to(MATMUL_PRECISION)

        if IS_DIVISIBLE:
            Di = tl.load(DELTA2 + offs_m2)
            lse = tl.load(LSE2 + offs_m2)
        else:
            Di = tl.load(DELTA2 + offs_m2, mask=offs_m2 < Q_LEN)
            lse = tl.load(LSE2 + offs_m2, mask=offs_m2 < Q_LEN)
        lse = tl.where(lse == -float("inf"), 0.0, lse)
        lse = lse[:, None]

        # ~~~~~~~~~~~ fully unmasked blocks ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
        # KV_IDX and KV_NUM_BLKS are always contiguous.
        kv_indices = KV_IDX + sparse_kv_idx_offset
        kv_start = tl.load(kv_indices) * SPARSE_KV_BLOCK_SIZE # first kv block we're loading
        sparse_kv_num_blocks = tl.load(KV_NUM_BLKS + sparse_kv_num_blks_offset)

        offs_n2 = kv_start + tl.arange(0, BLOCK_N2)
        dq = bwd_dq_inner(
            arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, in_ptr16, out_ptr0,
            K, V,
            dq, q, do, Di, lse,
            off_zq, off_hq2, offs_m2, offs_n2,
            stride_kn, stride_kd, stride_vn, stride_vd,
            kv_indices, sparse_kv_num_blocks,
            MATMUL_PRECISION,
            IS_FULL_BLOCKS=False,
        )

        if HAS_FULL_BLOCKS:
            # ~~~~~~~~~~~ partial unmasked blocks ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
            # FULL_KV_IDX and FULL_KV_NUM_BLKS are always contiguous.
            kv_indices = FULL_KV_IDX + sparse_kv_idx_offset
            kv_start = tl.load(kv_indices) * SPARSE_KV_BLOCK_SIZE # first kv block we're loading
            sparse_kv_num_blocks = tl.load(FULL_KV_NUM_BLKS + sparse_kv_num_blks_offset)

            offs_n2 = kv_start + tl.arange(0, BLOCK_N2)
            dq = bwd_dq_inner(
                arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, in_ptr16, out_ptr0,
                K, V,
                dq, q, do, Di, lse,
                off_zq, off_hq2, offs_m2, offs_n2,
                stride_kn, stride_kd, stride_vn, stride_vd,
                kv_indices, sparse_kv_num_blocks,
                MATMUL_PRECISION,
                IS_FULL_BLOCKS=True,
            )

        # Write back dQ.
        dq_ptrs = DQ2 + offs_m2[:, None] * stride_dqm + offs_k[None, :] * stride_dqd
        dq *= SM_SCALE
        if IS_DIVISIBLE and SAFE_HEAD_DIM:
            tl.store(dq_ptrs, dq)
        else:
            tl.store(dq_ptrs, dq, mask=(offs_m2[:, None] < Q_LEN) & (offs_k[None, :] < QK_HEAD_DIM))
    else:
        # THIS BLOCK DOES DK & DV
        SPARSE_Q_MULTIPLE = (SPARSE_Q_BLOCK_SIZE // BLOCK_M1)
        SPARSE_KV_MULTIPLE = (SPARSE_KV_BLOCK_SIZE // BLOCK_N1)

        pid_mask = pid // SPARSE_KV_MULTIPLE

        stride_q_num_blks_h = 8
        stride_q_idx_h = 64
        stride_q_idx_n = 8


        dv = tl.zeros([BLOCK_N1, V_HEAD_DIM_ROUNDED], dtype=tl.float32)
        dk = tl.zeros([BLOCK_N1, QK_HEAD_DIM_ROUNDED], dtype=tl.float32)

        start_n1 = pid * BLOCK_N1
        offs_n1 = start_n1 + tl.arange(0, BLOCK_N1)

        # load K and V: they stay in SRAM throughout the inner loop.
        k = load_checked_2d(K, offs_n1, offs_k, stride_kn, stride_kd, IS_DIVISIBLE, SAFE_HEAD_DIM, KV_LEN, QK_HEAD_DIM)
        v = load_checked_2d(V, offs_n1, offs_v, stride_vn, stride_vd, IS_DIVISIBLE, SAFE_HEAD_DIM, KV_LEN, V_HEAD_DIM)

        if PRESCALE_QK:
            k = (k * SM_SCALE * RCP_LN2).to(MATMUL_PRECISION)

        for off_g in range(0, GQA_SHARED_HEADS):
            off_hq1 = off_hkv * GQA_SHARED_HEADS + off_g

            # Offset Q, DQ, DO, DELTA & LSE. These inputs are offsetted by query heads.
            q_adj1 = (stride_qh * off_hq1 + stride_qz * off_zq).to(tl.int64)
            do_adj1 = (stride_doh * off_hq1 + stride_doz * off_zq).to(tl.int64)
            dq_adj1 = (stride_dqh * off_hq1 + stride_dqz * off_zq).to(tl.int64)
            off_chz1 = ((off_zq * HQ + off_hq1) * Q_LEN).to(tl.int64)

            Q1 = Q + q_adj1
            DO1 = DO + do_adj1
            # TODO: This does not work if DQ is not the same layout as Q (for example,
            # if Q is broadcasted)
            LSE1 = LSE + off_chz1
            DELTA1 = DELTA + off_chz1

            sparse_idx_hq1 = off_hq1 % SPARSE_HQ
            sparse_hz_offset = sparse_idx_z * SPARSE_HQ + sparse_idx_hq1

            sparse_q_num_blks_offset = sparse_hz_offset * stride_q_num_blks_h + pid_mask
            sparse_q_idx_offset = sparse_hz_offset * stride_q_idx_h + pid_mask * stride_q_idx_n  # noqa: B950

            # ~~~~~~~~~~~~~~~ fully unmasked blocks ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
            # Q_IDX and Q_NUM_BLKS are always contiguous.
            q_indices = Q_IDX + sparse_q_idx_offset
            q_start = tl.load(q_indices) * SPARSE_Q_BLOCK_SIZE # first q block we're loading
            sparse_q_num_blocks = tl.load(Q_NUM_BLKS + sparse_q_num_blks_offset)

            offs_m1 = q_start + tl.arange(0, BLOCK_M1)
            dk, dv = bwd_dkdv_inner(
                arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, in_ptr16, out_ptr0,
                Q1, DO1, DELTA1, LSE1,
                dk, dv, k, v,
                off_zq, off_hq1, offs_n1, offs_m1,
                stride_qm, stride_qd, stride_dom, stride_dod,
                q_indices, sparse_q_num_blocks,
                MATMUL_PRECISION,
                IS_FULL_BLOCKS=False,
            )


            if HAS_FULL_BLOCKS:
                # ~~~~~~~~~~~~~~~ fully unmasked blocks ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
                # FULL_Q_IDX and FULL_Q_NUM_BLKS are always contiguous.
                q_indices = FULL_Q_IDX + sparse_q_idx_offset
                q_start = tl.load(q_indices) * SPARSE_Q_BLOCK_SIZE # first q block we're loading
                sparse_q_num_blocks = tl.load(FULL_Q_NUM_BLKS + sparse_q_num_blks_offset)

                offs_m1 = q_start + tl.arange(0, BLOCK_M1)
                dk, dv = bwd_dkdv_inner(
                    arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, in_ptr16, out_ptr0,
                    Q1, DO1, DELTA1, LSE1,
                    dk, dv, k, v,
                    off_zq, off_hq1, offs_n1, offs_m1,
                    stride_qm, stride_qd, stride_dom, stride_dod,
                    q_indices, sparse_q_num_blocks,
                    MATMUL_PRECISION,
                    IS_FULL_BLOCKS=True,
                )

        # Write back dV and dK.
        dv_ptrs = DV + offs_n1[:, None] * stride_dvm + offs_v[None, :] * stride_dvd

        index_n = offs_n1[:, None]
        index_k = offs_k[None, :]
        index_v = offs_v[None, :]

        if IS_DIVISIBLE and SAFE_HEAD_DIM:
            tl.store(dv_ptrs, dv)
        else:
            tl.store(dv_ptrs, dv, mask=(index_n < KV_LEN) & (index_v < V_HEAD_DIM))

        dk *= SM_SCALE

        if SAFE_HEAD_DIM:
            mask = index_n < KV_LEN
        else:
            mask = (index_n < KV_LEN) & (index_k < QK_HEAD_DIM)

        # first compute broadcasted dk of shape [Bq, Hkv, KV_LEN, QK_HEAD_DIM]
        # then reduce to dk of shape [Bkv, Hkv, KV_LEN, QK_HEAD_DIM]
        tl.static_assert(dk.shape == [BLOCK_N1, QK_HEAD_DIM_ROUNDED])
        xindex = index_k + 64*index_n + 65472*off_hkv + 261888*off_zq
        tl.store(out_ptr0 + (tl.broadcast_to(xindex, dk.shape)), dk, mask)

@triton.jit
def bwd_dq_inner(
        arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, in_ptr16, out_ptr0,
        K, V,  # pointers
        dq, q, do, Di, lse,
        off_z, off_hq, offs_m2, offs_n2,
        stride_kn, stride_kd, stride_vn, stride_vd,
        kv_indices, sparse_kv_num_blocks,
        MATMUL_PRECISION,
        IS_FULL_BLOCKS,
):
    PRESCALE_QK : tl.constexpr = False
    ROWS_GUARANTEED_SAFE : tl.constexpr = False
    BLOCKS_ARE_CONTIGUOUS : tl.constexpr = False
    WRITE_DQ : tl.constexpr = True
    OUTPUT_LOGSUMEXP : tl.constexpr = True
    FLOAT32_PRECISION : tl.constexpr = 'tf32'
    IS_DIVISIBLE : tl.constexpr = False
    SM_SCALE : tl.constexpr = 0.125
    GQA_SHARED_HEADS : tl.constexpr = 1
    HAS_FULL_BLOCKS : tl.constexpr = True
    QK_HEAD_DIM : tl.constexpr = 64
    QK_HEAD_DIM_ROUNDED : tl.constexpr = 64
    V_HEAD_DIM : tl.constexpr = 64
    V_HEAD_DIM_ROUNDED : tl.constexpr = 64
    SAFE_HEAD_DIM : tl.constexpr = True
    BLOCK_M1 : tl.constexpr = 64
    BLOCK_N1 : tl.constexpr = 64
    BLOCK_M2 : tl.constexpr = 64
    BLOCK_N2 : tl.constexpr = 64
    SPARSE_Q_BLOCK_SIZE : tl.constexpr = 128
    SPARSE_KV_BLOCK_SIZE : tl.constexpr = 128
    INDEX_DTYPE : tl.constexpr = tl.int32

    SPARSE_KV_MULTIPLE: tl.constexpr = (SPARSE_KV_BLOCK_SIZE // BLOCK_N2)
    RCP_LN2: tl.constexpr = 1.44269504
    Q_LEN = 1023
    KV_LEN = 1023

    offs_k = tl.arange(0, QK_HEAD_DIM_ROUNDED)
    offs_v = tl.arange(0, V_HEAD_DIM_ROUNDED)

    kT_ptrs = K + offs_n2[None, :] * stride_kn + offs_k[:, None] * stride_kd
    vT_ptrs = V + offs_n2[None, :] * stride_vn + offs_v[:, None] * stride_vd
    # BLOCK_M2 must be a multiple of BLOCK_N2, otherwise the code wouldn't work.
    tl.static_assert(BLOCK_M2 % BLOCK_N2 == 0)

    hi = tl.minimum(sparse_kv_num_blocks * SPARSE_KV_MULTIPLE, tl.maximum(tl.cdiv(KV_LEN, BLOCK_N2), 1))
    if not IS_DIVISIBLE:
        if hi >= 1:
            for start_n in range(0, hi - 1):
                dq = bwd_dq_block_mn(
                    arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, in_ptr16, out_ptr0,
                    dq, q, kT_ptrs, vT_ptrs, do, Di, lse, Q_LEN, KV_LEN,
                    off_z, off_hq, offs_m2, offs_n2, offs_k, offs_v,
                    stride_kn, stride_kd, stride_vn, stride_vd,
                    kv_indices, sparse_kv_num_blocks,
                    MATMUL_PRECISION, RCP_LN2,
                    IS_FULL_BLOCKS,
                )

                # Increment pointers.
                offset = get_offset_for_next_block(
                    start_n, kv_indices, sparse_kv_num_blocks,
                    SPARSE_KV_BLOCK_SIZE, SPARSE_KV_MULTIPLE, BLOCK_N2, BLOCKS_ARE_CONTIGUOUS
                )

                kT_ptrs += offset * stride_kn
                vT_ptrs += offset * stride_vn

                offs_n2 += offset

            dq = bwd_dq_block_mn(
                arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, in_ptr16, out_ptr0,
                dq, q, kT_ptrs, vT_ptrs, do, Di, lse, Q_LEN, KV_LEN,
                off_z, off_hq, offs_m2, offs_n2, offs_k, offs_v,
                stride_kn, stride_kd, stride_vn, stride_vd,
                kv_indices, sparse_kv_num_blocks,
                MATMUL_PRECISION, RCP_LN2,
                IS_FULL_BLOCKS, CHECK_BLOCK_BOUNDARY=True,
            )
    else:
        for start_n in range(0, hi):
            dq = bwd_dq_block_mn(
                arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, in_ptr16, out_ptr0,
                dq, q, kT_ptrs, vT_ptrs, do, Di, lse, Q_LEN, KV_LEN,
                off_z, off_hq, offs_m2, offs_n2, offs_k, offs_v,
                stride_kn, stride_kd, stride_vn, stride_vd,
                kv_indices, sparse_kv_num_blocks,
                MATMUL_PRECISION, RCP_LN2,
                IS_FULL_BLOCKS,
            )

            # Increment pointers.
            offset = get_offset_for_next_block(
                start_n, kv_indices, sparse_kv_num_blocks,
                SPARSE_KV_BLOCK_SIZE, SPARSE_KV_MULTIPLE, BLOCK_N2, BLOCKS_ARE_CONTIGUOUS
            )

            kT_ptrs += offset * stride_kn
            vT_ptrs += offset * stride_vn

            offs_n2 += offset

    return dq


@triton.jit
def bwd_dq_block_mn(
        arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, in_ptr16, out_ptr0,
        dq, q, kT_ptrs, vT_ptrs, do, Di, lse, Q_LEN, KV_LEN,
        off_z, off_hq, offs_m2, offs_n2, offs_k, offs_v,
        stride_kn, stride_kd, stride_vn, stride_vd,
        kv_indices, sparse_kv_num_blocks,
        MATMUL_PRECISION, RCP_LN2,
        IS_FULL_BLOCKS, CHECK_BLOCK_BOUNDARY=False,
):
    PRESCALE_QK : tl.constexpr = False
    ROWS_GUARANTEED_SAFE : tl.constexpr = False
    BLOCKS_ARE_CONTIGUOUS : tl.constexpr = False
    WRITE_DQ : tl.constexpr = True
    OUTPUT_LOGSUMEXP : tl.constexpr = True
    FLOAT32_PRECISION : tl.constexpr = 'tf32'
    IS_DIVISIBLE : tl.constexpr = False
    SM_SCALE : tl.constexpr = 0.125
    GQA_SHARED_HEADS : tl.constexpr = 1
    HAS_FULL_BLOCKS : tl.constexpr = True
    QK_HEAD_DIM : tl.constexpr = 64
    QK_HEAD_DIM_ROUNDED : tl.constexpr = 64
    V_HEAD_DIM : tl.constexpr = 64
    V_HEAD_DIM_ROUNDED : tl.constexpr = 64
    SAFE_HEAD_DIM : tl.constexpr = True
    BLOCK_M1 : tl.constexpr = 64
    BLOCK_N1 : tl.constexpr = 64
    BLOCK_M2 : tl.constexpr = 64
    BLOCK_N2 : tl.constexpr = 64
    SPARSE_Q_BLOCK_SIZE : tl.constexpr = 128
    SPARSE_KV_BLOCK_SIZE : tl.constexpr = 128
    INDEX_DTYPE : tl.constexpr = tl.int32


    # NB reversed order since K is transposed
    kT = load_checked_2d(kT_ptrs, offs_k, offs_n2, None, None, SAFE_HEAD_DIM, IS_DIVISIBLE, QK_HEAD_DIM, KV_LEN)
    qk = tl.dot(q, kT, input_precision=FLOAT32_PRECISION)
    if not PRESCALE_QK:
        qk *= SM_SCALE
    # ~~~~~~~~~~~~~~~~~~~ Apply score modification  ~~~~~~~~~~~~~~~~~~~
    pre_mod_scores = qk
    n = get_bounded_indices(offs_n2[None, :], KV_LEN if CHECK_BLOCK_BOUNDARY else None)
    # The boundary check is done for the outer loop, but here it's possible since we're iterating across N dim
    # that the M reads out of bounds prior to the last loop
    m = get_bounded_indices(offs_m2[:, None], Q_LEN if (not IS_DIVISIBLE or CHECK_BLOCK_BOUNDARY) else None)

    tmp0 = (qk)
    post_mod_scores = tmp0


    if CHECK_BLOCK_BOUNDARY:
        # Mask out the elements that are out of the KV_LEN for non divisible seqlen.
        post_mod_scores = tl.where(offs_n2[None, :] < KV_LEN, post_mod_scores, float("-inf"))

    if not IS_FULL_BLOCKS:
        tmp1 = (m)
        tmp2 = tl.load(in_ptr16 + tmp1)
        tmp3 = (n)
        tmp4 = tl.load(in_ptr16 + tmp3)
        tmp5 = tmp2 == tmp4
        mask_mod_output = tmp5


        if CHECK_BLOCK_BOUNDARY:
            mask_mod_output = tl.where(offs_n2[None, :] < KV_LEN, mask_mod_output, False)
        # apply mask for partial masked block
        post_mod_scores = tl.where(mask_mod_output, post_mod_scores, float("-inf"))
    # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
    if not PRESCALE_QK:
        post_mod_scores *= RCP_LN2
    p = tl.math.exp2(post_mod_scores - lse)
    # Compute dP and dS.
    # NB reversed order since V is transposed
    vT = load_checked_2d(vT_ptrs, offs_v, offs_n2, None, None, SAFE_HEAD_DIM, IS_DIVISIBLE, V_HEAD_DIM, KV_LEN)

    dp = tl.dot(do, vT, input_precision=FLOAT32_PRECISION)
    ds = p * (dp - Di[:, None])
    # ~~~~~~~~~~~~~~~~~~~ Apply joint modification  ~~~~~~~~~~~~~~~~~~~
    tmp6 = (ds)
    grad_scores = tmp6

    if CHECK_BLOCK_BOUNDARY:
        grad_scores = tl.where(offs_n2[None, :] < KV_LEN, grad_scores, 0.0)

    # ~~~~~~~~~~~~~~~~~~~ Apply other buffer grad writes ~~~~~~~~~~~~~
    if WRITE_DQ:
        scatter_mask = (offs_m2[:, None] < Q_LEN ) & (offs_n2[None, :] < KV_LEN)

    # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
    ds = grad_scores

    if not IS_FULL_BLOCKS:
        if CHECK_BLOCK_BOUNDARY:
            mask_mod_output = tl.where(offs_n2[None, :] < KV_LEN, mask_mod_output, False)
        # (grads) apply mask for partially unmasked block
        ds = tl.where(mask_mod_output, ds, 0.0)
    # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
    ds = ds.to(MATMUL_PRECISION)
    # Compute dQ.
    # dq += tl.dot(ds, kT, input_precision=FLOAT32_PRECISION)
    dq += tl.dot(ds, tl.trans(kT), input_precision=FLOAT32_PRECISION)

    return dq


@triton.jit
def bwd_dkdv_inner(
        arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, in_ptr16, out_ptr0,
        Q, DO, DELTA, LSE, # pointers
        dk, dv, k, v,
        off_z, off_hq, offs_n1, offs_m1,
        stride_qm, stride_qd, stride_dom, stride_dod,
        q_indices, sparse_q_num_blocks,
        MATMUL_PRECISION,
        IS_FULL_BLOCKS,
):
    PRESCALE_QK : tl.constexpr = False
    ROWS_GUARANTEED_SAFE : tl.constexpr = False
    BLOCKS_ARE_CONTIGUOUS : tl.constexpr = False
    WRITE_DQ : tl.constexpr = True
    OUTPUT_LOGSUMEXP : tl.constexpr = True
    FLOAT32_PRECISION : tl.constexpr = 'tf32'
    IS_DIVISIBLE : tl.constexpr = False
    SM_SCALE : tl.constexpr = 0.125
    GQA_SHARED_HEADS : tl.constexpr = 1
    HAS_FULL_BLOCKS : tl.constexpr = True
    QK_HEAD_DIM : tl.constexpr = 64
    QK_HEAD_DIM_ROUNDED : tl.constexpr = 64
    V_HEAD_DIM : tl.constexpr = 64
    V_HEAD_DIM_ROUNDED : tl.constexpr = 64
    SAFE_HEAD_DIM : tl.constexpr = True
    BLOCK_M1 : tl.constexpr = 64
    BLOCK_N1 : tl.constexpr = 64
    BLOCK_M2 : tl.constexpr = 64
    BLOCK_N2 : tl.constexpr = 64
    SPARSE_Q_BLOCK_SIZE : tl.constexpr = 128
    SPARSE_KV_BLOCK_SIZE : tl.constexpr = 128
    INDEX_DTYPE : tl.constexpr = tl.int32

    SPARSE_Q_MULTIPLE: tl.constexpr = (SPARSE_Q_BLOCK_SIZE // BLOCK_M1)
    RCP_LN2: tl.constexpr = 1.44269504
    Q_LEN = 1023
    KV_LEN = 1023

    offs_k = tl.arange(0, QK_HEAD_DIM_ROUNDED)
    offs_v = tl.arange(0, V_HEAD_DIM_ROUNDED)

    qT_ptrs = Q + offs_m1[None, :] * stride_qm + offs_k[:, None] * stride_qd
    do_ptrs = DO + offs_m1[:, None] * stride_dom + offs_v[None, :] * stride_dod
    # BLOCK_N1 must be a multiple of BLOCK_M1, otherwise the code wouldn't work.
    tl.static_assert(BLOCK_N1 % BLOCK_M1 == 0)
    hi = tl.minimum(sparse_q_num_blocks * SPARSE_Q_MULTIPLE, tl.maximum(tl.cdiv(Q_LEN, BLOCK_M1), 1))

    if not IS_DIVISIBLE:
        if hi >= 1:
            for start_m in range(0, hi - 1):
                dk, dv = bwd_dkdv_block_mn(
                    arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, in_ptr16, out_ptr0,
                    dk, dv, qT_ptrs, k, v, do_ptrs, DELTA, LSE, Q_LEN, KV_LEN,
                    off_z, off_hq, offs_n1, offs_m1, offs_k, offs_v,
                    stride_qm, stride_qd, stride_dom, stride_dod,
                    q_indices, sparse_q_num_blocks,
                    MATMUL_PRECISION, RCP_LN2,
                    IS_FULL_BLOCKS,
                )
                # Increment pointers.
                offset = get_offset_for_next_block(
                    start_m, q_indices, sparse_q_num_blocks,
                    SPARSE_Q_BLOCK_SIZE, SPARSE_Q_MULTIPLE, BLOCK_M1, BLOCKS_ARE_CONTIGUOUS
                )

                qT_ptrs += offset * stride_qm
                do_ptrs += offset * stride_dom

                offs_m1 += offset

            dk, dv = bwd_dkdv_block_mn(
                arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, in_ptr16, out_ptr0,
                dk, dv, qT_ptrs, k, v, do_ptrs, DELTA, LSE, Q_LEN, KV_LEN,
                off_z, off_hq, offs_n1, offs_m1, offs_k, offs_v,
                stride_qm, stride_qd, stride_dom, stride_dod,
                q_indices, sparse_q_num_blocks,
                MATMUL_PRECISION, RCP_LN2,
                IS_FULL_BLOCKS, CHECK_BLOCK_BOUNDARY=True,
            )
    else:
        for start_m in range(0, hi):
            dk, dv = bwd_dkdv_block_mn(
                arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, in_ptr16, out_ptr0,
                dk, dv, qT_ptrs, k, v, do_ptrs, DELTA, LSE, Q_LEN, KV_LEN,
                off_z, off_hq, offs_n1, offs_m1, offs_k, offs_v,
                stride_qm, stride_qd, stride_dom, stride_dod,
                q_indices, sparse_q_num_blocks,
                MATMUL_PRECISION, RCP_LN2,
                IS_FULL_BLOCKS,
            )
            # Increment pointers.
            offset = get_offset_for_next_block(
                start_m, q_indices, sparse_q_num_blocks,
                SPARSE_Q_BLOCK_SIZE, SPARSE_Q_MULTIPLE, BLOCK_M1, BLOCKS_ARE_CONTIGUOUS
            )

            qT_ptrs += offset * stride_qm
            do_ptrs += offset * stride_dom

            offs_m1 += offset

    return dk, dv


@triton.jit
def bwd_dkdv_block_mn(
        arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, in_ptr16, out_ptr0,
        dk, dv, qT_ptrs, k, v, do_ptrs, DELTA, LSE, Q_LEN, KV_LEN,
        off_z, off_hq, offs_n1, offs_m1, offs_k, offs_v,
        stride_qm, stride_qd, stride_dom, stride_dod,
        q_indices, sparse_q_num_blocks,
        MATMUL_PRECISION, RCP_LN2,
        IS_FULL_BLOCKS, CHECK_BLOCK_BOUNDARY=False,
):
    PRESCALE_QK : tl.constexpr = False
    ROWS_GUARANTEED_SAFE : tl.constexpr = False
    BLOCKS_ARE_CONTIGUOUS : tl.constexpr = False
    WRITE_DQ : tl.constexpr = True
    OUTPUT_LOGSUMEXP : tl.constexpr = True
    FLOAT32_PRECISION : tl.constexpr = 'tf32'
    IS_DIVISIBLE : tl.constexpr = False
    SM_SCALE : tl.constexpr = 0.125
    GQA_SHARED_HEADS : tl.constexpr = 1
    HAS_FULL_BLOCKS : tl.constexpr = True
    QK_HEAD_DIM : tl.constexpr = 64
    QK_HEAD_DIM_ROUNDED : tl.constexpr = 64
    V_HEAD_DIM : tl.constexpr = 64
    V_HEAD_DIM_ROUNDED : tl.constexpr = 64
    SAFE_HEAD_DIM : tl.constexpr = True
    BLOCK_M1 : tl.constexpr = 64
    BLOCK_N1 : tl.constexpr = 64
    BLOCK_M2 : tl.constexpr = 64
    BLOCK_N2 : tl.constexpr = 64
    SPARSE_Q_BLOCK_SIZE : tl.constexpr = 128
    SPARSE_KV_BLOCK_SIZE : tl.constexpr = 128
    INDEX_DTYPE : tl.constexpr = tl.int32


    # NB reversed order since Q is transposed
    qT = load_checked_2d(qT_ptrs, offs_k, offs_m1, None, None, SAFE_HEAD_DIM, IS_DIVISIBLE, QK_HEAD_DIM, Q_LEN)
    # Load LSE before computing qk to reduce pipeline stall.
    if IS_DIVISIBLE:
        lse = tl.load(LSE + offs_m1)
    else:
        lse = tl.load(LSE + offs_m1, mask=offs_m1 < Q_LEN)
    lse = tl.where(lse == -float("inf"), 0.0, lse)
    qkT = tl.dot(k, qT, input_precision=FLOAT32_PRECISION)
    if not PRESCALE_QK:
        qkT *= SM_SCALE
    # ~~~~~~~~~~~~~~~~~~~ Apply score modification  ~~~~~~~~~~~~~~~~~~~
    m = get_bounded_indices(offs_m1[None, :], Q_LEN if CHECK_BLOCK_BOUNDARY else None)
    # The boundary check is done for the outer loop, but here it's possible since we're iterating across M dim
    # that the n reads out of bounds prior to the last loop
    n = get_bounded_indices(offs_n1[:, None], KV_LEN if (not IS_DIVISIBLE or CHECK_BLOCK_BOUNDARY) else None)

    pre_mod_scores = qkT
    tmp7 = (qkT)
    post_mod_scores = tmp7


    if CHECK_BLOCK_BOUNDARY:
        # Mask out the elements that are out of the KV_LEN for non divisible seqlen.
        post_mod_scores = tl.where(offs_n1[:, None] < KV_LEN, post_mod_scores, float("-inf"))

    if not IS_FULL_BLOCKS:
        tmp8 = (m)
        tmp9 = tl.load(in_ptr16 + tmp8)
        tmp10 = (n)
        tmp11 = tl.load(in_ptr16 + tmp10)
        tmp12 = tmp9 == tmp11
        mask_mod_output = tmp12

        if CHECK_BLOCK_BOUNDARY:
            mask_mod_output = tl.where(offs_n1[:, None] < KV_LEN, mask_mod_output, False)
        # (grads) apply mask for fully masked block
        post_mod_scores = tl.where(mask_mod_output, post_mod_scores, float("-inf"))
    # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
    if not PRESCALE_QK:
        post_mod_scores *= RCP_LN2
    pT = tl.math.exp2(post_mod_scores - lse[None, :])
    do = load_checked_2d(do_ptrs, offs_m1, offs_v, None, None, IS_DIVISIBLE, SAFE_HEAD_DIM, Q_LEN, V_HEAD_DIM)
    # Compute dV.
    ppT = pT
    dv += tl.dot(ppT.to(MATMUL_PRECISION), do, input_precision=FLOAT32_PRECISION)
    if IS_DIVISIBLE:
        Di = tl.load(DELTA + offs_m1)
    else:
        Di = tl.load(DELTA + offs_m1, mask=offs_m1 < Q_LEN)
    # Compute dP and dS.
    dpT = tl.dot(v, tl.trans(do), input_precision=FLOAT32_PRECISION)
    # dpT = tl.dot(v, do, input_precision=FLOAT32_PRECISION)
    dsT = pT * (dpT - Di[None, :])
    # ~~~~~~~~~~~~~~~~~~~ Apply joint modification  ~~~~~~~~~~~~~~~~~~~
    tmp13 = (dsT)
    grad_scores = tmp13


    # ~~~~~~~~~~~~~~~~~~~ Apply other buffer grad writes ~~~~~~~~~~~~~
    if not WRITE_DQ:
        idx_b = off_z
        idx_h = off_hq
        idx_m = m
        idx_n = n
        scatter_mask = (offs_m1[None, :] < Q_LEN) & (offs_n1[:, None] < KV_LEN)

    # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

    if CHECK_BLOCK_BOUNDARY:
        grad_scores = tl.where(offs_n1[:, None] < KV_LEN, grad_scores, 0.0)

    dsT = grad_scores
    if not IS_FULL_BLOCKS:
        if CHECK_BLOCK_BOUNDARY:
            mask_mod_output = tl.where(offs_n1[:, None] < KV_LEN, mask_mod_output, False)
        # (grads) apply mask for partially unmasked block
        dsT = tl.where(mask_mod_output, dsT, 0.0)
    # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
    dk += tl.dot(dsT.to(MATMUL_PRECISION), tl.trans(qT), input_precision=FLOAT32_PRECISION)
    # dk += tl.dot(dsT.to(MATMUL_PRECISION), qT, input_precision=FLOAT32_PRECISION)

    return dk, dv

# Utility triton funcs
@triton.jit
def get_offset_for_next_block(
        loop_iter, col_indices, total_blocks,
        SPARSE_BLOCK, SPARSE_BLOCK_MULTIPLE, BLOCK,
        BLOCKS_ARE_CONTIGUOUS: tl.constexpr
):
    if BLOCKS_ARE_CONTIGUOUS:
        return BLOCK
    cur_block_idx = loop_iter // SPARSE_BLOCK_MULTIPLE
    cur_block = tl.load(col_indices + cur_block_idx, eviction_policy="evict_last")
    next_block = tl.load(col_indices + cur_block_idx + 1, eviction_policy="evict_last", mask=cur_block_idx + 1 < total_blocks)
    needs_jump = (loop_iter + 1) % SPARSE_BLOCK_MULTIPLE == 0
    jump_to_block = (next_block - cur_block ) * SPARSE_BLOCK - (SPARSE_BLOCK_MULTIPLE - 1) * BLOCK
    offset = jump_to_block * needs_jump + (1 - needs_jump) * BLOCK
    return offset

@triton.jit
def get_bounded_indices(indices, max_len=None):
    return indices % max_len if max_len is not None else indices

@triton.jit
def load_checked_block(block_ptr, IS_DIVISIBLE: tl.constexpr, SAFE_HEAD_DIM: tl.constexpr):
    if IS_DIVISIBLE and SAFE_HEAD_DIM:
        return tl.load(block_ptr)
    elif IS_DIVISIBLE and not SAFE_HEAD_DIM:
        return tl.load(block_ptr, boundary_check=(1,), padding_option="zero")
    elif not IS_DIVISIBLE and SAFE_HEAD_DIM:
        return tl.load(block_ptr, boundary_check=(0,), padding_option="zero")
    else:
        return tl.load(block_ptr, boundary_check=(0, 1), padding_option="zero")

@triton.jit
def load_checked_2d(
        ptr,
        offs_m,
        offs_n,
        stride_m,
        stride_n,
        IS_DIVISIBLE_M: tl.constexpr,
        IS_DIVISIBLE_N: tl.constexpr,
        M_LEN: tl.constexpr,
        N_DIM: tl.constexpr,
):
    # Calculate final pointer if strides are provided
    if stride_m is not None and stride_n is not None:
        ptr = ptr + offs_m[:, None] * stride_m + offs_n[None, :] * stride_n

    # Handle all masking cases
    if not IS_DIVISIBLE_M and not IS_DIVISIBLE_N:
        return tl.load(ptr, mask=(offs_m[:, None] < M_LEN) & (offs_n[None, :] < N_DIM), other=0.0)
    elif IS_DIVISIBLE_M and not IS_DIVISIBLE_N:
        return tl.load(ptr, mask=(offs_n[None, :] < N_DIM), other=0.0)
    elif not IS_DIVISIBLE_M and IS_DIVISIBLE_N:
        return tl.load(ptr, mask=(offs_m[:, None] < M_LEN), other=0.0)
    else:  # Both divisible
        return tl.load(ptr)

import time

start_time = time.time_ns()
# with torch.profiler.profile() as prof:
triton_tem_fused_zeros_1.precompile()
# prof.export_chrome_trace("trace_compile.json")
# print(prof.key_averages().table(sort_by="self_cpu_time_total", row_limit=-1))
precompile_time_taken_ns = time.time_ns() - start_time

print("johnlu done, compile time:", precompile_time_taken_ns / 1e9, "s")

But the TTGIR I get is not the same as the one you mentioned; the IR around the tt.trans operation is:

Snippet from TTGIR
        %vT_ptrs_118:5 = scf.for %vT_ptrs_148 = %c0_i32 to %dq_110 step %c1_i32 iter_args(%arg19 = %cst_17, %kT_ptrs_149 = %kT_ptrs_79, %offs_n2_150 = %offs_n2_70, %vT_ptrs_151 = %vT_ptrs_82, %offs_n2_152 = %offs_n2_71) -> (tensor<64x64xf32, #mma>, tensor<64x64x!tt.ptr<f16>, #blocked>, tensor<64xi32, #ttg.slice<{dim = 0, parent = #blocked}>>, tensor<64x64x!tt.ptr<f16>, #blocked>, tensor<64xi32, #ttg.slice<{dim = 0, parent = #mma}>>)  : i32 {
          %kT_153 = tt.expand_dims %offs_n2_150 {axis = 0 : i32} : tensor<64xi32, #ttg.slice<{dim = 0, parent = #blocked}>> -> tensor<1x64xi32, #blocked> loc(#loc778)
          %kT_154 = tt.expand_dims %offs_n2_152 {axis = 0 : i32} : tensor<64xi32, #ttg.slice<{dim = 0, parent = #mma}>> -> tensor<1x64xi32, #mma> loc(#loc778)
          %kT_155 = arith.cmpi slt, %kT_153, %cst_10 : tensor<1x64xi32, #blocked> loc(#loc779)
          %kT_156 = tt.broadcast %kT_155 : tensor<1x64xi1, #blocked> -> tensor<64x64xi1, #blocked> loc(#loc780)
          %kT_157 = tt.load %kT_ptrs_149, %kT_156, %cst_15 {ttig.one_matrix_per_load} : tensor<64x64x!tt.ptr<f16>, #blocked> loc(#loc780)
          %kT_158 = ttg.local_alloc %kT_157 : (tensor<64x64xf16, #blocked>) -> !ttg.memdesc<64x64xf16, #shared1, #smem> loc(#loc780)
          %dq_159 = ttg.convert_layout %kT_157 : tensor<64x64xf16, #blocked> -> tensor<64x64xf16, #linear> loc(#loc657)
          %dq_160 = tt.trans %dq_159 {order = array<i32: 1, 0>} : tensor<64x64xf16, #linear> -> tensor<64x64xf16, #ttg.dot_op<{opIdx = 1, parent = #mma, kWidth = 2}>> loc(#loc657)
          %kT_161 = ttg.local_load %kT_158 : !ttg.memdesc<64x64xf16, #shared1, #smem> -> tensor<64x64xf16, #ttg.dot_op<{opIdx = 1, parent = #mma, kWidth = 2}>> loc(#loc780)
          %qk_162 = tt.dot %q_113, %kT_161, %cst_17, inputPrecision = tf32 : tensor<64x64xf16, #ttg.dot_op<{opIdx = 0, parent = #mma, kWidth = 1}>> * tensor<64x64xf16, #ttg.dot_op<{opIdx = 1, parent = #mma, kWidth = 2}>> -> tensor<64x64xf32, #mma> loc(#loc658)

The IR still contains the convert_layout operations, and there are other ops (e.g. ttg.local_alloc) that do not appear for you. How did you compile the benchmark? Does this PR contain all of your code?
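
For anyone comparing IR: one way to dump the TTGIR while precompiling this benchmark is sketched below. The environment variable names (MLIR_ENABLE_DUMP, TRITON_ALWAYS_COMPILE) are my assumption of the upstream Triton debug flags and may differ in your build.

import os

os.environ["MLIR_ENABLE_DUMP"] = "1"       # assumed flag: print the IR after every MLIR pass
os.environ["TRITON_ALWAYS_COMPILE"] = "1"  # assumed flag: bypass the compilation cache

# Re-running the precompile step should then print the TTGIR stages (including any
# remaining ttg.convert_layout / ttg.local_alloc ops) to stderr for comparison.
triton_tem_fused_zeros_1.precompile()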

@chengjunlu chengjunlu force-pushed the chengjun/enhance_remove_layout branch from 5828b55 to d85a2ba Compare September 4, 2025 04:48
@chengjunlu
Contributor Author

@etiotto Added the missing changes for debugging the backward kernel.

chengjunlu and others added 3 commits September 4, 2025 12:40
…d values with different layout in scf.for.

Signed-off-by: Lu,Chengjun <chengjun.lu@intel.com>
Signed-off-by: Lu,Chengjun <chengjun.lu@intel.com>
@etiotto etiotto requested a review from Copilot September 24, 2025 19:39
Contributor

@Copilot Copilot AI left a comment

Pull Request Overview

Copilot reviewed 3 out of 3 changed files in this pull request and generated 2 comments.



Signed-off-by: Ettore Tiotto <ettore.tiotto@intel.com>
@etiotto etiotto requested a review from Copilot September 30, 2025 19:56
Contributor

@Copilot Copilot AI left a comment

Pull Request Overview

Copilot reviewed 1 out of 1 changed files in this pull request and generated 2 comments.



Comment on lines 1502 to 1504
LogicalResult result = getRematerializableSlice(convertOp.getSrcMutable(),
                                                targetType.getEncoding(),
                                                slice, layout, nullptr);

Copilot AI Sep 30, 2025

[nitpick] Adding a nullptr parameter without documentation or context makes the API unclear. Consider adding a comment explaining what this parameter represents or using a named constant instead of nullptr.

Suggested change

Before:
LogicalResult result = getRematerializableSlice(convertOp.getSrcMutable(),
                                                targetType.getEncoding(),
                                                slice, layout, nullptr);

After:
// No filter function is provided, so we pass kNoRematerializationFilter (nullptr).
static constexpr auto kNoRematerializationFilter = nullptr;
LogicalResult result = getRematerializableSlice(convertOp.getSrcMutable(),
                                                targetType.getEncoding(),
                                                slice, layout, kNoRematerializationFilter);


Signed-off-by: Ettore Tiotto <ettore.tiotto@intel.com>
@etiotto etiotto marked this pull request as ready for review October 1, 2025 16:39
@etiotto etiotto requested review from anmyachev and removed request for etiotto October 1, 2025 16:40
Signed-off-by: Ettore Tiotto <ettore.tiotto@intel.com>
@etiotto
Contributor

etiotto commented Oct 1, 2025

Did a performance run on microbenchmarks. No degradations on PVC and BMG. Potential improvement for gemm-streamk on BMG:

[image: microbenchmark results showing a potential gemm-streamk improvement on BMG]

@etiotto etiotto changed the title [Draft] [BACKEND] Enhance the remove layout implementation to reduce the duplicated values with different layout in scf.for. [BACKEND] Enhance the remove layout implementation to reduce the duplicated values with different layout in scf.for. Oct 1, 2025
Successfully merging this pull request may close these issues.

[FlexAttn] [BACKWARD] Improve the remove layout pass for flex attn backward kernel
[BACKEND] Enhance the remove layout for Intel GPU