Skip to content

make append_attn supports mask_offset #3138

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 6 commits into from
Aug 14, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 10 additions & 0 deletions custom_ops/gpu_ops/append_attention.cu
Original file line number Diff line number Diff line change
Expand Up @@ -72,6 +72,7 @@ std::vector<paddle::Tensor> AppendAttentionKernel(
const paddle::optional<paddle::Tensor>& cache_v_zp,
const paddle::optional<paddle::Tensor>& out_linear_shifts,
const paddle::optional<paddle::Tensor>& out_linear_smooths,
const paddle::optional<paddle::Tensor>& mask_offset,
const paddle::optional<paddle::Tensor>& kv_signal_data,
const paddle::optional<paddle::Tensor>& q_norm_weight,
const paddle::optional<paddle::Tensor>& k_norm_weight,
Expand Down Expand Up @@ -441,6 +442,7 @@ std::vector<paddle::Tensor> AppendAttention(
const paddle::optional<paddle::Tensor>& cache_v_zp,
const paddle::optional<paddle::Tensor>& out_linear_shifts,
const paddle::optional<paddle::Tensor>& out_linear_smooths,
const paddle::optional<paddle::Tensor>& mask_offset,
const paddle::optional<paddle::Tensor>& kv_signal_data,
const paddle::optional<paddle::Tensor>& q_norm_weight,
const paddle::optional<paddle::Tensor>& k_norm_weight,
Expand Down Expand Up @@ -479,6 +481,10 @@ std::vector<paddle::Tensor> AppendAttention(
meta_data.block_size = key_cache.dims()[2];
meta_data.batch_size = seq_lens_this_time.dims()[0];

if (mask_offset) {
meta_data.mask_offset = mask_offset.get().data<int>();
}

auto dispatch_by_template = [&](auto temp_args) -> std::vector<paddle::Tensor> {
return AppendAttentionKernel<type2value<decltype(temp_args)>::value>(
meta_data,
Expand Down Expand Up @@ -514,6 +520,7 @@ std::vector<paddle::Tensor> AppendAttention(
cache_v_zp,
out_linear_shifts,
out_linear_smooths,
mask_offset,
kv_signal_data,
q_norm_weight,
k_norm_weight,
Expand Down Expand Up @@ -594,6 +601,7 @@ std::vector<std::vector<int64_t>> AppendAttentionInferShape(
const paddle::optional<std::vector<int64_t>>& cache_v_zp_shape,
const paddle::optional<std::vector<int64_t>>& out_linear_shifts_shape,
const paddle::optional<std::vector<int64_t>>& out_linear_smooths_shape,
const paddle::optional<std::vector<int64_t>>& mask_offset_shape,
const paddle::optional<std::vector<int64_t>>& kv_signal_data_shape,
const paddle::optional<std::vector<int64_t>>& q_norm_weight_shape,
const paddle::optional<std::vector<int64_t>>& k_norm_weight_shape,
Expand Down Expand Up @@ -657,6 +665,7 @@ std::vector<paddle::DataType> AppendAttentionInferDtype(
const paddle::optional<paddle::DataType>& cache_v_zp_dtype,
const paddle::optional<paddle::DataType>& out_linear_shifts_dtype,
const paddle::optional<paddle::DataType>& out_linear_smooths_dtype,
const paddle::optional<paddle::DataType>& mask_offset_dtype,
const paddle::optional<paddle::DataType>& kv_signal_data_dtype,
const paddle::optional<paddle::DataType>& q_norm_weight_dtype,
const paddle::optional<paddle::DataType>& k_norm_weight_dtype,
Expand Down Expand Up @@ -738,6 +747,7 @@ PD_BUILD_STATIC_OP(append_attention)
paddle::Optional("cache_v_zp"),
paddle::Optional("out_linear_shifts"),
paddle::Optional("out_linear_smooths"),
paddle::Optional("mask_offset"),
paddle::Optional("kv_signal_data"),
paddle::Optional("q_norm_weight"),
paddle::Optional("k_norm_weight")})
Expand Down
19 changes: 14 additions & 5 deletions custom_ops/gpu_ops/append_attn/append_attention_c16_impl.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -43,6 +43,7 @@ __global__ void multi_query_append_attention_kernel(
const int *__restrict__ tile_ids_per_batch,
const int *__restrict__ cu_seqlens_q,
const int *__restrict__ block_table, // [bsz, block_num_per_seq]
const int *__restrict__ mask_offset,
const int max_seq_len,
const int max_dec_len,
const int max_block_num_per_seq,
Expand Down Expand Up @@ -141,6 +142,7 @@ __global__ void multi_query_append_attention_kernel(
} else {
o_base_ptr_int8 = out + o_offset;
}
const int *mask_offset_this_seq = mask_offset ? mask_offset + q_start_seq_id : nullptr;
smem_t qo_smem(smem);

uint32_t q_smem_offset_r = smem_t::get_permuted_offset<num_vecs_per_head>(
Expand Down Expand Up @@ -179,7 +181,7 @@ __global__ void multi_query_append_attention_kernel(
kv_len - q_len +
tile_id * num_rows_per_block / GROUP_SIZE,
chunk_start)))
: chunk_len) /
: mask_offset ? 0 : chunk_len) /
(num_frags_z * 16);
uint32_t k_smem_offset_r = smem_t::get_permuted_offset<num_vecs_per_head>(
8 * (tid / 16) + tid % 8, (tid % 16) / 8);
Expand Down Expand Up @@ -250,7 +252,8 @@ __global__ void multi_query_append_attention_kernel(
q_len,
kv_len,
chunk_end,
s_frag);
s_frag,
mask_offset_this_seq);
}

// update m,d
Expand Down Expand Up @@ -406,6 +409,7 @@ __global__ void multi_query_append_attention_warp1_4_kernel(
const int *__restrict__ tile_ids_per_batch,
const int *__restrict__ cu_seqlens_q,
const int *__restrict__ block_table, // [bsz, block_num_per_seq]
const int *__restrict__ mask_offset,
const int max_seq_len,
const int max_dec_len,
const int max_block_num_per_seq,
Expand Down Expand Up @@ -502,7 +506,7 @@ __global__ void multi_query_append_attention_warp1_4_kernel(
tid % 8 * num_elems_per_128b<T>();
}
}

const int *mask_offset_this_seq = mask_offset ? mask_offset + q_start_seq_id : nullptr;
smem_t qo_smem(smem);

uint32_t q_smem_offset_r = smem_t::get_permuted_offset<num_vecs_per_head>(
Expand Down Expand Up @@ -543,7 +547,7 @@ __global__ void multi_query_append_attention_warp1_4_kernel(
kv_len - q_len +
tile_id * num_rows_per_block / GROUP_SIZE,
chunk_start)))
: chunk_len) /
: mask_offset ? 0 : chunk_len) /
(NUM_WARP_KV * num_frags_z * 16);

