From f6179267bd125ebadaad3ac79c1a83f140f4b247 Mon Sep 17 00:00:00 2001 From: skywalker2012 <108259496+skywalker2012@users.noreply.github.com> Date: Wed, 25 Sep 2024 19:17:15 +0800 Subject: [PATCH 1/3] support argmin bf16 --- paddle/phi/backends/xpu/xpu3_op_list.cc | 4 +++- paddle/phi/kernels/xpu/arg_min_max_kernel.cc | 9 +++++++-- 2 files changed, 10 insertions(+), 3 deletions(-) diff --git a/paddle/phi/backends/xpu/xpu3_op_list.cc b/paddle/phi/backends/xpu/xpu3_op_list.cc index 92b66f843a4195..f788350c941578 100644 --- a/paddle/phi/backends/xpu/xpu3_op_list.cc +++ b/paddle/phi/backends/xpu/xpu3_op_list.cc @@ -49,7 +49,9 @@ XPUOpMap& get_kl3_ops() { phi::DataType::FLOAT32, phi::DataType::FLOAT16})}, {"arg_min", - XPUKernelSet({phi::DataType::FLOAT32, phi::DataType::FLOAT16})}, + XPUKernelSet({phi::DataType::FLOAT32, + phi::DataType::FLOAT16, + phi::DataType::BFLOAT16})}, {"argsort_grad", XPUKernelSet({phi::DataType::INT32, phi::DataType::INT64, diff --git a/paddle/phi/kernels/xpu/arg_min_max_kernel.cc b/paddle/phi/kernels/xpu/arg_min_max_kernel.cc index 693e0ba8070edc..3152116a49a77c 100644 --- a/paddle/phi/kernels/xpu/arg_min_max_kernel.cc +++ b/paddle/phi/kernels/xpu/arg_min_max_kernel.cc @@ -196,7 +196,12 @@ PD_REGISTER_KERNEL(argmax, kernel->OutputAt(0).SetDataType(phi::DataType::UNDEFINED); } -PD_REGISTER_KERNEL( - argmin, XPU, ALL_LAYOUT, phi::ArgMinKernel, float, phi::dtype::float16) { +PD_REGISTER_KERNEL(argmin, + XPU, + ALL_LAYOUT, + phi::ArgMinKernel, + float, + phi::dtype::float16, + phi::dtype::bfloat16) { kernel->OutputAt(0).SetDataType(phi::DataType::UNDEFINED); } From c5460985f723742acc1198c564f34f2802147fbc Mon Sep 17 00:00:00 2001 From: skywalker2012 <108259496+skywalker2012@users.noreply.github.com> Date: Mon, 30 Sep 2024 11:16:47 +0800 Subject: [PATCH 2/3] convert_float_to_uint16 --- test/xpu/test_arg_min_op_xpu.py | 10 ++++++++-- 1 file changed, 8 insertions(+), 2 deletions(-) diff --git a/test/xpu/test_arg_min_op_xpu.py b/test/xpu/test_arg_min_op_xpu.py index b1dcb2e93dcc47..56518deddf9dc8 100644 --- a/test/xpu/test_arg_min_op_xpu.py +++ b/test/xpu/test_arg_min_op_xpu.py @@ -20,6 +20,7 @@ create_test_class, get_xpu_op_support_types, ) +from op_test import convert_float_to_uint16 from op_test_xpu import XPUOpTest import paddle @@ -41,8 +42,13 @@ def setUp(self): self.dtype = self.in_type self.initTestCase() - self.x = (np.random.random(self.dims)).astype(self.dtype) - self.inputs = {'X': self.x} + self.x = (np.random.random(self.dims)).astype( + self.dtype if self.dtype != np.uint16 else np.float32 + ) + + self.inputs = {'X': self.x + if self.dtype != np.uint16 + else convert_float_to_uint16(self.x)} self.attrs = {'axis': self.axis, 'use_xpu': True} self.outputs = {'Out': np.argmin(self.x, axis=self.axis)} From 3b3de34c2744c53acd4b21ac832361b054f68768 Mon Sep 17 00:00:00 2001 From: skywalker2012 <108259496+skywalker2012@users.noreply.github.com> Date: Mon, 30 Sep 2024 12:48:54 +0800 Subject: [PATCH 3/3] pre-commit modify --- test/xpu/test_arg_min_op_xpu.py | 10 +++++++--- 1 file changed, 7 insertions(+), 3 deletions(-) diff --git a/test/xpu/test_arg_min_op_xpu.py b/test/xpu/test_arg_min_op_xpu.py index 56518deddf9dc8..5d1de55037c6b8 100644 --- a/test/xpu/test_arg_min_op_xpu.py +++ b/test/xpu/test_arg_min_op_xpu.py @@ -46,9 +46,13 @@ def setUp(self): self.dtype if self.dtype != np.uint16 else np.float32 ) - self.inputs = {'X': self.x - if self.dtype != np.uint16 - else convert_float_to_uint16(self.x)} + self.inputs = { + 'X': ( + self.x + if self.dtype != np.uint16 + else convert_float_to_uint16(self.x) + ) + } self.attrs = {'axis': self.axis, 'use_xpu': True} self.outputs = {'Out': np.argmin(self.x, axis=self.axis)}