diff --git a/cmake/external/flashattn.cmake b/cmake/external/flashattn.cmake
index 16f50667422142..496b476ae30d42 100644
--- a/cmake/external/flashattn.cmake
+++ b/cmake/external/flashattn.cmake
@@ -276,9 +276,6 @@ else()
       -DCMAKE_JOB_POOLS:STRING=compile=${FA_JOB_POOLS_COMPILE}
       -DNVCC_ARCH_BIN=${FA_NVCC_ARCH_BIN}
       -DWITH_FLASHATTN_V3=${WITH_FLASHATTN_V3}
-      -DDISABLE_FP8=ON # umiswing: disable FP8, SM8x and PACKGQA on FA3
-      -DDISABLE_SM8X=ON
-      -DDISABLE_PACKGQA=ON
       -DSKIP_BUILD_FA=${SKIP_BUILD_FA}
       ${EXTERNAL_OPTIONAL_ARGS}
     CMAKE_CACHE_ARGS
diff --git a/paddle/phi/backends/dynload/flashattnv3.h b/paddle/phi/backends/dynload/flashattnv3.h
index 899b30c270ad53..d3dd3d3f7eb252 100644
--- a/paddle/phi/backends/dynload/flashattnv3.h
+++ b/paddle/phi/backends/dynload/flashattnv3.h
@@ -24,7 +24,7 @@ namespace phi {
 namespace dynload {
 
 extern std::once_flag flashattnv3_dso_flag;
-extern void *flashattnv3_dso_handle;
+extern void* flashattnv3_dso_handle;
 
 #define DYNAMIC_LOAD_FLASHATTN_V3_WRAP(__name)                             \
   struct DynLoad__##__name {                                               \
@@ -34,7 +34,7 @@ extern void *flashattnv3_dso_handle;
       std::call_once(flashattnv3_dso_flag, []() {                         \
         flashattnv3_dso_handle = phi::dynload::GetFlashAttnV3DsoHandle(); \
       });                                                                 \
-      static void *p_##__name = dlsym(flashattnv3_dso_handle, #__name);   \
+      static void* p_##__name = dlsym(flashattnv3_dso_handle, #__name);   \
       return reinterpret_cast<flashattnFunc>(p_##__name)(args...);        \
     }                                                                     \
   };                                                                      \
@@ -45,223 +45,12 @@ extern void *flashattnv3_dso_handle;
 
 #ifdef PADDLE_WITH_CUDA
 #define FLASHATTN_V3_ROUTINE_EACH(__macro) \
-  __macro(fa3_create_fwd_params_handle);   \
-  __macro(fa3_clear_fwd_params_handle);    \
-  __macro(fa3_destroy_fwd_params_handle);  \
-  __macro(fa3_create_bwd_params_handle);   \
-  __macro(fa3_clear_bwd_params_handle);    \
-  __macro(fa3_destroy_bwd_params_handle);  \
-  __macro(fa3_cast_to_fwd_params_handle);  \
-  __macro(fa3_run_mha_fwd_combine);        \
-  __macro(fa3_run_mha_fwd);                \
-  __macro(fa3_run_mha_bwd);                \
-  __macro(fa3_get_pagedkv_tma);            \
-  __macro(fa3_get_pack_gqa);               \
-  __macro(fa3_get_num_splits);
-
-FLASHATTN_V3_ROUTINE_EACH(DECLARE_DYNAMIC_LOAD_FLASHATTN_V3_WRAP)
-
-#define FLASHATTN_V3_HANDLE_ROUTINE(member)                             \
-  DECLARE_DYNAMIC_LOAD_FLASHATTN_V3_WRAP(fa3_fwd_params_get_##member);  \
-  DECLARE_DYNAMIC_LOAD_FLASHATTN_V3_WRAP(fa3_fwd_params_set_##member);  \
-  DECLARE_DYNAMIC_LOAD_FLASHATTN_V3_WRAP(fa3_bwd_params_get_##member);  \
-  DECLARE_DYNAMIC_LOAD_FLASHATTN_V3_WRAP(fa3_bwd_params_set_##member);
-
-// The QKV matrices.
-FLASHATTN_V3_HANDLE_ROUTINE(q_ptr)
-FLASHATTN_V3_HANDLE_ROUTINE(k_ptr)
-FLASHATTN_V3_HANDLE_ROUTINE(v_ptr)
-
-// The stride between rows of the Q, K and V matrices.
-FLASHATTN_V3_HANDLE_ROUTINE(q_batch_stride)
-FLASHATTN_V3_HANDLE_ROUTINE(k_batch_stride)
-FLASHATTN_V3_HANDLE_ROUTINE(v_batch_stride)
-FLASHATTN_V3_HANDLE_ROUTINE(q_row_stride)
-FLASHATTN_V3_HANDLE_ROUTINE(k_row_stride)
-FLASHATTN_V3_HANDLE_ROUTINE(v_row_stride)
-FLASHATTN_V3_HANDLE_ROUTINE(q_head_stride)
-FLASHATTN_V3_HANDLE_ROUTINE(k_head_stride)
-FLASHATTN_V3_HANDLE_ROUTINE(v_head_stride)
-FLASHATTN_V3_HANDLE_ROUTINE(v_dim_stride)
-
-// The number of heads.
-FLASHATTN_V3_HANDLE_ROUTINE(h)
-FLASHATTN_V3_HANDLE_ROUTINE(h_k)
-
-// The O matrix (output).
-FLASHATTN_V3_HANDLE_ROUTINE(o_ptr)
-FLASHATTN_V3_HANDLE_ROUTINE(oaccum_ptr)
-
-// The stride between rows of O.
-FLASHATTN_V3_HANDLE_ROUTINE(o_batch_stride)
-FLASHATTN_V3_HANDLE_ROUTINE(o_row_stride)
-FLASHATTN_V3_HANDLE_ROUTINE(o_head_stride)
-
-// The pointer to the softmax sum.
-FLASHATTN_V3_HANDLE_ROUTINE(softmax_lse_ptr)
-FLASHATTN_V3_HANDLE_ROUTINE(softmax_lseaccum_ptr)
-
-// For FP8 scaling
-FLASHATTN_V3_HANDLE_ROUTINE(q_descale_ptr)
-FLASHATTN_V3_HANDLE_ROUTINE(k_descale_ptr)
-FLASHATTN_V3_HANDLE_ROUTINE(v_descale_ptr)
-FLASHATTN_V3_HANDLE_ROUTINE(q_descale_batch_stride)
-FLASHATTN_V3_HANDLE_ROUTINE(q_descale_head_stride)
-FLASHATTN_V3_HANDLE_ROUTINE(k_descale_batch_stride)
-FLASHATTN_V3_HANDLE_ROUTINE(k_descale_head_stride)
-FLASHATTN_V3_HANDLE_ROUTINE(v_descale_batch_stride)
-FLASHATTN_V3_HANDLE_ROUTINE(v_descale_head_stride)
-
-// The dimensions.
-FLASHATTN_V3_HANDLE_ROUTINE(b)
-FLASHATTN_V3_HANDLE_ROUTINE(seqlen_q)
-FLASHATTN_V3_HANDLE_ROUTINE(seqlen_k)
-FLASHATTN_V3_HANDLE_ROUTINE(seqlen_knew)
-FLASHATTN_V3_HANDLE_ROUTINE(d)
-FLASHATTN_V3_HANDLE_ROUTINE(seqlen_q_rounded)
-FLASHATTN_V3_HANDLE_ROUTINE(seqlen_k_rounded)
-FLASHATTN_V3_HANDLE_ROUTINE(d_rounded)
-FLASHATTN_V3_HANDLE_ROUTINE(rotary_dim)
-FLASHATTN_V3_HANDLE_ROUTINE(total_q)
-FLASHATTN_V3_HANDLE_ROUTINE(total_k)
-FLASHATTN_V3_HANDLE_ROUTINE(total_knew)
-FLASHATTN_V3_HANDLE_ROUTINE(b_k)
-FLASHATTN_V3_HANDLE_ROUTINE(dv)
-FLASHATTN_V3_HANDLE_ROUTINE(dv_rounded)
-
-// The scaling factors for the kernel.
-FLASHATTN_V3_HANDLE_ROUTINE(scale_softmax)
-FLASHATTN_V3_HANDLE_ROUTINE(softcap)
-
-// array of length b+1 holding starting offset of each sequence.
-FLASHATTN_V3_HANDLE_ROUTINE(cu_seqlens_q)
-FLASHATTN_V3_HANDLE_ROUTINE(cu_seqlens_k)
-FLASHATTN_V3_HANDLE_ROUTINE(cu_seqlens_knew)
-FLASHATTN_V3_HANDLE_ROUTINE(leftpad_k)
-
-// If provided, the actual length of each q/k sequence.
-FLASHATTN_V3_HANDLE_ROUTINE(seqused_q)
-FLASHATTN_V3_HANDLE_ROUTINE(seqused_k)
-
-// The stride between rows of Oaccum.
-FLASHATTN_V3_HANDLE_ROUTINE(oaccum_split_stride)
-FLASHATTN_V3_HANDLE_ROUTINE(oaccum_batch_stride)
-FLASHATTN_V3_HANDLE_ROUTINE(oaccum_row_stride)
-FLASHATTN_V3_HANDLE_ROUTINE(oaccum_head_stride)
-
-// The stride between rows of LSEaccum.
-FLASHATTN_V3_HANDLE_ROUTINE(lseaccum_split_stride)
-FLASHATTN_V3_HANDLE_ROUTINE(lseaccum_batch_stride)
-FLASHATTN_V3_HANDLE_ROUTINE(lseaccum_head_stride)
-
-// The K_new and V_new matrices.
-FLASHATTN_V3_HANDLE_ROUTINE(knew_ptr)
-FLASHATTN_V3_HANDLE_ROUTINE(vnew_ptr)
-
-// The stride between rows of the Q, K and V matrices.
-FLASHATTN_V3_HANDLE_ROUTINE(knew_batch_stride)
-FLASHATTN_V3_HANDLE_ROUTINE(vnew_batch_stride)
-FLASHATTN_V3_HANDLE_ROUTINE(knew_row_stride)
-FLASHATTN_V3_HANDLE_ROUTINE(vnew_row_stride)
-FLASHATTN_V3_HANDLE_ROUTINE(knew_head_stride)
-FLASHATTN_V3_HANDLE_ROUTINE(vnew_head_stride)
-
-FLASHATTN_V3_HANDLE_ROUTINE(qv_ptr)
-FLASHATTN_V3_HANDLE_ROUTINE(qv_batch_stride)
-FLASHATTN_V3_HANDLE_ROUTINE(qv_row_stride)
-FLASHATTN_V3_HANDLE_ROUTINE(qv_head_stride)
-
-// The cos and sin matrices for rotary embedding.
-FLASHATTN_V3_HANDLE_ROUTINE(rotary_cos_ptr)
-FLASHATTN_V3_HANDLE_ROUTINE(rotary_sin_ptr)
-
-// The indices to index into the KV cache.
-FLASHATTN_V3_HANDLE_ROUTINE(kv_batch_idx)
-
-// Paged KV cache
-FLASHATTN_V3_HANDLE_ROUTINE(page_table)
-FLASHATTN_V3_HANDLE_ROUTINE(page_table_batch_stride)
-FLASHATTN_V3_HANDLE_ROUTINE(page_size)
-FLASHATTN_V3_HANDLE_ROUTINE(num_pages)
-FLASHATTN_V3_HANDLE_ROUTINE(pagedkv_tma)
-
-// The dropout probability (probability of keeping an activation).
-FLASHATTN_V3_HANDLE_ROUTINE(p_dropout)
-FLASHATTN_V3_HANDLE_ROUTINE(p_dropout_in_uint8_t)
-
-// Scale factor of 1 / (1 - p_dropout).
-FLASHATTN_V3_HANDLE_ROUTINE(rp_dropout)
-
-// Local window size
-FLASHATTN_V3_HANDLE_ROUTINE(window_size_left)
-FLASHATTN_V3_HANDLE_ROUTINE(window_size_right)
-
-// Pointer to the RNG seed (idx 0) and offset (idx 1).
-FLASHATTN_V3_HANDLE_ROUTINE(rng_state)
-
-FLASHATTN_V3_HANDLE_ROUTINE(is_bf16)
-FLASHATTN_V3_HANDLE_ROUTINE(is_fp32)
-FLASHATTN_V3_HANDLE_ROUTINE(is_e4m3)
-FLASHATTN_V3_HANDLE_ROUTINE(is_causal)
-FLASHATTN_V3_HANDLE_ROUTINE(is_local)
-
-FLASHATTN_V3_HANDLE_ROUTINE(is_rotary_interleaved)
-
-FLASHATTN_V3_HANDLE_ROUTINE(num_splits)  // For split-KV version
-FLASHATTN_V3_HANDLE_ROUTINE(pack_gqa)
-
-FLASHATTN_V3_HANDLE_ROUTINE(tile_count_semaphore)
-FLASHATTN_V3_HANDLE_ROUTINE(num_splits_dynamic_ptr)
-FLASHATTN_V3_HANDLE_ROUTINE(skip_scheduler_metadata_computation)
-
-FLASHATTN_V3_HANDLE_ROUTINE(arch)
-FLASHATTN_V3_HANDLE_ROUTINE(num_sm)
-
-#define FLASHATTN_V3_BWD_HANDLE_ROUTINE(type, member)                   \
-  DECLARE_DYNAMIC_LOAD_FLASHATTN_V3_WRAP(fa3_bwd_params_get_##member);  \
-  DECLARE_DYNAMIC_LOAD_FLASHATTN_V3_WRAP(fa3_bwd_params_set_##member);
-
-// The dO and dQKV matrices.
-FLASHATTN_V3_BWD_HANDLE_ROUTINE(void *, do_ptr)
-FLASHATTN_V3_BWD_HANDLE_ROUTINE(void *, dq_ptr)
-FLASHATTN_V3_BWD_HANDLE_ROUTINE(void *, dk_ptr)
-FLASHATTN_V3_BWD_HANDLE_ROUTINE(void *, dv_ptr)
-
-// To accumulate dQ
-FLASHATTN_V3_BWD_HANDLE_ROUTINE(void *, dq_accum_ptr)
-FLASHATTN_V3_BWD_HANDLE_ROUTINE(void *, dk_accum_ptr)
-FLASHATTN_V3_BWD_HANDLE_ROUTINE(void *, dv_accum_ptr)
-
-// // To accumulate dK and dV in case we're splitting the bwd along seqlen_q
-// dimension void *__restrict__ dk_accum_ptr; void *__restrict__
-// dv_accum_ptr;
-
-// The stride between rows of the dO, dQ, dK and dV matrices.
-FLASHATTN_V3_BWD_HANDLE_ROUTINE(int64_t, do_batch_stride)
-FLASHATTN_V3_BWD_HANDLE_ROUTINE(int64_t, do_row_stride)
-FLASHATTN_V3_BWD_HANDLE_ROUTINE(int64_t, do_head_stride)
-FLASHATTN_V3_BWD_HANDLE_ROUTINE(int64_t, dq_batch_stride)
-FLASHATTN_V3_BWD_HANDLE_ROUTINE(int64_t, dk_batch_stride)
-FLASHATTN_V3_BWD_HANDLE_ROUTINE(int64_t, dv_batch_stride)
-FLASHATTN_V3_BWD_HANDLE_ROUTINE(int64_t, dq_row_stride)
-FLASHATTN_V3_BWD_HANDLE_ROUTINE(int64_t, dk_row_stride)
-FLASHATTN_V3_BWD_HANDLE_ROUTINE(int64_t, dv_row_stride)
-FLASHATTN_V3_BWD_HANDLE_ROUTINE(int64_t, dq_head_stride)
-FLASHATTN_V3_BWD_HANDLE_ROUTINE(int64_t, dk_head_stride)
-FLASHATTN_V3_BWD_HANDLE_ROUTINE(int64_t, dv_head_stride)
-
-// The pointer to the softmax d sum.
-FLASHATTN_V3_BWD_HANDLE_ROUTINE(void *, dsoftmax_sum)
-FLASHATTN_V3_BWD_HANDLE_ROUTINE(void *, softmax_lse_log2_ptr)
-
-FLASHATTN_V3_BWD_HANDLE_ROUTINE(int *, dq_semaphore)
-FLASHATTN_V3_BWD_HANDLE_ROUTINE(int *, dk_semaphore)
-FLASHATTN_V3_BWD_HANDLE_ROUTINE(int *, dv_semaphore)
-
-FLASHATTN_V3_BWD_HANDLE_ROUTINE(bool, deterministic)
-FLASHATTN_V3_BWD_HANDLE_ROUTINE(int64_t, dq_accum_split_stride)
+  __macro(flash_attn_v3_fwd);              \
+  __macro(flash_attn_v3_bwd);
 
 #endif
 
+FLASHATTN_V3_ROUTINE_EACH(DECLARE_DYNAMIC_LOAD_FLASHATTN_V3_WRAP);
+
 #undef DYNAMIC_LOAD_FLASHATTN_V3_WRAP
 
 }  // namespace dynload
diff --git a/paddle/phi/infermeta/backward.cc b/paddle/phi/infermeta/backward.cc
index a7d368ea869b22..914b4a7a960395 100644
--- a/paddle/phi/infermeta/backward.cc
+++ b/paddle/phi/infermeta/backward.cc
@@ -348,23 +348,6 @@ void FlashAttnQKVPackedGradInferMeta(const MetaTensor& qkv, MetaTensor* dqkv) {
   }
 }
 
-void FlashAttnV3GradInferMeta(const MetaTensor& q,
-                              const MetaTensor& k,
-                              const MetaTensor& v,
-                              MetaTensor* dq,
-                              MetaTensor* dk,
-                              MetaTensor* dv) {
-  if (dq) {
-    dq->share_meta(q);
-  }
-  if (dk) {
-    dk->share_meta(k);
-  }
-  if (dv) {
-    dv->share_meta(v);
-  }
-}
-
 void Flatten2GradInferMeta(const MetaTensor& x,
                            const MetaTensor& x_shape,
                            const MetaTensor& out_grad,
diff --git a/paddle/phi/infermeta/backward.h b/paddle/phi/infermeta/backward.h
index bca0c6f53906f9..c497979b2bca9b 100644
--- a/paddle/phi/infermeta/backward.h
+++ b/paddle/phi/infermeta/backward.h
@@ -248,13 +248,6 @@ void FlashAttnGradInferMeta(const MetaTensor& q,
                             MetaTensor* dq);
 
 void FlashAttnQKVPackedGradInferMeta(const MetaTensor& qkv, MetaTensor* dq);
 
-void FlashAttnV3GradInferMeta(const MetaTensor& q,
-                              const MetaTensor& k,
-                              const MetaTensor& v,
-                              MetaTensor* dq,
-                              MetaTensor* dk,
-                              MetaTensor* dv);
-
 void Flatten2GradInferMeta(const MetaTensor& x,
                            const MetaTensor& x_shape,
                            const MetaTensor& out_grad,
diff --git a/paddle/phi/infermeta/ternary.cc b/paddle/phi/infermeta/ternary.cc
index a66797a4d22437..4f2c639ae37e54 100644
--- a/paddle/phi/infermeta/ternary.cc
+++ b/paddle/phi/infermeta/ternary.cc
@@ -726,37 +726,6 @@ void CalcReducedAttnScoresInferMeta(const MetaTensor& q,
   reduced_scores->set_dims({batch_size, num_heads, 1, seqlen_k});
 }
 
-void FlashAttnV3InferMeta(const MetaTensor& q,
-                          const MetaTensor& k,
-                          const MetaTensor& v,
-                          MetaTensor* out,
-                          MetaTensor* softmax_lse) {
-  // TODO(umiswing): support varlen
-  constexpr bool is_varlen_q = false;
-  auto const sizes = q.dims();
-  const int batch_size = sizes[0];
-  const int seqlen_q = sizes[1];
-  int num_heads = q.dims()[q.dims().size() - 2];
-  int const head_size_v = v.dims()[v.dims().size() - 1];
-  auto q_type = q.dtype();
-  auto out_type =
-      q_type == phi::DataType::FLOAT8_E4M3FN ? phi::DataType::BFLOAT16 : q_type;
-  if (!is_varlen_q) {
-    out->set_dims({batch_size, seqlen_q, num_heads, head_size_v});
-  } else {
-    // TODO(umiswing): support varlen
-  }
-
-  out->set_dtype(out_type);
-
-  if (!is_varlen_q) {
-    softmax_lse->set_dims({batch_size, num_heads, seqlen_q});
-  } else {
-    // TODO(umiswing): support varlen
-  }
-  softmax_lse->set_dtype(phi::DataType::FLOAT32);
-}
-
 void ArangeTensorInferMeta(const MetaTensor& start,
                            const MetaTensor& end,
                            const MetaTensor& step,
diff --git a/paddle/phi/infermeta/ternary.h b/paddle/phi/infermeta/ternary.h
index 14dd2685949573..04a3cd7736c850 100644
--- a/paddle/phi/infermeta/ternary.h
+++ b/paddle/phi/infermeta/ternary.h
@@ -162,12 +162,6 @@ void CalcReducedAttnScoresInferMeta(const MetaTensor& q,
                                     const MetaTensor& softmax_lse,
                                     MetaTensor* reduced_scores);
 
-void FlashAttnV3InferMeta(const MetaTensor& q,
-                          const MetaTensor& k,
-                          const MetaTensor& v,
-                          MetaTensor* out,
-                          MetaTensor* softmax_lse);
-
 void InstanceNormInferMeta(const MetaTensor& x,
                            const MetaTensor& scale,
                            const MetaTensor& bias,
diff --git a/paddle/phi/kernels/fusion/gpu/fused_gemm_epilogue_grad_kernel.cu b/paddle/phi/kernels/fusion/gpu/fused_gemm_epilogue_grad_kernel.cu
index cf3f788e19a1fb..af1df2c3a6ffd6 100644
--- a/paddle/phi/kernels/fusion/gpu/fused_gemm_epilogue_grad_kernel.cu
+++ b/paddle/phi/kernels/fusion/gpu/fused_gemm_epilogue_grad_kernel.cu
@@ -16,6 +16,7 @@
 #include "paddle/phi/common/data_type.h"
 #include "paddle/phi/core/kernel_registry.h"
 #include "paddle/phi/core/tensor_utils.h"
+#include "paddle/phi/kernels/full_kernel.h"
 #include "paddle/phi/kernels/funcs/fused_gemm_epilogue.h"
 
 namespace phi {
@@ -34,6 +35,21 @@ void FusedGemmEpilogueGradKernel(
     DenseTensor* x_grad,
     DenseTensor* y_grad,
     DenseTensor* bias_grad) {
+  if (x.numel() == 0) {
+    dev_ctx.template Alloc<T>(y_grad);
+    phi::FullKernel<T, Context>(
+        dev_ctx, common::vectorize(y.dims()), 0.0, y.dtype(), y_grad);
+
+    if (bias_grad) {
+      dev_ctx.template Alloc<T>(bias_grad);
+      phi::FullKernel<T, Context>(dev_ctx,
+                                  common::vectorize(bias_grad->dims()),
+                                  0.0,
+                                  bias_grad->dtype(),
+                                  bias_grad);
+    }
+    return;
+  }
 #if defined(PADDLE_WITH_CUDA) && CUDA_VERSION < 11060
   PADDLE_THROW(common::errors::Unimplemented(
       "The fused_gemm_epilogue operator only support CUDA 11.6 "
diff --git a/paddle/phi/kernels/fusion/gpu/fused_gemm_epilogue_kernel.cu b/paddle/phi/kernels/fusion/gpu/fused_gemm_epilogue_kernel.cu
index e22e627059a11b..a1f2e6cd035526 100644
--- a/paddle/phi/kernels/fusion/gpu/fused_gemm_epilogue_kernel.cu
+++ b/paddle/phi/kernels/fusion/gpu/fused_gemm_epilogue_kernel.cu
@@ -74,6 +74,10 @@ void FusedGemmEpilogueKernel(const Context& dev_ctx,
                              const std::string& activation,
                              DenseTensor* out,
                              DenseTensor* reserve_space) {
+  if (out->numel() == 0) {
+    dev_ctx.template Alloc<T>(out);
+    return;
+  }
 #if defined(PADDLE_WITH_CUDA) && CUDA_VERSION < 11060
   PADDLE_THROW(common::errors::Unimplemented(
       "The fused_gemm_epilogue operator only support CUDA 11.6 "
diff --git a/paddle/phi/kernels/gpu/flash_attn_grad_kernel.cu b/paddle/phi/kernels/gpu/flash_attn_grad_kernel.cu
index d55283e1fc10fb..c611b134aa7f16 100644
--- a/paddle/phi/kernels/gpu/flash_attn_grad_kernel.cu
+++ b/paddle/phi/kernels/gpu/flash_attn_grad_kernel.cu
@@ -26,9 +26,6 @@
 #include "paddle/phi/kernels/gpu/flash_attn_utils.h"
 #include "paddle/phi/kernels/reduce_sum_kernel.h"
 #include "paddle/phi/kernels/slice_kernel.h"
-#ifdef PADDLE_WITH_FLASHATTN_V3
-#include "paddle/phi/kernels/gpu/flash_attn_v3_grad_kernel.h"
-#endif
 
 COMMON_DECLARE_bool(cudnn_deterministic);
 COMMON_DECLARE_int32(flash_attn_version);
 
@@ -626,7 +623,7 @@ void FlashAttnGradBaseKernel(
   const float softmax_unscale = std::sqrt(head_size);
 
   int version =
-      FLAGS_flash_attn_version == 3 && !FLAGS_cudnn_deterministic &&
+      FLAGS_flash_attn_version == 3 &&
              (head_size == 64 || head_size == 128 || head_size == 256)
          ? FLAGS_flash_attn_version
          : 2;
@@ -792,22 +789,73 @@ void FlashAttnGradBaseKernel(
             "FlashMask or Dense Mask is unsupported in FlashAttention V3"));
       }
 
-      FlashAttnV3GradKernel<T, Context>(ctx,
-                                        q,
-                                        k,
-                                        v,
-                                        out,
-                                        softmax_lse,
-                                        dout,
-                                        params.softmax_scale,
-                                        causal,
-                                        -1,   // window_size_left
-                                        -1,   // window_size_right
-                                        0.f,  // softcap
-                                        0,    // sm_margin
-                                        dq,
-                                        dk,
-                                        dv);
+      bool deterministic = FLAGS_cudnn_deterministic ? true : false;
+      succ = phi::dynload::flash_attn_v3_bwd(
+          dout.data(),
+          q.data(),
+          k.data(),
+          v.data(),
+          out.data(),
+          params.softmax_d.data(),
+          softmax_lse.data(),
+          params.softmax_lse_log2.data(),
+          params.rng_state.data(),
+          kdq->data(),
+          kdk->data(),
+          kdv->data(),
+          params.dq_accum.data(),
+          params.batch_size,
+          params.max_seqlen_q,
+          params.max_seqlen_k,
+          params.seqlen_q_rounded,
+          params.seqlen_k_rounded,
+          params.num_heads,
+          params.num_heads_k,
+          params.head_size,
+          params.head_size_rounded,
+          params.dropout,
+          params.softmax_scale,
+          softmax_unscale,
+          params.causal,
+          params.is_bf16,
+          num_splits,
+          deterministic,
+          stream,
+          params.seed,
+          params.offset,
+          params.attn_mask_tensor ? params.attn_mask_tensor->data() : nullptr,
+          params.attn_mask_tensor ? params.mask_dims.data() : nullptr,
+          is_flashmask ? downstart_row_indices_data : nullptr,
+          is_flashmask ? downend_row_indices_data : nullptr,
+          is_flashmask ? upend_row_indices_data : nullptr,
+          is_flashmask ? upstart_row_indices_data : nullptr,
+          is_flashmask ? flashmask_maxmin.data() : nullptr,
+          is_flashmask ? params.startend_row_indices_dims.data() : nullptr,
+          q.strides()[0],
+          k.strides()[0],
+          v.strides()[0],
+          q.strides()[1],
+          k.strides()[1],
+          v.strides()[1],
+          q.strides()[2],
+          k.strides()[2],
+          v.strides()[2],
+          out.strides()[0],
+          out.strides()[1],
+          out.strides()[2],
+          kdq->strides()[0],
+          kdk->strides()[0],
+          kdv->strides()[0],
+          kdq->strides()[1],
+          kdk->strides()[1],
+          kdv->strides()[1],
+          kdq->strides()[2],
+          kdk->strides()[kdk->strides().size() - 2],
+          kdv->strides()[kdv->strides().size() - 2],
+          dout.strides()[0],
+          dout.strides()[1],
+          dout.strides()[2],
+          params.dq_semaphore.data());
 #else
       RaiseNotSupportedError(3);
 #endif
@@ -877,22 +925,20 @@ void FlashAttnGradBaseKernel(
         dout.strides()[0]);
   }
 #endif
-  if (version != 3) {
-    CheckFlashAttnStatus(succ);  // umiswing: no return status in fa3
-    if (!is_mha) {
-      if (dk) {
-        if (dk->meta().is_contiguous())
-          phi::SumKernel<T, Context>(ctx, dk_tmp, {3}, dk->type(), false, dk);
-        else
-          kvReduceBatchedForGQA<T, Context>(ctx, dk_tmp, dk);
-      }
+  CheckFlashAttnStatus(succ);
+  if (!is_mha) {
+    if (dk) {
+      if (dk->meta().is_contiguous())
+        phi::SumKernel<T, Context>(ctx, dk_tmp, {3}, dk->type(), false, dk);
+      else
+        kvReduceBatchedForGQA<T, Context>(ctx, dk_tmp, dk);
+    }
 
-      if (dv) {
-        if (dv->meta().is_contiguous())
-          phi::SumKernel<T, Context>(ctx, dv_tmp, {3}, dv->type(), false, dv);
-        else
-          kvReduceBatchedForGQA<T, Context>(ctx, dv_tmp, dv);
-      }
+    if (dv) {
+      if (dv->meta().is_contiguous())
+        phi::SumKernel<T, Context>(ctx, dv_tmp, {3}, dv->type(), false, dv);
+      else
+        kvReduceBatchedForGQA<T, Context>(ctx, dv_tmp, dv);
     }
   }
 #else
diff --git a/paddle/phi/kernels/gpu/flash_attn_kernel.cu b/paddle/phi/kernels/gpu/flash_attn_kernel.cu
index 4bd1b86e158948..49babd255e9a5d 100644
--- a/paddle/phi/kernels/gpu/flash_attn_kernel.cu
+++ b/paddle/phi/kernels/gpu/flash_attn_kernel.cu
@@ -31,12 +31,7 @@
 #include "paddle/phi/kernels/slice_kernel.h"
 #include "paddle/utils/none.h"
 
-#ifdef PADDLE_WITH_FLASHATTN_V3
-#include "paddle/phi/kernels/gpu/flash_attn_v3_kernel.h"
-#endif
-
 COMMON_DECLARE_int32(flash_attn_version);
-COMMON_DECLARE_bool(cudnn_deterministic);
 
 namespace phi {
 
 template <typename T, typename Context>
@@ -378,7 +373,7 @@ void FlashAttnBaseKernel(
   const float softmax_unscale = std::sqrt(head_size);
 
   int version =
-      FLAGS_flash_attn_version == 3 && !FLAGS_cudnn_deterministic &&
+      FLAGS_flash_attn_version == 3 &&
              (head_size == 64 || head_size == 128 || head_size == 256)
          ? FLAGS_flash_attn_version
          : 2;
@@ -537,25 +532,58 @@ void FlashAttnBaseKernel(
           "FlashMask or Dense Mask is unsupported in FlashAttention V3"));
     }
 
-    FlashAttnV3Kernel<T, Context>(ctx,
-                                  q,
-                                  k,
-                                  v,
-                                  paddle::none,  // q_v_
-                                  paddle::none,  // q_descale_
-                                  paddle::none,  // k_descale_
-                                  paddle::none,  // v_descale_
-                                  params.softmax_scale,
-                                  params.causal,
-                                  -1,     // window_size_left
-                                  -1,     // window_size_right
-                                  0.f,    // softcap
-                                  1,      // num_splits
-                                  false,  // manual_set_pack_gqa
-                                  false,  // pack_gqa_
-                                  0,      // sm_margin
-                                  out,
-                                  softmax_lse);
+    succ = phi::dynload::flash_attn_v3_fwd(
+        q.data(),
+        k.data(),
+        v.data(),
+        params.rng_state.data(),
+        out->data(),
+        params.return_softmax ? params.softmax->data() : nullptr,
+        params.softmax_lse->data(),
+        params.batch_size,
+        params.max_seqlen_q,
+        params.max_seqlen_k,
+        params.seqlen_q_rounded,
+        params.seqlen_k_rounded,
+        params.num_heads,
+        params.num_heads_k,
+        params.head_size,
+        params.head_size_rounded,
+        params.dropout,
+        params.softmax_scale,
+        softmax_unscale,
+        params.causal,
+        params.return_softmax,
+        params.is_bf16,
+        stream,
+        params.seed,
+        params.offset,
+        params.attn_mask_tensor ? params.attn_mask_tensor->data() : nullptr,
+        params.mask_dims.data(),
+        is_flashmask ? downstart_row_indices_data : nullptr,
+        is_flashmask ? downend_row_indices_data : nullptr,
+        is_flashmask ? upend_row_indices_data : nullptr,
+        is_flashmask ? upstart_row_indices_data : nullptr,
+        is_flashmask ? flashmask_maxmin.data() : nullptr,
+        is_flashmask ? params.startend_row_indices_dims.data() : nullptr,
+        q.strides()[0],
+        k.strides()[0],
+        v.strides()[0],
+        q.strides()[1],
+        k.strides()[1],
+        v.strides()[1],
+        q.strides()[2],
+        k.strides()[2],
+        v.strides()[2],
+        out->strides()[0],
+        out->strides()[1],
+        out->strides()[2],
+        /*is_e4m3=*/false,
+        params.tile_count_semaphore.data(),
+        /*descale_q_ptr=*/nullptr,
+        /*descale_k_ptr=*/nullptr,
+        /*descale_v_ptr=*/nullptr,
+        use_gqa_packing);
 #else
     RaiseNotSupportedError(3);
 #endif
@@ -608,9 +636,7 @@ void FlashAttnBaseKernel(
         out->strides()[0]);
   }
 #endif
-  if (version != 3) {
-    CheckFlashAttnStatus(succ);  // umiswing: no return status in fa3
-  }
+  CheckFlashAttnStatus(succ);
 #else
   RaiseNotSupportedError();
 #endif
diff --git a/paddle/phi/kernels/gpu/flash_attn_v3_grad_kernel.cu b/paddle/phi/kernels/gpu/flash_attn_v3_grad_kernel.cu
deleted file mode 100644
index 4aaf9681fbf990..00000000000000
--- a/paddle/phi/kernels/gpu/flash_attn_v3_grad_kernel.cu
+++ /dev/null
@@ -1,693 +0,0 @@
-// Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.
-//
-// Licensed under the Apache License, Version 2.0 (the "License");
-// you may not use this file except in compliance with the License.
-// You may obtain a copy of the License at
-//
-//     http://www.apache.org/licenses/LICENSE-2.0
-//
-// Unless required by applicable law or agreed to in writing, software
-// distributed under the License is distributed on an "AS IS" BASIS,
-// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-// See the License for the specific language governing permissions and
-// limitations under the License.
-
-#include "paddle/common/enforce.h"
-#include "paddle/common/flags.h"
-#include "paddle/phi/backends/gpu/gpu_context.h"
-#include "paddle/phi/common/bfloat16.h"
-#include "paddle/phi/core/dense_tensor.h"
-#include "paddle/phi/core/kernel_registry.h"
-#include "paddle/phi/core/platform/device_context.h"
-#include "paddle/phi/core/tensor_utils.h"
-#include "paddle/phi/kernels/empty_kernel.h"
-#include "paddle/phi/kernels/funcs/math_function.h"
-#include "paddle/phi/kernels/gpu/flash_attn_utils.h"
-#include "paddle/phi/kernels/gpu/flash_attn_v3_utils.h"
-
-#include "paddle/phi/kernels/concat_kernel.h"
-#include "paddle/phi/kernels/gpu/flash_attn_v3_grad_kernel.h"
-#include "paddle/phi/kernels/slice_kernel.h"
-
-COMMON_DECLARE_bool(cudnn_deterministic);
-
-namespace phi {
-
-// b: batch_size
-// s_q: seqlen_q
-// s_k: seqlen_k
-// h: num_heads
-// h_k: num_heads_k
-// d: head_size
-template <typename T, typename Context>
-void FlashAttnV3GradBaseKernel(
-    const Context &ctx,
-    const DenseTensor
-        &dout,  // (b, s_q, h, d) or (total_q, h, d) if there is cu_seqlens_q
-    const DenseTensor
-        &q,  // (b, s_q, h, d) or (total_q, h, d) if there is cu_seqlens_q
-    const DenseTensor
-        &k,  // (b, s_k, h_k, d) or (total_k, h_k, d) if there is cu_seqlens_k
-    const DenseTensor
-        &v,  // (b, s_k, h_k, d) or (total_k, h_k, d) if there is cu_seqlens_k
-    const DenseTensor
-        &out,  // (b, s_q, h, d) or (total_q, h, d) if there is cu_seqlens_q
-    const DenseTensor
-        &softmax_lse,  // (b, h, s_q) or (h, total_q) if there is cu_seqlens_q
-    const paddle::optional<DenseTensor>
-        &dq_,  // (b, s_q, h, d) or (total_q, h, d) if there is cu_seqlens_q
-    const paddle::optional<DenseTensor>
-        &dk_,  // (b, s_k, h_k, d) or (total_k, h_k, d) if there is cu_seqlens_k
-    const paddle::optional<DenseTensor>
-        &dv_,  // (b, s_k, h_k, d) or (total_k, h_k, d) if there is cu_seqlens_k
-    const paddle::optional<DenseTensor> &cu_seqlens_q_,  // b+1
-    const paddle::optional<DenseTensor> &cu_seqlens_k_,  // b+1
-    const paddle::optional<DenseTensor>
-        &seqused_q_,  // b. If given, only this many elements of each batch
-                      // element's queries and outputs are used.
-    const paddle::optional<DenseTensor>
-        &seqused_k_,  // b. If given, only this many elements of each batch
-                      // element's keys are used.
-    int max_seqlen_q_,
-    int max_seqlen_k_,
-    float const softmax_scale,
-    bool is_causal,
-    int window_size_left,
-    int window_size_right,
-    float const softcap,
-    bool const deterministic,
-    int const sm_margin,
-    DenseTensor *dq,
-    DenseTensor *dk,
-    DenseTensor *dv,
-    DenseTensor *softmax_d,
-    DenseTensor *softmax_lse_log2,
-    DenseTensor *dq_accum,
-    DenseTensor *dk_accum,
-    DenseTensor *dv_accum) {
-#ifdef PADDLE_WITH_FLASHATTN_V3
-
-  // TODO(umiswing): support ampere
-  int device_id = ctx.GetPlace().GetDeviceId();
-  auto dprops = paddle::platform::GetDeviceProperties(device_id);
-  const bool is_sm90 = dprops.major == 9 && dprops.minor == 0;
-  PADDLE_ENFORCE_EQ(is_sm90,
-                    true,
-                    common::errors::Unavailable(
-                        "FlashAttention-3 only supports Hopper GPUs."));
-
-  auto q_type = q.dtype();
-  PADDLE_ENFORCE_EQ(
-      (q_type == phi::DataType::FLOAT16 || q_type == phi::DataType::BFLOAT16),
-      true,
-      common::errors::InvalidArgument(
-          "FlashAttention-3 bwd only support fp16 and bf16 data type"));
-  PADDLE_ENFORCE_EQ(k.dtype(),
-                    q_type,
-                    common::errors::InvalidArgument(
-                        "query and key must have the same dtype"));
-  PADDLE_ENFORCE_EQ(v.dtype(),
-                    q_type,
-                    common::errors::InvalidArgument(
-                        "query and value must have the same dtype"));
-  PADDLE_ENFORCE_EQ(out.dtype(),
-                    q_type,
-                    common::errors::InvalidArgument(
-                        "query and out must have the same dtype"));
-  PADDLE_ENFORCE_EQ(dout.dtype(),
-                    q_type,
-                    common::errors::InvalidArgument(
-                        "query and dout must have the same dtype"));
-
-  CHECK_DEVICE(q);
-  CHECK_DEVICE(k);
-  CHECK_DEVICE(v);
-  CHECK_DEVICE(out);
-  CHECK_DEVICE(dout);
-  CHECK_DEVICE(softmax_lse);
-
-  PADDLE_ENFORCE_EQ(q.strides()[q.strides().size() - 1],
-                    1,
-                    common::errors::InvalidArgument(
-                        "Input tensor must have contiguous last dimension"));
-  PADDLE_ENFORCE_EQ(k.strides()[k.strides().size() - 1],
-                    1,
-                    common::errors::InvalidArgument(
-                        "Input tensor must have contiguous last dimension"));
-  PADDLE_ENFORCE_EQ(v.strides()[v.strides().size() - 1],
-                    1,
-                    common::errors::InvalidArgument(
-                        "Input tensor must have contiguous last dimension"));
-  PADDLE_ENFORCE_EQ(out.strides()[out.strides().size() - 1],
-                    1,
-                    common::errors::InvalidArgument(
-                        "out tensor must have contiguous last dimension"));
-  PADDLE_ENFORCE_EQ(dout.strides()[dout.strides().size() - 1],
-                    1,
-                    common::errors::InvalidArgument(
-                        "dout tensor must have contiguous last dimension"));
-
-  DenseTensor cu_seqlens_q;
-  bool const is_varlen_q = cu_seqlens_q_.is_initialized();
-  if (is_varlen_q) {
-    cu_seqlens_q = cu_seqlens_q_.get();
-    CHECK_DEVICE(cu_seqlens_q);
-    CHECK_CONTIGUOUS(cu_seqlens_q);
-    PADDLE_ENFORCE_EQ(cu_seqlens_q.dtype(),
-                      phi::DataType::INT32,
-                      common::errors::InvalidArgument(
-                          "cu_seqlens_q must have dtype paddle.int32"));
-    PADDLE_ENFORCE_GT(
-        max_seqlen_q_,
-        0,
-        common::errors::InvalidArgument(
-            "max_seqlen_q must be provided if cu_seqlens_q is provided"));
-  }
-  DenseTensor cu_seqlens_k;
-  bool const is_varlen_k = cu_seqlens_k_.is_initialized();
-  if (is_varlen_k) {
-    cu_seqlens_k = cu_seqlens_k_.get();
-    CHECK_DEVICE(cu_seqlens_k);
-    CHECK_CONTIGUOUS(cu_seqlens_k);
-    PADDLE_ENFORCE_EQ(cu_seqlens_k.dtype(),
-                      phi::DataType::INT32,
-                      common::errors::InvalidArgument(
-                          "cu_seqlens_k must have dtype paddle.int32"));
-    PADDLE_ENFORCE_GT(
-        max_seqlen_k_,
-        0,
-        common::errors::InvalidArgument(
-            "max_seqlen_k must be provided if cu_seqlens_k is provided"));
-  }
-  // This is what we will template on
-  bool const is_varlen = is_varlen_q || is_varlen_k ||
-                         seqused_q_.is_initialized() ||
-                         seqused_k_.is_initialized();
-#ifdef FLASHATTENTION_DISABLE_VARLEN
-  PADDLE_ENFORCE_EQ(!is_varlen,
-                    true,
-                    common::errors::Unavailable(
-                        "This flash attention build does not support varlen."));
-#endif
-
-  auto const sizes = q.dims();
-  int const batch_size = !is_varlen_q ? sizes[0] : cu_seqlens_q.dims()[0] - 1;
-  int const seqlen_q = !is_varlen_q ? sizes[1] : max_seqlen_q_;
-  int const total_q = !is_varlen_q ? batch_size * sizes[1] : sizes[0];
-  int const num_heads = q.dims()[q.dims().size() - 2];
-  int const head_size = q.dims()[q.dims().size() - 1];
-  int const seqlen_k = !is_varlen_k ? k.dims()[1] : max_seqlen_k_;
-  int const total_k = !is_varlen_k ? batch_size * k.dims()[1] : k.dims()[0];
-  int const num_heads_k = k.dims()[k.dims().size() - 2];
-  PADDLE_ENFORCE_EQ(
-      head_size % 8,
-      0,
-      common::errors::InvalidArgument("head_size should be a multiple of 8"));
-  int const max_headdim = get_max_headdim();
-  PADDLE_ENFORCE_LE(
-      head_size,
-      max_headdim,
-      common::errors::InvalidArgument(
-          "FlashAttention forward only supports head dimension at most %d",
-          max_headdim));
-  PADDLE_ENFORCE_EQ(
-      num_heads % num_heads_k,
-      0,
-      common::errors::InvalidArgument(
-          "Number of heads in key/value must divide number of heads in query"));
-
-  // This needs to go before kBlockM & kBlockN since we rely on the correct
-  // window_size and is_causal to set kBlockM
-  if (window_size_left >= seqlen_k - 1) {
-    window_size_left = -1;
-  }
-  if (window_size_right >= seqlen_q - 1) {
-    window_size_right = -1;
-  }
-  if (is_causal) {
-    window_size_right = 0;
-  }
-  // There's a case where is_causal=false, window_size=(-1, 0). Then
-  // set_params_bprop will set params.is_causal=true. If we don't have is_causal
-  // here matching params.is_causal, we might get the wrong kBlockM (and cause
-  // IMA).
-  is_causal = window_size_left < 0 && window_size_right == 0;
-
-  int const arch = dprops.major * 10 + dprops.minor;
-  int const head_size_rounded = round_up_headdim(head_size);
-  // Very important that these match the kernel configs
-  bool const is_local =
-      (window_size_left >= 0 || window_size_right >= 0) && !is_causal;
-  int const kBlockM_sm90 =
-      head_size_rounded <= 64
-          ? (is_causal && softcap > 0.0 ? 96 : 128)
-          : (head_size_rounded <= 96
-                 ? 64
-                 : (head_size_rounded <= 128
-                        ? (is_causal || is_local || softcap > 0.0 ? 64 : 80)
-                        : 64));
-  int const kBlockM_sm80 = head_size_rounded <= 64 ? 128 : 64;
-  int const kBlockM_sm86 = head_size_rounded <= 192 ? 64 : 32;
-  int const kBlockM =
-      arch >= 90 ? kBlockM_sm90
-                 : (arch == 86 || arch == 89 ? kBlockM_sm86 : kBlockM_sm80);
-  int const kBlockN_sm90 =
-      head_size_rounded <= 128 ? 128 : (head_size_rounded <= 192 ? 96 : 80);
-  int const kBlockN_sm80 =
-      head_size_rounded <= 128 ? 128 : (head_size_rounded <= 192 ? 80 : 64);
-  int const kBlockN_sm86 =
-      head_size_rounded <= 64
-          ? 128
-          : (head_size_rounded <= 96
-                 ? 128
-                 : (head_size_rounded <= 128
-                        ? 96
-                        : (head_size_rounded <= 192 ? 64 : 64)));
-  int const kBlockN =
-      arch >= 90 ? kBlockN_sm90
-                 : (arch == 86 || arch == 89 ? kBlockN_sm86 : kBlockN_sm80);
-  auto round_multiple = [](int x, int m) { return (x + m - 1) / m * m; };
-  int const seqlen_q_rounded = round_multiple(seqlen_q, kBlockM);
-  int const seqlen_k_rounded = round_multiple(seqlen_k, kBlockN);
-  int const total_q_padded_rounded =
-      round_multiple(total_q + batch_size * kBlockM, kBlockM);
-  int const total_k_padded_rounded =
-      round_multiple(total_k + batch_size * kBlockN, kBlockN);
-
-  if (!is_varlen_q) {
-    CHECK_SHAPE(q, batch_size, seqlen_q, num_heads, head_size);
-    CHECK_SHAPE(out, batch_size, seqlen_q, num_heads, head_size);
-    CHECK_SHAPE(dout, batch_size, seqlen_q, num_heads, head_size);
-  } else {
-    CHECK_SHAPE(q, total_q, num_heads, head_size);
-    CHECK_SHAPE(out, total_q, num_heads, head_size);
-    CHECK_SHAPE(dout, total_q, num_heads, head_size);
-    CHECK_SHAPE(cu_seqlens_q, batch_size + 1);
-  }
-  if (!is_varlen_k) {
-    CHECK_SHAPE(k, batch_size, seqlen_k, num_heads_k, head_size);
-    CHECK_SHAPE(v, batch_size, seqlen_k, num_heads_k, head_size);
-  } else {
-    CHECK_SHAPE(k, total_k, num_heads_k, head_size);
-    CHECK_SHAPE(v, total_k, num_heads_k, head_size);
-    CHECK_SHAPE(cu_seqlens_k, batch_size + 1);
-  }
-
-  if (seqused_q_.is_initialized()) {
-    auto seqused_q = seqused_q_.get();
-    PADDLE_ENFORCE_EQ(
-        seqused_q.dtype(),
-        phi::DataType::INT32,
-        common::errors::InvalidArgument("seqused_q must have dtype int32"));
-    CHECK_DEVICE(seqused_q);
-    CHECK_CONTIGUOUS(seqused_q);
-    CHECK_SHAPE(seqused_q, batch_size);
-  }
-  if (seqused_k_.is_initialized()) {
-    auto seqused_k = seqused_k_.get();
-    PADDLE_ENFORCE_EQ(
-        seqused_k.dtype(),
-        phi::DataType::INT32,
-        common::errors::InvalidArgument("seqused_k must have dtype int32"));
-    CHECK_DEVICE(seqused_k);
-    CHECK_CONTIGUOUS(seqused_k);
-    CHECK_SHAPE(seqused_k, batch_size);
-  }
-
-  if (dq_.is_initialized()) {
-    *dq = dq_.get();
-    PADDLE_ENFORCE_EQ(
-        dq->dtype(),
-        q_type,
-        common::errors::InvalidArgument("dq must have the same dtype as q"));
-    CHECK_DEVICE((*dq));
-    PADDLE_ENFORCE_EQ(dq->strides()[dq->strides().size() - 1],
-                      1,
-                      common::errors::InvalidArgument(
-                          "dq must have contiguous last dimension"));
-    if (!is_varlen_q) {
-      CHECK_SHAPE((*dq), batch_size, seqlen_q, num_heads, head_size);
-    } else {
-      CHECK_SHAPE((*dq), total_q, num_heads, head_size);
-    }
-  } else {
-    *dq = phi::EmptyLike<T, Context>(ctx, q);
-  }
-  if (dk_.is_initialized()) {
-    *dk = dk_.get();
-    PADDLE_ENFORCE_EQ(
-        dk->dtype(),
-        q_type,
-        common::errors::InvalidArgument("dk must have the same dtype as q"));
-    CHECK_DEVICE((*dk));
-    PADDLE_ENFORCE_EQ(dk->strides()[dk->strides().size() - 1],
-                      1,
-                      common::errors::InvalidArgument(
-                          "dk must have contiguous last dimension"));
-    if (!is_varlen_k) {
-      CHECK_SHAPE((*dk), batch_size, seqlen_k, num_heads_k, head_size);
-    } else {
-      CHECK_SHAPE((*dk), total_k, num_heads_k, head_size);
-    }
-  } else {
-    *dk = phi::EmptyLike<T, Context>(ctx, k);
-  }
-  if (dv_.is_initialized()) {
-    *dv = dv_.get();
-    PADDLE_ENFORCE_EQ(
-        dv->dtype(),
-        q_type,
-        common::errors::InvalidArgument("dv must have the same dtype as q"));
-    CHECK_DEVICE((*dv));
-    PADDLE_ENFORCE_EQ(dv->strides()[dv->strides().size() - 1],
-                      1,
-                      common::errors::InvalidArgument(
-                          "dv must have contiguous last dimension"));
-    if (!is_varlen_k) {
-      CHECK_SHAPE((*dv), batch_size, seqlen_k, num_heads_k, head_size);
-    } else {
-      CHECK_SHAPE((*dv), total_k, num_heads_k, head_size);
-    }
-  } else {
-    *dv = phi::EmptyLike<T, Context>(ctx, v);
-  }
-
-  // Otherwise the kernel will be launched from cuda:0 device
-  // Cast to char to avoid compiler warning about narrowing
-
-  // Need softmax_d to have total_q_padded_rounded since we want its address to
-  // be aligned by 16/8 bytes for TMA / LDG.64
-  if (!is_varlen) {
-    if (softmax_d) {
-      // Need softmax_d to have seqlen_q_rounded since we want its address to be
-      // aligned by 16/8 bytes for TMA / LDG.64
-      softmax_d->Resize(
-          common::make_ddim({batch_size, num_heads, seqlen_q_rounded}));
-    }
-    if (softmax_lse_log2) {
-      softmax_lse_log2->Resize(
-          common::make_ddim({batch_size, num_heads, seqlen_q_rounded}));
-    }
-  } else {
-    if (softmax_d) {
-      softmax_d->Resize(common::make_ddim({num_heads, total_q_padded_rounded}));
-    }
-    if (softmax_lse_log2) {
-      softmax_lse_log2->Resize(
-          common::make_ddim({num_heads, total_q_padded_rounded}));
-    }
-  }
-  if (softmax_d) {
-    ctx.template Alloc<float>(softmax_d);
-  }
-  if (softmax_lse_log2) {
-    ctx.template Alloc<float>(softmax_lse_log2);
-  }
-  if (dq_accum) {
-    if (!is_varlen) {
-      dq_accum->Resize(common::make_ddim(
-          {batch_size, num_heads, seqlen_q_rounded * head_size_rounded}));
-    } else {
-      dq_accum->Resize(common::make_ddim(
-          {num_heads, total_q_padded_rounded * head_size_rounded}));
-    }
-    ctx.template Alloc<float>(dq_accum);
-  }
-  if (num_heads_k != num_heads) {  // MQA / GQA
-    if (!is_varlen) {
-      if (dk_accum) {
-        dk_accum->Resize(common::make_ddim(
-            {batch_size, num_heads_k, seqlen_k_rounded * head_size_rounded}));
-      }
-      if (dv_accum) {
-        dv_accum->Resize(common::make_ddim(
-            {batch_size, num_heads_k, seqlen_k_rounded * head_size_rounded}));
-      }
-    } else {
-      if (dk_accum) {
-        dk_accum->Resize(common::make_ddim(
-            {num_heads_k, total_k_padded_rounded, head_size_rounded}));
-      }
-      if (dv_accum) {
-        dv_accum->Resize(common::make_ddim(
-            {num_heads_k, total_k_padded_rounded, head_size_rounded}));
-      }
-    }
-    if (dk_accum) {
-      ctx.template Alloc<float>(dk_accum);
-    }
-    if (dv_accum) {
-      ctx.template Alloc<float>(dv_accum);
-    }
-    phi::funcs::SetConstant<Context, float> set_zero;
-
-    if (dk_accum) {
-      set_zero(ctx, dk_accum, float{0});
-    }
-    if (dv_accum) {
-      set_zero(ctx, dv_accum, float{0});
-    }
-  }
-
-  Flash_bwd_params *params_handle = get_flash_bwd_params_handle();
-  dynload::fa3_clear_bwd_params_handle(params_handle);
-  set_params_dgrad(
-      params_handle,
-      batch_size,
-      seqlen_q,
-      seqlen_k,
-      seqlen_q_rounded,
-      seqlen_k_rounded,
-      num_heads,
-      num_heads_k,
-      head_size,
-      head_size_rounded,
-      q,
-      k,
-      v,
-      out,
-      dout,
-      dq,
-      dk,
-      dv,
-      !is_varlen_q ? nullptr : cu_seqlens_q.data<int>(),
-      !is_varlen_k ? nullptr : cu_seqlens_k.data<int>(),
-      seqused_q_.is_initialized()
-          ? const_cast<int *>(seqused_q_.get().data<int>())
-          : nullptr,
-      seqused_k_.is_initialized()
-          ? const_cast<int *>(seqused_k_.get().data<int>())
-          : nullptr,
-      dq_accum ? dq_accum->data() : nullptr,
-      num_heads_k != num_heads && dk_accum ? dk_accum->data() : nullptr,
-      num_heads_k != num_heads && dv_accum ? dv_accum->data() : nullptr,
-      const_cast<float *>(softmax_lse.data<float>()),
-      softmax_d ? const_cast<float *>(softmax_d->data<float>()) : nullptr,
-      /*p_dropout=*/0.f,
-      softmax_scale,
-      window_size_left,
-      window_size_right,
-      dprops,
-      softcap,
-      deterministic,
-      sm_margin);
-  dynload::fa3_bwd_params_set_total_q(params_handle, total_q);
-  dynload::fa3_bwd_params_set_total_k(params_handle, total_k);
-  dynload::fa3_bwd_params_set_softmax_lse_log2_ptr(
-      params_handle, softmax_lse_log2 ? softmax_lse_log2->data() : nullptr);
-  dynload::fa3_bwd_params_set_dv(params_handle,
-                                 head_size);  // We don't support hdim_v being
-                                              // different from hdim_qk for now
-
-  // auto tile_count_semaphore = (params.is_causal || params.is_local) ?
-  // paddle::zeros({1}, opts.dtype(torch::kInt32)) : torch::empty({1},
-  // opts.dtype(torch::kInt32)); params.tile_count_semaphore =
-  // tile_count_semaphore.data_ptr<int>(); Will be zero'ed out in the backward
-  // preprocess kernel
-  DenseTensor dq_semaphore = phi::Empty<int>(
-      ctx, {(seqlen_q + kBlockM - 1) / kBlockM, batch_size, num_heads});
-  dynload::fa3_bwd_params_set_dq_semaphore(params_handle,
-                                           dq_semaphore.data<int>());
-  if (num_heads_k != num_heads &&
-      dynload::fa3_bwd_params_get_deterministic(params_handle)) {
-    // TODO(tridao): do we need to zero them out?
-    DenseTensor dk_semaphore = phi::Empty<int>(
-        ctx, {(seqlen_k + kBlockN - 1) / kBlockN, batch_size, num_heads_k});
-    DenseTensor dv_semaphore = phi::Empty<int>(
-        ctx, {(seqlen_k + kBlockN - 1) / kBlockN, batch_size, num_heads_k});
-    dynload::fa3_bwd_params_set_dk_semaphore(params_handle,
-                                             dk_semaphore.data<int>());
-    dynload::fa3_bwd_params_set_dv_semaphore(params_handle,
-                                             dv_semaphore.data<int>());
-  }
-
-#ifdef FLASHATTENTION_DISABLE_LOCAL
-  PADDLE_ENFORCE_EQ(
-      !dynload::fa3_bwd_params_get_is_local(params_handle),
-      true,
-      "This flash attention build does not support local attention.");
-#endif
-#ifdef FLASHATTENTION_DISABLE_SOFTCAP
-  PADDLE_ENFORCE_EQ(
-      dynload::fa3_bwd_params_get_softcap(params_handle),
-      0.0,
-      "This flash attention build does not support tanh softcapping.");
-#endif
-
-  if (total_q > 0 && total_k > 0 && num_heads_k > 0) {
-    dynload::fa3_run_mha_bwd(params_handle, ctx.stream());
-  } else if (total_k > 0 && num_heads_k > 0) {
-    // If seqlen_q == 0, then we have an empty tensor. We need to set the output
-    // to 0.
-    phi::funcs::SetConstant<Context, T> set_zero;
-    set_zero(ctx, dk, T{0});
-    set_zero(ctx, dv, T{0});
-    if (softmax_d) {
-      phi::funcs::SetConstant<Context, float> set_zero_fp32;
-      set_zero_fp32(ctx, softmax_d, float{0});
-    }
-  } else if (total_q > 0 && num_heads_k > 0) {
-    phi::funcs::SetConstant<Context, T> set_zero;
-    set_zero(ctx, dq, T{0});
-    if (softmax_d) {
-      phi::funcs::SetConstant<Context, float> set_zero_fp32;
-      set_zero_fp32(ctx, softmax_d, float{0});
-    }
-  }
-#else
-  RaiseNotSupportedError();
-#endif
-}
-
-template <typename T, typename Context>
-void FlashAttnV3GradKernel(const Context &ctx,
-                           const DenseTensor &q,
-                           const DenseTensor &k,
-                           const DenseTensor &v,
-                           const DenseTensor &out,
-                           const DenseTensor &softmax_lse,
-                           const DenseTensor &out_grad,
-                           float const softmax_scale,
-                           bool is_causal,
-                           int window_size_left,
-                           int window_size_right,
-                           float const softcap,
-                           int const sm_margin,
-                           DenseTensor *dq,
-                           DenseTensor *dk,
-                           DenseTensor *dv) {
-#ifdef PADDLE_WITH_FLASHATTN_V3
-  PADDLE_ENFORCE_EQ(
-      window_size_left,
-      -1,
-      common::errors::InvalidArgument("window_size is not supported, please "
-                                      "set window_size_left/right to -1"));
-  PADDLE_ENFORCE_EQ(
-      window_size_right,
-      -1,
-      common::errors::InvalidArgument("window_size is not supported, please "
-                                      "set window_size_left/right to -1"));
-  PADDLE_ENFORCE_EQ(softcap,
-                    0,
-                    common::errors::InvalidArgument(
-                        "softcap is not supported, please set softcap to 0"));
-  PADDLE_ENFORCE_EQ(
-      sm_margin,
-      0,
-      common::errors::InvalidArgument(
-          "sm_margin is not supported, please set sm_margin to 0"));
-  PADDLE_ENFORCE_EQ(FLAGS_cudnn_deterministic,
-                    false,
-                    common::errors::InvalidArgument(
-                        "deterministic is not supported in flash attention 3, "
-                        "please set FLAGS_cudnn_deterministic to false"));
-  // umiswing: fake grad tensor for FlashAttnV3GradBaseKernel
-  DenseTensor softmax_d;
-  DenseTensor softmax_lse_log2;
-  DenseTensor dq_accum;
-  DenseTensor dk_accum;
-  DenseTensor dv_accum;
-  // TODO(umiswing): remove padding in mla
-  DenseTensor v_padded;
-  DenseTensor out_padded;
-  DenseTensor out_grad_padded;
-  DenseTensor dv_padded;
-  const int64_t b = q.dims()[0];
-  const int64_t s_q = q.dims()[1];
-  const int64_t s_k = k.dims()[1];
-  const int64_t h_q = q.dims()[2];
-  const int64_t h_k = k.dims()[2];
-  const int64_t d_q = q.dims()[3];
-  const int64_t d_v = v.dims()[3];
-  if (q.dims()[q.dims().size() - 1] > v.dims()[v.dims().size() - 1]) {
-    PADDLE_ENFORCE_EQ(v.dims()[v.dims().size() - 1],
-                      out.dims()[out.dims().size() - 1],
-                      common::errors::InvalidArgument(
-                          "head_dim_v and head_dim_o must be equal"));
-    PADDLE_ENFORCE_EQ(v.dims()[v.dims().size() - 2],
-                      out.dims()[out.dims().size() - 2],
-                      common::errors::InvalidArgument(
-                          "num_heads_v and num_heads_o must be equal"));
-    PADDLE_ENFORCE_EQ(
-        v.dims()[v.dims().size() - 3],
-        out.dims()[out.dims().size() - 3],
-        common::errors::InvalidArgument("seqlen_v and seqlen_o must be equal"));
-    DenseTensor padding = Empty<T, Context>(ctx, {b, s_k, h_k, d_q - d_v});
-    funcs::SetConstant<Context, T> set_zero;
-    set_zero(ctx, &padding, T{0});
-    ConcatKernel<T, Context>(ctx, {&v, &padding}, {3}, &v_padded);
-    ConcatKernel<T, Context>(ctx, {&out, &padding}, {3}, &out_padded);
-    ConcatKernel<T, Context>(ctx, {&out_grad, &padding}, {3}, &out_grad_padded);
-  } else {
-    v_padded = v;
-    out_padded = out;
-    out_grad_padded = out_grad;
-  }
-  FlashAttnV3GradBaseKernel<T, Context>(ctx,
-                                        out_grad_padded,
-                                        q,
-                                        k,
-                                        v_padded,
-                                        out_padded,
-                                        softmax_lse,
-                                        paddle::none,
-                                        paddle::none,
-                                        paddle::none,
-                                        paddle::none,
-                                        paddle::none,
-                                        paddle::none,
-                                        paddle::none,
-                                        0,
-                                        0,
-                                        softmax_scale,
-                                        is_causal,
-                                        window_size_left,
-                                        window_size_right,
-                                        softcap,
-                                        FLAGS_cudnn_deterministic,
-                                        sm_margin,
-                                        dq,
-                                        dk,
-                                        &dv_padded,
-                                        &softmax_d,
-                                        &softmax_lse_log2,
-                                        &dq_accum,
-                                        &dk_accum,
-                                        &dv_accum);
-
-  if (q.dims()[q.dims().size() - 1] > v.dims()[v.dims().size() - 1]) {
-    *dv = Slice<T, Context>(ctx, dv_padded, {3}, {0}, {d_v});
-  } else {
-    *dv = dv_padded;
-  }
-#else
-  RaiseNotSupportedError();
-#endif
-}
-
-}  // namespace phi
-
-PD_REGISTER_KERNEL(flash_attn_v3_grad,
-                   GPU,
-                   ALL_LAYOUT,
-                   phi::FlashAttnV3GradKernel,
-                   phi::dtype::float16,
-                   phi::dtype::bfloat16) {}
diff --git a/paddle/phi/kernels/gpu/flash_attn_v3_grad_kernel.h b/paddle/phi/kernels/gpu/flash_attn_v3_grad_kernel.h
deleted file mode 100644
index 65845b77799256..00000000000000
--- a/paddle/phi/kernels/gpu/flash_attn_v3_grad_kernel.h
+++ /dev/null
@@ -1,41 +0,0 @@
-// Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.
-//
-// Licensed under the Apache License, Version 2.0 (the "License");
-// you may not use this file except in compliance with the License.
-// You may obtain a copy of the License at
-//
-//     http://www.apache.org/licenses/LICENSE-2.0
-//
-// Unless required by applicable law or agreed to in writing, software
-// distributed under the License is distributed on an "AS IS" BASIS,
-// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-// See the License for the specific language governing permissions and
-// limitations under the License.
-
-#pragma once
-namespace phi {
-template <typename T, typename Context>
-void FlashAttnV3GradKernel(
-    const Context &ctx,
-    const DenseTensor
-        &q,  // (b, s_q, h, d) or (total_q, h, d) if there is cu_seqlens_q
-    const DenseTensor
-        &k,  // (b, s_k, h_k, d) or (total_k, h_k, d) if there is cu_seqlens_k
-    const DenseTensor
-        &v,  // (b, s_k, h_k, d) or (total_k, h_k, d) if there is cu_seqlens_k
-    const DenseTensor
-        &out,  // (b, s_q, h, d) or (total_q, h, d) if there is cu_seqlens_q
-    const DenseTensor
-        &softmax_lse,  // (b, h, s_q) or (h, total_q) if there is cu_seqlens_q
-    const DenseTensor &
-        out_grad,  // (b, s_q, h, d) or (total_q, h, d) if there is cu_seqlens_q
-    float const softmax_scale,
-    bool is_causal,
-    int window_size_left,
-    int window_size_right,
-    float const softcap,
-    int const sm_margin,
-    DenseTensor *dq,
-    DenseTensor *dk,
-    DenseTensor *dv);
-}  // namespace phi
diff --git a/paddle/phi/kernels/gpu/flash_attn_v3_kernel.cu b/paddle/phi/kernels/gpu/flash_attn_v3_kernel.cu
deleted file mode 100644
index 56359e826644d8..00000000000000
--- a/paddle/phi/kernels/gpu/flash_attn_v3_kernel.cu
+++ /dev/null
@@ -1,1081 +0,0 @@
-// Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.
-//
-// Licensed under the Apache License, Version 2.0 (the "License");
-// you may not use this file except in compliance with the License.
-// You may obtain a copy of the License at
-//
-//     http://www.apache.org/licenses/LICENSE-2.0
-//
-// Unless required by applicable law or agreed to in writing, software
-// distributed under the License is distributed on an "AS IS" BASIS,
-// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-// See the License for the specific language governing permissions and
-// limitations under the License.
-
-#include "paddle/phi/kernels/flash_attn_kernel.h"
-
-#include <cstddef>
-#include "glog/logging.h"  // For VLOG()
-#include "paddle/common/enforce.h"
-#include "paddle/common/errors.h"
-#include "paddle/common/flags.h"
-#include "paddle/phi/common/data_type.h"
-#include "paddle/phi/core/dense_tensor.h"
-#include "paddle/phi/core/kernel_registry.h"
-#include "paddle/phi/core/platform/device_context.h"
-#include "paddle/phi/core/tensor_utils.h"
-#include "paddle/phi/core/utils/data_type.h"
-#include "paddle/phi/kernels/empty_kernel.h"
-#include "paddle/phi/kernels/funcs/elementwise_base.h"
-#include "paddle/phi/kernels/slice_kernel.h"
-#include "paddle/utils/none.h"
-
-#include "paddle/phi/kernels/gpu/flash_attn_utils.h"
-#include "paddle/phi/kernels/gpu/flash_attn_v3_utils.h"
-
-#include "paddle/phi/kernels/gpu/flash_attn_v3_kernel.h"
-
-namespace phi {
-
-template <typename T, typename Context>
-void FlashAttnV3BaseKernel(
-    const Context &ctx,
-    const DenseTensor &q,
-    const DenseTensor &k,
-    const DenseTensor &v,
-    const paddle::optional<DenseTensor>
-        &k_new_,  // (b, s_k_new, h_k, d) or (total_k_new, h_k, d) if there is
-                  // cu_seqlens_k_new
-    const paddle::optional<DenseTensor>
-        &v_new_,  // (b, s_k_new, h_k, dv) or (total_k_new, h_k, dv) if there is
-                  // cu_seqlens_k_new
-    const paddle::optional<DenseTensor>
-        &q_v_,  // (b, s_q, h, dv) or (total_q_new, h, dv) if there is
-                // cu_seqlens_q
-    const paddle::optional<DenseTensor>
-        &out_,  // (b, s_q, h, dv) or (total_q, h, dv) if there is cu_seqlens_q
-    const paddle::optional<DenseTensor> &cu_seqlens_q_,      // b+1
-    const paddle::optional<DenseTensor> &cu_seqlens_k_,      // b+1
-    const paddle::optional<DenseTensor> &cu_seqlens_k_new_,  // b+1
-    const paddle::optional<DenseTensor>
-        &seqused_q_,  // b. If given, only this many elements of each batch
-                      // element's queries and outputs are used.
-    const paddle::optional<DenseTensor>
-        &seqused_k_,  // b. If given, only this many elements of each batch
-                      // element's keys are used.
-    const paddle::optional<DenseTensor>
-        &page_table_,  // (b_k, max_num_pages_per_seq)
-    const paddle::optional<DenseTensor>
-        &kv_batch_idx_,  // b. indices to index into the KV cache
-    const paddle::optional<DenseTensor> &leftpad_k_,  // b
-    const paddle::optional<DenseTensor>
-        &rotary_cos_,  // seqlen_ro x (rotary_dim / 2)
-    const paddle::optional<DenseTensor>
-        &rotary_sin_,  // seqlen_ro x (rotary_dim / 2)
-    const paddle::optional<DenseTensor> &q_descale_,  // (b, h_k), not (b, h)
-    const paddle::optional<DenseTensor> &k_descale_,  // (b, h_k)
-    const paddle::optional<DenseTensor> &v_descale_,  // (b, h_k)
-    const paddle::optional<DenseTensor> &scheduler_metadata_,  // (b + 1)
-    const int
-        max_seqlen_q_,  // if max_seqlen_q_ is set to 0, it indicates that it is
-                        // uninitialized and should not be referenced
-    // TODO(tridao): check if we need max_seqlen_k
-    const int
-        max_seqlen_k_,  // if max_seqlen_q_ is set to 0, it indicates that it is
-                        // uninitialized and should not be referenced
-    const float softmax_scale,
-    bool is_causal,
-    int window_size_left,
-    int window_size_right,
-    const float softcap,
-    const bool is_rotary_interleaved,  // if true, rotary combines indices 0 &
-                                       // 1, else indices 0 & rotary_dim / 2
-    int num_splits,
-    const bool manual_set_pack_gqa,
-    const bool
-        pack_gqa_,  // the pack_gqa_ will be used only if manual_set_pack_gqa is
-                    // set to True; otherwise, the internal heuristic
-                    // get_pack_gqa() from fa3 will decide whether to pack gqa
-    const int sm_margin,
-    DenseTensor *out,
-    DenseTensor *softmax_lse,
-    DenseTensor *out_accum,
-    DenseTensor *softmax_lse_accum) {
-#ifdef PADDLE_WITH_FLASHATTN_V3
-  // TODO(umiswing): support ampere
-  int device_id = ctx.GetPlace().GetDeviceId();
-  auto dprops = paddle::platform::GetDeviceProperties(device_id);
-  const bool is_sm90 = dprops.major == 9 && dprops.minor == 0;
-  PADDLE_ENFORCE_EQ(is_sm90,
-                    true,
-                    common::errors::Unavailable(
-                        "FlashAttention-3 only supports Hopper GPUs."));
-
-  auto q_type = q.dtype();
-  PADDLE_ENFORCE_EQ(
-      (q_type == phi::DataType::FLOAT16 || q_type == phi::DataType::BFLOAT16 ||
-       q_type == phi::DataType::FLOAT8_E4M3FN),
-      true,
-      common::errors::InvalidArgument(
-          "FlashAttention-3 only supports fp16, bf16, and fp8_e4m3 data type"));
-
-  PADDLE_ENFORCE_EQ(k.dtype(),
-                    q_type,
-                    common::errors::InvalidArgument(
-                        "query and key must have the same dtype"));
-  PADDLE_ENFORCE_EQ(v.dtype(),
-                    q_type,
-                    common::errors::InvalidArgument(
-                        "query and value must have the same dtype"));
-
-  CHECK_DEVICE(q);
-  CHECK_DEVICE(k);
-  CHECK_DEVICE(v);
-
-  PADDLE_ENFORCE_EQ(q.strides()[q.strides().size() - 1],
-                    1,
-                    common::errors::InvalidArgument(
-                        "Input tensor must have contiguous last dimension"));
-  PADDLE_ENFORCE_EQ(k.strides()[k.strides().size() - 1],
-                    1,
-                    common::errors::InvalidArgument(
-                        "Input tensor must have contiguous last dimension"));
-  PADDLE_ENFORCE_EQ(v.strides()[v.strides().size() - 1],
-                    1,
-                    common::errors::InvalidArgument(
-                        "Input tensor must have contiguous last dimension"));
-
-  DenseTensor page_table;
-  // const bool paged_KV = page_table_.has_value();
-  // umiswing: this is stupid but idk how to use paddle::optional
-  const bool paged_KV = page_table_.is_initialized();
-  if (paged_KV) {
-    page_table = page_table_.get();
-    CHECK_DEVICE(page_table);
-    PADDLE_ENFORCE_EQ(page_table.dtype(),
-                      phi::DataType::INT32,
-                      common::errors::InvalidArgument(
-                          "page_table must have dtype paddle.int32"));
-    PADDLE_ENFORCE_EQ(page_table.strides()[page_table.strides().size() - 1],
-                      1,
-                      common::errors::InvalidArgument(
-                          "page_table must have contiguous last dimension"));
-  }
-
-  // TODO(umiswing): support cusum
-
-  DenseTensor cu_seqlens_q;
-  // bool const is_varlen_q = cu_seqlens_q_.has_value();
-  // TODO(umiswing): this is stupid, must fix it (after understand
-  // paddle::optional)
-  const bool is_varlen_q = cu_seqlens_q_.is_initialized();
-  if (is_varlen_q) {
-    cu_seqlens_q = cu_seqlens_q_.get();
-    CHECK_DEVICE(cu_seqlens_q);
-    CHECK_CONTIGUOUS(cu_seqlens_q);
-    PADDLE_ENFORCE_EQ(cu_seqlens_q.dtype(),
-                      phi::DataType::INT32,
-                      common::errors::InvalidArgument(
-                          "cu_seqlens_q must have dtype paddle.int32"));
-    PADDLE_ENFORCE_NE(
-        max_seqlen_q_,
-        0,
-        common::errors::InvalidArgument(
-            "max_seqlen_q must be provided if cu_seqlens_q is provided"));
-  }
-
-  DenseTensor cu_seqlens_k;
-  const bool is_varlen_k = cu_seqlens_k_.is_initialized();
-  if (is_varlen_k) {
-    cu_seqlens_k = cu_seqlens_k_.get();
-    CHECK_DEVICE(cu_seqlens_k);
-    CHECK_CONTIGUOUS(cu_seqlens_k);
-    PADDLE_ENFORCE_EQ(cu_seqlens_k.dtype(),
-                      phi::DataType::INT32,
-                      common::errors::InvalidArgument(
-                          "cu_seqlens_k must have dtype paddle.int32"));
-    PADDLE_ENFORCE_NE(
-        max_seqlen_k_,
-        0,
-        common::errors::InvalidArgument(
-            "max_seqlen_k must be provided if cu_seqlens_k is provided"));
-    PADDLE_ENFORCE_EQ(
-        !paged_KV,
-        true,
-        common::errors::InvalidArgument(
-            "If cu_seqlens_k is passed in, then page table is not supported"));
-    PADDLE_ENFORCE_EQ(
-        !kv_batch_idx_,
-        true,
-        common::errors::InvalidArgument(
-            "If cu_seqlens_k is passed in, then page table is not supported"));
-  }
-
-  auto const sizes = q.dims();
-  const int batch_size = !is_varlen_q ? sizes[0] : cu_seqlens_q.dims()[0] - 1;
-  int seqlen_q = !is_varlen_q ? sizes[1] : max_seqlen_q_;
-  int total_q = !is_varlen_q ? batch_size * sizes[1] : sizes[0];
-  int num_heads = q.dims()[q.dims().size() - 2];
-  int const head_size = q.dims()[q.dims().size() - 1];
-  int const head_size_v = v.dims()[v.dims().size() - 1];
-  int const max_num_pages_per_seq = !paged_KV ? 0 : page_table.dims()[1];
-  int const num_pages = !paged_KV ? 0 : k.dims()[0];
-  int const page_size = !paged_KV ? 1 : k.dims()[1];
-  int const seqlen_k =
-      !is_varlen_k
-          ? (!paged_KV ? k.dims()[1] : max_num_pages_per_seq * page_size)
-          : max_seqlen_k_;
-  int const total_k = !is_varlen_k ? batch_size * k.dims()[1] : k.dims()[0];
-  int const num_heads_k = k.dims()[k.dims().size() - 2];
-  int const batch_size_k =
-      !paged_KV ? (!is_varlen_k ? k.dims()[0] : cu_seqlens_k.dims()[0] - 1)
-                : page_table.dims()[0];
-  if (!kv_batch_idx_.is_initialized()) {
-    PADDLE_ENFORCE_EQ(batch_size,
-                      batch_size_k,
-                      common::errors::InvalidArgument(
-                          "batch_size must be equal to batch_size_k"));
-  }
-  int const max_headdim = get_max_headdim();
-  PADDLE_ENFORCE_LE(
-      head_size,
-      max_headdim,
-      common::errors::InvalidArgument(
-          "FlashAttention forward only supports head dimension at most %d",
-          max_headdim));
-  PADDLE_ENFORCE_EQ(
-      num_heads % num_heads_k,
-      0,
-      common::errors::InvalidArgument(
-          "Number of heads in key/value must divide number of heads in query"));
-  if (head_size_v != head_size) {
-    PADDLE_ENFORCE_EQ(
-        ((head_size > 128 && head_size <= 192 && head_size_v > 96 &&
-          head_size_v <= 128) ||
-         (head_size <= 64 && head_size_v <= 512)),
-        true,
-        common::errors::InvalidArgument(
-            "If V headdim is different from Q/K dim, we only support "
-            "Q/K headdim in (128, 192] and V headdim in (96, 128], "
-            "or (Q/K <= 64 and V <= 512)."));
-    PADDLE_ENFORCE_EQ(dprops.major,
-                      9,
-                      common::errors::InvalidArgument(
-                          "Only Hopper supports different V headdim"));
-    if (head_size_v > 256) {
-      PADDLE_ENFORCE_EQ((q_type == phi::DataType::FLOAT16 ||
-                         q_type == phi::DataType::BFLOAT16),
-                        true,
-                        common::errors::InvalidArgument(
-                            "HeaddimV > 256 requires fp16 and bf16 data type"));
-    }
-  }
-
-  // This needs to go before kBlockM & kBlockN since we rely on the correct
-  // window_size and is_causal to set kBlockM
-  // TODO(tridao): check this
-  if (window_size_left >= seqlen_k - 1) {
-    window_size_left = -1;
-  }
-  if (window_size_right >= seqlen_q - 1) {
-    window_size_right = -1;
-  }
-  // causal=true is the same as causal=false in this case
-  if (seqlen_q == 1 && window_size_left == -1 && window_size_right == -1) {
-    // Special case of hdim 128 where we want causal to have kBlockN=128, better
-    // for pagedKV and TMA
-    if ((head_size <= 64 || head_size > 128) || !paged_KV) {
-      is_causal = false;
-    }
-  }
-  if (is_causal) {
-    window_size_right = 0;
-  }
-  // There's a case where is_causal=false, window_size=(-1, 0). Then
-  // set_params_fprop will set params.is_causal=true. If we don't have is_causal
-  // here matching params.is_causal, we might get the wrong kBlockM.
- is_causal = window_size_left < 0 && window_size_right == 0; - - if (!is_varlen_q) { - CHECK_SHAPE(q, batch_size, seqlen_q, num_heads, head_size); - } else { - CHECK_SHAPE(q, total_q, num_heads, head_size); - CHECK_SHAPE(cu_seqlens_q, batch_size + 1); - } - if (!paged_KV) { - if (!is_varlen_k) { - CHECK_SHAPE(k, batch_size_k, seqlen_k, num_heads_k, head_size); - CHECK_SHAPE(v, batch_size_k, seqlen_k, num_heads_k, head_size_v); - } else { - CHECK_SHAPE(k, total_k, num_heads_k, head_size); - CHECK_SHAPE(v, total_k, num_heads_k, head_size_v); - CHECK_SHAPE(cu_seqlens_k, batch_size + 1); - } - } else { - CHECK_SHAPE(k, num_pages, page_size, num_heads_k, head_size); - CHECK_SHAPE(v, num_pages, page_size, num_heads_k, head_size_v); - CHECK_SHAPE(page_table, batch_size_k, max_num_pages_per_seq); - } - - if (seqused_q_.is_initialized()) { - auto seqused_q = seqused_q_.get(); - PADDLE_ENFORCE_EQ( - seqused_q.dtype(), - phi::DataType::INT32, - common::errors::InvalidArgument("seqused_q must have dtype int32")); - CHECK_DEVICE(seqused_q); - CHECK_CONTIGUOUS(seqused_q); - CHECK_SHAPE(seqused_q, batch_size); - } - if (seqused_k_.is_initialized()) { - auto seqused_k = seqused_k_.get(); - PADDLE_ENFORCE_EQ( - seqused_k.dtype(), - phi::DataType::INT32, - common::errors::InvalidArgument("seqused_k must have dtype int32")); - CHECK_DEVICE(seqused_k); - CHECK_CONTIGUOUS(seqused_k); - CHECK_SHAPE(seqused_k, batch_size); - } - - if (leftpad_k_.is_initialized()) { - auto leftpad_k = leftpad_k_.get(); - PADDLE_ENFORCE_EQ( - leftpad_k.dtype(), - phi::DataType::INT32, - common::errors::InvalidArgument("leftpad_k must have dtype int32")); - CHECK_DEVICE(leftpad_k); - CHECK_CONTIGUOUS(leftpad_k); - CHECK_SHAPE(leftpad_k, batch_size); - } - - // This is what we will template on - bool const is_varlen = - is_varlen_q || is_varlen_k || seqused_q_.is_initialized() || - seqused_k_.is_initialized() || leftpad_k_.is_initialized(); -#ifdef FLASHATTENTION_DISABLE_VARLEN - PADDLE_ENFORCE_EQ(!is_varlen, - true, - common::errors::Unavailable( - "This flash attention build does not support varlen.")); -#endif - - int const alignment = q_type == phi::DataType::FLOAT8_E4M3FN ? 16 : 8; - PADDLE_ENFORCE_EQ(head_size % alignment, - 0, - common::errors::InvalidArgument( - "head_size should be a multiple of %d", alignment)); - PADDLE_ENFORCE_EQ(head_size_v % alignment, - 0, - common::errors::InvalidArgument( - "head_size_v should be a multiple of %d", alignment)); - - auto out_type = - q_type == phi::DataType::FLOAT8_E4M3FN ? phi::DataType::BFLOAT16 : q_type; - if (out_.is_initialized()) { - *out = out_.get(); - PADDLE_ENFORCE_EQ( - out->dtype(), - out_type, - common::errors::InvalidArgument( - "For FP16/BF16 input, output must have the same dtype as " - "inputs. 
For FP8 input, output must have dtype BF16")); - CHECK_DEVICE((*out)); - PADDLE_ENFORCE_EQ(out->strides()[out->strides().size() - 1], - 1, - common::errors::InvalidArgument( - "Output tensor must have contiguous last dimension")); - if (!is_varlen_q) { - CHECK_SHAPE((*out), batch_size, seqlen_q, num_heads, head_size_v); - } else { - CHECK_SHAPE((*out), total_q, num_heads, head_size_v); - } - } else { - if (!is_varlen_q) { - out->Resize( - common::make_ddim({batch_size, seqlen_q, num_heads, head_size_v})); - } else { - out->Resize(common::make_ddim({total_q, num_heads, head_size_v})); - } - if (q_type == phi::DataType::FLOAT8_E4M3FN) { - ctx.template Alloc(out); - } else { - // umiswing: assuming T is Input Type - ctx.template Alloc(out); - } - } - - auto round_multiple = [](int x, int m) { return (x + m - 1) / m * m; }; - int const head_size_rounded = round_up_headdim(head_size); - int const head_size_v_rounded = round_up_headdim(head_size_v); - int const seqlen_q_rounded = round_multiple(seqlen_q, 128); - int const seqlen_k_rounded = round_multiple(seqlen_k, 128); - - if (!is_varlen_q) { - softmax_lse->Resize(common::make_ddim({batch_size, num_heads, seqlen_q})); - } else { - softmax_lse->Resize(common::make_ddim({num_heads, total_q})); - } - ctx.template Alloc(softmax_lse); - - Flash_fwd_params *params_handle = get_flash_fwd_params_handle(); - dynload::fa3_clear_fwd_params_handle(params_handle); - set_params_fprop( - params_handle, - batch_size, - seqlen_q, - seqlen_k, - seqlen_q_rounded, - seqlen_k_rounded, - num_heads, - num_heads_k, - head_size, - head_size_rounded, - q, - k, - v, - out, - !is_varlen_q ? nullptr : cu_seqlens_q.data(), - !is_varlen_k ? nullptr : cu_seqlens_k.data(), - seqused_q_.is_initialized() ? const_cast(seqused_q_.get().data()) - : nullptr, - seqused_k_.is_initialized() ? 
const_cast(seqused_k_.get().data()) - : nullptr, - softmax_lse->data(), - /*p_dropout=*/0.f, - softmax_scale, - window_size_left, - window_size_right, - dprops, - softcap, - sm_margin); - phi::dynload::fa3_fwd_params_set_total_q(params_handle, total_q); - phi::dynload::fa3_fwd_params_set_total_k(params_handle, total_k); - phi::dynload::fa3_fwd_params_set_b_k(params_handle, batch_size_k); - phi::dynload::fa3_fwd_params_set_dv(params_handle, head_size_v); - phi::dynload::fa3_fwd_params_set_dv_rounded(params_handle, - head_size_v_rounded); - - if (leftpad_k_ - .is_initialized()) { // This needs to be set before get_pagedkv_tma - phi::dynload::fa3_fwd_params_set_leftpad_k(params_handle, - leftpad_k_.get().data()); - } - if (paged_KV) { - phi::dynload::fa3_fwd_params_set_page_table(params_handle, - page_table.data()); - phi::dynload::fa3_fwd_params_set_page_table_batch_stride( - params_handle, page_table.strides()[0]); - } - phi::dynload::fa3_fwd_params_set_page_size(params_handle, page_size); - phi::dynload::fa3_fwd_params_set_num_pages(params_handle, num_pages); - - if (k_new_.is_initialized()) { // This needs to be set before get_pagedkv_tma - DenseTensor k_new, v_new; - PADDLE_ENFORCE_EQ( - v_new_.is_initialized(), - true, - common::errors::InvalidArgument( - "If k_new is supplied, v_new must also be passed in")); - PADDLE_ENFORCE_EQ( - seqused_k_.is_initialized(), - true, - common::errors::InvalidArgument( - "If k_new is supplied, seqlens_k must also be passed in")); - PADDLE_ENFORCE_LE( - seqlen_q, - seqlen_k, - common::errors::InvalidArgument( - "If k_new is supplied, it must have seqlen <= the seqlen " - "of the KV cache")); - DenseTensor cu_seqlens_k_new; - bool const is_varlen_k_new = cu_seqlens_k_new_.is_initialized(); - if (is_varlen_k_new) { - cu_seqlens_k_new = cu_seqlens_k_new_.get(); - CHECK_DEVICE(cu_seqlens_k_new); - CHECK_CONTIGUOUS(cu_seqlens_k_new); - PADDLE_ENFORCE_EQ(cu_seqlens_k_new.dtype(), - phi::DataType::INT32, - common::errors::InvalidArgument( - "cu_seqlens_k_new must have dtype paddle.int32")); - } - k_new = k_new_.get(); - v_new = v_new_.get(); - PADDLE_ENFORCE_EQ(k_new.dtype(), - q_type, - common::errors::InvalidArgument( - "k_new must have the same dtype as query")); - PADDLE_ENFORCE_EQ(v_new.dtype(), - q_type, - common::errors::InvalidArgument( - "v_new must have the same dtype as query")); - CHECK_DEVICE(k_new); - CHECK_DEVICE(v_new); - PADDLE_ENFORCE_EQ(k_new.strides()[k_new.strides().size() - 1], - 1, - common::errors::InvalidArgument( - "k_new tensor must have contiguous last dimension")); - PADDLE_ENFORCE_EQ(v_new.strides()[v_new.strides().size() - 1], - 1, - common::errors::InvalidArgument( - "v_new tensor must have contiguous last dimension")); - // We don't need max_seqlen_k_new, so seqlen_k_new can be whatever when - // is_varlen_k_new - int seqlen_k_new = !is_varlen_k_new ? k_new.dims()[1] : 0; - int total_k_new = - !is_varlen_k_new ? 
batch_size * k_new.dims()[1] : k_new.dims()[0]; - if (!is_varlen_k_new) { - CHECK_SHAPE(k_new, batch_size, seqlen_k_new, num_heads_k, head_size); - CHECK_SHAPE(v_new, batch_size, seqlen_k_new, num_heads_k, head_size_v); - } else { - CHECK_SHAPE(k_new, total_k_new, num_heads_k, head_size); - CHECK_SHAPE(v_new, total_k_new, num_heads_k, head_size_v); - CHECK_SHAPE(cu_seqlens_k_new, batch_size + 1); - } - // umiswing: dump this to shared library - phi::dynload::fa3_fwd_params_set_seqlen_knew(params_handle, seqlen_k_new); - phi::dynload::fa3_fwd_params_set_total_knew(params_handle, total_k_new); - phi::dynload::fa3_fwd_params_set_knew_ptr(params_handle, - const_cast(k_new.data())); - phi::dynload::fa3_fwd_params_set_vnew_ptr(params_handle, - const_cast(v_new.data())); - // All stride are in elements, not bytes. - phi::dynload::fa3_fwd_params_set_knew_row_stride( - params_handle, k_new.strides()[k_new.strides().size() - 3]); - phi::dynload::fa3_fwd_params_set_vnew_row_stride( - params_handle, v_new.strides()[v_new.strides().size() - 3]); - phi::dynload::fa3_fwd_params_set_knew_head_stride( - params_handle, k_new.strides()[k_new.strides().size() - 2]); - phi::dynload::fa3_fwd_params_set_vnew_head_stride( - params_handle, v_new.strides()[v_new.strides().size() - 2]); - if (!is_varlen_k_new) { - phi::dynload::fa3_fwd_params_set_knew_batch_stride(params_handle, - k_new.strides()[0]); - phi::dynload::fa3_fwd_params_set_vnew_batch_stride(params_handle, - v_new.strides()[0]); - } - if (is_varlen_k_new) { - phi::dynload::fa3_fwd_params_set_cu_seqlens_knew( - params_handle, cu_seqlens_k_new.data()); - } - } - - // 992 = 32 * 31 is the max supported batch in prepare_varlen_num_blocks - // kernel - bool const use_dynamic_split = - is_varlen && phi::dynload::fa3_fwd_params_get_b(params_handle) <= 992; - // Temporarily set num_splits_dynamic_ptr to 1 since get_num_splits checks it - phi::dynload::fa3_fwd_params_set_num_splits_dynamic_ptr( - params_handle, !use_dynamic_split ? nullptr : reinterpret_cast(1)); - - phi::dynload::fa3_fwd_params_set_pagedkv_tma( - params_handle, phi::dynload::fa3_get_pagedkv_tma(params_handle)); - if (num_splits <= 0) { - num_splits = phi::dynload::fa3_get_num_splits(params_handle); - } - phi::dynload::fa3_fwd_params_set_num_splits(params_handle, num_splits); - - // Always enable PackGQA for Split, and get_pack_gqa requires - // params.num_splits to decide - const bool pack_gqa = manual_set_pack_gqa - ? pack_gqa_ - : phi::dynload::fa3_get_pack_gqa(params_handle); - phi::dynload::fa3_fwd_params_set_pack_gqa(params_handle, pack_gqa); - - // This needs to be set after get_num_splits - DenseTensor tile_count_semaphore; // Contains the semaphore and optionally - // num_splits_dynamic - // We don't use the persistent scheduler if Split and not Varlen - const bool params_is_causal = - phi::dynload::fa3_fwd_params_get_is_causal(params_handle); - const bool params_is_local = - phi::dynload::fa3_fwd_params_get_is_local(params_handle); - const int params_num_splits = - phi::dynload::fa3_fwd_params_get_num_splits(params_handle); - const int params_b = phi::dynload::fa3_fwd_params_get_b(params_handle); - const int params_arch = phi::dynload::fa3_fwd_params_get_arch(params_handle); - bool const scheduler_needs_semaphore = - params_arch >= 90 ? 
(((params_is_causal || params_is_local) &&
-             (params_num_splits == 1)) ||
-            is_varlen)
-          : ((params_is_causal && !is_varlen) ||
-             (is_varlen && params_num_splits > 1));
-  if (scheduler_needs_semaphore || use_dynamic_split) {
-    int metadata_size = static_cast<int>(scheduler_needs_semaphore) +
-                        static_cast<int>(use_dynamic_split) * params_b;
-    phi::dynload::fa3_fwd_params_set_skip_scheduler_metadata_computation(
-        params_handle, scheduler_metadata_.is_initialized());
-    if (scheduler_metadata_.is_initialized()) {
-      DenseTensor scheduler_metadata = scheduler_metadata_.get();
-      CHECK_DEVICE(scheduler_metadata);
-      CHECK_SHAPE(scheduler_metadata, metadata_size);
-      CHECK_CONTIGUOUS(scheduler_metadata);
-      PADDLE_ENFORCE_EQ(scheduler_metadata.dtype(),
-                        phi::DataType::INT32,
-                        common::errors::InvalidArgument(
-                            "scheduler_metadata must have dtype int32"));
-      tile_count_semaphore = scheduler_metadata;
-    } else {
-      tile_count_semaphore = phi::Empty<int32_t>(ctx, {metadata_size});
-    }
-    if (scheduler_needs_semaphore && !use_dynamic_split) {
-      phi::funcs::SetConstant<Context, int32_t> set_zero;
-      set_zero(ctx,
-               &tile_count_semaphore,
-               int32_t{0});  // If varlen we'll manually do the zero-ing
-    }
-    phi::dynload::fa3_fwd_params_set_tile_count_semaphore(
-        params_handle,
-        scheduler_needs_semaphore
-            ? const_cast<int *>(tile_count_semaphore.data<int>())
-            : nullptr);
-    phi::dynload::fa3_fwd_params_set_num_splits_dynamic_ptr(
-        params_handle,
-        use_dynamic_split
-            ? const_cast<int *>(tile_count_semaphore.data<int>()) + 1
-            : nullptr);
-  }
-
-  if (q_v_.is_initialized()) {
-    PADDLE_ENFORCE_LE(head_size,
-                      64,
-                      common::errors::InvalidArgument(
-                          "q_v is only supported for head_size <= 64"));
-    PADDLE_ENFORCE_EQ(
-        (q_type == phi::DataType::FLOAT16 ||
-         q_type == phi::DataType::BFLOAT16),
-        true,
-        common::errors::InvalidArgument(
-            "q_v is only supported for fp16 and bf16 data type"));
-    PADDLE_ENFORCE_EQ(params_arch,
-                      90,
-                      common::errors::InvalidArgument(
-                          "q_v is only supported for Hopper GPUs"));
-    DenseTensor q_v = q_v_.get();
-    PADDLE_ENFORCE_EQ(q_v.dtype(),
-                      q_type,
-                      common::errors::InvalidArgument(
-                          "q_v must have the same dtype as query"));
-    CHECK_DEVICE(q_v);
-    PADDLE_ENFORCE_EQ(q_v.strides()[q_v.strides().size() - 1],
-                      1,
-                      common::errors::InvalidArgument(
-                          "q_v tensor must have contiguous last dimension"));
-    if (!is_varlen_q) {
-      CHECK_SHAPE(q_v, batch_size, seqlen_q, num_heads, head_size_v);
-    } else {
-      CHECK_SHAPE(q_v, total_q, num_heads, head_size_v);
-    }
-    phi::dynload::fa3_fwd_params_set_qv_ptr(params_handle,
-                                            const_cast<void *>(q_v.data()));
-    // All strides are in elements, not bytes.
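Since this comment recurs throughout the patch, a quick illustration of the convention before the qv stride setters that follow (plain C++, my scaffolding, not Paddle's API):

```cpp
// Minimal sketch of what "strides in elements" means for the contiguous
// [batch, seqlen, heads, head_dim] layout used here.
#include <array>
#include <cassert>

std::array<long, 4> ContiguousStrides(long s, long h, long d) {
  // batch stride, row (token) stride, head stride, element stride
  return {s * h * d, h * d, d, 1};
}

int main() {
  auto st = ContiguousStrides(128, 8, 64);  // e.g. q of shape [b, 128, 8, 64]
  assert(st[0] == 128 * 8 * 64);  // q_batch_stride
  assert(st[1] == 8 * 64);        // q_row_stride
  assert(st[2] == 64);            // q_head_stride; multiply by
                                  // sizeof(element) to get byte strides
}
```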
-    phi::dynload::fa3_fwd_params_set_qv_row_stride(
-        params_handle, q_v.strides()[q_v.strides().size() - 3]);
-    phi::dynload::fa3_fwd_params_set_qv_head_stride(
-        params_handle, q_v.strides()[q_v.strides().size() - 2]);
-    if (!is_varlen_q) {
-      phi::dynload::fa3_fwd_params_set_qv_batch_stride(params_handle,
-                                                       q_v.strides()[0]);
-    }
-  }
-
-  if (rotary_cos_.is_initialized()) {
-    PADDLE_ENFORCE_EQ(
-        k_new_.is_initialized(),
-        true,
-        common::errors::InvalidArgument(
-            "If rotary cos/sin are provided, new key / value to be "
-            "appended to KV cache must also be provided"));
-    DenseTensor rotary_cos = rotary_cos_.get();
-    CHECK_DEVICE(rotary_cos);
-    CHECK_CONTIGUOUS(rotary_cos);
-    int params_rotary_dim = rotary_cos.dims()[1] * 2;
-    phi::dynload::fa3_fwd_params_set_rotary_dim(params_handle,
-                                                params_rotary_dim);
-    PADDLE_ENFORCE_LE(
-        params_rotary_dim,
-        head_size,
-        common::errors::InvalidArgument("rotary_dim must be <= headdim"));
-    PADDLE_ENFORCE_EQ(
-        params_rotary_dim % 16,
-        0,
-        common::errors::InvalidArgument(
-            "Only rotary dimensions divisible by 16 are currently supported"));
-    const int seqlen_ro = rotary_cos.dims()[0];
-    if (paged_KV) {
-      PADDLE_ENFORCE_GE(
-          seqlen_ro,
-          seqlen_k,
-          common::errors::InvalidArgument(
-              "cos/sin seqlen must be at least the seqlen of KV cache"));
-    }
-    CHECK_SHAPE(rotary_cos, seqlen_ro, params_rotary_dim / 2);
-    PADDLE_ENFORCE_EQ(rotary_cos.dtype(),
-                      q_type,
-                      common::errors::InvalidArgument(
-                          "rotary_cos must have the same dtype as query"));
-
-    PADDLE_ENFORCE_EQ(
-        rotary_sin_.is_initialized(),
-        true,
-        common::errors::InvalidArgument(
-            "If rotary cos is provided, rotary sin must also be provided"));
-    auto rotary_sin = rotary_sin_.get();
-    CHECK_DEVICE(rotary_sin);
-    CHECK_CONTIGUOUS(rotary_sin);
-    CHECK_SHAPE(rotary_sin, seqlen_ro, params_rotary_dim / 2);
-    PADDLE_ENFORCE_EQ(rotary_sin.dtype(),
-                      q_type,
-                      common::errors::InvalidArgument(
-                          "rotary_sin must have the same dtype as query"));
-
-    phi::dynload::fa3_fwd_params_set_rotary_cos_ptr(
-        params_handle, const_cast<void *>(rotary_cos.data()));
-    phi::dynload::fa3_fwd_params_set_rotary_sin_ptr(
-        params_handle, const_cast<void *>(rotary_sin.data()));
-    dynload::fa3_fwd_params_set_is_rotary_interleaved(params_handle,
-                                                      is_rotary_interleaved);
-  } else {
-    phi::dynload::fa3_fwd_params_set_rotary_dim(params_handle, 0);
-  }
-
-  if (kv_batch_idx_.is_initialized()) {
-    DenseTensor kv_batch_idx = kv_batch_idx_.get();
-    CHECK_DEVICE(kv_batch_idx);
-    CHECK_CONTIGUOUS(kv_batch_idx);
-    PADDLE_ENFORCE_EQ(
-        kv_batch_idx.dtype(),
-        phi::DataType::INT32,
-        common::errors::InvalidArgument("kv_batch_idx must have dtype int32"));
-    phi::dynload::fa3_fwd_params_set_kv_batch_idx(
-        params_handle, reinterpret_cast<int *>(kv_batch_idx.data()));
-  }
-
-  if (phi::dynload::fa3_fwd_params_get_num_splits(params_handle) > 1) {
-    PADDLE_ENFORCE_LE(
-        phi::dynload::fa3_fwd_params_get_num_splits(params_handle),
-        256,
-        common::errors::InvalidArgument("num_splits > 256 not supported"));
-    if (!is_varlen_q) {
-      out_accum->Resize(common::make_ddim(
-          {phi::dynload::fa3_fwd_params_get_num_splits(params_handle),
-           batch_size,
-           num_heads,
-           seqlen_q,
-           head_size_v}));
-      softmax_lse_accum->Resize(common::make_ddim(
-          {phi::dynload::fa3_fwd_params_get_num_splits(params_handle),
-           batch_size,
-           num_heads,
-           seqlen_q}));
-      ctx.template Alloc<float>(out_accum);
-      ctx.template Alloc<float>(softmax_lse_accum);
-      phi::dynload::fa3_fwd_params_set_oaccum_batch_stride(
-          params_handle, out_accum->strides()[1]);
-      phi::dynload::fa3_fwd_params_set_lseaccum_batch_stride(
params_handle, softmax_lse_accum->strides()[1]); - } else { - out_accum->Resize(common::make_ddim( - {phi::dynload::fa3_fwd_params_get_num_splits(params_handle), - num_heads, - total_q, - head_size_v})); - softmax_lse_accum->Resize(common::make_ddim( - {phi::dynload::fa3_fwd_params_get_num_splits(params_handle), - num_heads, - total_q})); - ctx.template Alloc(out_accum); - ctx.template Alloc(softmax_lse_accum); - } - phi::dynload::fa3_fwd_params_set_is_fp32(params_handle, false); - phi::dynload::fa3_fwd_params_set_oaccum_ptr( - params_handle, const_cast(out_accum->data())); - phi::dynload::fa3_fwd_params_set_softmax_lseaccum_ptr( - params_handle, const_cast(softmax_lse_accum->data())); - phi::dynload::fa3_fwd_params_set_oaccum_split_stride( - params_handle, out_accum->strides()[0]); - phi::dynload::fa3_fwd_params_set_oaccum_row_stride( - params_handle, out_accum->strides()[out_accum->strides().size() - 2]); - phi::dynload::fa3_fwd_params_set_oaccum_head_stride( - params_handle, out_accum->strides()[out_accum->strides().size() - 3]); - phi::dynload::fa3_fwd_params_set_lseaccum_split_stride( - params_handle, softmax_lse_accum->strides()[0]); - phi::dynload::fa3_fwd_params_set_lseaccum_head_stride( - params_handle, - softmax_lse_accum->strides()[softmax_lse_accum->strides().size() - 2]); - } - - if (q_type == phi::DataType::FLOAT8_E4M3FN) { - if (q_descale_.is_initialized()) { - DenseTensor q_descale = q_descale_.get(); - CHECK_DEVICE(q_descale); - CHECK_SHAPE(q_descale, batch_size, num_heads_k); - phi::dynload::fa3_fwd_params_set_q_descale_ptr( - params_handle, const_cast(q_descale.data())); - phi::dynload::fa3_fwd_params_set_q_descale_batch_stride( - params_handle, q_descale.strides()[0]); - phi::dynload::fa3_fwd_params_set_q_descale_head_stride( - params_handle, q_descale.strides()[1]); - } else { - phi::dynload::fa3_fwd_params_set_q_descale_ptr(params_handle, nullptr); - } - if (k_descale_.is_initialized()) { - DenseTensor k_descale = k_descale_.get(); - CHECK_DEVICE(k_descale); - CHECK_SHAPE(k_descale, batch_size, num_heads_k); - phi::dynload::fa3_fwd_params_set_k_descale_ptr( - params_handle, const_cast(k_descale.data())); - phi::dynload::fa3_fwd_params_set_k_descale_batch_stride( - params_handle, k_descale.strides()[0]); - phi::dynload::fa3_fwd_params_set_k_descale_head_stride( - params_handle, k_descale.strides()[1]); - } else { - phi::dynload::fa3_fwd_params_set_k_descale_ptr(params_handle, nullptr); - } - if (v_descale_.is_initialized()) { - DenseTensor v_descale = v_descale_.get(); - CHECK_DEVICE(v_descale); - CHECK_SHAPE(v_descale, batch_size, num_heads_k); - phi::dynload::fa3_fwd_params_set_v_descale_ptr( - params_handle, const_cast(v_descale.data())); - phi::dynload::fa3_fwd_params_set_v_descale_batch_stride( - params_handle, v_descale.strides()[0]); - phi::dynload::fa3_fwd_params_set_v_descale_head_stride( - params_handle, v_descale.strides()[1]); - } else { - phi::dynload::fa3_fwd_params_set_v_descale_ptr(params_handle, nullptr); - } - } - -#ifdef FLASHATTENTION_DISABLE_LOCAL - PADDLE_ENFORCE_EQ( - !phi::dynload::fa3_fwd_params_get_is_local(params_handle), - true, - common::errors::InvalidArgument( - "This flash attention build does not support local attention.")); -#endif -#ifdef FLASHATTENTION_DISABLE_SOFTCAP - PADDLE_ENFORCE_EQ( - phi::dynload::fa3_fwd_params_get_softcap(params_handle), - 0.0, - common::errors::InvalidArgument( - "This flash attention build does not support tanh softcapping.")); -#endif -#ifdef FLASHATTENTION_DISABLE_SPLIT - 
PADDLE_ENFORCE_EQ(phi::dynload::fa3_fwd_params_get_num_splits(params_handle), - 1, - common::errors::InvalidArgument( - "This flash attention build does not support splits.")); -#endif -#ifdef FLASHATTENTION_DISABLE_PACKGQA - PADDLE_ENFORCE_EQ( - (!phi::dynload::fa3_fwd_params_get_pack_gqa(params_handle) || - phi::dynload::fa3_fwd_params_get_arch(params_handle) < 90 || - (phi::dynload::fa3_fwd_params_get_page_table(params_handle) && - !phi::dynload::fa3_fwd_params_get_pagedkv_tma(params_handle)) || - phi::dynload::fa3_fwd_params_get_num_splits(params_handle) > 1), - true, - common::errors::InvalidArgument( - "This flash attention build does not support pack_gqa.")); -#endif -#ifdef FLASHATTENTION_DISABLE_PAGEDKV - PADDLE_ENFORCE_EQ( - (!(phi::dynload::fa3_fwd_params_get_page_table(params_handle) && - !phi::dynload::fa3_fwd_params_get_pagedkv_tma(params_handle))), - true, - common::errors::InvalidArgument( - "This flash attention build does not support paged KV.")); -#endif -#ifdef FLASHATTENTION_DISABLE_APPENDKV - PADDLE_ENFORCE_EQ( - !k_new_.is_initialized(), - true, - common::errors::InvalidArgument( - "This flash attention build does not support appending KV.")); -#endif - - if (total_q > 0 && - (total_k + dynload::fa3_fwd_params_get_total_knew(params_handle)) > 0 && - num_heads_k > 0) { - dynload::fa3_run_mha_fwd(params_handle, ctx.stream()); - if (dynload::fa3_fwd_params_get_num_splits(params_handle) > 1) { - if (out_type == phi::DataType::BFLOAT16) { - // Since we want output in BF16. Otherwise fwd_combine will output to - // FP16 - dynload::fa3_fwd_params_set_is_bf16(params_handle, true); - } - // Unless there's seqused_q, for the purpose of attn_combine, we can just - // treat it as batch=1 and seqlen = total_q, and don't need to dispatch to - // Varlen there. However, with dynamic split, each row needs to know which - // batch it belongs to to read the number of splits, so we just use the - // varlen version of combine kernel. if (is_varlen_q && - // !seqused_q_.has_value()) { if (is_varlen_q) { - // params.b = 1; - // params.seqlen_q = total_q; - // } - // } - dynload::fa3_run_mha_fwd_combine( - params_handle, ctx.stream(), true /*enable_pdl*/); - } - } else if (total_q > 0 && num_heads_k > 0) { - PADDLE_ENFORCE_EQ( - (out->dtype() == phi::DataType::BFLOAT16 || - out->dtype() == phi::DataType::FLOAT16 || - out->dtype() == phi::DataType::FLOAT8_E4M3FN), - true, - common::errors::InvalidArgument("flash attention 3 supports bfloat16, " - "float16 and float8_e4m3fn only.")); - // If seqlen_k == 0, then we have an empty tensor. We need to set the output - // to 0. 
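One aside on the empty-KV branch that follows this comment: the removed code zero-fills `out` and fills `softmax_lse` with +infinity. A hedged reading (not stated in the source) of why the sentinel is +inf rather than the mathematically natural -inf:

```cpp
// Sketch: with lse = +inf, any exp(score - lse) is exp(-inf) == 0, so an
// empty KV contributes nothing downstream; a -inf sentinel would instead
// produce exp(+inf) == +inf and propagate infinities/NaNs.
#include <cassert>
#include <cmath>

int main() {
  const float lse = INFINITY;  // sentinel written when seqlen_k == 0
  const float score = 3.5f;    // arbitrary finite attention score
  assert(std::exp(score - lse) == 0.f);            // harmless zero weight
  assert(std::isinf(std::exp(score + INFINITY)));  // the -inf alternative
}
```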
- if (out->dtype() == phi::DataType::BFLOAT16) { - phi::funcs::SetConstant set_zero; - set_zero( - ctx, - out, - phi::dtype::bfloat16{0}); // If varlen we'll manually do the zero-ing - } else if (out->dtype() == phi::DataType::FLOAT16) { - phi::funcs::SetConstant set_zero; - set_zero( - ctx, - out, - phi::dtype::float16{0}); // If varlen we'll manually do the zero-ing - } else if (out->dtype() == phi::DataType::FLOAT8_E4M3FN) { - phi::funcs::SetConstant set_zero; - set_zero(ctx, - out, - phi::dtype::float8_e4m3fn{ - 0}); // If varlen we'll manually do the zero-ing - } - phi::funcs::SetConstant set_infinity; - set_infinity(ctx, softmax_lse, std::numeric_limits::infinity()); - } - -#else - RaiseNotSupportedError(); -#endif -} - -template -void FlashAttnV3Kernel(const Context &ctx, - const DenseTensor &q, - const DenseTensor &k, - const DenseTensor &v, - const paddle::optional &q_v_, - const paddle::optional &q_descale_, - const paddle::optional &k_descale_, - const paddle::optional &v_descale_, - const float softmax_scale, - bool is_causal, - int window_size_left, - int window_size_right, - const float softcap, - int num_splits, - const bool manual_set_pack_gqa, - const bool pack_gqa_, - const int sm_margin, - DenseTensor *out, - DenseTensor *softmax_lse) { -#ifdef PADDLE_WITH_FLASHATTN_V3 - // umiswing: the following options have not been fully tested - PADDLE_ENFORCE_EQ(q_v_.is_initialized(), - false, - common::errors::InvalidArgument("q_v_ is not supported")); - PADDLE_ENFORCE_EQ( - q_descale_.is_initialized(), - false, - common::errors::InvalidArgument("q_descale_ is not supported")); - PADDLE_ENFORCE_EQ( - k_descale_.is_initialized(), - false, - common::errors::InvalidArgument("k_descale_ is not supported")); - PADDLE_ENFORCE_EQ( - v_descale_.is_initialized(), - false, - common::errors::InvalidArgument("v_descale_ is not supported")); - PADDLE_ENFORCE_EQ( - window_size_left, - -1, - common::errors::InvalidArgument("window_size is not supported, please " - "set window_size_left/right to -1")); - PADDLE_ENFORCE_EQ( - window_size_right, - -1, - common::errors::InvalidArgument("window_size is not supported, please " - "set window_size_left/right to -1")); - PADDLE_ENFORCE_EQ(softcap, - 0, - common::errors::InvalidArgument( - "softcap is not supported, please set softcap to 0")); - PADDLE_ENFORCE_EQ( - num_splits, - 1, - common::errors::InvalidArgument( - "num_splits is not supported, please set num_splits to 1")); - PADDLE_ENFORCE_EQ(manual_set_pack_gqa, - false, - common::errors::InvalidArgument( - "manual_set_pack_gqa is not supported, please set " - "manual_set_pack_gqa to false")); - PADDLE_ENFORCE_EQ( - pack_gqa_, - false, - common::errors::InvalidArgument( - "pack_gqa_ is not supported, please set pack_gqa_ to false")); - PADDLE_ENFORCE_EQ( - sm_margin, - 0, - common::errors::InvalidArgument( - "sm_margin is not supported, please set sm_margin to 0")); - - DenseTensor out_accum; - DenseTensor softmax_lse_accum; - FlashAttnV3BaseKernel(ctx, - q, - k, - v, - paddle::none, // k_new_ - paddle::none, // v_new_ - q_v_, - paddle::none, // out_ - paddle::none, // cu_seqlens_q_ - paddle::none, // cu_seqlens_k_ - paddle::none, // cu_seqlens_k_new_ - paddle::none, // seqused_q_ - paddle::none, // seqused_k_ - paddle::none, // page_table_ - paddle::none, // kv_batch_idx_ - paddle::none, // leftpad_k_ - paddle::none, // rotary_cos_ - paddle::none, // rotary_sin_ - q_descale_, - k_descale_, - v_descale_, - paddle::none, // scheduler_metadata - 0, // max_seqlen_q_ - 0, // max_seqlen_k_ - 
softmax_scale, - is_causal, - window_size_left, - window_size_right, - softcap, - true, // is_rotary_interleaved - num_splits, - manual_set_pack_gqa, - pack_gqa_, - sm_margin, - out, - softmax_lse, - &out_accum, - &softmax_lse_accum); -#else - RaiseNotSupportedError(); -#endif -} - -} // namespace phi - -PD_REGISTER_KERNEL(flash_attn_v3, - GPU, - ALL_LAYOUT, - phi::FlashAttnV3Kernel, - phi::dtype::float16, - phi::dtype::bfloat16) {} diff --git a/paddle/phi/kernels/gpu/flash_attn_v3_kernel.h b/paddle/phi/kernels/gpu/flash_attn_v3_kernel.h deleted file mode 100644 index 4cdf6398e87972..00000000000000 --- a/paddle/phi/kernels/gpu/flash_attn_v3_kernel.h +++ /dev/null @@ -1,37 +0,0 @@ -// Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -#pragma once -namespace phi { -template -void FlashAttnV3Kernel(const Context &ctx, - const DenseTensor &q, - const DenseTensor &k, - const DenseTensor &v, - const paddle::optional &q_v_, - const paddle::optional &q_descale_, - const paddle::optional &k_descale_, - const paddle::optional &v_descale_, - const float softmax_scale, - bool is_causal, - int window_size_left, - int window_size_right, - const float softcap, - int num_splits, - const bool manual_set_pack_gqa, - const bool pack_gqa_, - const int sm_margin, - DenseTensor *out, - DenseTensor *softmax_lse); -} // namespace phi diff --git a/paddle/phi/kernels/gpu/flash_attn_v3_utils.cu b/paddle/phi/kernels/gpu/flash_attn_v3_utils.cu deleted file mode 100644 index cbfaeb8726642c..00000000000000 --- a/paddle/phi/kernels/gpu/flash_attn_v3_utils.cu +++ /dev/null @@ -1,319 +0,0 @@ -// Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. 
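The deleted flash_attn_v3_utils.cu below is built around the opaque-handle idea its own comment describes (the params struct layout stays inside libflashattnv3.so for ABI compatibility; phi only ever holds pointers). A toy, self-contained illustration of that pattern, with made-up names:

```cpp
#include <memory>

// In the real code this definition is hidden inside the shared library, so
// its layout can change without breaking the caller's ABI.
struct Params { float scale = 0.f; };

Params* create_params() { return new Params; }        // exported by the .so
void destroy_params(Params* p) { delete p; }          // exported by the .so
void set_scale(Params* p, float s) { p->scale = s; }  // field accessor

int main() {
  // Callers hold only a pointer plus the library's own deleter, mirroring
  // the unique_ptr-with-custom-deleter usage in the deleted file below.
  std::unique_ptr<Params, decltype(&destroy_params)> params(create_params(),
                                                            &destroy_params);
  set_scale(params.get(), 0.125f);
}
```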
- -#include "paddle/phi/kernels/gpu/flash_attn_v3_utils.h" -#include "paddle/phi/common/bfloat16.h" -namespace phi { -#ifdef PADDLE_WITH_FLASHATTN_V3 - -void destroy_flash_fwd_params_handle(Flash_fwd_params *params_handle) { - phi::dynload::fa3_destroy_fwd_params_handle(params_handle); -} - -void destroy_flash_bwd_params_handle(Flash_bwd_params *params_handle) { - phi::dynload::fa3_destroy_bwd_params_handle(params_handle); -} - -// umiswing: no singleton, the details of Flash_fwd_params and Flash_bwd_params -// are encapsulated within libflashattnv3.so to ensure abi compatibility, only -// opaque pointers are exposed to phi -Flash_fwd_params *get_flash_fwd_params_handle() { - static std::unique_ptr - params_handle(phi::dynload::fa3_create_fwd_params_handle(), - &destroy_flash_fwd_params_handle); - - return params_handle.get(); -} - -Flash_bwd_params *get_flash_bwd_params_handle() { - static std::unique_ptr - params_handle(phi::dynload::fa3_create_bwd_params_handle(), - &destroy_flash_bwd_params_handle); - - return params_handle.get(); -} - -void set_params_fprop(Flash_fwd_params *params_handle, - // sizes - const size_t b, - const size_t seqlen_q, - const size_t seqlen_k, - const size_t seqlen_q_rounded, - const size_t seqlen_k_rounded, - const size_t h, - const size_t h_k, - const size_t d, - const size_t d_rounded, - // device pointers - const DenseTensor &q, - const DenseTensor &k, - const DenseTensor &v, - const DenseTensor *out, - void *cu_seqlens_q_d, - void *cu_seqlens_k_d, - void *seqused_q, - void *seqused_k, - void *softmax_lse_d, - float p_dropout, - float softmax_scale, - int window_size_left, - int window_size_right, - const gpuDeviceProp &dprops, - const float softcap, - const int sm_margin) { - dynload::fa3_fwd_params_set_is_bf16(params_handle, - q.dtype() == phi::DataType::BFLOAT16); - dynload::fa3_fwd_params_set_is_e4m3( - params_handle, q.dtype() == phi::DataType::FLOAT8_E4M3FN); - - // Set the pointers and strides. - dynload::fa3_fwd_params_set_q_ptr(params_handle, - const_cast(q.data())); - dynload::fa3_fwd_params_set_k_ptr(params_handle, - const_cast(k.data())); - dynload::fa3_fwd_params_set_v_ptr(params_handle, - const_cast(v.data())); - // All stride are in elements, not bytes. 
- dynload::fa3_fwd_params_set_q_row_stride(params_handle, - q.strides()[q.strides().size() - 3]); - dynload::fa3_fwd_params_set_k_row_stride(params_handle, - k.strides()[k.strides().size() - 3]); - dynload::fa3_fwd_params_set_v_row_stride(params_handle, - v.strides()[v.strides().size() - 3]); - dynload::fa3_fwd_params_set_q_head_stride( - params_handle, q.strides()[q.strides().size() - 2]); - dynload::fa3_fwd_params_set_k_head_stride( - params_handle, k.strides()[k.strides().size() - 2]); - dynload::fa3_fwd_params_set_v_head_stride( - params_handle, v.strides()[v.strides().size() - 2]); - dynload::fa3_fwd_params_set_v_dim_stride(params_handle, - v.strides()[v.strides().size() - 1]); - dynload::fa3_fwd_params_set_o_ptr(params_handle, - const_cast(out->data())); - dynload::fa3_fwd_params_set_o_row_stride( - params_handle, out->strides()[out->strides().size() - 3]); - dynload::fa3_fwd_params_set_o_head_stride( - params_handle, out->strides()[out->strides().size() - 2]); - - if (cu_seqlens_q_d == nullptr) { - dynload::fa3_fwd_params_set_q_batch_stride(params_handle, q.strides()[0]); - dynload::fa3_fwd_params_set_o_batch_stride(params_handle, - out->strides()[0]); - } - if (cu_seqlens_k_d == nullptr) { - dynload::fa3_fwd_params_set_k_batch_stride(params_handle, k.strides()[0]); - dynload::fa3_fwd_params_set_v_batch_stride(params_handle, v.strides()[0]); - } - - dynload::fa3_fwd_params_set_cu_seqlens_q(params_handle, - static_cast(cu_seqlens_q_d)); - dynload::fa3_fwd_params_set_cu_seqlens_k(params_handle, - static_cast(cu_seqlens_k_d)); - dynload::fa3_fwd_params_set_seqused_q(params_handle, - static_cast(seqused_q)); - dynload::fa3_fwd_params_set_seqused_k(params_handle, - static_cast(seqused_k)); - - // Softmax sum - dynload::fa3_fwd_params_set_softmax_lse_ptr(params_handle, softmax_lse_d); - - // Set the dimensions. - dynload::fa3_fwd_params_set_b(params_handle, b); - dynload::fa3_fwd_params_set_h(params_handle, h); - dynload::fa3_fwd_params_set_h_k(params_handle, h_k); - dynload::fa3_fwd_params_set_seqlen_q(params_handle, seqlen_q); - dynload::fa3_fwd_params_set_seqlen_k(params_handle, seqlen_k); - dynload::fa3_fwd_params_set_seqlen_q_rounded(params_handle, seqlen_q_rounded); - dynload::fa3_fwd_params_set_seqlen_k_rounded(params_handle, seqlen_k_rounded); - dynload::fa3_fwd_params_set_d(params_handle, d); - dynload::fa3_fwd_params_set_d_rounded(params_handle, d_rounded); - - // Set the different scale values. - dynload::fa3_fwd_params_set_scale_softmax(params_handle, softmax_scale); - dynload::fa3_fwd_params_set_softcap(params_handle, softcap); - - // Set this to probability of keeping an element to simplify things. - dynload::fa3_fwd_params_set_p_dropout(params_handle, 1.f - p_dropout); - // Convert p from float to int so we don't have to convert the random uint to - // float to compare. 
[Minor] We want to round down since when we do the
-  // comparison we use <= instead of <
-  // params.p_dropout_in_uint =
-  //     uint32_t(std::floor(params.p_dropout * 4294967295.0));
-  // params.p_dropout_in_uint16_t =
-  //     uint16_t(std::floor(params.p_dropout * 65535.0));
-  dynload::fa3_fwd_params_set_p_dropout_in_uint8_t(
-      params_handle,
-      uint8_t(std::floor(dynload::fa3_fwd_params_get_p_dropout(params_handle) *
-                         255.0)));
-  dynload::fa3_fwd_params_set_rp_dropout(
-      params_handle,
-      1.f / dynload::fa3_fwd_params_get_p_dropout(params_handle));
-  PADDLE_ENFORCE_LT(
-      p_dropout,
-      1.f,
-      common::errors::InvalidArgument("p_dropout must be less than 1"));
-
-  PADDLE_ENFORCE_EQ(
-      p_dropout,
-      0.0f,
-      common::errors::InvalidArgument(
-          "This flash attention build does not support dropout."));
-
-  // Causal is the special case where window_size_right == 0 and
-  // window_size_left < 0. Local is the more general case where
-  // window_size_right >= 0 or window_size_left >= 0.
-  dynload::fa3_fwd_params_set_is_causal(
-      params_handle, window_size_left < 0 && window_size_right == 0);
-  dynload::fa3_fwd_params_set_is_local(
-      params_handle,
-      (window_size_left >= 0 || window_size_right >= 0) &&
-          !dynload::fa3_fwd_params_get_is_causal(params_handle));
-
-  // TODO(tridao): check this
-  if (window_size_left < 0 && window_size_right >= 0) {
-    window_size_left = seqlen_k - 1;
-  }
-  if (window_size_left >= 0 && window_size_right < 0) {
-    window_size_right = seqlen_q - 1;
-  }
-  dynload::fa3_fwd_params_set_window_size_left(params_handle,
-                                               window_size_left);
-  dynload::fa3_fwd_params_set_window_size_right(params_handle,
-                                                window_size_right);
-
-  int arch = dprops.major * 10 + dprops.minor;
-  int num_sm = dprops.multiProcessorCount - sm_margin;
-
-  dynload::fa3_fwd_params_set_arch(params_handle, arch);
-  dynload::fa3_fwd_params_set_num_sm(params_handle, num_sm);
-
-#ifdef FLASHATTENTION_DISABLE_LOCAL
-  PADDLE_ENFORCE_EQ(
-      !dynload::fa3_fwd_params_get_is_local(params_handle),
-      true,
-      common::errors::InvalidArgument(
-          "This flash attention build does not support local attention."));
-#endif
-}
-
-void set_params_dgrad(Flash_bwd_params *params_handle,
-                      // sizes
-                      const size_t b,
-                      const size_t seqlen_q,
-                      const size_t seqlen_k,
-                      const size_t seqlen_q_rounded,
-                      const size_t seqlen_k_rounded,
-                      const size_t h,
-                      const size_t h_k,
-                      const size_t d,
-                      const size_t d_rounded,
-                      // device pointers
-                      const DenseTensor &q,
-                      const DenseTensor &k,
-                      const DenseTensor &v,
-                      const DenseTensor &out,
-                      const DenseTensor &dout,
-                      DenseTensor *dq,
-                      DenseTensor *dk,
-                      DenseTensor *dv,
-                      void *cu_seqlens_q_d,
-                      void *cu_seqlens_k_d,
-                      void *seqused_q,
-                      void *seqused_k,
-                      void *dq_accum_d,
-                      void *dk_accum_d,
-                      void *dv_accum_d,
-                      void *softmax_lse_d,
-                      void *dsoftmax_sum_d,
-                      float p_dropout,
-                      float softmax_scale,
-                      int window_size_left,
-                      int window_size_right,
-                      const gpuDeviceProp &dprops,
-                      const float softcap,
-                      bool deterministic,
-                      int const sm_margin) {
-  set_params_fprop(dynload::fa3_cast_to_fwd_params_handle(params_handle),
-                   b,
-                   seqlen_q,
-                   seqlen_k,
-                   seqlen_q_rounded,
-                   seqlen_k_rounded,
-                   h,
-                   h_k,
-                   d,
-                   d_rounded,
-                   q,
-                   k,
-                   v,
-                   &out,
-                   cu_seqlens_q_d,
-                   cu_seqlens_k_d,
-                   seqused_q,
-                   seqused_k,
-                   softmax_lse_d,
-                   p_dropout,
-                   softmax_scale,
-                   window_size_left,
-                   window_size_right,
-                   dprops,
-                   softcap,
-                   sm_margin);
-
-  // Set the pointers and strides.
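Before the dgrad pointer/stride setters below, a standalone check of the keep-probability encoding that set_params_fprop stores above (values worked out for p_dropout = 0.1):

```cpp
// Sketch: storing floor(p_keep * 255) once lets the kernel compare raw
// random bytes with <= instead of converting every random draw to float.
#include <cassert>
#include <cmath>
#include <cstdint>

int main() {
  const float p_dropout = 0.1f;  // probability of dropping an element
  const float p_keep = 1.f - p_dropout;
  const uint8_t threshold = static_cast<uint8_t>(std::floor(p_keep * 255.0));
  assert(threshold == 229);      // a random byte r keeps iff r <= 229
  const float rp = 1.f / p_keep; // rescale applied to surviving values
  assert(std::fabs(rp - 1.0f / 0.9f) < 1e-6f);
}
```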
- dynload::fa3_bwd_params_set_do_ptr(params_handle, - const_cast(dout.data())); - dynload::fa3_bwd_params_set_do_row_stride( - params_handle, dout.strides()[dout.strides().size() - 3]); - dynload::fa3_bwd_params_set_do_head_stride( - params_handle, dout.strides()[dout.strides().size() - 2]); - dynload::fa3_bwd_params_set_dq_ptr(params_handle, dq->data()); - dynload::fa3_bwd_params_set_dk_ptr(params_handle, dk->data()); - dynload::fa3_bwd_params_set_dv_ptr(params_handle, dv->data()); - dynload::fa3_bwd_params_set_dq_row_stride( - params_handle, dq->strides()[dq->strides().size() - 3]); - dynload::fa3_bwd_params_set_dk_row_stride( - params_handle, dk->strides()[dk->strides().size() - 3]); - dynload::fa3_bwd_params_set_dv_row_stride( - params_handle, dv->strides()[dv->strides().size() - 3]); - dynload::fa3_bwd_params_set_dq_head_stride( - params_handle, dq->strides()[dq->strides().size() - 2]); - dynload::fa3_bwd_params_set_dk_head_stride( - params_handle, dk->strides()[dk->strides().size() - 2]); - dynload::fa3_bwd_params_set_dv_head_stride( - params_handle, dv->strides()[dv->strides().size() - 2]); - - if (cu_seqlens_q_d == nullptr) { - dynload::fa3_bwd_params_set_do_batch_stride(params_handle, - dout.strides()[0]); - dynload::fa3_bwd_params_set_dq_batch_stride(params_handle, - dq->strides()[0]); - dynload::fa3_bwd_params_set_dk_batch_stride(params_handle, - dk->strides()[0]); - dynload::fa3_bwd_params_set_dv_batch_stride(params_handle, - dv->strides()[0]); - } - - dynload::fa3_bwd_params_set_dq_accum_ptr(params_handle, dq_accum_d); - dynload::fa3_bwd_params_set_dk_accum_ptr(params_handle, dk_accum_d); - dynload::fa3_bwd_params_set_dv_accum_ptr(params_handle, dv_accum_d); - - // Softmax sum - dynload::fa3_bwd_params_set_dsoftmax_sum(params_handle, dsoftmax_sum_d); - - dynload::fa3_bwd_params_set_deterministic(params_handle, deterministic); -} - -#endif -} // namespace phi diff --git a/paddle/phi/kernels/gpu/flash_attn_v3_utils.h b/paddle/phi/kernels/gpu/flash_attn_v3_utils.h deleted file mode 100644 index 59c5fe363feb1a..00000000000000 --- a/paddle/phi/kernels/gpu/flash_attn_v3_utils.h +++ /dev/null @@ -1,163 +0,0 @@ -// Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -#pragma once - -#ifdef PADDLE_WITH_FLASHATTN_V3 -#include "paddle/phi/backends/dynload/flashattnv3.h" -#endif -#include "paddle/phi/core/dense_tensor.h" -#include "paddle/phi/core/platform/device_context.h" - -namespace phi { -#ifdef PADDLE_WITH_FLASHATTN_V3 - -#define CHECK_DEVICE(x) \ - PADDLE_ENFORCE_EQ( \ - x.place().GetType(), \ - phi::AllocationType::GPU, \ - common::errors::InvalidArgument(#x " must be on CUDA Device")) - -#define CHECK_SHAPE(x, ...) 
\ - PADDLE_ENFORCE_EQ(x.dims(), \ - common::make_ddim({__VA_ARGS__}), \ - common::errors::InvalidArgument( \ - #x " must have shape (" #__VA_ARGS__ ")")) - -#define CHECK_CONTIGUOUS(x) \ - PADDLE_ENFORCE_EQ(x.meta().is_contiguous(), \ - true, \ - common::errors::InvalidArgument(#x " must be contiguous")) - -Flash_fwd_params *get_flash_fwd_params_handle(); - -Flash_bwd_params *get_flash_bwd_params_handle(); - -inline int get_max_headdim() { -#ifndef FLASHATTENTION_DISABLE_HDIM256 - return 256; -#endif -#ifndef FLASHATTENTION_DISABLE_HDIM192 - return 192; -#endif -#ifndef FLASHATTENTION_DISABLE_HDIM128 - return 128; -#endif -#ifndef FLASHATTENTION_DISABLE_HDIM96 - return 96; -#endif -#ifndef FLASHATTENTION_DISABLE_HDIM64 - return 64; -#endif - return 0; -} - -inline int round_up_headdim(int head_size) { -#ifndef FLASHATTENTION_DISABLE_HDIM64 - if (head_size <= 64) { - return 64; - } -#endif -#ifndef FLASHATTENTION_DISABLE_HDIM96 - if (head_size <= 96) { - return 96; - } -#endif -#ifndef FLASHATTENTION_DISABLE_HDIM128 - if (head_size <= 128) { - return 128; - } -#endif -#ifndef FLASHATTENTION_DISABLE_HDIM192 - if (head_size <= 192) { - return 192; - } -#endif -#ifndef FLASHATTENTION_DISABLE_HDIM256 - if (head_size <= 256) { - return 256; - } -#endif - return 256; -} - -void set_params_fprop(Flash_fwd_params *params_handle, - // sizes - const size_t b, - const size_t seqlen_q, - const size_t seqlen_k, - const size_t seqlen_q_rounded, - const size_t seqlen_k_rounded, - const size_t h, - const size_t h_k, - const size_t d, - const size_t d_rounded, - // device pointers - const DenseTensor &q, - const DenseTensor &k, - const DenseTensor &v, - const DenseTensor *out, - void *cu_seqlens_q_d, - void *cu_seqlens_k_d, - void *seqused_q, - void *seqused_k, - void *softmax_lse_d, - float p_dropout, - float softmax_scale, - int window_size_left, - int window_size_right, - const gpuDeviceProp &dprops, - const float softcap = 0.f, - const int sm_margin = 0); - -void set_params_dgrad(Flash_bwd_params *params_handle, - // sizes - const size_t b, - const size_t seqlen_q, - const size_t seqlen_k, - const size_t seqlen_q_rounded, - const size_t seqlen_k_rounded, - const size_t h, - const size_t h_k, - const size_t d, - const size_t d_rounded, - // device pointers - const DenseTensor &q, - const DenseTensor &k, - const DenseTensor &v, - const DenseTensor &out, - const DenseTensor &dout, - DenseTensor *dq, - DenseTensor *dk, - DenseTensor *dv, - void *cu_seqlens_q_d, - void *cu_seqlens_k_d, - void *seqused_q, - void *seqused_k, - void *dq_accum_d, - void *dk_accum_d, - void *dv_accum_d, - void *softmax_lse_d, - void *dsoftmax_sum_d, - float p_dropout, - float softmax_scale, - int window_size_left, - int window_size_right, - const gpuDeviceProp &dprops, - const float softcap = 0.f, - bool deterministic = false, - int const sm_margin = 0); -#endif - -} // namespace phi diff --git a/paddle/phi/kernels/gpu/nonzero_kernel.cu b/paddle/phi/kernels/gpu/nonzero_kernel.cu index a79173e6f15bd1..ba64aefd2c9fca 100644 --- a/paddle/phi/kernels/gpu/nonzero_kernel.cu +++ b/paddle/phi/kernels/gpu/nonzero_kernel.cu @@ -80,6 +80,10 @@ void RestrictNonZeroKernel(const Context &dev_ctx, DenseTensor *out) { DenseTensor in_data; auto dims = condition.dims(); + if (condition.numel() == 0) { + dev_ctx.template Alloc(out); + return; + } using Functor = IndexFunctor; Functor index_functor{dims}; diff --git a/paddle/phi/kernels/gpu/put_along_axis_grad_kernel.cu b/paddle/phi/kernels/gpu/put_along_axis_grad_kernel.cu index 
624ad5cabc4723..f49828065bb692 100644 --- a/paddle/phi/kernels/gpu/put_along_axis_grad_kernel.cu +++ b/paddle/phi/kernels/gpu/put_along_axis_grad_kernel.cu @@ -41,6 +41,15 @@ void PutAlongAxisGradKernel(const Context& dev_ctx, errors::PreconditionNotMet( "PutAlongAxisGradOpCUDAKernel only runs on GPU.")); + if (x.numel() == 0) { + if (x_grad) { + dev_ctx.template Alloc(x_grad); + } + if (value_grad) { + dev_ctx.template Alloc(value_grad); + } + return; + } const auto& index_type = index.dtype(); if (x_grad) { phi::Copy(dev_ctx, out_grad, dev_ctx.GetPlace(), false, x_grad); diff --git a/paddle/phi/kernels/stride/slice_grad_kernel.cc b/paddle/phi/kernels/stride/slice_grad_kernel.cc index 8ab6b8b5e6f28b..3ef61ef57f19e0 100644 --- a/paddle/phi/kernels/stride/slice_grad_kernel.cc +++ b/paddle/phi/kernels/stride/slice_grad_kernel.cc @@ -16,6 +16,7 @@ #include "paddle/common/flags.h" #include "paddle/phi/backends/all_context.h" #include "paddle/phi/core/kernel_registry.h" +#include "paddle/phi/kernels/full_kernel.h" #include "paddle/phi/kernels/funcs/strided_utils.h" #include "paddle/phi/kernels/slice_kernel.h" @@ -46,6 +47,16 @@ void SliceGradStridedKernel(const Context& dev_ctx, })); DenseTensor tmp; tmp.set_meta(out_grad.meta()); + if (out_grad.numel() == 0) { + // set zero to input_grad + + PD_VISIT_ALL_TYPES(input.dtype(), "SliceGradStridedKernel", ([&] { + phi::StridedTensorFill( + *input_grad, 0, input_grad); + })); + + return; + } SliceStridedKernel(dev_ctx, *input_grad, axes, diff --git a/paddle/phi/ops/yaml/backward.yaml b/paddle/phi/ops/yaml/backward.yaml index ce4442de5aa5e7..72d84ca1afa212 100644 --- a/paddle/phi/ops/yaml/backward.yaml +++ b/paddle/phi/ops/yaml/backward.yaml @@ -1156,17 +1156,6 @@ func : flash_attn_unpadded_grad data_type: q -- backward_op : flash_attn_v3_grad - forward : flash_attn_v3 (Tensor q, Tensor k, Tensor v, Tensor q_v_, Tensor q_descale_, Tensor k_descale_, Tensor v_descale_, float softmax_scale, bool is_causal, int window_size_left, int window_size_right, float softcap, int num_splits, bool manual_set_pack_gqa, bool pack_gqa_, int sm_margin) -> Tensor(out), Tensor(softmax_lse) - args : (Tensor q, Tensor k, Tensor v, Tensor out, Tensor softmax_lse, Tensor out_grad, float softmax_scale, bool is_causal, int window_size_left, int window_size_right, float softcap, int sm_margin) - output : Tensor(q_grad), Tensor(k_grad), Tensor(v_grad) - infer_meta : - func : FlashAttnV3GradInferMeta - param : [q, k, v] - kernel : - func : flash_attn_v3_grad - data_type : q - - backward_op : flash_attn_varlen_qkvpacked_grad forward : flash_attn_varlen_qkvpacked (Tensor qkv, Tensor cu_seqlens_q, Tensor cu_seqlens_k, Tensor fixed_seed_offset, Tensor attn_mask, Scalar max_seqlen_q, Scalar max_seqlen_k, float scale, float dropout = 0.0, bool causal = false, bool return_softmax = false, bool is_test = false, str rng_name = "", bool varlen_padded = true) -> Tensor(out), Tensor(softmax), Tensor(softmax_lse), Tensor(seed_offset) args : (Tensor qkv, Tensor cu_seqlens_q, Tensor cu_seqlens_k, Tensor out, Tensor softmax_lse, Tensor seed_offset, Tensor attn_mask, Tensor out_grad, Scalar max_seqlen_q, Scalar max_seqlen_k, float scale, float dropout = 0.0, bool causal = false, bool varlen_padded = true) diff --git a/paddle/phi/ops/yaml/ops.yaml b/paddle/phi/ops/yaml/ops.yaml index f4892cb9b9e91a..e72119430cc913 100755 --- a/paddle/phi/ops/yaml/ops.yaml +++ b/paddle/phi/ops/yaml/ops.yaml @@ -1991,18 +1991,6 @@ intermediate : softmax_lse, seed_offset backward : flash_attn_unpadded_grad -- op : 
flash_attn_v3 - args : (Tensor q, Tensor k, Tensor v, Tensor q_v_, Tensor q_descale_, Tensor k_descale_, Tensor v_descale_, float softmax_scale, bool is_causal, int window_size_left, int window_size_right, float softcap, int num_splits, bool manual_set_pack_gqa, bool pack_gqa_, int sm_margin) - output : Tensor(out), Tensor(softmax_lse) - optional : q_v_, q_descale_, k_descale_, v_descale_ - infer_meta : - func : FlashAttnV3InferMeta - param : [q, k, v] - kernel : - func : flash_attn_v3 - data_type : q - backward : flash_attn_v3_grad - - op : flash_attn_varlen_qkvpacked args : (Tensor qkv, Tensor cu_seqlens_q, Tensor cu_seqlens_k, Tensor fixed_seed_offset, Tensor attn_mask, Scalar max_seqlen_q, Scalar max_seqlen_k, float scale, float dropout = 0.0, bool causal = false, bool return_softmax = false, bool is_test = false, str rng_name = "", bool varlen_padded = true) output : Tensor(out), Tensor(softmax), Tensor(softmax_lse), Tensor(seed_offset) diff --git a/python/paddle/base/layer_helper_base.py b/python/paddle/base/layer_helper_base.py index 22a385dc601aed..b0720a048647c4 100644 --- a/python/paddle/base/layer_helper_base.py +++ b/python/paddle/base/layer_helper_base.py @@ -21,7 +21,6 @@ import paddle from . import core, unique_name -from .data_feeder import convert_dtype from .framework import ( Variable, _current_expected_place, @@ -368,7 +367,7 @@ def create_parameter( if not dtype: dtype = self.__dtype if isinstance(dtype, core.DataType): - dtype = convert_dtype(dtype) + dtype = paddle.pir.core.datatype_to_vartype[dtype] if is_bias: suffix = 'b' default_initializer = ( @@ -449,6 +448,8 @@ def create_parameter( ) else: if in_pir_mode(): + if isinstance(dtype, core.VarDesc.VarType): + dtype = paddle.pir.core.vartype_to_datatype[dtype] return paddle.pir.core.create_parameter( dtype=dtype, shape=shape, diff --git a/python/paddle/distributed/fleet/meta_optimizers/dygraph_optimizer/dygraph_sharding_optimizer.py b/python/paddle/distributed/fleet/meta_optimizers/dygraph_optimizer/dygraph_sharding_optimizer.py index 053c0d9a2ba055..08995369cbb86c 100755 --- a/python/paddle/distributed/fleet/meta_optimizers/dygraph_optimizer/dygraph_sharding_optimizer.py +++ b/python/paddle/distributed/fleet/meta_optimizers/dygraph_optimizer/dygraph_sharding_optimizer.py @@ -746,6 +746,14 @@ def fused_allreduce(*_): return fused_allreduce + def _increase_comm_buffers_acc_steps(self, increment): + for buffer in self._comm_buffer_list: + buffer._acc_steps += increment + + def _reset_comm_buffers_acc_steps(self, acc_steps): + for buffer in self._comm_buffer_list: + buffer._acc_steps = acc_steps + def _build_comm_buffers( self, acc_steps, group_size=256 * 1024 * 1024, free_grads_in_comm=False ): diff --git a/python/paddle/distributed/fleet/meta_parallel/pipeline_parallel.py b/python/paddle/distributed/fleet/meta_parallel/pipeline_parallel.py index 5511d6438501ad..5c063aa48e41d3 100644 --- a/python/paddle/distributed/fleet/meta_parallel/pipeline_parallel.py +++ b/python/paddle/distributed/fleet/meta_parallel/pipeline_parallel.py @@ -2760,6 +2760,9 @@ def backward_async_comm( # reset dynamic meta counter if self._dynamic_shape: + assert self._p2p_helper._dynamic_cnt == len( + self._p2p_helper._send_recv_meta_list + ), "p2p dynamic_cnt should equal to send_recv_meta_list" self._p2p_helper._dynamic_cnt = 0 return train_loss @@ -3359,6 +3362,13 @@ def forward_backward_pipeline( backward_send_recv_buffer_queue.empty() ), "send_recv buffer should be empty" + # reset dynamic meta counter + if self._dynamic_shape: + assert 
self._p2p_helper._dynamic_cnt == len( + self._p2p_helper._send_recv_meta_list + ), "p2p dynamic_cnt should equal to send_recv_meta_list" + self._p2p_helper._dynamic_cnt = 0 + self._flush_records() self._sync_overlap_grads() diff --git a/python/paddle/distributed/fleet/meta_parallel/pp_utils/p2p_communication.py b/python/paddle/distributed/fleet/meta_parallel/pp_utils/p2p_communication.py index 6aeedd1f6471cd..8dd7c613b6512d 100644 --- a/python/paddle/distributed/fleet/meta_parallel/pp_utils/p2p_communication.py +++ b/python/paddle/distributed/fleet/meta_parallel/pp_utils/p2p_communication.py @@ -763,6 +763,7 @@ def recv_forward(self, pp_first_stage, sync_recv=True, batch_p2p_comm=True): sync_recv=sync_recv, send_recv_meta=self._send_recv_meta, batch_p2p_comm=batch_p2p_comm, + dynamic_shape=self._dynamic_shape, ) if self._dynamic_shape: self._dynamic_cnt += 1 @@ -807,9 +808,6 @@ def recv_backward( if _timers is not None: _timers("recv_backward").stop() - if self._dynamic_shape and need_increase_cnt: - self._dynamic_cnt += 1 - return output_tensor_grad def send_forward( @@ -823,10 +821,6 @@ def send_forward( if _timers is not None: _timers("send_forward").start() - assert ( - not self._dynamic_shape - ), "p2p_helper.send_forward function doesn't support dynamic_shape now" - if not pp_last_stage: self._send_meta(output_tensor, skip_check_meta=skip_check_meta) _p2p_helper( @@ -836,7 +830,11 @@ def send_forward( recv_next=False, send_recv_meta=self._send_recv_meta, batch_p2p_comm=batch_p2p_comm, + dynamic_shape=self._dynamic_shape, ) + if self._dynamic_shape: + self._dynamic_cnt += 1 + if _timers is not None: _timers("send_forward").stop() @@ -847,11 +845,9 @@ def send_backward( if _timers is not None: _timers("send_backward").start() - assert ( - not self._dynamic_shape - ), "p2p_helper.send_backward function doesn't support dynamic_shape now" - if not pp_first_stage: + if self._dynamic_shape: + self._send_meta(input_tensor_grad, reverse=True) _p2p_helper( tensor_send_next=None, tensor_send_prev=input_tensor_grad, @@ -859,7 +855,10 @@ def send_backward( recv_next=False, send_recv_meta=self._send_recv_meta, batch_p2p_comm=batch_p2p_comm, + dynamic_shape=self._dynamic_shape, ) + if self._dynamic_shape: + self._dynamic_cnt += 1 if _timers is not None: _timers("send_backward").stop() diff --git a/python/paddle/incubate/nn/functional/fused_matmul_bias.py b/python/paddle/incubate/nn/functional/fused_matmul_bias.py index e34fa4e71608da..9879bc874f55e7 100644 --- a/python/paddle/incubate/nn/functional/fused_matmul_bias.py +++ b/python/paddle/incubate/nn/functional/fused_matmul_bias.py @@ -20,6 +20,7 @@ from paddle.base.layer_helper import LayerHelper from paddle.framework import ( in_dynamic_mode, + in_dynamic_or_pir_mode, in_pir_mode, ) from paddle.tensor.linalg import matmul @@ -71,11 +72,7 @@ def fused_matmul_bias( """ if bias is None: return matmul(x, y, transpose_x, transpose_y, name) - if in_dynamic_mode(): - return _legacy_C_ops.fused_gemm_epilogue( - x, y, bias, 'trans_x', transpose_x, 'trans_y', transpose_y - ) - if in_pir_mode(): + if in_dynamic_or_pir_mode(): out, _ = _C_ops.fused_gemm_epilogue( x, y, bias, transpose_x, transpose_y, "none" ) diff --git a/python/paddle/nn/functional/flash_attention.py b/python/paddle/nn/functional/flash_attention.py index 598bb198705562..4ccd12e106c508 100644 --- a/python/paddle/nn/functional/flash_attention.py +++ b/python/paddle/nn/functional/flash_attention.py @@ -414,7 +414,6 @@ def flash_attention( rng_name="", training=True, name=None, - 
softmax_scale=None, ): r""" The equation is: @@ -485,84 +484,20 @@ def flash_attention( sdp_func_name = _select_sdp(head_dim) if sdp_func_name == "flash_attn": - if "xpu" in paddle.get_device(): - fa_version = 2 - elif paddle.get_flags(["FLAGS_cudnn_deterministic"])[ - "FLAGS_cudnn_deterministic" - ]: - fa_version = 2 - else: - fa_version = paddle.base.framework.get_flags( - ["FLAGS_flash_attn_version"] - )["FLAGS_flash_attn_version"] - assert ( - in_dynamic_or_pir_mode() or fa_version == 2 - ), "flash attention 3 only support dynamic or pir mode" - assert ( - dropout == 0.0 or fa_version == 2 - ), "flash attention 3 does not support dropout" - assert ( - not return_softmax or fa_version == 2 - ), "flash attention 3 does not support return softmax" - assert ( - fixed_seed_offset is None or fa_version == 2 - ), "flash attention 3 does not support return softmax" - assert ( - rng_name == "" or fa_version == 2 - ), "flash attention 3 does not support setting rng_name" - assert ( - training or fa_version == 2 - ), "flash attention 3 does not support setting training" - assert ( - name is None or fa_version == 2 - ), "flash attention 3 does not support setting name" - assert ( - softmax_scale is None or fa_version == 3 - ), "flash attention 2 does not support setting softmax_scale" if in_dynamic_or_pir_mode(): - if fa_version == 2: - (result_attention, result_softmax, _, _) = _C_ops.flash_attn( - query, - key, - value, - fixed_seed_offset, - None, - dropout, - causal, - return_softmax, - not training, - rng_name, - ) - return result_attention, ( - result_softmax if return_softmax else None - ) - elif fa_version == 3: - if softmax_scale is None: - softmax_scale = query.shape[-1] ** (-0.5) - - out, softmax_lse = _C_ops.flash_attn_v3( - query, - key, - value, - None, # q_v_ - None, # q_descale_ - None, # k_descale_ - None, # v_descale_ - softmax_scale, - causal, - -1, # window_size_left - -1, # window_size_right - 0.0, # softcap - 1, # num_splits - False, # manual_set_pack_gqa - False, # pack_gqa_ - 0, # sm_margin - ) - return out, None # return_softmax - else: - raise ValueError( - f"Invalid flash attention version: {fa_version}" - ) + (result_attention, result_softmax, _, _) = _C_ops.flash_attn( + query, + key, + value, + fixed_seed_offset, + None, + dropout, + causal, + return_softmax, + not training, + rng_name, + ) + return result_attention, result_softmax if return_softmax else None helper = LayerHelper('flash_attn', **locals()) dtype = helper.input_dtype(input_param_name='q') diff --git a/third_party/flashattn b/third_party/flashattn index 7c0f9623bc710b..b459f7deb627a8 160000 --- a/third_party/flashattn +++ b/third_party/flashattn @@ -1 +1 @@ -Subproject commit 7c0f9623bc710bf1fd1a8bd17bdccc68cf2237d5 +Subproject commit b459f7deb627a8e3a61649907d66da9ed87233a2
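A closing sketch of the split-combine step this patch removes (fa3_run_mha_fwd_combine over the out_accum / softmax_lse_accum buffers set up earlier). The formula is the standard logsumexp-weighted reduction over splits, my reconstruction rather than the library's code:

```cpp
#include <algorithm>
#include <cassert>
#include <cmath>
#include <vector>

// out_i / lse_i are one split's partial attention output and log-sum-exp
// for a single (batch, head, query, channel) entry; the true output is the
// lse-weighted average across the split dimension.
float CombineSplits(const std::vector<float>& out_i,
                    const std::vector<float>& lse_i) {
  float lse = -INFINITY;  // running logsumexp over splits
  for (float l : lse_i) {
    float hi = std::max(lse, l), lo = std::min(lse, l);
    lse = hi + std::log1p(std::exp(lo - hi));
  }
  float out = 0.f;
  for (size_t i = 0; i < out_i.size(); ++i) {
    out += std::exp(lse_i[i] - lse) * out_i[i];
  }
  return out;
}

int main() {
  // Two splits that saw identical scores must combine back to the same value.
  float o = CombineSplits({1.f, 1.f}, {0.7f, 0.7f});
  assert(std::fabs(o - 1.f) < 1e-6f);
}
```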