diff --git a/paddle/phi/kernels/cpu/group_norm_grad_kernel.cc b/paddle/phi/kernels/cpu/group_norm_grad_kernel.cc
index 75d7a164a9924..21e2df5fff65c 100644
--- a/paddle/phi/kernels/cpu/group_norm_grad_kernel.cc
+++ b/paddle/phi/kernels/cpu/group_norm_grad_kernel.cc
@@ -52,19 +52,25 @@ void GroupNormGradKernel(const Context& dev_ctx,
       data_layout == DataLayout::kNCHW ? x_dims[1]
                                        : x_dims[x_dims.size() - 1]);
   const int group_size = C / groups;
-  dev_ctx.template Alloc<T>(d_x);
   phi::funcs::SetConstant<Context, T> set_zero;
 
   auto* x_data = y.data<T>();
-  auto* d_x_data = d_x->data<T>();
   auto* y_data = d_y.data<T>();
   auto* var_data = var.data<T>();
+
+  T* d_x_data = nullptr;
+  if (d_x) {
+    dev_ctx.template Alloc<T>(d_x);
+    d_x_data = d_x->data<T>();
+  }
+
   T* d_scale_data = nullptr;
   if (d_scale) {
     dev_ctx.template Alloc<T>(d_scale);
     set_zero(dev_ctx, d_scale, static_cast<T>(0));
     d_scale_data = d_scale->data<T>();
   }
+
   T* d_bias_data = nullptr;
   if (d_bias) {
     dev_ctx.template Alloc<T>(d_bias);
@@ -124,22 +130,23 @@ void GroupNormGradKernel(const Context& dev_ctx,
             d_scale_data[gid * group_size + cid] += val * dval;
           }
         }
-
-        for (int cid = 0; cid < number; cid++) {
-          for (int imid = 0; imid < imsize;
-               imid++, iter_d_x_data++, tmp_x++, tmp_y++) {
-            T v_y = tmp_x[0];
-            T dly = tmp_y[0];
-            T dss = dp_scale;
-            T dbs = dp_bias;
-            T v_scale = 1., v_bias = 0.;
-            if (scale_data) v_scale = scale_data[gid * group_size + cid];
-            if (bias_data) v_bias = bias_data[gid * group_size + cid];
-            v_y -= v_bias;
-            if (v_scale != 0) v_y /= v_scale;
-            iter_d_x_data[0] =
-                (dly * v_scale - number_inv * dss * v_y - number_inv * dbs) *
-                var_inv;
+        if (d_x_data) {
+          for (int cid = 0; cid < number; cid++) {
+            for (int imid = 0; imid < imsize;
+                 imid++, iter_d_x_data++, tmp_x++, tmp_y++) {
+              T v_y = tmp_x[0];
+              T dly = tmp_y[0];
+              T dss = dp_scale;
+              T dbs = dp_bias;
+              T v_scale = 1., v_bias = 0.;
+              if (scale_data) v_scale = scale_data[gid * group_size + cid];
+              if (bias_data) v_bias = bias_data[gid * group_size + cid];
+              v_y -= v_bias;
+              if (v_scale != 0) v_y /= v_scale;
+              iter_d_x_data[0] =
+                  (dly * v_scale - number_inv * dss * v_y - number_inv * dbs) *
+                  var_inv;
+            }
           }
         }
       } else {
@@ -162,35 +169,40 @@ void GroupNormGradKernel(const Context& dev_ctx,
             d_scale_data[gid * group_size + cid] += val * dval;
          }
        }
-
-        for (int cid = 0; cid < number; cid++) {
-          tmp_x = x_src_data + cid;
-          tmp_y = y_src_data + cid;
-          iter_d_x_data = tmp_d_x + cid;
-          for (int imid = 0; imid < imsize;
-               imid++, iter_d_x_data += C, tmp_x += C, tmp_y += C) {
-            T v_y = tmp_x[0];
-            T dly = tmp_y[0];
-            T dss = dp_scale;
-            T dbs = dp_bias;
-            T v_scale = 1.0, v_bias = 0.;
-            if (scale_data) v_scale = scale_data[gid * group_size + cid];
-            if (bias_data) v_bias = bias_data[gid * group_size + cid];
-            v_y -= v_bias;
-            if (v_scale != 0) v_y /= v_scale;
-            iter_d_x_data[0] =
-                (dly * v_scale - number_inv * dss * v_y - number_inv * dbs) *
-                var_inv;
+        if (d_x_data) {
+          for (int cid = 0; cid < number; cid++) {
+            tmp_x = x_src_data + cid;
+            tmp_y = y_src_data + cid;
+            iter_d_x_data = tmp_d_x + cid;
+            for (int imid = 0; imid < imsize;
+                 imid++, iter_d_x_data += C, tmp_x += C, tmp_y += C) {
+              T v_y = tmp_x[0];
+              T dly = tmp_y[0];
+              T dss = dp_scale;
+              T dbs = dp_bias;
+              T v_scale = 1.0, v_bias = 0.;
+              if (scale_data) v_scale = scale_data[gid * group_size + cid];
+              if (bias_data) v_bias = bias_data[gid * group_size + cid];
+              v_y -= v_bias;
+              if (v_scale != 0) v_y /= v_scale;
+              iter_d_x_data[0] =
+                  (dly * v_scale - number_inv * dss * v_y - number_inv * dbs) *
+                  var_inv;
+            }
          }
        }
        iter_x_data = iter_x_data_backup + group_size;
        iter_y_data = iter_y_data_backup + group_size;
-        iter_d_x_data = iter_d_x_data_backup + group_size;
+        if (d_x_data) {
+          iter_d_x_data = iter_d_x_data_backup + group_size;
+        }
      }
    }
    if (data_layout == DataLayout::kNHWC) {
      iter_x_data = x_data + (bid + 1) * C * imsize;
-      iter_d_x_data = d_x_data + (bid + 1) * C * imsize;
+      if (d_x_data) {
+        iter_d_x_data = d_x_data + (bid + 1) * C * imsize;
+      }
      iter_y_data = y_data + (bid + 1) * C * imsize;
    }
  }
