Skip to content

Commit ec8f33c

Browse files
authored
Merge pull request #7076 from JiayiFeng/move_enfoce_position
move ENFORCE position
2 parents 3b54948 + a04f30e commit ec8f33c

File tree

1 file changed

+4
-2
lines changed

1 file changed

+4
-2
lines changed

paddle/operators/conv_op.cc

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -31,8 +31,6 @@ void ConvOp::InferShape(framework::InferShapeContext* ctx) const {
3131
std::vector<int> paddings = ctx->Attrs().Get<std::vector<int>>("paddings");
3232
int groups = ctx->Attrs().Get<int>("groups");
3333
std::vector<int> dilations = ctx->Attrs().Get<std::vector<int>>("dilations");
34-
int input_channels = in_dims[1];
35-
int output_channels = filter_dims[0];
3634

3735
PADDLE_ENFORCE(in_dims.size() == 4 || in_dims.size() == 5,
3836
"Conv intput should be 4-D or 5-D tensor.");
@@ -45,9 +43,13 @@ void ConvOp::InferShape(framework::InferShapeContext* ctx) const {
4543
PADDLE_ENFORCE_EQ(
4644
paddings.size(), strides.size(),
4745
"Conv paddings dimension and Conv strides dimension should be the same.");
46+
47+
int input_channels = in_dims[1];
4848
PADDLE_ENFORCE_EQ(input_channels, filter_dims[1] * groups,
4949
"The number of input channels should be equal to filter "
5050
"channels * groups.");
51+
52+
int output_channels = filter_dims[0];
5153
PADDLE_ENFORCE_EQ(
5254
output_channels % groups, 0,
5355
"The number of output channels should be divided by groups.");

0 commit comments

Comments
 (0)