Commit 2f5d0d0

[Kernel] Fix GroupNormGradKernel when d_x is nullptr (#72358)
* fix GroupNormGradKernel when d_x is nullptr
* update UT
* fix relative import
* fix coverage
* fix rtol and atol
* fix for coverage
* update UT
* add nhwc for coverage
1 parent 1c5dc6e commit 2f5d0d0
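
d_x (the gradient with respect to the input) is an optional output of the backward kernel: it is only non-null when the caller actually requests the input gradient. A minimal sketch of the user-level call that leaves d_x null, closely mirroring the unit test added below; the tensor shapes and GroupNorm configuration are illustrative, not prescribed by the commit:

import paddle

# The fix targets the CPU kernel, so run this path on CPU.
paddle.device.set_device("cpu")

x = paddle.randn([16, 32])
x.stop_gradient = False
gn = paddle.nn.GroupNorm(num_groups=8, num_channels=32)
y = gn(x)

# Only gn.weight is passed to paddle.grad, so the backward kernel
# receives a null d_x output; before this commit the CPU kernel
# unconditionally allocated and dereferenced d_x on this call.
dw = paddle.grad(y, gn.weight)[0]
print(dw.shape)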

2 files changed: +102 -39 lines changed

paddle/phi/kernels/cpu/group_norm_grad_kernel.cc

+51 -39
@@ -52,19 +52,25 @@ void GroupNormGradKernel(const Context& dev_ctx,
       data_layout == DataLayout::kNCHW ? x_dims[1] : x_dims[x_dims.size() - 1]);
   const int group_size = C / groups;
 
-  dev_ctx.template Alloc<T>(d_x);
   phi::funcs::SetConstant<CPUContext, T> set_zero;
 
   auto* x_data = y.data<T>();
-  auto* d_x_data = d_x->data<T>();
   auto* y_data = d_y.data<T>();
   auto* var_data = var.data<T>();
+
+  T* d_x_data = nullptr;
+  if (d_x) {
+    dev_ctx.template Alloc<T>(d_x);
+    d_x_data = d_x->data<T>();
+  }
+
   T* d_scale_data = nullptr;
   if (d_scale) {
     dev_ctx.template Alloc<T>(d_scale);
     set_zero(dev_ctx, d_scale, static_cast<T>(0));
     d_scale_data = d_scale->data<T>();
   }
+
   T* d_bias_data = nullptr;
   if (d_bias) {
     dev_ctx.template Alloc<T>(d_bias);
@@ -124,22 +130,23 @@ void GroupNormGradKernel(const Context& dev_ctx,
             d_scale_data[gid * group_size + cid] += val * dval;
           }
         }
-
-        for (int cid = 0; cid < number; cid++) {
-          for (int imid = 0; imid < imsize;
-               imid++, iter_d_x_data++, tmp_x++, tmp_y++) {
-            T v_y = tmp_x[0];
-            T dly = tmp_y[0];
-            T dss = dp_scale;
-            T dbs = dp_bias;
-            T v_scale = 1., v_bias = 0.;
-            if (scale_data) v_scale = scale_data[gid * group_size + cid];
-            if (bias_data) v_bias = bias_data[gid * group_size + cid];
-            v_y -= v_bias;
-            if (v_scale != 0) v_y /= v_scale;
-            iter_d_x_data[0] =
-                (dly * v_scale - number_inv * dss * v_y - number_inv * dbs) *
-                var_inv;
+        if (d_x_data) {
+          for (int cid = 0; cid < number; cid++) {
+            for (int imid = 0; imid < imsize;
+                 imid++, iter_d_x_data++, tmp_x++, tmp_y++) {
+              T v_y = tmp_x[0];
+              T dly = tmp_y[0];
+              T dss = dp_scale;
+              T dbs = dp_bias;
+              T v_scale = 1., v_bias = 0.;
+              if (scale_data) v_scale = scale_data[gid * group_size + cid];
+              if (bias_data) v_bias = bias_data[gid * group_size + cid];
+              v_y -= v_bias;
+              if (v_scale != 0) v_y /= v_scale;
+              iter_d_x_data[0] =
+                  (dly * v_scale - number_inv * dss * v_y - number_inv * dbs) *
+                  var_inv;
+            }
           }
         }
       } else {
@@ -162,35 +169,40 @@ void GroupNormGradKernel(const Context& dev_ctx,
             d_scale_data[gid * group_size + cid] += val * dval;
           }
         }
-
-        for (int cid = 0; cid < number; cid++) {
-          tmp_x = x_src_data + cid;
-          tmp_y = y_src_data + cid;
-          iter_d_x_data = tmp_d_x + cid;
-          for (int imid = 0; imid < imsize;
-               imid++, iter_d_x_data += C, tmp_x += C, tmp_y += C) {
-            T v_y = tmp_x[0];
-            T dly = tmp_y[0];
-            T dss = dp_scale;
-            T dbs = dp_bias;
-            T v_scale = 1.0, v_bias = 0.;
-            if (scale_data) v_scale = scale_data[gid * group_size + cid];
-            if (bias_data) v_bias = bias_data[gid * group_size + cid];
-            v_y -= v_bias;
-            if (v_scale != 0) v_y /= v_scale;
-            iter_d_x_data[0] =
-                (dly * v_scale - number_inv * dss * v_y - number_inv * dbs) *
-                var_inv;
+        if (d_x_data) {
+          for (int cid = 0; cid < number; cid++) {
+            tmp_x = x_src_data + cid;
+            tmp_y = y_src_data + cid;
+            iter_d_x_data = tmp_d_x + cid;
+            for (int imid = 0; imid < imsize;
+                 imid++, iter_d_x_data += C, tmp_x += C, tmp_y += C) {
+              T v_y = tmp_x[0];
+              T dly = tmp_y[0];
+              T dss = dp_scale;
+              T dbs = dp_bias;
+              T v_scale = 1.0, v_bias = 0.;
+              if (scale_data) v_scale = scale_data[gid * group_size + cid];
+              if (bias_data) v_bias = bias_data[gid * group_size + cid];
+              v_y -= v_bias;
+              if (v_scale != 0) v_y /= v_scale;
+              iter_d_x_data[0] =
+                  (dly * v_scale - number_inv * dss * v_y - number_inv * dbs) *
+                  var_inv;
+            }
           }
         }
         iter_x_data = iter_x_data_backup + group_size;
         iter_y_data = iter_y_data_backup + group_size;
-        iter_d_x_data = iter_d_x_data_backup + group_size;
+        if (d_x_data) {
+          iter_d_x_data = iter_d_x_data_backup + group_size;
+        }
       }
     }
     if (data_layout == DataLayout::kNHWC) {
       iter_x_data = x_data + (bid + 1) * C * imsize;
-      iter_d_x_data = d_x_data + (bid + 1) * C * imsize;
+      if (d_x_data) {
+        iter_d_x_data = d_x_data + (bid + 1) * C * imsize;
+      }
       iter_y_data = y_data + (bid + 1) * C * imsize;
     }
   }

test/legacy_test/test_group_norm_op_v2.py

+51
@@ -16,6 +16,7 @@
 import unittest
 
 import numpy as np
+from utils import dygraph_guard
 
 import paddle
 from paddle import base
@@ -610,5 +611,55 @@ def test_one_dim_input_static_API():
         self.assertRaises(ValueError, test_one_dim_input_static_API)
 
 
+class TestGroupNormWithOptionalgradX(unittest.TestCase):
+    def test_group_norm_cpu_with_optional_grad(self):
+        with dygraph_guard():
+            origin_device = paddle.device.get_device()
+            paddle.device.set_device("cpu")
+            x = paddle.randn([16, 32])
+            x.stop_gradient = False
+            gpn = paddle.nn.GroupNorm(num_groups=8, num_channels=32)
+            y = gpn(x)
+            dw_ref, db_ref, dx_ref = paddle.grad(y, [gpn.weight, gpn.bias, x])
+            try:
+                dw, db, dx = (
+                    paddle.grad(y, gpn.weight)[0],
+                    paddle.grad(y, gpn.bias)[0],
+                    paddle.grad(y, x)[0],
+                )
+            except Exception as e:
+                raise e
+            finally:
+                paddle.device.set_device(origin_device)
+            np.testing.assert_equal(dw.numpy(), dw_ref.numpy())
+            np.testing.assert_equal(db.numpy(), db_ref.numpy())
+            np.testing.assert_equal(dx.numpy(), dx_ref.numpy())
+
+    def test_group_norm_cpu_with_optional_grad_nhwc(self):
+        with dygraph_guard():
+            origin_device = paddle.device.get_device()
+            paddle.device.set_device("cpu")
+            x = paddle.randn([4, 32, 32, 32])
+            x.stop_gradient = False
+            gpn = paddle.nn.GroupNorm(
+                num_groups=8, num_channels=32, data_format="NHWC"
+            )
+            y = gpn(x)
+            dw_ref, db_ref, dx_ref = paddle.grad(y, [gpn.weight, gpn.bias, x])
+            try:
+                dw, db, dx = (
+                    paddle.grad(y, gpn.weight)[0],
+                    paddle.grad(y, gpn.bias)[0],
+                    paddle.grad(y, x)[0],
+                )
+            except Exception as e:
+                raise e
+            finally:
+                paddle.device.set_device(origin_device)
+            np.testing.assert_equal(dw.numpy(), dw_ref.numpy())
+            np.testing.assert_equal(db.numpy(), db_ref.numpy())
+            np.testing.assert_equal(dx.numpy(), dx_ref.numpy())
+
+
 if __name__ == '__main__':
     unittest.main()