@@ -52,19 +52,25 @@ void GroupNormGradKernel(const Context& dev_ctx,
52
52
data_layout == DataLayout::kNCHW ? x_dims[1 ] : x_dims[x_dims.size () - 1 ]);
53
53
const int group_size = C / groups;
54
54
55
- dev_ctx.template Alloc <T>(d_x);
56
55
phi::funcs::SetConstant<CPUContext, T> set_zero;
57
56
58
57
auto * x_data = y.data <T>();
59
- auto * d_x_data = d_x->data <T>();
60
58
auto * y_data = d_y.data <T>();
61
59
auto * var_data = var.data <T>();
60
+
61
+ T* d_x_data = nullptr ;
62
+ if (d_x) {
63
+ dev_ctx.template Alloc <T>(d_x);
64
+ d_x_data = d_x->data <T>();
65
+ }
66
+
62
67
T* d_scale_data = nullptr ;
63
68
if (d_scale) {
64
69
dev_ctx.template Alloc <T>(d_scale);
65
70
set_zero (dev_ctx, d_scale, static_cast <T>(0 ));
66
71
d_scale_data = d_scale->data <T>();
67
72
}
73
+
68
74
T* d_bias_data = nullptr ;
69
75
if (d_bias) {
70
76
dev_ctx.template Alloc <T>(d_bias);
@@ -124,22 +130,23 @@ void GroupNormGradKernel(const Context& dev_ctx,
124
130
d_scale_data[gid * group_size + cid] += val * dval;
125
131
}
126
132
}
127
-
128
- for (int cid = 0 ; cid < number; cid++) {
129
- for (int imid = 0 ; imid < imsize;
130
- imid++, iter_d_x_data++, tmp_x++, tmp_y++) {
131
- T v_y = tmp_x[0 ];
132
- T dly = tmp_y[0 ];
133
- T dss = dp_scale;
134
- T dbs = dp_bias;
135
- T v_scale = 1 ., v_bias = 0 .;
136
- if (scale_data) v_scale = scale_data[gid * group_size + cid];
137
- if (bias_data) v_bias = bias_data[gid * group_size + cid];
138
- v_y -= v_bias;
139
- if (v_scale != 0 ) v_y /= v_scale;
140
- iter_d_x_data[0 ] =
141
- (dly * v_scale - number_inv * dss * v_y - number_inv * dbs) *
142
- var_inv;
133
+ if (d_x_data) {
134
+ for (int cid = 0 ; cid < number; cid++) {
135
+ for (int imid = 0 ; imid < imsize;
136
+ imid++, iter_d_x_data++, tmp_x++, tmp_y++) {
137
+ T v_y = tmp_x[0 ];
138
+ T dly = tmp_y[0 ];
139
+ T dss = dp_scale;
140
+ T dbs = dp_bias;
141
+ T v_scale = 1 ., v_bias = 0 .;
142
+ if (scale_data) v_scale = scale_data[gid * group_size + cid];
143
+ if (bias_data) v_bias = bias_data[gid * group_size + cid];
144
+ v_y -= v_bias;
145
+ if (v_scale != 0 ) v_y /= v_scale;
146
+ iter_d_x_data[0 ] =
147
+ (dly * v_scale - number_inv * dss * v_y - number_inv * dbs) *
148
+ var_inv;
149
+ }
143
150
}
144
151
}
145
152
} else {
@@ -162,35 +169,42 @@ void GroupNormGradKernel(const Context& dev_ctx,
162
169
d_scale_data[gid * group_size + cid] += val * dval;
163
170
}
164
171
}
165
-
166
- for (int cid = 0 ; cid < number; cid++) {
167
- tmp_x = x_src_data + cid;
168
- tmp_y = y_src_data + cid;
169
- iter_d_x_data = tmp_d_x + cid;
170
- for (int imid = 0 ; imid < imsize;
171
- imid++, iter_d_x_data += C, tmp_x += C, tmp_y += C) {
172
- T v_y = tmp_x[0 ];
173
- T dly = tmp_y[0 ];
174
- T dss = dp_scale;
175
- T dbs = dp_bias;
176
- T v_scale = 1.0 , v_bias = 0 .;
177
- if (scale_data) v_scale = scale_data[gid * group_size + cid];
178
- if (bias_data) v_bias = bias_data[gid * group_size + cid];
179
- v_y -= v_bias;
180
- if (v_scale != 0 ) v_y /= v_scale;
181
- iter_d_x_data[0 ] =
182
- (dly * v_scale - number_inv * dss * v_y - number_inv * dbs) *
183
- var_inv;
172
+ if (d_x_data) {
173
+ for (int cid = 0 ; cid < number; cid++) {
174
+ tmp_x = x_src_data + cid;
175
+ tmp_y = y_src_data + cid;
176
+ iter_d_x_data = tmp_d_x + cid;
177
+ for (int imid = 0 ; imid < imsize;
178
+ imid++, iter_d_x_data += C, tmp_x += C, tmp_y += C) {
179
+ T v_y = tmp_x[0 ];
180
+ T dly = tmp_y[0 ];
181
+ T dss = dp_scale;
182
+ T dbs = dp_bias;
183
+ T v_scale = 1.0 , v_bias = 0 .;
184
+ if (scale_data) v_scale = scale_data[gid * group_size + cid];
185
+ if (bias_data) v_bias = bias_data[gid * group_size + cid];
186
+ v_y -= v_bias;
187
+ if (v_scale != 0 ) v_y /= v_scale;
188
+ iter_d_x_data[0 ] =
189
+ (dly * v_scale - number_inv * dss * v_y - number_inv * dbs) *
190
+ var_inv;
191
+ }
184
192
}
185
193
}
186
194
iter_x_data = iter_x_data_backup + group_size;
187
195
iter_y_data = iter_y_data_backup + group_size;
188
- iter_d_x_data = iter_d_x_data_backup + group_size;
196
+ if (d_x_data) {
197
+ iter_d_x_data = iter_d_x_data_backup + group_size;
198
+ }
189
199
}
190
200
}
191
201
if (data_layout == DataLayout::kNHWC ) {
192
- iter_x_data = x_data + (bid + 1 ) * C * imsize;
193
- iter_d_x_data = d_x_data + (bid + 1 ) * C * imsize;
202
+ if (x_data) {
203
+ iter_x_data = x_data + (bid + 1 ) * C * imsize;
204
+ }
205
+ if (d_x_data) {
206
+ iter_d_x_data = d_x_data + (bid + 1 ) * C * imsize;
207
+ }
194
208
iter_y_data = y_data + (bid + 1 ) * C * imsize;
195
209
}
196
210
}
0 commit comments