uint32_t k_smem_offset_r = smem_t::get_permuted_offset<num_vecs_per_head>(
Expand Down Expand Up @@ -616,7 +620,8 @@ __global__ void multi_query_append_attention_warp1_4_kernel(
q_len,
kv_len,
chunk_end,
s_frag);
s_frag,
mask_offset_this_seq);
}

// update m,d
Expand Down Expand Up @@ -882,6 +887,7 @@ void MultiQueryAppendAttention(
tile_ids_per_batch.data<int>(),
cu_seqlens_q.data<int>(),
block_table.data<int>(),
meta_data.mask_offset,
max_seq_len,
max_dec_len,
max_block_num_per_seq,
Expand Down Expand Up @@ -939,6 +945,7 @@ void MultiQueryAppendAttention(
tile_ids_per_batch.data<int>(),
cu_seqlens_q.data<int>(),
block_table.data<int>(),
meta_data.mask_offset,
max_seq_len,
max_dec_len,
max_block_num_per_seq,
Expand Down Expand Up @@ -1103,6 +1110,7 @@ void MultiQueryAppendAttention(
tile_ids_per_batch.data<int>(),
cu_seqlens_q.data<int>(),
block_table.data<int>(),
meta_data.mask_offset,
max_seq_len,
max_dec_len,
max_block_num_per_seq,
Expand Down Expand Up @@ -1171,6 +1179,7 @@ void MultiQueryAppendAttention(
tile_ids_per_batch.data<int>(),
cu_seqlens_q.data<int>(),
block_table.data<int>(),
meta_data.mask_offset,
max_seq_len,
max_dec_len,
max_block_num_per_seq,
Expand Down
19 changes: 14 additions & 5 deletions custom_ops/gpu_ops/append_attn/append_attention_c4_impl.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -48,6 +48,7 @@ __global__ void multi_query_append_attention_c4_kernel(
const int *__restrict__ tile_ids_per_batch,
const int *__restrict__ cu_seqlens_q,
const int *__restrict__ block_table, // [bsz, block_num_per_seq]
const int *__restrict__ mask_offset,
const int max_seq_len,
const int max_dec_len,
const int max_block_num_per_seq,
Expand Down Expand Up @@ -172,6 +173,7 @@ __global__ void multi_query_append_attention_c4_kernel(
} else {
o_base_ptr_int8 = out + o_offset;
}
const int *mask_offset_this_seq = mask_offset ? mask_offset + q_start_seq_id : nullptr;
smem_t qo_smem(smem);

uint32_t q_smem_offset_r = smem_t::get_permuted_offset<num_vecs_per_head>(
Expand Down Expand Up @@ -248,7 +250,7 @@ __global__ void multi_query_append_attention_c4_kernel(
kv_len - q_len +
tile_id * num_rows_per_block / GROUP_SIZE,
chunk_start)))
: chunk_len) /
: mask_offset ? 0 : chunk_len) /
(num_frags_z * 16);

uint32_t k_smem_offset_r =
Expand Down Expand Up @@ -338,7 +340,8 @@ __global__ void multi_query_append_attention_c4_kernel(
q_len,
kv_len,
chunk_end,
s_frag);
s_frag,
mask_offset_this_seq);
}

