
Cherry-pick fix fused gemm api #72487

Closed
3 changes: 0 additions & 3 deletions cmake/external/flashattn.cmake
@@ -276,9 +276,6 @@ else()
-DCMAKE_JOB_POOLS:STRING=compile=${FA_JOB_POOLS_COMPILE}
-DNVCC_ARCH_BIN=${FA_NVCC_ARCH_BIN}
-DWITH_FLASHATTN_V3=${WITH_FLASHATTN_V3}
-DDISABLE_FP8=ON # umiswing: disable FP8, SM8x and PACKGQA on FA3
-DDISABLE_SM8X=ON
-DDISABLE_PACKGQA=ON
-DSKIP_BUILD_FA=${SKIP_BUILD_FA}
${EXTERNAL_OPTIONAL_ARGS}
CMAKE_CACHE_ARGS
223 changes: 6 additions & 217 deletions paddle/phi/backends/dynload/flashattnv3.h
@@ -24,7 +24,7 @@ namespace phi {
namespace dynload {

extern std::once_flag flashattnv3_dso_flag;
extern void *flashattnv3_dso_handle;
extern void* flashattnv3_dso_handle;

#define DYNAMIC_LOAD_FLASHATTN_V3_WRAP(__name) \
struct DynLoad__##__name { \
@@ -34,7 +34,7 @@ extern void *flashattnv3_dso_handle;
std::call_once(flashattnv3_dso_flag, []() { \
flashattnv3_dso_handle = phi::dynload::GetFlashAttnV3DsoHandle(); \
}); \
static void *p_##__name = dlsym(flashattnv3_dso_handle, #__name); \
static void* p_##__name = dlsym(flashattnv3_dso_handle, #__name); \
return reinterpret_cast<flashattnFunc>(p_##__name)(args...); \
} \
}; \
@@ -45,223 +45,12 @@ extern void *flashattnv3_dso_handle;

#ifdef PADDLE_WITH_CUDA
#define FLASHATTN_V3_ROUTINE_EACH(__macro) \
__macro(fa3_create_fwd_params_handle); \
__macro(fa3_clear_fwd_params_handle); \
__macro(fa3_destroy_fwd_params_handle); \
__macro(fa3_create_bwd_params_handle); \
__macro(fa3_clear_bwd_params_handle); \
__macro(fa3_destroy_bwd_params_handle); \
__macro(fa3_cast_to_fwd_params_handle); \
__macro(fa3_run_mha_fwd_combine); \
__macro(fa3_run_mha_fwd); \
__macro(fa3_run_mha_bwd); \
__macro(fa3_get_pagedkv_tma); \
__macro(fa3_get_pack_gqa); \
__macro(fa3_get_num_splits);

FLASHATTN_V3_ROUTINE_EACH(DECLARE_DYNAMIC_LOAD_FLASHATTN_V3_WRAP)

#define FLASHATTN_V3_HANDLE_ROUTINE(member) \
DECLARE_DYNAMIC_LOAD_FLASHATTN_V3_WRAP(fa3_fwd_params_get_##member); \
DECLARE_DYNAMIC_LOAD_FLASHATTN_V3_WRAP(fa3_fwd_params_set_##member); \
DECLARE_DYNAMIC_LOAD_FLASHATTN_V3_WRAP(fa3_bwd_params_get_##member); \
DECLARE_DYNAMIC_LOAD_FLASHATTN_V3_WRAP(fa3_bwd_params_set_##member);

// The QKV matrices.
FLASHATTN_V3_HANDLE_ROUTINE(q_ptr)
FLASHATTN_V3_HANDLE_ROUTINE(k_ptr)
FLASHATTN_V3_HANDLE_ROUTINE(v_ptr)

// The stride between rows of the Q, K and V matrices.
FLASHATTN_V3_HANDLE_ROUTINE(q_batch_stride)
FLASHATTN_V3_HANDLE_ROUTINE(k_batch_stride)
FLASHATTN_V3_HANDLE_ROUTINE(v_batch_stride)
FLASHATTN_V3_HANDLE_ROUTINE(q_row_stride)
FLASHATTN_V3_HANDLE_ROUTINE(k_row_stride)
FLASHATTN_V3_HANDLE_ROUTINE(v_row_stride)
FLASHATTN_V3_HANDLE_ROUTINE(q_head_stride)
FLASHATTN_V3_HANDLE_ROUTINE(k_head_stride)
FLASHATTN_V3_HANDLE_ROUTINE(v_head_stride)
FLASHATTN_V3_HANDLE_ROUTINE(v_dim_stride)

// The number of heads.
FLASHATTN_V3_HANDLE_ROUTINE(h)
FLASHATTN_V3_HANDLE_ROUTINE(h_k)

// The O matrix (output).
FLASHATTN_V3_HANDLE_ROUTINE(o_ptr)
FLASHATTN_V3_HANDLE_ROUTINE(oaccum_ptr)

// The stride between rows of O.
FLASHATTN_V3_HANDLE_ROUTINE(o_batch_stride)
FLASHATTN_V3_HANDLE_ROUTINE(o_row_stride)
FLASHATTN_V3_HANDLE_ROUTINE(o_head_stride)

// The pointer to the softmax sum.
FLASHATTN_V3_HANDLE_ROUTINE(softmax_lse_ptr)
FLASHATTN_V3_HANDLE_ROUTINE(softmax_lseaccum_ptr)

// For FP8 scaling
FLASHATTN_V3_HANDLE_ROUTINE(q_descale_ptr)
FLASHATTN_V3_HANDLE_ROUTINE(k_descale_ptr)
FLASHATTN_V3_HANDLE_ROUTINE(v_descale_ptr)
FLASHATTN_V3_HANDLE_ROUTINE(q_descale_batch_stride)
FLASHATTN_V3_HANDLE_ROUTINE(q_descale_head_stride)
FLASHATTN_V3_HANDLE_ROUTINE(k_descale_batch_stride)
FLASHATTN_V3_HANDLE_ROUTINE(k_descale_head_stride)
FLASHATTN_V3_HANDLE_ROUTINE(v_descale_batch_stride)
FLASHATTN_V3_HANDLE_ROUTINE(v_descale_head_stride)

// The dimensions.
FLASHATTN_V3_HANDLE_ROUTINE(b)
FLASHATTN_V3_HANDLE_ROUTINE(seqlen_q)
FLASHATTN_V3_HANDLE_ROUTINE(seqlen_k)
FLASHATTN_V3_HANDLE_ROUTINE(seqlen_knew)
FLASHATTN_V3_HANDLE_ROUTINE(d)
FLASHATTN_V3_HANDLE_ROUTINE(seqlen_q_rounded)
FLASHATTN_V3_HANDLE_ROUTINE(seqlen_k_rounded)
FLASHATTN_V3_HANDLE_ROUTINE(d_rounded)
FLASHATTN_V3_HANDLE_ROUTINE(rotary_dim)
FLASHATTN_V3_HANDLE_ROUTINE(total_q)
FLASHATTN_V3_HANDLE_ROUTINE(total_k)
FLASHATTN_V3_HANDLE_ROUTINE(total_knew)
FLASHATTN_V3_HANDLE_ROUTINE(b_k)
FLASHATTN_V3_HANDLE_ROUTINE(dv)
FLASHATTN_V3_HANDLE_ROUTINE(dv_rounded)

// The scaling factors for the kernel.
FLASHATTN_V3_HANDLE_ROUTINE(scale_softmax)
FLASHATTN_V3_HANDLE_ROUTINE(softcap)

// array of length b+1 holding starting offset of each sequence.
FLASHATTN_V3_HANDLE_ROUTINE(cu_seqlens_q)
FLASHATTN_V3_HANDLE_ROUTINE(cu_seqlens_k)
FLASHATTN_V3_HANDLE_ROUTINE(cu_seqlens_knew)
FLASHATTN_V3_HANDLE_ROUTINE(leftpad_k)

// If provided, the actual length of each q/k sequence.
FLASHATTN_V3_HANDLE_ROUTINE(seqused_q)
FLASHATTN_V3_HANDLE_ROUTINE(seqused_k)

// The stride between rows of Oaccum.
FLASHATTN_V3_HANDLE_ROUTINE(oaccum_split_stride)
FLASHATTN_V3_HANDLE_ROUTINE(oaccum_batch_stride)
FLASHATTN_V3_HANDLE_ROUTINE(oaccum_row_stride)
FLASHATTN_V3_HANDLE_ROUTINE(oaccum_head_stride)

// The stride between rows of LSEaccum.
FLASHATTN_V3_HANDLE_ROUTINE(lseaccum_split_stride)
FLASHATTN_V3_HANDLE_ROUTINE(lseaccum_batch_stride)
FLASHATTN_V3_HANDLE_ROUTINE(lseaccum_head_stride)

// The K_new and V_new matrices.
FLASHATTN_V3_HANDLE_ROUTINE(knew_ptr)
FLASHATTN_V3_HANDLE_ROUTINE(vnew_ptr)

// The stride between rows of the Q, K and V matrices.
FLASHATTN_V3_HANDLE_ROUTINE(knew_batch_stride)
FLASHATTN_V3_HANDLE_ROUTINE(vnew_batch_stride)
FLASHATTN_V3_HANDLE_ROUTINE(knew_row_stride)
FLASHATTN_V3_HANDLE_ROUTINE(vnew_row_stride)
FLASHATTN_V3_HANDLE_ROUTINE(knew_head_stride)
FLASHATTN_V3_HANDLE_ROUTINE(vnew_head_stride)

FLASHATTN_V3_HANDLE_ROUTINE(qv_ptr)
FLASHATTN_V3_HANDLE_ROUTINE(qv_batch_stride)
FLASHATTN_V3_HANDLE_ROUTINE(qv_row_stride)
FLASHATTN_V3_HANDLE_ROUTINE(qv_head_stride)

// The cos and sin matrices for rotary embedding.
FLASHATTN_V3_HANDLE_ROUTINE(rotary_cos_ptr)
FLASHATTN_V3_HANDLE_ROUTINE(rotary_sin_ptr)

// The indices to index into the KV cache.
FLASHATTN_V3_HANDLE_ROUTINE(kv_batch_idx)

// Paged KV cache
FLASHATTN_V3_HANDLE_ROUTINE(page_table)
FLASHATTN_V3_HANDLE_ROUTINE(page_table_batch_stride)
FLASHATTN_V3_HANDLE_ROUTINE(page_size)
FLASHATTN_V3_HANDLE_ROUTINE(num_pages)
FLASHATTN_V3_HANDLE_ROUTINE(pagedkv_tma)

// The dropout probability (probability of keeping an activation).
FLASHATTN_V3_HANDLE_ROUTINE(p_dropout)
FLASHATTN_V3_HANDLE_ROUTINE(p_dropout_in_uint8_t)

// Scale factor of 1 / (1 - p_dropout).
FLASHATTN_V3_HANDLE_ROUTINE(rp_dropout)

// Local window size
FLASHATTN_V3_HANDLE_ROUTINE(window_size_left)
FLASHATTN_V3_HANDLE_ROUTINE(window_size_right)

// Pointer to the RNG seed (idx 0) and offset (idx 1).
FLASHATTN_V3_HANDLE_ROUTINE(rng_state)

FLASHATTN_V3_HANDLE_ROUTINE(is_bf16)
FLASHATTN_V3_HANDLE_ROUTINE(is_fp32)
FLASHATTN_V3_HANDLE_ROUTINE(is_e4m3)
FLASHATTN_V3_HANDLE_ROUTINE(is_causal)
FLASHATTN_V3_HANDLE_ROUTINE(is_local)

FLASHATTN_V3_HANDLE_ROUTINE(is_rotary_interleaved)

FLASHATTN_V3_HANDLE_ROUTINE(num_splits) // For split-KV version
FLASHATTN_V3_HANDLE_ROUTINE(pack_gqa)

FLASHATTN_V3_HANDLE_ROUTINE(tile_count_semaphore)
FLASHATTN_V3_HANDLE_ROUTINE(num_splits_dynamic_ptr)
FLASHATTN_V3_HANDLE_ROUTINE(skip_scheduler_metadata_computation)

FLASHATTN_V3_HANDLE_ROUTINE(arch)
FLASHATTN_V3_HANDLE_ROUTINE(num_sm)

#define FLASHATTN_V3_BWD_HANDLE_ROUTINE(type, member) \
DECLARE_DYNAMIC_LOAD_FLASHATTN_V3_WRAP(fa3_bwd_params_get_##member); \
DECLARE_DYNAMIC_LOAD_FLASHATTN_V3_WRAP(fa3_bwd_params_set_##member);

// The dO and dQKV matrices.
FLASHATTN_V3_BWD_HANDLE_ROUTINE(void *, do_ptr)
FLASHATTN_V3_BWD_HANDLE_ROUTINE(void *, dq_ptr)
FLASHATTN_V3_BWD_HANDLE_ROUTINE(void *, dk_ptr)
FLASHATTN_V3_BWD_HANDLE_ROUTINE(void *, dv_ptr)

// To accumulate dQ
FLASHATTN_V3_BWD_HANDLE_ROUTINE(void *, dq_accum_ptr)
FLASHATTN_V3_BWD_HANDLE_ROUTINE(void *, dk_accum_ptr)
FLASHATTN_V3_BWD_HANDLE_ROUTINE(void *, dv_accum_ptr)

// // To accumulate dK and dV in case we're splitting the bwd along seqlen_q
// dimension void *__restrict__ dk_accum_ptr; void *__restrict__
// dv_accum_ptr;

// The stride between rows of the dO, dQ, dK and dV matrices.
FLASHATTN_V3_BWD_HANDLE_ROUTINE(int64_t, do_batch_stride)
FLASHATTN_V3_BWD_HANDLE_ROUTINE(int64_t, do_row_stride)
FLASHATTN_V3_BWD_HANDLE_ROUTINE(int64_t, do_head_stride)
FLASHATTN_V3_BWD_HANDLE_ROUTINE(int64_t, dq_batch_stride)
FLASHATTN_V3_BWD_HANDLE_ROUTINE(int64_t, dk_batch_stride)
FLASHATTN_V3_BWD_HANDLE_ROUTINE(int64_t, dv_batch_stride)
FLASHATTN_V3_BWD_HANDLE_ROUTINE(int64_t, dq_row_stride)
FLASHATTN_V3_BWD_HANDLE_ROUTINE(int64_t, dk_row_stride)
FLASHATTN_V3_BWD_HANDLE_ROUTINE(int64_t, dv_row_stride)
FLASHATTN_V3_BWD_HANDLE_ROUTINE(int64_t, dq_head_stride)
FLASHATTN_V3_BWD_HANDLE_ROUTINE(int64_t, dk_head_stride)
FLASHATTN_V3_BWD_HANDLE_ROUTINE(int64_t, dv_head_stride)

// The pointer to the softmax d sum.
FLASHATTN_V3_BWD_HANDLE_ROUTINE(void *, dsoftmax_sum)
FLASHATTN_V3_BWD_HANDLE_ROUTINE(void *, softmax_lse_log2_ptr)

FLASHATTN_V3_BWD_HANDLE_ROUTINE(int *, dq_semaphore)
FLASHATTN_V3_BWD_HANDLE_ROUTINE(int *, dk_semaphore)
FLASHATTN_V3_BWD_HANDLE_ROUTINE(int *, dv_semaphore)

FLASHATTN_V3_BWD_HANDLE_ROUTINE(bool, deterministic)
FLASHATTN_V3_BWD_HANDLE_ROUTINE(int64_t, dq_accum_split_stride)
__macro(flash_attn_v3_fwd); \
__macro(flash_attn_v3_bwd);
#endif

FLASHATTN_V3_ROUTINE_EACH(DECLARE_DYNAMIC_LOAD_FLASHATTN_V3_WRAP);

#undef DYNAMIC_LOAD_FLASHATTN_V3_WRAP

} // namespace dynload
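
For context on the file above: each name passed to DYNAMIC_LOAD_FLASHATTN_V3_WRAP expands into a small functor that opens the FlashAttention-V3 shared library exactly once (guarded by `std::once_flag`) and then resolves and caches the symbol with `dlsym` on first call. The sketch below is a simplified illustration of that pattern, not the real macro expansion; the library path, symbol name, and `int` return type are placeholder assumptions.

```cpp
// Simplified sketch of the lazy dlopen/dlsym wrapper the macro generates.
// "libflashattnv3.so" and "fa3_example_symbol" are placeholders, and the
// int return type is assumed for illustration only.
#include <dlfcn.h>

#include <mutex>

namespace sketch {

std::once_flag dso_flag;
void* dso_handle = nullptr;

struct DynLoad_fa3_example_symbol {
  template <typename... Args>
  int operator()(Args... args) {
    using Fn = int (*)(Args...);
    // Load the shared library exactly once across all wrappers.
    std::call_once(dso_flag, []() {
      dso_handle = dlopen("libflashattnv3.so", RTLD_LAZY);
    });
    // Resolve the symbol once and cache the pointer for later calls.
    static void* p = dlsym(dso_handle, "fa3_example_symbol");
    return reinterpret_cast<Fn>(p)(args...);
  }
};

}  // namespace sketch
```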
17 changes: 0 additions & 17 deletions paddle/phi/infermeta/backward.cc
@@ -348,23 +348,6 @@ void FlashAttnQKVPackedGradInferMeta(const MetaTensor& qkv, MetaTensor* dqkv) {
}
}

void FlashAttnV3GradInferMeta(const MetaTensor& q,
const MetaTensor& k,
const MetaTensor& v,
MetaTensor* dq,
MetaTensor* dk,
MetaTensor* dv) {
if (dq) {
dq->share_meta(q);
}
if (dk) {
dk->share_meta(k);
}
if (dv) {
dv->share_meta(v);
}
}

void Flatten2GradInferMeta(const MetaTensor& x,
const MetaTensor& x_shape,
const MetaTensor& out_grad,
7 changes: 0 additions & 7 deletions paddle/phi/infermeta/backward.h
@@ -248,13 +248,6 @@ void FlashAttnGradInferMeta(const MetaTensor& q,

void FlashAttnQKVPackedGradInferMeta(const MetaTensor& qkv, MetaTensor* dq);

void FlashAttnV3GradInferMeta(const MetaTensor& q,
const MetaTensor& k,
const MetaTensor& v,
MetaTensor* dq,
MetaTensor* dk,
MetaTensor* dv);

void Flatten2GradInferMeta(const MetaTensor& x,
const MetaTensor& x_shape,
const MetaTensor& out_grad,
31 changes: 0 additions & 31 deletions paddle/phi/infermeta/ternary.cc
@@ -726,37 +726,6 @@ void CalcReducedAttnScoresInferMeta(const MetaTensor& q,
reduced_scores->set_dims({batch_size, num_heads, 1, seqlen_k});
}

void FlashAttnV3InferMeta(const MetaTensor& q,
const MetaTensor& k,
const MetaTensor& v,
MetaTensor* out,
MetaTensor* softmax_lse) {
// TODO(umiswing): support varlen
constexpr bool is_varlen_q = false;
auto const sizes = q.dims();
const int batch_size = sizes[0];
const int seqlen_q = sizes[1];
int num_heads = q.dims()[q.dims().size() - 2];
int const head_size_v = v.dims()[v.dims().size() - 1];
auto q_type = q.dtype();
auto out_type =
q_type == phi::DataType::FLOAT8_E4M3FN ? phi::DataType::BFLOAT16 : q_type;
if (!is_varlen_q) {
out->set_dims({batch_size, seqlen_q, num_heads, head_size_v});
} else {
// TODO(umiswing): support varlen
}

out->set_dtype(out_type);

if (!is_varlen_q) {
softmax_lse->set_dims({batch_size, num_heads, seqlen_q});
} else {
// TODO(umiswing): support varlen
}
softmax_lse->set_dtype(phi::DataType::FLOAT32);
}

void ArangeTensorInferMeta(const MetaTensor& start,
const MetaTensor& end,
const MetaTensor& step,
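
The removed FlashAttnV3InferMeta above encoded the dense (non-varlen) output-shape rule for the FA3 forward pass: `out` is `[batch, seqlen_q, num_heads, head_size_v]` with `head_size_v` taken from the last dimension of `v`, `softmax_lse` is `[batch, num_heads, seqlen_q]` in float32, and the output dtype follows `q` except that FP8 (E4M3) inputs produce bfloat16 outputs. A minimal standalone restatement of that rule, with illustrative sizes and placeholder types (not the phi MetaTensor API):

```cpp
// Standalone restatement of the removed shape rule; sizes are illustrative.
#include <array>
#include <cstdint>

struct Fa3Shapes {
  std::array<int64_t, 4> out;          // [b, seqlen_q, num_heads, head_size_v]
  std::array<int64_t, 3> softmax_lse;  // [b, num_heads, seqlen_q], float32
};

Fa3Shapes InferFa3Shapes(const std::array<int64_t, 4>& q_dims,    // [b, sq, h, d]
                         const std::array<int64_t, 4>& v_dims) {  // [b, sk, hk, dv]
  const int64_t b = q_dims[0];
  const int64_t seqlen_q = q_dims[1];
  const int64_t num_heads = q_dims[2];    // second-to-last dim of q
  const int64_t head_size_v = v_dims[3];  // last dim of v drives the out width
  return {{b, seqlen_q, num_heads, head_size_v}, {b, num_heads, seqlen_q}};
}

// Example: q = [2, 1024, 16, 128], v = [2, 1024, 16, 128]
//   -> out = [2, 1024, 16, 128], softmax_lse = [2, 16, 1024]
```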
6 changes: 0 additions & 6 deletions paddle/phi/infermeta/ternary.h
@@ -162,12 +162,6 @@ void CalcReducedAttnScoresInferMeta(const MetaTensor& q,
const MetaTensor& softmax_lse,
MetaTensor* reduced_scores);

void FlashAttnV3InferMeta(const MetaTensor& q,
const MetaTensor& k,
const MetaTensor& v,
MetaTensor* out,
MetaTensor* softmax_lse);

void InstanceNormInferMeta(const MetaTensor& x,
const MetaTensor& scale,
const MetaTensor& bias,
16 changes: 16 additions & 0 deletions paddle/phi/kernels/fusion/gpu/fused_gemm_epilogue_grad_kernel.cu
@@ -16,6 +16,7 @@
#include "paddle/phi/common/data_type.h"
#include "paddle/phi/core/kernel_registry.h"
#include "paddle/phi/core/tensor_utils.h"
#include "paddle/phi/kernels/full_kernel.h"
#include "paddle/phi/kernels/funcs/fused_gemm_epilogue.h"

namespace phi {
@@ -34,6 +34,21 @@ void FusedGemmEpilogueGradKernel(
DenseTensor* x_grad,
DenseTensor* y_grad,
DenseTensor* bias_grad) {
if (x.numel() == 0) {
dev_ctx.template Alloc<T>(y_grad);
phi::FullKernel<T>(
dev_ctx, common::vectorize(y.dims()), 0.0, y.dtype(), y_grad);

if (bias_grad) {
dev_ctx.template Alloc<T>(bias_grad);
phi::FullKernel<T>(dev_ctx,
common::vectorize(bias_grad->dims()),
0.0,
bias_grad->dtype(),
bias_grad);
}
return;
}
#if defined(PADDLE_WITH_CUDA) && CUDA_VERSION < 11060
PADDLE_THROW(common::errors::Unimplemented(
"The fused_gemm_epilogue operator only support CUDA 11.6 "
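
The new early return above handles an empty `x` (zero elements). A note on why `y_grad` and `bias_grad` are zero-filled rather than left untouched: `dY = X^T * dOut` and `dBias = sum_over_rows(dOut)` both reduce over M, the row count of `X`, so when M is 0 every entry of those reductions is exactly zero. The CPU reference below is only a sketch of that arithmetic, not the CUDA kernel:

```cpp
// CPU reference for the gradient reductions; when M == 0 the loops never run
// and dy/dbias keep their zero fill, matching the kernel's FullKernel output.
#include <cstddef>
#include <vector>

void GemmEpilogueGradRef(const std::vector<float>& x,     // M x K
                         const std::vector<float>& dout,  // M x N
                         std::size_t M, std::size_t K, std::size_t N,
                         std::vector<float>* dy,          // K x N
                         std::vector<float>* dbias) {     // N
  dy->assign(K * N, 0.0f);
  dbias->assign(N, 0.0f);
  for (std::size_t m = 0; m < M; ++m) {
    for (std::size_t n = 0; n < N; ++n) {
      (*dbias)[n] += dout[m * N + n];
      for (std::size_t k = 0; k < K; ++k) {
        (*dy)[k * N + n] += x[m * K + k] * dout[m * N + n];
      }
    }
  }
}
```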
4 changes: 4 additions & 0 deletions paddle/phi/kernels/fusion/gpu/fused_gemm_epilogue_kernel.cu
@@ -74,6 +74,10 @@ void FusedGemmEpilogueKernel(const Context& dev_ctx,
const std::string& activation,
DenseTensor* out,
DenseTensor* reserve_space) {
if (out->numel() == 0) {
dev_ctx.template Alloc<T>(out);
return;
}
#if defined(PADDLE_WITH_CUDA) && CUDA_VERSION < 11060
PADDLE_THROW(common::errors::Unimplemented(
"The fused_gemm_epilogue operator only support CUDA 11.6 "
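
The forward-side guard above is simpler: when `out` has zero elements (for example a 0-sized batch dimension), there is nothing for the cuBLASLt epilogue to compute, but the output still has to be allocated so callers receive a valid empty tensor. A toy sketch of that contract, using placeholder types rather than phi's DenseTensor:

```cpp
// Toy illustration of the zero-element fast path; ToyTensor is a placeholder.
#include <cstdint>
#include <vector>

struct ToyTensor {
  std::vector<int64_t> dims;
  bool allocated = false;
  int64_t numel() const {
    int64_t n = 1;
    for (int64_t d : dims) n *= d;
    return n;  // 0 whenever any dimension is 0
  }
};

void ToyFusedGemm(const ToyTensor& x, const ToyTensor& y, ToyTensor* out) {
  (void)x;
  (void)y;
  if (out->numel() == 0) {
    out->allocated = true;  // stands in for dev_ctx.template Alloc<T>(out)
    return;                 // skip the GEMM + bias/activation epilogue
  }
  // ... regular fused GEMM epilogue path would run here ...
}
```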