Skip to content

Commit ecdc6e6

Browse files
Fix GroupNormGradKernel when d_x is nullptr (allocate and write d_x only when it is requested)
1 parent 3df7186 commit ecdc6e6

File tree

1 file changed

+54
-40
lines changed

1 file changed

+54
-40
lines changed

paddle/phi/kernels/cpu/group_norm_grad_kernel.cc

+54 -40
Original file line number | Diff line number | Diff line change
@@ -52,19 +52,25 @@ void GroupNormGradKernel(const Context& dev_ctx,
5252
data_layout == DataLayout::kNCHW ? x_dims[1] : x_dims[x_dims.size() - 1]);
5353
const int group_size = C / groups;
5454

55-
dev_ctx.template Alloc<T>(d_x);
5655
phi::funcs::SetConstant<CPUContext, T> set_zero;
5756

5857
auto* x_data = y.data<T>();
59-
auto* d_x_data = d_x->data<T>();
6058
auto* y_data = d_y.data<T>();
6159
auto* var_data = var.data<T>();
60+
61+
T* d_x_data = nullptr;
62+
if (d_x) {
63+
dev_ctx.template Alloc<T>(d_x);
64+
d_x_data = d_x->data<T>();
65+
}
66+
6267
T* d_scale_data = nullptr;
6368
if (d_scale) {
6469
dev_ctx.template Alloc<T>(d_scale);
6570
set_zero(dev_ctx, d_scale, static_cast<T>(0));
6671
d_scale_data = d_scale->data<T>();
6772
}
73+
6874
T* d_bias_data = nullptr;
6975
if (d_bias) {
7076
dev_ctx.template Alloc<T>(d_bias);
@@ -124,22 +130,23 @@ void GroupNormGradKernel(const Context& dev_ctx,
124130
d_scale_data[gid * group_size + cid] += val * dval;
125131
}
126132
}
127-
128-
for (int cid = 0; cid < number; cid++) {
129-
for (int imid = 0; imid < imsize;
130-
imid++, iter_d_x_data++, tmp_x++, tmp_y++) {
131-
T v_y = tmp_x[0];
132-
T dly = tmp_y[0];
133-
T dss = dp_scale;
134-
T dbs = dp_bias;
135-
T v_scale = 1., v_bias = 0.;
136-
if (scale_data) v_scale = scale_data[gid * group_size + cid];
137-
if (bias_data) v_bias = bias_data[gid * group_size + cid];
138-
v_y -= v_bias;
139-
if (v_scale != 0) v_y /= v_scale;
140-
iter_d_x_data[0] =
141-
(dly * v_scale - number_inv * dss * v_y - number_inv * dbs) *
142-
var_inv;
133+
if (d_x_data) {
134+
for (int cid = 0; cid < number; cid++) {
135+
for (int imid = 0; imid < imsize;
136+
imid++, iter_d_x_data++, tmp_x++, tmp_y++) {
137+
T v_y = tmp_x[0];
138+
T dly = tmp_y[0];
139+
T dss = dp_scale;
140+
T dbs = dp_bias;
141+
T v_scale = 1., v_bias = 0.;
142+
if (scale_data) v_scale = scale_data[gid * group_size + cid];
143+
if (bias_data) v_bias = bias_data[gid * group_size + cid];
144+
v_y -= v_bias;
145+
if (v_scale != 0) v_y /= v_scale;
146+
iter_d_x_data[0] =
147+
(dly * v_scale - number_inv * dss * v_y - number_inv * dbs) *
148+
var_inv;
149+
}
143150
}
144151
}
145152
} else {
@@ -162,35 +169,42 @@ void GroupNormGradKernel(const Context& dev_ctx,
162169
d_scale_data[gid * group_size + cid] += val * dval;
163170
}
164171
}
165-
166-
for (int cid = 0; cid < number; cid++) {
167-
tmp_x = x_src_data + cid;
168-
tmp_y = y_src_data + cid;
169-
iter_d_x_data = tmp_d_x + cid;
170-
for (int imid = 0; imid < imsize;
171-
imid++, iter_d_x_data += C, tmp_x += C, tmp_y += C) {
172-
T v_y = tmp_x[0];
173-
T dly = tmp_y[0];
174-
T dss = dp_scale;
175-
T dbs = dp_bias;
176-
T v_scale = 1.0, v_bias = 0.;
177-
if (scale_data) v_scale = scale_data[gid * group_size + cid];
178-
if (bias_data) v_bias = bias_data[gid * group_size + cid];
179-
v_y -= v_bias;
180-
if (v_scale != 0) v_y /= v_scale;
181-
iter_d_x_data[0] =
182-
(dly * v_scale - number_inv * dss * v_y - number_inv * dbs) *
183-
var_inv;
172+
if (d_x_data) {
173+
for (int cid = 0; cid < number; cid++) {
174+
tmp_x = x_src_data + cid;
175+
tmp_y = y_src_data + cid;
176+
iter_d_x_data = tmp_d_x + cid;
177+
for (int imid = 0; imid < imsize;
178+
imid++, iter_d_x_data += C, tmp_x += C, tmp_y += C) {
179+
T v_y = tmp_x[0];
180+
T dly = tmp_y[0];
181+
T dss = dp_scale;
182+
T dbs = dp_bias;
183+
T v_scale = 1.0, v_bias = 0.;
184+
if (scale_data) v_scale = scale_data[gid * group_size + cid];
185+
if (bias_data) v_bias = bias_data[gid * group_size + cid];
186+
v_y -= v_bias;
187+
if (v_scale != 0) v_y /= v_scale;
188+
iter_d_x_data[0] =
189+
(dly * v_scale - number_inv * dss * v_y - number_inv * dbs) *
190+
var_inv;
191+
}
184192
}
185193
}
186194
iter_x_data = iter_x_data_backup + group_size;
187195
iter_y_data = iter_y_data_backup + group_size;
188-
iter_d_x_data = iter_d_x_data_backup + group_size;
196+
if (d_x_data) {
197+
iter_d_x_data = iter_d_x_data_backup + group_size;
198+
}
189199
}
190200
}
191201
if (data_layout == DataLayout::kNHWC) {
192-
iter_x_data = x_data + (bid + 1) * C * imsize;
193-
iter_d_x_data = d_x_data + (bid + 1) * C * imsize;
202+
if (x_data) {
203+
iter_x_data = x_data + (bid + 1) * C * imsize;
204+
}
205+
if (d_x_data) {
206+
iter_d_x_data = d_x_data + (bid + 1) * C * imsize;
207+
}
194208
iter_y_data = y_data + (bid + 1) * C * imsize;
195209
}
196210
}

0 commit comments

Comments
 (0)