From d80428f7d0cde573c7d6ddbef0f87fa2fa6bd118 Mon Sep 17 00:00:00 2001 From: carryyu <569782149@qq.com> Date: Fri, 1 Aug 2025 15:25:33 +0800 Subject: [PATCH 1/2] make append_attn supports mask_offset --- custom_ops/gpu_ops/append_attention.cu | 10 ++++++++++ .../append_attn/append_attention_c16_impl.cuh | 19 ++++++++++++++----- .../append_attn/append_attention_c4_impl.cuh | 19 ++++++++++++++----- .../append_attn/append_attention_c8_impl.cuh | 19 ++++++++++++++----- .../append_attn/append_attention_func.cuh | 16 +++++++++++----- custom_ops/gpu_ops/append_attn/utils.cuh | 4 ++++ custom_ops/gpu_ops/cpp_extensions.cc | 1 + .../layers/attention/append_attn_backend.py | 2 ++ .../layers/attention/ops/append_attention.py | 2 ++ 9 files changed, 72 insertions(+), 20 deletions(-) diff --git a/custom_ops/gpu_ops/append_attention.cu b/custom_ops/gpu_ops/append_attention.cu index 2ba7555e7f..ceb7e8fe26 100644 --- a/custom_ops/gpu_ops/append_attention.cu +++ b/custom_ops/gpu_ops/append_attention.cu @@ -72,6 +72,7 @@ std::vector AppendAttentionKernel( const paddle::optional& cache_v_zp, const paddle::optional& out_linear_shifts, const paddle::optional& out_linear_smooths, + const paddle::optional& mask_offset, const paddle::optional& kv_signal_data, const std::string& cache_quant_type_str, const bool use_neox_rotary_style, @@ -429,6 +430,7 @@ std::vector AppendAttention( const paddle::optional& cache_v_zp, const paddle::optional& out_linear_shifts, const paddle::optional& out_linear_smooths, + const paddle::optional& mask_offset, const paddle::optional& kv_signal_data, const std::string& compute_dtype, const std::string& cache_quant_type_str, @@ -464,6 +466,10 @@ std::vector AppendAttention( meta_data.block_size = key_cache.dims()[2]; meta_data.batch_size = seq_lens_this_time.dims()[0]; + if (mask_offset) { + meta_data.mask_offset = mask_offset.get().data(); + } + auto dispatch_by_template = [&](auto temp_args) -> std::vector { return AppendAttentionKernel::value>( meta_data, @@ -499,6 +505,7 @@ std::vector AppendAttention( cache_v_zp, out_linear_shifts, out_linear_smooths, + mask_offset, kv_signal_data, cache_quant_type_str, use_neox_rotary_style, @@ -576,6 +583,7 @@ std::vector> AppendAttentionInferShape( const paddle::optional>& cache_v_zp_shape, const paddle::optional>& out_linear_shifts_shape, const paddle::optional>& out_linear_smooths_shape, + const paddle::optional>& mask_offset_shape, const paddle::optional>& kv_signal_data_shape, const std::string& compute_dtype, const std::string& cache_quant_type_str, @@ -636,6 +644,7 @@ std::vector AppendAttentionInferDtype( const paddle::optional& cache_v_zp_dtype, const paddle::optional& out_linear_shifts_dtype, const paddle::optional& out_linear_smooths_dtype, + const paddle::optional& mask_offset_dtype, const paddle::optional& kv_signal_data_dtype, const std::string& compute_dtype, const std::string& cache_quant_type_str, @@ -714,6 +723,7 @@ PD_BUILD_STATIC_OP(append_attention) paddle::Optional("cache_v_zp"), paddle::Optional("out_linear_shifts"), paddle::Optional("out_linear_smooths"), + paddle::Optional("mask_offset"), paddle::Optional("kv_signal_data")}) .Outputs({"fmha_out", "qkv_out", "key_cache_out", "value_cache_out"}) .SetInplaceMap({{"key_cache", "key_cache_out"}, diff --git a/custom_ops/gpu_ops/append_attn/append_attention_c16_impl.cuh b/custom_ops/gpu_ops/append_attn/append_attention_c16_impl.cuh index b7d8441c68..19f156fc20 100644 --- a/custom_ops/gpu_ops/append_attn/append_attention_c16_impl.cuh +++ 
b/custom_ops/gpu_ops/append_attn/append_attention_c16_impl.cuh @@ -43,6 +43,7 @@ __global__ void multi_query_append_attention_kernel( const int *__restrict__ tile_ids_per_batch, const int *__restrict__ cu_seqlens_q, const int *__restrict__ block_table, // [bsz, block_num_per_seq] + const int *__restrict__ mask_offset, const int max_seq_len, const int max_dec_len, const int max_block_num_per_seq, @@ -141,6 +142,7 @@ __global__ void multi_query_append_attention_kernel( } else { o_base_ptr_int8 = out + o_offset; } + const int *mask_offset_this_seq = mask_offset ? mask_offset + q_start_seq_id : nullptr; smem_t qo_smem(smem); uint32_t q_smem_offset_r = smem_t::get_permuted_offset( @@ -179,7 +181,7 @@ __global__ void multi_query_append_attention_kernel( kv_len - q_len + tile_id * num_rows_per_block / GROUP_SIZE, chunk_start))) - : chunk_len) / + : mask_offset ? 0 : chunk_len) / (num_frags_z * 16); uint32_t k_smem_offset_r = smem_t::get_permuted_offset( 8 * (tid / 16) + tid % 8, (tid % 16) / 8); @@ -250,7 +252,8 @@ __global__ void multi_query_append_attention_kernel( q_len, kv_len, chunk_end, - s_frag); + s_frag, + mask_offset_this_seq); } // update m,d @@ -406,6 +409,7 @@ __global__ void multi_query_append_attention_warp1_4_kernel( const int *__restrict__ tile_ids_per_batch, const int *__restrict__ cu_seqlens_q, const int *__restrict__ block_table, // [bsz, block_num_per_seq] + const int *__restrict__ mask_offset, const int max_seq_len, const int max_dec_len, const int max_block_num_per_seq, @@ -502,7 +506,7 @@ __global__ void multi_query_append_attention_warp1_4_kernel( tid % 8 * num_elems_per_128b(); } } - + const int *mask_offset_this_seq = mask_offset ? mask_offset + q_start_seq_id : nullptr; smem_t qo_smem(smem); uint32_t q_smem_offset_r = smem_t::get_permuted_offset( @@ -543,7 +547,7 @@ __global__ void multi_query_append_attention_warp1_4_kernel( kv_len - q_len + tile_id * num_rows_per_block / GROUP_SIZE, chunk_start))) - : chunk_len) / + : mask_offset ? 
0 : chunk_len) / (NUM_WARP_KV * num_frags_z * 16); uint32_t k_smem_offset_r = smem_t::get_permuted_offset( @@ -616,7 +620,8 @@ __global__ void multi_query_append_attention_warp1_4_kernel( q_len, kv_len, chunk_end, - s_frag); + s_frag, + mask_offset_this_seq); } // update m,d @@ -882,6 +887,7 @@ void MultiQueryAppendAttention( tile_ids_per_batch.data(), cu_seqlens_q.data(), block_table.data(), + meta_data.mask_offset, max_seq_len, max_dec_len, max_block_num_per_seq, @@ -939,6 +945,7 @@ void MultiQueryAppendAttention( tile_ids_per_batch.data(), cu_seqlens_q.data(), block_table.data(), + meta_data.mask_offset, max_seq_len, max_dec_len, max_block_num_per_seq, @@ -1104,6 +1111,7 @@ void MultiQueryAppendAttention( tile_ids_per_batch.data(), cu_seqlens_q.data(), block_table.data(), + meta_data.mask_offset, max_seq_len, max_dec_len, max_block_num_per_seq, @@ -1172,6 +1180,7 @@ void MultiQueryAppendAttention( tile_ids_per_batch.data(), cu_seqlens_q.data(), block_table.data(), + meta_data.mask_offset, max_seq_len, max_dec_len, max_block_num_per_seq, diff --git a/custom_ops/gpu_ops/append_attn/append_attention_c4_impl.cuh b/custom_ops/gpu_ops/append_attn/append_attention_c4_impl.cuh index 9f003af88b..9e6fd356b5 100644 --- a/custom_ops/gpu_ops/append_attn/append_attention_c4_impl.cuh +++ b/custom_ops/gpu_ops/append_attn/append_attention_c4_impl.cuh @@ -48,6 +48,7 @@ __global__ void multi_query_append_attention_c4_kernel( const int *__restrict__ tile_ids_per_batch, const int *__restrict__ cu_seqlens_q, const int *__restrict__ block_table, // [bsz, block_num_per_seq] + const int *__restrict__ mask_offset, const int max_seq_len, const int max_dec_len, const int max_block_num_per_seq, @@ -172,6 +173,7 @@ __global__ void multi_query_append_attention_c4_kernel( } else { o_base_ptr_int8 = out + o_offset; } + const int *mask_offset_this_seq = mask_offset ? mask_offset + q_start_seq_id : nullptr; smem_t qo_smem(smem); uint32_t q_smem_offset_r = smem_t::get_permuted_offset( @@ -248,7 +250,7 @@ __global__ void multi_query_append_attention_c4_kernel( kv_len - q_len + tile_id * num_rows_per_block / GROUP_SIZE, chunk_start))) - : chunk_len) / + : mask_offset ? 0 : chunk_len) / (num_frags_z * 16); uint32_t k_smem_offset_r = @@ -338,7 +340,8 @@ __global__ void multi_query_append_attention_c4_kernel( q_len, kv_len, chunk_end, - s_frag); + s_frag, + mask_offset_this_seq); } update_mdo_states( @@ -505,6 +508,7 @@ __global__ void multi_query_append_attention_c4_warp1_4_kernel( const int *__restrict__ tile_ids_per_batch, const int *__restrict__ cu_seqlens_q, const int *__restrict__ block_table, // [bsz, block_num_per_seq] + const int *__restrict__ mask_offset, const int max_seq_len, const int max_dec_len, const int max_block_num_per_seq, @@ -627,7 +631,7 @@ __global__ void multi_query_append_attention_c4_warp1_4_kernel( tid % 8 * num_elems_per_128b(); } } - + const int *mask_offset_this_seq = mask_offset ? mask_offset + q_start_seq_id : nullptr; smem_t qo_smem(smem); uint32_t q_smem_offset_r = smem_t::get_permuted_offset( @@ -706,7 +710,7 @@ __global__ void multi_query_append_attention_c4_warp1_4_kernel( kv_len - q_len + tile_id * num_rows_per_block / GROUP_SIZE, chunk_start))) - : chunk_len) / + : mask_offset ? 
0 : chunk_len) / (NUM_WARP_KV * num_frags_z * 16); uint32_t k_smem_offset_r = @@ -793,7 +797,8 @@ __global__ void multi_query_append_attention_c4_warp1_4_kernel( q_len, kv_len, chunk_end, - s_frag); + s_frag, + mask_offset_this_seq); } update_mdo_states( @@ -1088,6 +1093,7 @@ void MultiQueryAppendC4Attention( tile_ids_per_batch.data(), cu_seqlens_q.data(), block_table.data(), + meta_data.mask_offset, max_seq_len, max_dec_len, max_block_num_per_seq, @@ -1151,6 +1157,7 @@ void MultiQueryAppendC4Attention( tile_ids_per_batch.data(), cu_seqlens_q.data(), block_table.data(), + meta_data.mask_offset, max_seq_len, max_dec_len, max_block_num_per_seq, @@ -1334,6 +1341,7 @@ void MultiQueryAppendC4Attention( tile_ids_per_batch.data(), cu_seqlens_q.data(), block_table.data(), + meta_data.mask_offset, max_seq_len, max_dec_len, max_block_num_per_seq, @@ -1410,6 +1418,7 @@ void MultiQueryAppendC4Attention( tile_ids_per_batch.data(), cu_seqlens_q.data(), block_table.data(), + meta_data.mask_offset, max_seq_len, max_dec_len, max_block_num_per_seq, diff --git a/custom_ops/gpu_ops/append_attn/append_attention_c8_impl.cuh b/custom_ops/gpu_ops/append_attn/append_attention_c8_impl.cuh index 3b72597e02..d24137b94a 100644 --- a/custom_ops/gpu_ops/append_attn/append_attention_c8_impl.cuh +++ b/custom_ops/gpu_ops/append_attn/append_attention_c8_impl.cuh @@ -48,6 +48,7 @@ __global__ void multi_query_append_attention_c8_kernel( const int *__restrict__ tile_ids_per_batch, const int *__restrict__ cu_seqlens_q, const int *__restrict__ block_table, // [bsz, block_num_per_seq] + const int *__restrict__ mask_offset, const int max_seq_len, const int max_dec_len, const int max_block_num_per_seq, @@ -179,6 +180,7 @@ __global__ void multi_query_append_attention_c8_kernel( } else { o_base_ptr_int8 = out + o_offset; } + const int *mask_offset_this_seq = mask_offset ? mask_offset + q_start_seq_id : nullptr; smem_t qo_smem(smem); uint32_t q_smem_offset_r = smem_t::get_permuted_offset( @@ -216,7 +218,7 @@ __global__ void multi_query_append_attention_c8_kernel( kv_len - q_len + tile_id * num_rows_per_block / GROUP_SIZE, chunk_start))) - : chunk_len) / + : mask_offset ? 0 : chunk_len) / (num_frags_z * 16); uint32_t k_smem_offset_r = @@ -305,7 +307,8 @@ __global__ void multi_query_append_attention_c8_kernel( q_len, kv_len, chunk_end, - s_frag); + s_frag, + mask_offset_this_seq); } // update m,d @@ -474,6 +477,7 @@ __global__ void multi_query_append_attention_c8_warp1_4_kernel( const int *__restrict__ tile_ids_per_batch, const int *__restrict__ cu_seqlens_q, const int *__restrict__ block_table, // [bsz, block_num_per_seq] + const int *__restrict__ mask_offset, const int max_seq_len, const int max_dec_len, const int max_block_num_per_seq, @@ -601,7 +605,7 @@ __global__ void multi_query_append_attention_c8_warp1_4_kernel( tid % 8 * num_elems_per_128b(); } } - + const int *mask_offset_this_seq = mask_offset ? mask_offset + q_start_seq_id : nullptr; smem_t qo_smem(smem); uint32_t q_smem_offset_r = smem_t::get_permuted_offset( @@ -642,7 +646,7 @@ __global__ void multi_query_append_attention_c8_warp1_4_kernel( kv_len - q_len + tile_id * num_rows_per_block / GROUP_SIZE, chunk_start))) - : chunk_len) / + : mask_offset ? 
0 : chunk_len) / (NUM_WARP_KV * num_frags_z * 16); uint32_t k_smem_offset_r = @@ -733,7 +737,8 @@ __global__ void multi_query_append_attention_c8_warp1_4_kernel( q_len, kv_len, chunk_end, - s_frag); + s_frag, + mask_offset_this_seq); } // update m,d @@ -1054,6 +1059,7 @@ void MultiQueryAppendC8Attention( tile_ids_per_batch.data(), cu_seqlens_q.data(), block_table.data(), + meta_data.mask_offset, max_seq_len, max_dec_len, max_block_num_per_seq, @@ -1111,6 +1117,7 @@ void MultiQueryAppendC8Attention( tile_ids_per_batch.data(), cu_seqlens_q.data(), block_table.data(), + meta_data.mask_offset, max_seq_len, max_dec_len, max_block_num_per_seq, @@ -1318,6 +1325,7 @@ void MultiQueryAppendC8Attention( tile_ids_per_batch.data(), cu_seqlens_q.data(), block_table.data(), + meta_data.mask_offset, max_seq_len, max_dec_len, max_block_num_per_seq, @@ -1388,6 +1396,7 @@ void MultiQueryAppendC8Attention( tile_ids_per_batch.data(), cu_seqlens_q.data(), block_table.data(), + meta_data.mask_offset, max_seq_len, max_dec_len, max_block_num_per_seq, diff --git a/custom_ops/gpu_ops/append_attn/append_attention_func.cuh b/custom_ops/gpu_ops/append_attn/append_attention_func.cuh index 8b6802d27d..060340395f 100644 --- a/custom_ops/gpu_ops/append_attn/append_attention_func.cuh +++ b/custom_ops/gpu_ops/append_attn/append_attention_func.cuh @@ -910,7 +910,8 @@ __device__ __forceinline__ void mask_s(const uint32_t qo_idx_base, const uint32_t qo_len, const uint32_t kv_len, const uint32_t chunk_end, - float (*s_frag)[num_frags_z][8]) { + float (*s_frag)[num_frags_z][8], + const int *mask_offset = nullptr) { const uint32_t tx = threadIdx.x; #pragma unroll for (uint32_t fx = 0; fx < num_frags_x; ++fx) { @@ -924,10 +925,15 @@ __device__ __forceinline__ void mask_s(const uint32_t qo_idx_base, group_size, kv_idx = kv_idx_base + fz * 16 + 2 * (tx % 4) + 8 * (reg_id / 4) + reg_id % 2; - const bool out_of_boundary = - (causal - ? (kv_idx > kv_len + q_idx - qo_len || (kv_idx >= chunk_end)) - : kv_idx >= chunk_end); + bool out_of_boundary; + if (mask_offset) { + out_of_boundary = q_idx < qo_len ? (kv_idx > mask_offset[q_idx]) : true; + } else { + out_of_boundary = + (causal + ? (kv_idx > kv_len + q_idx - qo_len || (kv_idx >= chunk_end)) + : kv_idx >= chunk_end); + } if constexpr (std::is_same::value) { s_frag[fx][fz][reg_id] = out_of_boundary ? -5e4f : s_frag[fx][fz][reg_id]; diff --git a/custom_ops/gpu_ops/append_attn/utils.cuh b/custom_ops/gpu_ops/append_attn/utils.cuh index 05f500126c..a05b6945c8 100644 --- a/custom_ops/gpu_ops/append_attn/utils.cuh +++ b/custom_ops/gpu_ops/append_attn/utils.cuh @@ -27,6 +27,7 @@ struct AppendAttnMetaData { int head_dims; int head_dims_v; int max_blocks_per_seq; + const int *mask_offset = nullptr; }; __forceinline__ __host__ __device__ int div_up(int a, int b) { @@ -474,6 +475,9 @@ __forceinline__ __host__ __device__ void vec_cast( if (causal) { \ constexpr bool CAUSAL = true; \ __VA_ARGS__ \ + } else { \ + constexpr bool CAUSAL = false; \ + __VA_ARGS__ \ } #define DISPATCH_ENABLE_PREFILL(enable_prefill, ENABLE_PREFILL, ...) 
\ diff --git a/custom_ops/gpu_ops/cpp_extensions.cc b/custom_ops/gpu_ops/cpp_extensions.cc index b4d7b952d5..a4b0c8f2b0 100644 --- a/custom_ops/gpu_ops/cpp_extensions.cc +++ b/custom_ops/gpu_ops/cpp_extensions.cc @@ -77,6 +77,7 @@ std::vector AppendAttention( const paddle::optional &cache_v_zp, const paddle::optional &out_linear_shifts, const paddle::optional &out_linear_smooths, + const paddle::optional &mask_offset, const paddle::optional &kv_signal_data, const std::string &compute_dtype, const std::string &cache_quant_type_str, const bool use_neox_rotary_style, const bool rope_3d, diff --git a/fastdeploy/model_executor/layers/attention/append_attn_backend.py b/fastdeploy/model_executor/layers/attention/append_attn_backend.py index cffc4adf72..11fa07f27f 100644 --- a/fastdeploy/model_executor/layers/attention/append_attn_backend.py +++ b/fastdeploy/model_executor/layers/attention/append_attn_backend.py @@ -62,6 +62,7 @@ class AppendAttentionMetadata(AttentionMetadata): block_tables: Optional[paddle.Tensor] = None rotary_embs: Optional[paddle.Tensor] = None attn_mask: Optional[paddle.Tensor] = None + mask_offset: Optional[paddle.Tensor] = None _fuse_kernel_compute_dtype: str = "bf16" # pd_disaggregation @@ -261,6 +262,7 @@ def forward_mixed( getattr(layer, "cache_v_zp", None), layer.linear_shift, layer.linear_smooth, + metadata.mask_offset, metadata.kv_signal_data_list[layer.layer_id], metadata._fuse_kernel_compute_dtype, getattr(layer, "cache_quant_type_str", "none"), diff --git a/fastdeploy/model_executor/layers/attention/ops/append_attention.py b/fastdeploy/model_executor/layers/attention/ops/append_attention.py index de538ad695..dd5b00f691 100644 --- a/fastdeploy/model_executor/layers/attention/ops/append_attention.py +++ b/fastdeploy/model_executor/layers/attention/ops/append_attention.py @@ -59,6 +59,7 @@ def append_attention( cache_v_zp: Optional[paddle.Tensor] = None, linear_shift: Optional[paddle.Tensor] = None, linear_smooth: Optional[paddle.Tensor] = None, + mask_offset: Optional[paddle.Tensor] = None, kv_signal_data: Optional[paddle.Tensor] = None, compute_type: str = "bf16", cache_quant_type: str = "none", @@ -113,6 +114,7 @@ def append_attention( cache_v_zp, linear_shift, linear_smooth, + mask_offset, kv_signal_data, compute_type, cache_quant_type, From 8dd52a673b2b674aac6841c89c3f15136e02d502 Mon Sep 17 00:00:00 2001 From: carryyu <569782149@qq.com> Date: Mon, 11 Aug 2025 17:46:10 +0800 Subject: [PATCH 2/2] add unittest --- test/layers/test_append_attention.py | 16 ++++++++++++++++ 1 file changed, 16 insertions(+) diff --git a/test/layers/test_append_attention.py b/test/layers/test_append_attention.py index 1c2ac0bbf6..e3e4de158c 100644 --- a/test/layers/test_append_attention.py +++ b/test/layers/test_append_attention.py @@ -349,6 +349,7 @@ def setUp(self): self.rope_theta = 10000 self.dtype = "float16" self.use_qk_norm = True + self.use_mask_offset = False self.init_tensor() def init_tensor(self): @@ -404,6 +405,12 @@ def init_tensor(self): self.cu_seqlens_k, ) = get_padding_offset(self.batch_size, self.seq_len, self.seq_lens_this_time) self.token_num = self.padding_offset.shape[0] + self.mask_offset = None + if self.use_mask_offset: + self.mask_offset = paddle.full(self.seq_len * self.batch_size, 0, "int32") + for i in range(self.batch_size): + for j in range(self.seq_len): + self.mask_offset[i * self.seq_len + j] = j def cmp_append_attention(self, naive_cache_k=None, naive_cache_v=None, attn_mask=None): paddle.disable_static() @@ -505,6 +512,7 @@ def 
cmp_append_attention(self, naive_cache_k=None, naive_cache_v=None, attn_mask None, # cache_v_zp None, # linear_shift None, # linear_smooth + self.mask_offset, # mask_offset None, # kv_signal_data q_norm_weight, # q_norm_weight k_norm_weight, # k_norm_weight @@ -560,6 +568,8 @@ def test_all(self): # encoder # self.seq_lens_encoder,self.seq_lens_decoder,self.max_enc_len_this_time,self.max_dec_len_this_time=get_encoder_decoder_len(self.batch_size,self.seq_len) self.seq_lens_this_time = self.seq_lens_encoder + if self.use_mask_offset: + print("encoder mask_offset: ", self.mask_offset) self.cmp_append_attention(attn_mask=self.attention_mask) naive_cache_k, naive_cache_v = block_cache_to_naive_cache( self.cache_k, @@ -590,6 +600,11 @@ def test_all(self): self.cu_seqlens_q, self.cu_seqlens_k, ) = get_padding_offset(self.batch_size, 1, self.seq_lens_this_time) + if self.use_mask_offset: + self.mask_offset = paddle.full(self.batch_size, 0, "int32") + for i in range(self.batch_size): + self.mask_offset[i] = self.seq_lens_dec[i] + print("decoder mask_offset: ", self.mask_offset) self.cmp_append_attention(naive_cache_k, naive_cache_v, None) @@ -614,6 +629,7 @@ def setUp(self): self.rope_theta = 10000 self.dtype = "float16" self.use_qk_norm = False + self.use_mask_offset = True self.init_tensor()
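
Reviewer note, a minimal Python sketch (not part of the patch) of how the new mask_offset input is laid out. It is inferred from the boundary check added to mask_s() — positions with kv_idx > mask_offset[q_idx] are masked — and from the encoder/decoder branches of the new unit test; the helper name and the seq_lens_* argument names below are illustrative only.

    import paddle

    def build_causal_mask_offset(seq_lens_this_time, seq_lens_decoder):
        """Per-query-token upper bound on the visible KV index.

        seq_lens_this_time[i]: number of new query tokens for request i
                               (the full prompt length in prefill, 1 in decode).
        seq_lens_decoder[i]:   number of KV tokens already cached for request i.
        Entry k is the largest kv_idx that query token k may attend to; the
        kernel masks every position with kv_idx > mask_offset[k].
        """
        offsets = []
        for q_len, cached in zip(seq_lens_this_time, seq_lens_decoder):
            for j in range(q_len):
                offsets.append(cached + j)  # causal bound: cached tokens plus self
        return paddle.to_tensor(offsets, dtype="int32")

    # Prefill: two requests of 4 prompt tokens, empty cache -> [0, 1, 2, 3, 0, 1, 2, 3],
    # matching the encoder branch of the new unit test.
    prefill_offsets = build_causal_mask_offset([4, 4], [0, 0])

    # Decode: one new token per request with 5 and 7 cached tokens -> [5, 7],
    # matching the decoder branch of the new unit test.
    decode_offsets = build_causal_mask_offset([1, 1], [5, 7])

The resulting int32 tensor is what the test passes as the new mask_offset argument of append_attention(); offsets equal to the causal bound reproduce the default causal mask, while smaller per-token values further restrict which cached positions a query may see, without changing the KV cache layout.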