update_mdo_states<num_frags_x, num_frags_y, num_frags_z>(
Expand Down Expand Up @@ -505,6 +508,7 @@ __global__ void multi_query_append_attention_c4_warp1_4_kernel(
const int *__restrict__ tile_ids_per_batch,
const int *__restrict__ cu_seqlens_q,
const int *__restrict__ block_table, // [bsz, block_num_per_seq]
const int *__restrict__ mask_offset,
const int max_seq_len,
const int max_dec_len,
const int max_block_num_per_seq,
Expand Down Expand Up @@ -627,7 +631,7 @@ __global__ void multi_query_append_attention_c4_warp1_4_kernel(
tid % 8 * num_elems_per_128b<T>();
}
}

const int *mask_offset_this_seq = mask_offset ? mask_offset + q_start_seq_id : nullptr;
smem_t qo_smem(smem);

uint32_t q_smem_offset_r = smem_t::get_permuted_offset<num_vecs_per_head>(
Expand Down Expand Up @@ -706,7 +710,7 @@ __global__ void multi_query_append_attention_c4_warp1_4_kernel(
kv_len - q_len +
tile_id * num_rows_per_block / GROUP_SIZE,
chunk_start)))
: chunk_len) /
: mask_offset ? 0 : chunk_len) /
(NUM_WARP_KV * num_frags_z * 16);

uint32_t k_smem_offset_r =
Expand Down Expand Up @@ -793,7 +797,8 @@ __global__ void multi_query_append_attention_c4_warp1_4_kernel(
q_len,
kv_len,
chunk_end,
s_frag);
s_frag,
mask_offset_this_seq);
}

update_mdo_states<num_frags_x, num_frags_y, num_frags_z>(
Expand Down Expand Up @@ -1088,6 +1093,7 @@ void MultiQueryAppendC4Attention(
tile_ids_per_batch.data<int>(),
cu_seqlens_q.data<int>(),
block_table.data<int>(),
meta_data.mask_offset,
max_seq_len,
max_dec_len,
max_block_num_per_seq,
Expand Down Expand Up @@ -1151,6 +1157,7 @@ void MultiQueryAppendC4Attention(
tile_ids_per_batch.data<int>(),
cu_seqlens_q.data<int>(),
block_table.data<int>(),
meta_data.mask_offset,
max_seq_len,
max_dec_len,
max_block_num_per_seq,
Expand Down Expand Up @@ -1335,6 +1342,7 @@ void MultiQueryAppendC4Attention(
tile_ids_per_batch.data<int>(),
cu_seqlens_q.data<int>(),
block_table.data<int>(),
meta_data.mask_offset,
max_seq_len,
max_dec_len,
max_block_num_per_seq,
Expand Down Expand Up @@ -1411,6 +1419,7 @@ void MultiQueryAppendC4Attention(
tile_ids_per_batch.data<int>(),
cu_seqlens_q.data<int>(),
block_table.data<int>(),
meta_data.mask_offset,
max_seq_len,
max_dec_len,
max_block_num_per_seq,
Expand Down
19 changes: 14 additions & 5 deletions custom_ops/gpu_ops/append_attn/append_attention_c8_impl.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -48,6 +48,7 @@ __global__ void multi_query_append_attention_c8_kernel(
const int *__restrict__ tile_ids_per_batch,
const int *__restrict__ cu_seqlens_q,
const int *__restrict__ block_table, // [bsz, block_num_per_seq]
const int *__restrict__ mask_offset,
const int max_seq_len,
const int max_dec_len,
const int max_block_num_per_seq,
Expand Down Expand Up @@ -179,6 +180,7 @@ __global__ void multi_query_append_attention_c8_kernel(
} else {
o_base_ptr_int8 = out + o_offset;
}
const int *mask_offset_this_seq = mask_offset ? mask_offset + q_start_seq_id : nullptr;
smem_t qo_smem(smem);

uint32_t q_smem_offset_r = smem_t::get_permuted_offset<num_vecs_per_head>(
Expand Down Expand Up @@ -216,7 +218,7 @@ __global__ void multi_query_append_attention_c8_kernel(
kv_len - q_len +
tile_id * num_rows_per_block / GROUP_SIZE,
chunk_start)))
: chunk_len) /
: mask_offset ? 0 : chunk_len) /
(num_frags_z * 16);

uint32_t k_smem_offset_r =
Expand Down Expand Up @@ -305,7 +307,8 @@ __global__ void multi_query_append_attention_c8_kernel(
q_len,
kv_len,
chunk_end,
s_frag);
s_frag,
mask_offset_this_seq);
}

// update m,d
Expand Down Expand Up @@ -474,6 +477,7 @@ __global__ void multi_query_append_attention_c8_warp1_4_kernel(
const int *__restrict__ tile_ids_per_batch,
const int *__restrict__ cu_seqlens_q,
const int *__restrict__ block_table, // [bsz, block_num_per_seq]
const int *__restrict__ mask_offset,
const int max_seq_len,
const int max_dec_len,
const int max_block_num_per_seq,
Expand Down Expand Up @@ -601,7 +605,7 @@ __global__ void multi_query_append_attention_c8_warp1_4_kernel(
tid % 8 * num_elems_per_128b<T>();
}
}

const int *mask_offset_this_seq = mask_offset ? mask_offset + q_start_seq_id : nullptr;
smem_t qo_smem(smem);

uint32_t q_smem_offset_r = smem_t::get_permuted_offset<num_vecs_per_head>(
Expand Down Expand Up @@ -642,7 +646,7 @@ __global__ void multi_query_append_attention_c8_warp1_4_kernel(
kv_len - q_len +
tile_id * num_rows_per_block / GROUP_SIZE,
chunk_start)))
: chunk_len) /
: mask_offset ? 0 : chunk_len) /
(NUM_WARP_KV * num_frags_z * 16);

