Skip to content

Commit b5334de

Browse files
authored
[XPU] update xhpc to 0405, solving flashmask error in some cases, and add xpu_wait for side_stream (PaddlePaddle#71908)
1 parent 1669cff commit b5334de

File tree

3 files changed

+9
-4
lines changed

3 files changed

+9
-4
lines changed

cmake/external/xpu.cmake

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -30,7 +30,7 @@ set(XPU_XFA_LIB_NAME "libxpu_flash_attention.so")
3030
set(XPU_XPUDNN_LIB_NAME "libxpu_dnn.so")
3131

3232
if(NOT DEFINED XPU_XHPC_BASE_DATE)
33-
set(XPU_XHPC_BASE_DATE "dev/20250318")
33+
set(XPU_XHPC_BASE_DATE "dev/20250405")
3434
endif()
3535
set(XPU_XCCL_BASE_VERSION "3.0.2.5") # For XRE5
3636
if(NOT DEFINED XPU_XFT_BASE_VERSION)

paddle/phi/kernels/xpu/flash_attn_grad_kernel.cc

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -232,6 +232,8 @@ void FlashAttnGradKernelBase(
232232
is_flashmask ? flashmask_stream : nullptr);
233233
PADDLE_ENFORCE_XDNN_SUCCESS(r, "mha_varlen_bwd");
234234
if (is_flashmask && flashmask_stream != nullptr) {
235+
r = xpu_wait(flashmask_stream);
236+
PADDLE_ENFORCE_XPU_SUCCESS(r);
235237
xpu_stream_destroy(flashmask_stream);
236238
}
237239
}

paddle/phi/kernels/xpu/flash_attn_kernel.cc

Lines changed: 6 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -210,12 +210,15 @@ void FlashAttnKernelBase(
210210
(const int*)upstart_row_indices_data, // upstart_row_indices_data
211211
(const int*)upend_row_indices_data, // upend_row_indices_data
212212
is_flashmask ? startend_row_indices->dims()[1]
213-
: 0, // flash_mask_head_num
214-
nullptr, // flashmask_maxmin
215-
is_flashmask ? flashmask_stream : nullptr // side_stream
213+
: 0, // flash_mask_head_num
214+
nullptr, // flashmask_maxmin
215+
is_flashmask ? flashmask_stream : nullptr, // side_stream
216+
0 // fixlen_batch_num
216217
);
217218
PADDLE_ENFORCE_XDNN_SUCCESS(r, "mha_varlen_fwd");
218219
if (is_flashmask && flashmask_stream != nullptr) {
220+
r = xpu_wait(flashmask_stream);
221+
PADDLE_ENFORCE_XPU_SUCCESS(r);
219222
xpu_stream_destroy(flashmask_stream);
220223
}
221224
}

0 commit comments

Comments
 (0)