Skip to content

Commit 58f9733

Browse files
authored
[XPU] support fa xte (#68473)
1 parent 9f80c7f commit 58f9733

File tree

1 file changed

+5
-1
lines changed

1 file changed

+5
-1
lines changed

paddle/phi/kernels/xpu/flash_attn_utils.h

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -42,7 +42,11 @@ XPU_FA_TGEMM get_flash_attn_tgemm() {
4242
(std::is_same<phi::dtype::float16, T>::value ||
4343
std::is_same<XPUTypeFP16, T>::value)) {
4444
return XPU_FA_TGEMM::FA_FLOAT16;
45-
} else if (std::getenv("XPU_PADDLE_FA_TGEMM_FLOAT")) {
45+
} else if ((std::is_same<phi::dtype::bfloat16, T>::value ||
46+
std::is_same<XPUTypeBF16, T>::value) &&
47+
std::getenv("XPU_PADDLE_FA_BFLOAT16_XTE") != nullptr) {
48+
return XPU_FA_TGEMM::FA_FLOAT16;
49+
} else if (std::getenv("XPU_PADDLE_FA_TGEMM_FLOAT") != nullptr) {
4650
return XPU_FA_TGEMM::FA_FLOAT;
4751
} else {
4852
return XPU_FA_TGEMM::FA_TFLOAT32;

0 commit comments

Comments
 (0)