
Commit c666302

concat op support negative axis (PaddlePaddle#18045)
test=release/1.5

Parent: 7e31d5a

File tree: 3 files changed, +36 -6 lines

paddle/fluid/operators/concat_op.cc (+8 -2)

@@ -36,7 +36,10 @@ class ConcatOp : public framework::OperatorWithKernel {
                    "Output(Out) of ConcatOp should not be null.");
 
     auto ins = ctx->GetInputsDim("X");
-    size_t axis = static_cast<size_t>(ctx->Attrs().Get<int>("axis"));
+    size_t axis =
+        ComputeAxis(static_cast<int64_t>(ctx->Attrs().Get<int>("axis")),
+                    static_cast<int64_t>(ins[0].size()));
+
     const size_t n = ins.size();
 
     PADDLE_ENFORCE_GT(n, 0, "Input tensors count should > 0.");
@@ -115,7 +118,10 @@ class ConcatOpMaker : public framework::OpProtoAndCheckerMaker {
                   "(bool, default false) Indicates if MKL-DNN kernel will be used")
         .SetDefault(false);
     AddAttr<int>("axis",
-                 "The axis along which the input tensors will be concatenated.")
+                 "The axis along which the input tensors will be concatenated."
+                 "The axis could also be negative numbers. Negative axis is "
+                 "interpreted as counting from the end of the rank."
+                 "i.e., axis + rank(X) th dimension.")
        .SetDefault(0);
     AddAttr<bool>("use_quantizer",
                   "(bool, default false) "

paddle/fluid/operators/concat_op.h (+13 -3)

@@ -23,13 +23,22 @@ limitations under the License. */
 namespace paddle {
 namespace operators {
 
+static inline int64_t ComputeAxis(int64_t axis, int64_t rank) {
+  if (axis < 0) {
+    axis = axis + rank;
+  }
+  return axis > 0 ? axis : 0;
+}
+
 template <typename DeviceContext, typename T>
 class ConcatKernel : public framework::OpKernel<T> {
  public:
   void Compute(const framework::ExecutionContext& ctx) const override {
     auto ins = ctx.MultiInput<framework::Tensor>("X");
     framework::Tensor* out = ctx.Output<framework::Tensor>("Out");
-    int64_t axis = static_cast<int64_t>(ctx.Attr<int>("axis"));
+    PADDLE_ENFORCE(ins[0], "The input should not be null.");
+    auto axis = ComputeAxis(static_cast<int64_t>(ctx.Attr<int>("axis")),
+                            static_cast<int64_t>(ins[0]->dims().size()));
     auto place = ctx.GetPlace();
     out->mutable_data<T>(place);
 
@@ -83,8 +92,9 @@ class ConcatGradKernel : public framework::OpKernel<T> {
        }
      }
    }
-
-    int64_t axis = static_cast<int64_t>(ctx.Attr<int>("axis"));
+    PADDLE_ENFORCE(ins[0], "The input should not be null.");
+    auto axis = ComputeAxis(static_cast<int64_t>(ctx.Attr<int>("axis")),
+                            static_cast<int64_t>(ins[0]->dims().size()));
 
    // get output tensor that the name is not kEmptyVarName
    std::vector<framework::Tensor*> outputs;

python/paddle/fluid/tests/unittests/test_concat_op.py (+15 -1)

@@ -25,9 +25,15 @@ def setUp(self):
         self.init_test_data()
         self.inputs = {'X': [('x0', self.x0), ('x1', self.x1), ('x2', self.x2)]}
         self.attrs = {'axis': self.axis}
+        if self.axis < 0:
+            self.actual_axis = self.axis + len(self.x0.shape)
+            self.actual_axis = self.actual_axis if self.actual_axis > 0 else 0
+        else:
+            self.actual_axis = self.axis
+
         self.outputs = {
             'Out': np.concatenate(
-                (self.x0, self.x1, self.x2), axis=self.axis)
+                (self.x0, self.x1, self.x2), axis=self.actual_axis)
         }
 
     def test_check_output(self):
@@ -75,5 +81,13 @@ def test_check_grad(self):
         pass
 
 
+class TestConcatOp5(TestConcatOp):
+    def init_test_data(self):
+        self.x0 = np.random.random((2, 1, 4, 5)).astype('float32')
+        self.x1 = np.random.random((2, 2, 4, 5)).astype('float32')
+        self.x2 = np.random.random((2, 3, 4, 5)).astype('float32')
+        self.axis = -3
+
+
 if __name__ == '__main__':
     unittest.main()
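The new TestConcatOp5 case can be sanity-checked outside the test harness with NumPy alone, since np.concatenate already accepts negative axes. A quick sketch using the same shapes as the test:

    import numpy as np

    x0 = np.random.random((2, 1, 4, 5)).astype('float32')
    x1 = np.random.random((2, 2, 4, 5)).astype('float32')
    x2 = np.random.random((2, 3, 4, 5)).astype('float32')

    out_neg = np.concatenate((x0, x1, x2), axis=-3)  # axis = -3, as in TestConcatOp5
    out_pos = np.concatenate((x0, x1, x2), axis=1)   # -3 + rank(4) == 1

    assert out_neg.shape == (2, 6, 4, 5)
    assert np.array_equal(out_neg, out_pos)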
