Skip to content

Commit 3a0260c

Browse files
authored
[XPU] gelu and gelu_grad support bfloat16 type (PaddlePaddle#71809)
1 parent e92c3f5 commit 3a0260c

File tree

3 files changed

+16
-5
lines changed

3 files changed

+16
-5
lines changed

paddle/phi/backends/xpu/xpu3_op_list.cc

Lines changed: 7 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -762,8 +762,13 @@ XPUOpMap& get_kl3_ops() {
762762
phi::DataType::FLOAT16,
763763
phi::DataType::BFLOAT16})},
764764
{"gelu_grad",
765-
XPUKernelSet({phi::DataType::FLOAT32, phi::DataType::FLOAT16})},
766-
{"gelu", XPUKernelSet({phi::DataType::FLOAT32, phi::DataType::FLOAT16})},
765+
XPUKernelSet({phi::DataType::FLOAT32,
766+
phi::DataType::FLOAT16,
767+
phi::DataType::BFLOAT16})},
768+
{"gelu",
769+
XPUKernelSet({phi::DataType::FLOAT32,
770+
phi::DataType::FLOAT16,
771+
phi::DataType::BFLOAT16})},
767772
{"generate_sequence_xpu",
768773
XPUKernelSet({
769774
phi::DataType::FLOAT32,

paddle/phi/kernels/xpu/gelu_grad_kernel.cc

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -45,4 +45,5 @@ PD_REGISTER_KERNEL(gelu_grad,
4545
ALL_LAYOUT,
4646
phi::GeluGradKernel,
4747
float,
48-
phi::dtype::float16) {}
48+
phi::dtype::float16,
49+
phi::dtype::bfloat16) {}

paddle/phi/kernels/xpu/gelu_kernel.cc

Lines changed: 7 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -38,5 +38,10 @@ void GeluKernel(const Context& dev_ctx,
3838
}
3939
} // namespace phi
4040

41-
PD_REGISTER_KERNEL(
42-
gelu, XPU, ALL_LAYOUT, phi::GeluKernel, float, phi::dtype::float16) {}
41+
PD_REGISTER_KERNEL(gelu,
42+
XPU,
43+
ALL_LAYOUT,
44+
phi::GeluKernel,
45+
float,
46+
phi::dtype::float16,
47+
phi::dtype::bfloat16) {}

0 commit comments

Comments
 (0)