uint32_t k_smem_offset_r =
Expand Down Expand Up @@ -733,7 +737,8 @@ __global__ void multi_query_append_attention_c8_warp1_4_kernel(
q_len,
kv_len,
chunk_end,
s_frag);
s_frag,
mask_offset_this_seq);
}

// update m,d
Expand Down Expand Up @@ -1054,6 +1059,7 @@ void MultiQueryAppendC8Attention(
tile_ids_per_batch.data<int>(),
cu_seqlens_q.data<int>(),
block_table.data<int>(),
meta_data.mask_offset,
max_seq_len,
max_dec_len,
max_block_num_per_seq,
Expand Down Expand Up @@ -1111,6 +1117,7 @@ void MultiQueryAppendC8Attention(
tile_ids_per_batch.data<int>(),
cu_seqlens_q.data<int>(),
block_table.data<int>(),
meta_data.mask_offset,
max_seq_len,
max_dec_len,
max_block_num_per_seq,
Expand Down Expand Up @@ -1318,6 +1325,7 @@ void MultiQueryAppendC8Attention(
tile_ids_per_batch.data<int>(),
cu_seqlens_q.data<int>(),
block_table.data<int>(),
meta_data.mask_offset,
max_seq_len,
max_dec_len,
max_block_num_per_seq,
Expand Down Expand Up @@ -1388,6 +1396,7 @@ void MultiQueryAppendC8Attention(
tile_ids_per_batch.data<int>(),
cu_seqlens_q.data<int>(),
block_table.data<int>(),
meta_data.mask_offset,
max_seq_len,
max_dec_len,
max_block_num_per_seq,
Expand Down
16 changes: 11 additions & 5 deletions custom_ops/gpu_ops/append_attn/append_attention_func.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -910,7 +910,8 @@ __device__ __forceinline__ void mask_s(const uint32_t qo_idx_base,
const uint32_t qo_len,
const uint32_t kv_len,
const uint32_t chunk_end,
float (*s_frag)[num_frags_z][8]) {
float (*s_frag)[num_frags_z][8],
const int *mask_offset = nullptr) {
const uint32_t tx = threadIdx.x;
#pragma unroll
for (uint32_t fx = 0; fx < num_frags_x; ++fx) {
Expand All @@ -924,10 +925,15 @@ __device__ __forceinline__ void mask_s(const uint32_t qo_idx_base,
group_size,
kv_idx = kv_idx_base + fz * 16 + 2 * (tx % 4) +
8 * (reg_id / 4) + reg_id % 2;
const bool out_of_boundary =
(causal
? (kv_idx > kv_len + q_idx - qo_len || (kv_idx >= chunk_end))
: kv_idx >= chunk_end);
bool out_of_boundary;
if (mask_offset) {
out_of_boundary = q_idx < qo_len ? (kv_idx > mask_offset[q_idx]) : true;
} else {
out_of_boundary =
(causal
? (kv_idx > kv_len + q_idx - qo_len || (kv_idx >= chunk_end))
: kv_idx >= chunk_end);
}
if constexpr (std::is_same<T, half>::value) {
s_frag[fx][fz][reg_id] =
out_of_boundary ? -5e4f : s_frag[fx][fz][reg_id];
Expand Down
Loading
Loading