diff --git a/test/dygraph_to_static/test_deal_inplace.py b/test/dygraph_to_static/test_deal_inplace.py
index 6cee042de0918..a24efca434256 100644
--- a/test/dygraph_to_static/test_deal_inplace.py
+++ b/test/dygraph_to_static/test_deal_inplace.py
@@ -90,7 +90,7 @@ def run_test(self, dygraph_fn, *inputs, static_n_times=1):
                dygraph_out.numpy(),
                static_out.numpy(),
                rtol=1e-5,
-                atol=1e-8,
+                atol=1e-6,
                err_msg=f"Run {i}-th check failed.",
            )

diff --git a/test/legacy_test/test_group_norm_op_v2.py b/test/legacy_test/test_group_norm_op_v2.py
index f2449f8544f1b..5b18586e7c7ed 100644
--- a/test/legacy_test/test_group_norm_op_v2.py
+++ b/test/legacy_test/test_group_norm_op_v2.py
@@ -16,6 +16,7 @@
 import unittest
 
 import numpy as np
+from utils import dygraph_guard
 
 import paddle
 from paddle import base
@@ -610,5 +611,55 @@ def test_one_dim_input_static_API():
        self.assertRaises(ValueError, test_one_dim_input_static_API)


+class TestGroupNormWithOptionalgradX(unittest.TestCase):
+    def test_group_norm_cpu_with_optional_grad(self):
+        with dygraph_guard():
+            origin_device = paddle.device.get_device()
+            paddle.device.set_device("cpu")
+            x = paddle.randn([16, 32])
+            x.stop_gradient = False
+            gpn = paddle.nn.GroupNorm(num_groups=8, num_channels=32)
+            y = gpn(x)
+            dw_ref, db_ref, dx_ref = paddle.grad(y, [gpn.weight, gpn.bias, x])
+            try:
+                dw, db, dx = (
+                    paddle.grad(y, gpn.weight)[0],
+                    paddle.grad(y, gpn.bias)[0],
+                    paddle.grad(y, x)[0],
+                )
+            except Exception as e:
+                raise e
+            finally:
+                paddle.device.set_device(origin_device)
+            np.testing.assert_equal(dw.numpy(), dw_ref.numpy())
+            np.testing.assert_equal(db.numpy(), db_ref.numpy())
+            np.testing.assert_equal(dx.numpy(), dx_ref.numpy())
+
+    def test_group_norm_cpu_with_optional_grad_nhwc(self):
+        with dygraph_guard():
+            origin_device = paddle.device.get_device()
+            paddle.device.set_device("cpu")
+            x = paddle.randn([4, 32, 32, 32])
+            x.stop_gradient = False
+            gpn = paddle.nn.GroupNorm(
+                num_groups=8, num_channels=32, data_format="NHWC"
+            )
+            y = gpn(x)
+            dw_ref, db_ref, dx_ref = paddle.grad(y, [gpn.weight, gpn.bias, x])
+            try:
+                dw, db, dx = (
+                    paddle.grad(y, gpn.weight)[0],
+                    paddle.grad(y, gpn.bias)[0],
+                    paddle.grad(y, x)[0],
+                )
+            except Exception as e:
+                raise e
+            finally:
+                paddle.device.set_device(origin_device)
+            np.testing.assert_equal(dw.numpy(), dw_ref.numpy())
+            np.testing.assert_equal(db.numpy(), db_ref.numpy())
+            np.testing.assert_equal(dx.numpy(), dx_ref.numpy())
+
+
 if __name__ == '__main__':
     unittest.main()