Skip to content

Commit e09e21b

Browse files
Merge pull request #6188 from wanghaoshuang/conv_fix
Make ConvTransProjection support for dilation
2 parents fb3e778 + 6173f91 commit e09e21b

File tree

2 files changed

+18
-3
lines changed

2 files changed

+18
-3
lines changed

paddle/gserver/layers/ConvTransProjection.cpp

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -24,13 +24,13 @@ size_t ConvTransProjection::calOutputSize() {
2424
if (outputH_ == 0) outputH_ = configOutH_;
2525
if (outputW_ == 0) outputW_ = configOutW_;
2626
imageH_ = imageSize(outputH_,
27-
filterH_,
27+
(filterH_ - 1) * dilationH_ + 1,
2828
paddingH_,
2929
strideH_,
3030
/* caffeMode */ true);
3131

3232
imageW_ = imageSize(outputW_,
33-
filterW_,
33+
(filterW_ - 1) * dilationW_ + 1,
3434
paddingW_,
3535
strideW_,
3636
/* caffeMode */ true);

paddle/gserver/tests/test_LayerGrad.cpp

Lines changed: 16 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -238,9 +238,24 @@ void testProjectionConv(size_t groups, bool isDeconv) {
238238
/* caffeMode */ true);
239239
conv->set_output_x(output_x);
240240
conv->set_output_y(output_y);
241+
LOG(INFO) << "DILATION:" << DILATION << "; output_x: " << output_x
242+
<< "; output_y: " << output_y;
241243
if (isDeconv) {
244+
int deconv_image_x = imageSize(output_x,
245+
(conv->filter_size() - 1) * DILATION + 1,
246+
conv->padding(),
247+
conv->stride(),
248+
/* caffeMode */ true);
249+
int deconv_image_y = imageSize(output_y,
250+
(conv->filter_size_y() - 1) * DILATION + 1,
251+
conv->padding_y(),
252+
conv->stride_y(),
253+
/* caffeMode */ true);
254+
255+
LOG(INFO) << " deconv_image_x: " << deconv_image_x
256+
<< "; deconv_image_y: " << deconv_image_y;
242257
conf.set_input_size(output_x * output_y * CHANNELS);
243-
conf.set_output_size(IMAGE_SIZE * IMAGE_SIZE * NUM_FILTERS);
258+
conf.set_output_size(deconv_image_x * deconv_image_y * NUM_FILTERS);
244259
} else {
245260
conf.set_input_size(IMAGE_SIZE * IMAGE_SIZE * CHANNELS);
246261
conf.set_output_size(output_x * output_y * NUM_FILTERS);

0 commit comments

Comments
 (0)