We read every piece of feedback, and take your input very seriously.
To see all available qualifiers, see our documentation.
There was an error while loading. Please reload this page.
1 parent 9f80c7f commit 58f9733Copy full SHA for 58f9733
paddle/phi/kernels/xpu/flash_attn_utils.h
@@ -42,7 +42,11 @@ XPU_FA_TGEMM get_flash_attn_tgemm() {
42
(std::is_same<phi::dtype::float16, T>::value ||
43
std::is_same<XPUTypeFP16, T>::value)) {
44
return XPU_FA_TGEMM::FA_FLOAT16;
45
- } else if (std::getenv("XPU_PADDLE_FA_TGEMM_FLOAT")) {
+ } else if ((std::is_same<phi::dtype::bfloat16, T>::value ||
46
+ std::is_same<XPUTypeBF16, T>::value) &&
47
+ std::getenv("XPU_PADDLE_FA_BFLOAT16_XTE") != nullptr) {
48
+ return XPU_FA_TGEMM::FA_FLOAT16;
49
+ } else if (std::getenv("XPU_PADDLE_FA_TGEMM_FLOAT") != nullptr) {
50
return XPU_FA_TGEMM::FA_FLOAT;
51
} else {
52
return XPU_FA_TGEMM::FA_TFLOAT32;
0 commit comments