
Commit 46e042d

dcu cutlass fa (#71337)
1 parent 9ef93d1 commit 46e042d

File tree

3 files changed: +68 -218 lines changed


cmake/external/flashattn.cmake

Lines changed: 18 additions & 39 deletions
@@ -32,51 +32,30 @@ if(WITH_ROCM)
   set(FLASHATTN_LIBRARIES
       "${FLASHATTN_INSTALL_DIR}/lib/libflashattn${CMAKE_SHARED_LIBRARY_SUFFIX}"
       CACHE FILEPATH "flash-attn Library" FORCE)
-
-  set(FLASHATTN_C_FLAGS ${CMAKE_C_FLAGS})
-  set(FLASHATTN_C_FLAGS_DEBUG ${CMAKE_C_FLAGS_DEBUG})
-  set(FLASHATTN_C_FLAGS_RELEASE ${CMAKE_C_FLAGS_RELEASE})
-  set(FLASHATTN_CXX_FLAGS
-      "${CMAKE_CXX_FLAGS} -w -Wno-deprecated-builtins -Wno-deprecated -DNDEBUG -U__HIP_NO_HALF_OPERATORS__ -U__HIP_NO_HALF_CONVERSIONS__ -fPIC -O3 -std=c++17 -D__HIP_PLATFORM_HCC__=1 --offload-arch=gfx928 -D__gfx940__ -mllvm -enable-num-vgprs-512=true"
-  )
-  set(FLASHATTN_CXX_FLAGS_RELEASE ${CMAKE_CXX_FLAGS_RELEASE})
-  set(FLASHATTN_CXX_FLAGS_DEBUG ${CMAKE_CXX_FLAGS_DEBUG})
+  set(FA_BUILD_DIR "${FLASHATTN_PREFIX_DIR}/src/extern_flashattn-build/")
 
   ExternalProject_Add(
     extern_flashattn
     GIT_REPOSITORY ${FA_REPOSITORY}
     GIT_TAG ${FA_TAG}
     SOURCE_DIR ${SOURCE_DIR}
-    PREFIX ${FLASHATTN_PREFIX_DIR}
-    UPDATE_COMMAND ""
-    PATCH_COMMAND ""
-    #BUILD_ALWAYS 1
-    CMAKE_ARGS -DCMAKE_CXX_COMPILER=${ROCM_PATH}/bin/hipcc
-               -DAMDGPU_TARGETS=gfx928
-               -DCMAKE_CXX_COMPILER_LAUNCHER=${CCACHE_PATH}
-               -DCMAKE_C_COMPILER=${CMAKE_C_COMPILER}
-               -DCMAKE_C_FLAGS=${FLASHATTN_C_FLAGS}
-               -DCMAKE_C_FLAGS_DEBUG=${FLASHATTN_C_FLAGS_DEBUG}
-               -DCMAKE_C_FLAGS_RELEASE=${FLASHATTN_C_FLAGS_RELEASE}
-               -DCMAKE_CXX_FLAGS=${FLASHATTN_CXX_FLAGS}
-               -DCMAKE_CXX_FLAGS_RELEASE=${FLASHATTN_CXX_FLAGS_RELEASE}
-               -DCMAKE_CXX_FLAGS_DEBUG=${FLASHATTN_CXX_FLAGS_DEBUG}
-               -DCMAKE_INSTALL_PREFIX=${FLASHATTN_INSTALL_DIR}
-               -DWITH_GPU=${WITH_GPU}
-               -DCMAKE_CUDA_COMPILER=${CMAKE_CUDA_COMPILER}
-               -DWITH_ROCM=${WITH_ROCM}
-               -DWITH_OMP=${USE_OMP}
-               -DBUILD_SHARED=ON
-               -DCMAKE_POSITION_INDEPENDENT_CODE=ON
-               -DCMAKE_BUILD_TYPE=${THIRD_PARTY_BUILD_TYPE}
-               -DCMAKE_JOB_POOL_COMPILE:STRING=compile
-               -DCMAKE_JOB_POOLS:STRING=compile=4
-               ${EXTERNAL_OPTIONAL_ARGS}
-    CMAKE_CACHE_ARGS
-      -DCMAKE_BUILD_TYPE:STRING=${THIRD_PARTY_BUILD_TYPE}
-      -DCMAKE_POSITION_INDEPENDENT_CODE:BOOL=ON
-      -DCMAKE_INSTALL_PREFIX:PATH=${FLASHATTN_INSTALL_DIR}
-    BUILD_BYPRODUCTS ${FLASHATTN_LIBRARIES})
+    CONFIGURE_COMMAND ""
+    BUILD_COMMAND ""
+    INSTALL_COMMAND ""
+    LOG_DOWNLOAD ON)
+
+  add_custom_command(
+    TARGET extern_flashattn
+    POST_BUILD
+    COMMAND ${CMAKE_COMMAND} -E make_directory ${FLASHATTN_INCLUDE_DIR}
+    COMMAND ${CMAKE_COMMAND} -E copy_if_different "${SOURCE_DIR}/flash_attn.h"
+            ${FLASHATTN_INCLUDE_DIR}/
+    COMMAND ${CMAKE_COMMAND} -E make_directory ${FA_BUILD_DIR}
+    COMMAND ${CMAKE_COMMAND} -E copy_if_different
+            "${SOURCE_DIR}/libflashattn.so" ${FA_BUILD_DIR}/
+    COMMAND ${CMAKE_COMMAND} -E make_directory ${FLASHATTN_LIB_DIR}
+    COMMAND ${CMAKE_COMMAND} -E copy_if_different
+            "${SOURCE_DIR}/libflashattn.so" ${FLASHATTN_LIB_DIR}/)
 else()
 
   add_definitions(-DPADDLE_WITH_FLASHATTN)
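Note on the hunk above: the ROCm path no longer configures and compiles flash-attn from source; the external project now only fetches a prebuilt libflashattn.so and copies it, together with flash_attn.h, into the include and lib directories. The kernels then reach the library's entry points through the phi::dynload:: wrappers seen in the .cu diff below. As a rough standalone sketch of consuming a prebuilt shared library at runtime (not Paddle's phi::dynload, and the symbol name flash_attn_version plus its signature are assumptions for illustration only):

// Sketch: open a prebuilt libflashattn.so and resolve one symbol at runtime
// on a POSIX system. The symbol name and signature are hypothetical.
#include <dlfcn.h>
#include <cstdio>

int main() {
  void* handle = dlopen("libflashattn.so", RTLD_NOW | RTLD_LOCAL);
  if (!handle) {
    std::fprintf(stderr, "dlopen failed: %s\n", dlerror());
    return 1;
  }
  // Look up an (assumed) C entry point exported by the library.
  using VersionFn = const char* (*)();
  auto fn = reinterpret_cast<VersionFn>(dlsym(handle, "flash_attn_version"));
  if (fn) {
    std::printf("flash-attn: %s\n", fn());
  } else {
    std::fprintf(stderr, "dlsym failed: %s\n", dlerror());
  }
  dlclose(handle);
  return 0;
}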

paddle/phi/kernels/gpu/flash_attn_grad_kernel.cu

Lines changed: 31 additions & 130 deletions
@@ -257,12 +257,8 @@ void FlashAttnUnpaddedGradBaseKernel(
     kdq = &dq_tmp;
   }
 
-#ifdef PADDLE_WITH_HIP
-  std::initializer_list<int64_t> dk_dv_shape = {total_k, num_heads, head_size};
-#else
   std::initializer_list<int64_t> dk_dv_shape = {
       total_k, num_heads_k, num_heads / num_heads_k, head_size};
-#endif
 
   DenseTensor *kdk = dk, *kdv = dv;
   DenseTensor dk_tmp;
@@ -316,43 +312,6 @@ void FlashAttnUnpaddedGradBaseKernel(
 
   VLOG(10) << "FlashAttn bwd seed: " << params.seed
            << ", offset: " << params.offset;
-#ifdef PADDLE_WITH_HIP
-  bool succ = phi::dynload::flash_attn_varlen_bwd(
-      dout.data(),
-      q.data(),
-      k.data(),
-      v.data(),
-      out.data(),
-      params.softmax_d.data(),
-      softmax_lse.data(),
-      cu_seqlens_q.data<int32_t>(),
-      cu_seqlens_k.data<int32_t>(),
-      params.rng_state.data(),
-      kdq->data(),
-      kdk->data(),
-      kdv->data(),
-      params.dq_accum.data(),
-      params.batch_size,
-      params.max_seqlen_q,
-      params.max_seqlen_k,
-      params.seqlen_q_rounded,
-      params.seqlen_k_rounded,
-      params.num_heads,
-      params.num_heads_k,
-      params.head_size,
-      params.head_size_rounded,
-      params.dropout,
-      params.softmax_scale,
-      1.0f / params.softmax_scale,
-      params.causal,
-      params.is_bf16,
-      num_splits,
-      stream,
-      params.seed,
-      params.offset,
-      params.attn_mask_tensor ? params.attn_mask_tensor->data() : nullptr,
-      params.attn_mask_tensor ? params.mask_dims.data() : nullptr);
-#else
   bool succ = phi::dynload::flash_attn_varlen_bwd(
       dout.data(),
       q.data(),
@@ -413,56 +372,19 @@ void FlashAttnUnpaddedGradBaseKernel(
       max_seqlen_k * kdv->strides()[0],
       max_seqlen_q * dout.strides()[0],
       varlen_padded);
-#endif
   CheckFlashAttnStatus(succ);
   if (!is_mha) {
     if (dk) {
-#ifdef PADDLE_WITH_HIP
-      if (dk->meta().is_contiguous())
-        phi::SumKernel<T, Context>(
-            ctx,
-            dk_tmp.Resize(
-                {total_k, num_heads_k, num_heads / num_heads_k, head_size}),
-            {2},
-            dk->type(),
-            false,
-            dk);
-      else
-        kvReduceForGQA<T, Context>(
-            ctx,
-            dk_tmp.Resize(
-                {total_k, num_heads_k, num_heads / num_heads_k, head_size}),
-            dk);
-#else
       if (dk->meta().is_contiguous())
         phi::SumKernel<T, Context>(ctx, dk_tmp, {2}, dk->type(), false, dk);
       else
         kvReduceForGQA<T, Context>(ctx, dk_tmp, dk);
-#endif
     }
     if (dv) {
-#ifdef PADDLE_WITH_HIP
-      if (dv->meta().is_contiguous())
-        phi::SumKernel<T, Context>(
-            ctx,
-            dv_tmp.Resize(
-                {total_k, num_heads_k, num_heads / num_heads_k, head_size}),
-            {2},
-            dv->type(),
-            false,
-            dv);
-      else
-        kvReduceForGQA<T, Context>(
-            ctx,
-            dv_tmp.Resize(
-                {total_k, num_heads_k, num_heads / num_heads_k, head_size}),
-            dv);
-#else
       if (dv->meta().is_contiguous())
         phi::SumKernel<T, Context>(ctx, dv_tmp, {2}, dv->type(), false, dv);
       else
         kvReduceForGQA<T, Context>(ctx, dv_tmp, dv);
-#endif
     }
   }
 #else
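Note on the hunk above: with the HIP-only branches removed, ROCm now takes the same grouped-query attention (GQA) path as CUDA. dk_tmp/dv_tmp come out with shape {total_k, num_heads_k, num_heads / num_heads_k, head_size}, and when the destination gradient is contiguous they are reduced over the group axis with phi::SumKernel(..., {2}, ...). A minimal CPU sketch of that reduction, with std::vector standing in for phi::DenseTensor, just to make the index math explicit:

// Sketch: reduce a GQA gradient laid out as
// [total_k, num_heads_k, group, head_size] over the group axis,
// mirroring what phi::SumKernel(..., {2}, ...) does in the path above.
#include <cstdint>
#include <cstdio>
#include <vector>

std::vector<float> ReduceGroupAxis(const std::vector<float>& dk_tmp,
                                   int64_t total_k, int64_t num_heads_k,
                                   int64_t group, int64_t head_size) {
  std::vector<float> dk(total_k * num_heads_k * head_size, 0.0f);
  for (int64_t t = 0; t < total_k; ++t)
    for (int64_t h = 0; h < num_heads_k; ++h)
      for (int64_t g = 0; g < group; ++g)
        for (int64_t d = 0; d < head_size; ++d) {
          int64_t src = ((t * num_heads_k + h) * group + g) * head_size + d;
          int64_t dst = (t * num_heads_k + h) * head_size + d;
          dk[dst] += dk_tmp[src];  // accumulate the query-head group into its KV head
        }
  return dk;
}

int main() {
  const int64_t total_k = 4, num_heads_k = 2, group = 3, head_size = 8;
  std::vector<float> dk_tmp(total_k * num_heads_k * group * head_size, 1.0f);
  auto dk = ReduceGroupAxis(dk_tmp, total_k, num_heads_k, group, head_size);
  // With all-ones input, every reduced element equals `group` (here 3).
  std::printf("dk[0] = %f\n", dk[0]);
  return 0;
}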
@@ -658,13 +580,8 @@ void FlashAttnGradBaseKernel(
 
   bool is_mha = (num_heads == num_heads_k);
 
-#ifdef PADDLE_WITH_HIP
-  std::initializer_list<int64_t> dk_dv_shape = {
-      batch_size, seqlen_k, num_heads, head_size};
-#else
   std::initializer_list<int64_t> dk_dv_shape = {
       batch_size, seqlen_k, num_heads_k, num_heads / num_heads_k, head_size};
-#endif
 
   DenseTensor* kdq = dq;
   DenseTensor dq_tmp;
@@ -825,7 +742,37 @@ void FlashAttnGradBaseKernel(
       params.seed,
       params.offset,
       params.attn_mask_tensor ? params.attn_mask_tensor->data() : nullptr,
-      params.attn_mask_tensor ? params.mask_dims.data() : nullptr);
+      params.attn_mask_tensor ? params.mask_dims.data() : nullptr,
+      is_flashmask ? downstart_row_indices_data : nullptr,
+      is_flashmask ? params.startend_row_indices_dims.data() : nullptr,
+      is_flashmask ? upend_row_indices_data : nullptr,
+      is_flashmask ? downend_row_indices_data : nullptr,
+      is_flashmask ? upstart_row_indices_data : nullptr,
+      is_flashmask ? flashmask_maxmin.data() : nullptr,
+      q.strides()[1],
+      k.strides()[1],
+      v.strides()[1],
+      q.strides()[2],
+      k.strides()[2],
+      v.strides()[2],
+      out.strides()[1],
+      out.strides()[2],
+      q.strides()[0],
+      k.strides()[0],
+      v.strides()[0],
+      out.strides()[0],
+      kdq->strides()[1],
+      kdk->strides()[1],
+      kdv->strides()[1],
+      kdq->strides()[2],
+      kdk->strides()[kdk->strides().size() - 2],
+      kdv->strides()[kdv->strides().size() - 2],
+      dout.strides()[1],
+      dout.strides()[2],
+      kdq->strides()[0],
+      kdk->strides()[0],
+      kdv->strides()[0],
+      dout.strides()[0]);
 #else
   bool succ;
   int arch =
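Note on the hunk above: the HIP backward call now also receives the optional flashmask row-index arrays and explicit element strides for q, k, v, out, dout and the gradient buffers instead of assuming a packed layout. Assuming the usual contiguous [batch_size, seqlen, num_heads, head_size] layout, strides()[0], strides()[1] and strides()[2] correspond to the batch, sequence and head strides passed in the call. A small sketch of how a row-major DenseTensor would report those values:

// Sketch: row-major strides for a [batch, seqlen, num_heads, head_size]
// tensor, i.e. what strides()[0..2] hold for a contiguous layout.
#include <array>
#include <cstdint>
#include <cstdio>

std::array<int64_t, 4> RowMajorStrides(const std::array<int64_t, 4>& dims) {
  std::array<int64_t, 4> strides{};
  int64_t running = 1;
  for (int i = 3; i >= 0; --i) {  // innermost dimension has stride 1
    strides[i] = running;
    running *= dims[i];
  }
  return strides;
}

int main() {
  // e.g. batch=2, seqlen=128, num_heads=16, head_size=64 (illustrative sizes)
  auto s = RowMajorStrides({2, 128, 16, 64});
  std::printf("batch stride %lld, seqlen stride %lld, head stride %lld\n",
              static_cast<long long>(s[0]), static_cast<long long>(s[1]),
              static_cast<long long>(s[2]));
  return 0;
}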
@@ -981,63 +928,17 @@ void FlashAttnGradBaseKernel(
   CheckFlashAttnStatus(succ);
   if (!is_mha) {
     if (dk) {
-#ifdef PADDLE_WITH_HIP
-      if (dk->meta().is_contiguous())
-        phi::SumKernel<T, Context>(ctx,
-                                   dk_tmp.Resize({batch_size,
-                                                  seqlen_k,
-                                                  num_heads_k,
-                                                  num_heads / num_heads_k,
-                                                  head_size}),
-                                   {3},
-                                   dk->type(),
-                                   false,
-                                   dk);
-      else
-        kvReduceBatchedForGQA<T, Context>(
-            ctx,
-            dk_tmp.Resize({batch_size,
-                           seqlen_k,
-                           num_heads_k,
-                           num_heads / num_heads_k,
-                           head_size}),
-            dk);
-#else
       if (dk->meta().is_contiguous())
         phi::SumKernel<T, Context>(ctx, dk_tmp, {3}, dk->type(), false, dk);
       else
         kvReduceBatchedForGQA<T, Context>(ctx, dk_tmp, dk);
-#endif
     }
 
     if (dv) {
-#ifdef PADDLE_WITH_HIP
-      if (dv->meta().is_contiguous())
-        phi::SumKernel<T, Context>(ctx,
-                                   dv_tmp.Resize({batch_size,
-                                                  seqlen_k,
-                                                  num_heads_k,
-                                                  num_heads / num_heads_k,
-                                                  head_size}),
-                                   {3},
-                                   dv->type(),
-                                   false,
-                                   dv);
-      else
-        kvReduceBatchedForGQA<T, Context>(
-            ctx,
-            dv_tmp.Resize({batch_size,
-                           seqlen_k,
-                           num_heads_k,
-                           num_heads / num_heads_k,
-                           head_size}),
-            dv);
-#else
       if (dv->meta().is_contiguous())
         phi::SumKernel<T, Context>(ctx, dv_tmp, {3}, dv->type(), false, dv);
       else
         kvReduceBatchedForGQA<T, Context>(ctx, dv_tmp, dv);
-#endif
     }
   }
 #else
