From a5ed9f6463f0a83c9a46003418f3e76e84b9b425 Mon Sep 17 00:00:00 2001
From: lj970926 <1783973490@qq.com>
Date: Tue, 11 Mar 2025 11:13:44 +0000
Subject: [PATCH 1/2] [XPU] support fp16 for group_norm_grad

---
 paddle/phi/backends/xpu/xpu3_op_list.cc       |   3 +-
 .../phi/kernels/xpu/group_norm_grad_kernel.cc | 116 +++++++++++++-----
 test/xpu/test_group_norm_op_xpu.py            |  84 ++++++++++++-
 3 files changed, 171 insertions(+), 32 deletions(-)

diff --git a/paddle/phi/backends/xpu/xpu3_op_list.cc b/paddle/phi/backends/xpu/xpu3_op_list.cc
index 7a268eca5129a0..2a163ace503404 100644
--- a/paddle/phi/backends/xpu/xpu3_op_list.cc
+++ b/paddle/phi/backends/xpu/xpu3_op_list.cc
@@ -1668,7 +1668,8 @@ XPUOpMap& get_kl3_ops() {
        XPUKernelSet({phi::DataType::FLOAT32,
                      phi::DataType::FLOAT16,
                      phi::DataType::BFLOAT16})},
-      {"group_norm_grad", XPUKernelSet({phi::DataType::FLOAT32})},
+      {"group_norm_grad",
+       XPUKernelSet({phi::DataType::FLOAT32, phi::DataType::FLOAT16})},
       {"meshgrid",
        XPUKernelSet({phi::DataType::FLOAT32,
                      phi::DataType::INT32,
diff --git a/paddle/phi/kernels/xpu/group_norm_grad_kernel.cc b/paddle/phi/kernels/xpu/group_norm_grad_kernel.cc
index 428b2699dc2753..5e077ef46b8903 100644
--- a/paddle/phi/kernels/xpu/group_norm_grad_kernel.cc
+++ b/paddle/phi/kernels/xpu/group_norm_grad_kernel.cc
@@ -42,6 +42,8 @@ void GroupNormGradKernel(const Context& dev_ctx,
                          DenseTensor* d_scale,
                          DenseTensor* d_bias) {
   using XPUType = typename XPUTypeTrait<T>::Type;
+  xpu::ctx_guard RAII_GUARD(dev_ctx.x_context());
+  int ret = xpu::SUCCESS;
   const DataLayout data_layout = common::StringToDataLayout(data_layout_str);
   const auto scale_ptr = scale.get_ptr();
   const auto bias_ptr = bias.get_ptr();
@@ -66,49 +68,107 @@ void GroupNormGradKernel(const Context& dev_ctx,
   auto* y_data = y.data<T>();
   auto* d_x_data = d_x->data<T>();
   auto* d_y_data = d_y.data<T>();
-  auto* mean_data = mean.data<T>();
-  auto* var_data = var.data<T>();
   T* d_scale_data = nullptr;
+  float* d_scale_data_fp32 = nullptr;
   if (d_scale) {
     dev_ctx.template Alloc<T>(d_scale);
     set_zero(dev_ctx, d_scale, static_cast<T>(0));
     d_scale_data = d_scale->data<T>();
+    if (!std::is_same_v<T, float>) {
+      d_scale_data_fp32 = RAII_GUARD.alloc_l3_or_gm<float>(d_scale->numel());
+    } else {
+      d_scale_data_fp32 = reinterpret_cast<float*>(d_scale_data);
+    }
   }
   T* d_bias_data = nullptr;
+  float* d_bias_data_fp32 = nullptr;
   if (d_bias) {
     dev_ctx.template Alloc<T>(d_bias);
     set_zero(dev_ctx, d_bias, static_cast<T>(0));
     d_bias_data = d_bias->data<T>();
+    if (!std::is_same_v<T, float>) {
+      d_bias_data_fp32 = RAII_GUARD.alloc_l3_or_gm<float>(d_bias->numel());
+    } else {
+      d_bias_data_fp32 = reinterpret_cast<float*>(d_bias_data);
+    }
   }
 
-  const T* scale_data = nullptr;
-  if (scale_ptr) scale_data = scale_ptr->data<T>();
-  const T* bias_data = nullptr;
-  if (bias_ptr) bias_data = bias_ptr->data<T>();
+  const float* scale_data = nullptr;
+  if (scale_ptr) {
+    float* scale_data_tmp =
+        RAII_GUARD.alloc_l3_or_gm<float>(scale_ptr->numel());
+    if (!std::is_same_v<T, float>) {
+      ret = xpu::cast<XPUType, float>(
+          dev_ctx.x_context(),
+          reinterpret_cast<const XPUType*>(scale_ptr->data<T>()),
+          scale_data_tmp,
+          scale_ptr->numel());
+      PADDLE_ENFORCE_XDNN_SUCCESS(ret, "cast");
+      scale_data = scale_data_tmp;
+    } else {
+      scale_data = scale_ptr->data<float>();
+    }
+  }
+  const float* bias_data = nullptr;
+  if (bias_ptr) {
+    float* bias_data_tmp = RAII_GUARD.alloc_l3_or_gm<float>(bias_ptr->numel());
+    if (!std::is_same_v<T, float>) {
+      ret = xpu::cast<XPUType, float>(
+          dev_ctx.x_context(),
+          reinterpret_cast<const XPUType*>(bias_ptr->data<T>()),
+          bias_data_tmp,
+          bias_ptr->numel());
+      PADDLE_ENFORCE_XDNN_SUCCESS(ret, "cast");
+      bias_data = bias_data_tmp;
+    } else {
+      bias_data = bias_ptr->data<float>();
+    }
+  }
 
-  int r = xpu::group_norm_grad(
-      dev_ctx.x_context(),
-      reinterpret_cast<const XPUType*>(x_data),
-      reinterpret_cast<const XPUType*>(y_data),
-      reinterpret_cast<const XPUType*>(d_y_data),
-      reinterpret_cast<XPUType*>(d_x_data),
-      N,
-      C,
-      L,
-      1,
-      groups,
-      static_cast<float>(epsilon),
-      reinterpret_cast<const float*>(scale_data),
-      reinterpret_cast<const float*>(bias_data),
-      reinterpret_cast<const float*>(mean_data),
-      reinterpret_cast<const float*>(var_data),
-      reinterpret_cast<float*>(d_scale_data),
-      reinterpret_cast<float*>(d_bias_data),
-      channel_first);
-  PADDLE_ENFORCE_XDNN_SUCCESS(r, "group_norm_grad");
+  ret =
+      xpu::group_norm_grad(dev_ctx.x_context(),
+                           reinterpret_cast<const XPUType*>(x_data),
+                           reinterpret_cast<const XPUType*>(y_data),
+                           reinterpret_cast<const XPUType*>(d_y_data),
+                           reinterpret_cast<XPUType*>(d_x_data),
+                           N,
+                           C,
+                           L,
+                           1,
+                           groups,
+                           epsilon,
+                           scale_data,
+                           bias_data,
+                           mean.data<float>(),
+                           var.data<float>(),
+                           d_scale_data_fp32,
+                           d_bias_data_fp32,
+                           channel_first);
+  PADDLE_ENFORCE_XDNN_SUCCESS(ret, "group_norm_grad");
+  if (!std::is_same_v<T, float>) {
+    if (d_scale) {
+      ret = xpu::cast<float, XPUType>(dev_ctx.x_context(),
+                                      d_scale_data_fp32,
+                                      reinterpret_cast<XPUType*>(d_scale_data),
+                                      d_scale->numel());
+      PADDLE_ENFORCE_XDNN_SUCCESS(ret, "cast");
+    }
+
+    if (d_bias) {
+      ret = xpu::cast<float, XPUType>(dev_ctx.x_context(),
+                                      d_bias_data_fp32,
+                                      reinterpret_cast<XPUType*>(d_bias_data),
+                                      d_bias->numel());
+      PADDLE_ENFORCE_XDNN_SUCCESS(ret, "cast");
+    }
+  }
 }
 }  // namespace phi
 
-PD_REGISTER_KERNEL(
-    group_norm_grad, XPU, ALL_LAYOUT, phi::GroupNormGradKernel, float) {}
+PD_REGISTER_KERNEL(group_norm_grad,
+                   XPU,
+                   ALL_LAYOUT,
+                   phi::GroupNormGradKernel,
+                   float,
+                   phi::dtype::float16) {}
diff --git a/test/xpu/test_group_norm_op_xpu.py b/test/xpu/test_group_norm_op_xpu.py
index 6c8a5a0e44280e..e100a3f70dd350 100644
--- a/test/xpu/test_group_norm_op_xpu.py
+++ b/test/xpu/test_group_norm_op_xpu.py
@@ -19,7 +19,6 @@
 from get_test_cover_info import (
     XPUOpTestWrapper,
     create_test_class,
-    get_xpu_op_support_types,
 )
 from op_test import OpTest
 from op_test_xpu import XPUOpTest
@@ -99,10 +98,89 @@ def init_test_case(self):
         self.attrs = {'epsilon': 1e-5, 'groups': 2, 'data_layout': "NHWC"}
 
 
-support_types = get_xpu_op_support_types('group_norm')
-for stype in support_types:
+for stype in ["float32"]:
     create_test_class(globals(), XPUTestGroupNormOp, stype)
 
+
+class TestGroupNormFP16(unittest.TestCase):
+    def setUp(self):
+        self.shape = [2, 100, 3, 5]
+        self.data_format = "NCHW"
+        self.epsilon = 1e-5
+        self.groups = 2
+
+    def test_dygraph(self):
+        paddle.disable_static()
+        inp = np.random.random(self.shape).astype("float16")
+        if self.data_format == "NHWC":
+            inp = np.transpose(inp, (0, 2, 3, 1))
+        scale = np.random.random([self.shape[1]]).astype("float16")
+        bias = np.random.random([self.shape[1]]).astype("float16")
+        inp_fp16 = paddle.to_tensor(inp, stop_gradient=False)
+        scale_fp16 = paddle.to_tensor(scale, stop_gradient=False)
+        bias_fp16 = paddle.to_tensor(bias, stop_gradient=False)
+
+        inp_fp32 = paddle.to_tensor(inp.astype("float32"), stop_gradient=False)
+        scale_fp32 = paddle.to_tensor(
+            scale.astype("float32"), stop_gradient=False
+        )
+        bias_fp32 = paddle.to_tensor(
+            bias.astype("float32"), stop_gradient=False
+        )
+
+        out_fp32 = paddle.nn.functional.group_norm(
+            inp_fp32,
+            self.groups,
+            self.epsilon,
+            scale_fp32,
+            bias_fp32,
+            self.data_format,
+        )
+        out_fp32.mean().backward()
+        inp_grad_fp32 = inp_fp32.grad.numpy()
+        scale_grad_fp32 = scale_fp32.grad.numpy()
+        bias_grad_fp32 = bias_fp32.grad.numpy()
+
+        out_fp16 = paddle.nn.functional.group_norm(
+            inp_fp16,
+            self.groups,
+            self.epsilon,
+            scale_fp16,
+            bias_fp16,
+            self.data_format,
+        )
+        out_fp16.mean().backward()
+        inp_grad_fp16 = inp_fp16.grad.numpy()
+        scale_grad_fp16 = scale_fp16.grad.numpy()
+        bias_grad_fp16 = bias_fp16.grad.numpy()
+
+        np.testing.assert_allclose(
+            out_fp32.numpy(),
+            out_fp16.numpy().astype("float32"),
+            atol=0.001,
+            rtol=0.001,
+        )
+        np.testing.assert_allclose(
+            inp_grad_fp32,
+            inp_grad_fp16.astype("float32"),
+            atol=0.001,
+            rtol=0.001,
+        )
+        np.testing.assert_allclose(
+            scale_grad_fp32,
+            scale_grad_fp16.astype("float32"),
+            atol=1e-4,
+            rtol=1e-4,
+        )
+        np.testing.assert_allclose(
+            bias_grad_fp32,
+            bias_grad_fp16.astype("float32"),
+            atol=1e-4,
+            rtol=1e-4,
+        )
+        paddle.enable_static()
+
+
 if __name__ == "__main__":
     paddle.enable_static()
     unittest.main()

From 1803fdc7f3ee29b1ba5d0462dd1b1465610ee83f Mon Sep 17 00:00:00 2001
From: lj970926 <1783973490@qq.com>
Date: Tue, 11 Mar 2025 11:45:26 +0000
Subject: [PATCH 2/2] refine code

---
 paddle/phi/kernels/xpu/group_norm_grad_kernel.cc | 7 ++++---
 1 file changed, 4 insertions(+), 3 deletions(-)

diff --git a/paddle/phi/kernels/xpu/group_norm_grad_kernel.cc b/paddle/phi/kernels/xpu/group_norm_grad_kernel.cc
index 5e077ef46b8903..dc2074342d780f 100644
--- a/paddle/phi/kernels/xpu/group_norm_grad_kernel.cc
+++ b/paddle/phi/kernels/xpu/group_norm_grad_kernel.cc
@@ -95,9 +95,9 @@ void GroupNormGradKernel(const Context& dev_ctx,
 
   const float* scale_data = nullptr;
   if (scale_ptr) {
-    float* scale_data_tmp =
-        RAII_GUARD.alloc_l3_or_gm<float>(scale_ptr->numel());
     if (!std::is_same_v<T, float>) {
+      float* scale_data_tmp =
+          RAII_GUARD.alloc_l3_or_gm<float>(scale_ptr->numel());
       ret = xpu::cast<XPUType, float>(
           dev_ctx.x_context(),
           reinterpret_cast<const XPUType*>(scale_ptr->data<T>()),
           scale_data_tmp,
           scale_ptr->numel());
       PADDLE_ENFORCE_XDNN_SUCCESS(ret, "cast");
       scale_data = scale_data_tmp;
     } else {
       scale_data = scale_ptr->data<float>();
     }
   }
   const float* bias_data = nullptr;
   if (bias_ptr) {
-    float* bias_data_tmp = RAII_GUARD.alloc_l3_or_gm<float>(bias_ptr->numel());
     if (!std::is_same_v<T, float>) {
+      float* bias_data_tmp =
+          RAII_GUARD.alloc_l3_or_gm<float>(bias_ptr->numel());
       ret = xpu::cast<XPUType, float>(
           dev_ctx.x_context(),
           reinterpret_cast<const XPUType*>(bias_ptr->data<T>()),
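
Note (not part of the patches above): a minimal dygraph sketch of the FP16 path these patches enable, mirroring the TestGroupNormFP16 test. It assumes an XPU-enabled Paddle build; the "xpu" device string, the shapes, and the group count are illustrative only.

import numpy as np
import paddle

# Minimal sketch, assuming an XPU-enabled Paddle build; paddle.set_device
# raises if no XPU device is available.
paddle.set_device("xpu")

x = paddle.to_tensor(
    np.random.random([2, 100, 3, 5]).astype("float16"), stop_gradient=False
)
weight = paddle.to_tensor(
    np.random.random([100]).astype("float16"), stop_gradient=False
)
bias = paddle.to_tensor(
    np.random.random([100]).astype("float16"), stop_gradient=False
)

out = paddle.nn.functional.group_norm(
    x, num_groups=2, epsilon=1e-5, weight=weight, bias=bias, data_format="NCHW"
)
# The backward pass dispatches to the XPU group_norm_grad kernel with
# T = float16, i.e. the registration added by this patch series.
out.mean().backward()

print(x.grad.dtype, weight.grad.dtype, bias.grad.dtype)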