Skip to content

Commit b27863c

Browse files
committed
add elementwise floordiv, mod; test=develop
1 parent e61d724 commit b27863c

10 files changed

+386
-2
lines changed
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,38 @@
1+
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.
2+
3+
Licensed under the Apache License, Version 2.0 (the "License");
4+
you may not use this file except in compliance with the License.
5+
You may obtain a copy of the License at
6+
7+
http://www.apache.org/licenses/LICENSE-2.0
8+
9+
Unless required by applicable law or agreed to in writing, software
10+
distributed under the License is distributed on an "AS IS" BASIS,
11+
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
See the License for the specific language governing permissions and
13+
limitations under the License. */
14+
15+
#include "paddle/fluid/operators/elementwise/elementwise_floordiv_op.h"
16+
#include <string>
17+
#include "paddle/fluid/operators/elementwise/elementwise_op.h"
18+
19+
namespace paddle {
20+
namespace operators {
21+
class ElementwiseFloorDivOpMaker : public ElementwiseOpMaker {
22+
protected:
23+
std::string GetName() const override { return "FloorDiv"; }
24+
std::string GetEquation() const override { return "Out = X % Y"; }
25+
};
26+
} // namespace operators
27+
} // namespace paddle
28+
29+
namespace ops = paddle::operators;
30+
31+
REGISTER_OP_WITHOUT_GRADIENT(elementwise_floordiv, ops::ElementwiseOp,
32+
ops::ElementwiseFloorDivOpMaker);
33+
34+
REGISTER_OP_CPU_KERNEL(
35+
elementwise_floordiv,
36+
ops::ElementwiseFloorDivKernel<paddle::platform::CPUDeviceContext, int>,
37+
ops::ElementwiseFloorDivKernel<paddle::platform::CPUDeviceContext,
38+
int64_t>);
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,23 @@
1+
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.
2+
3+
Licensed under the Apache License, Version 2.0 (the "License");
4+
you may not use this file except in compliance with the License.
5+
You may obtain a copy of the License at
6+
7+
http://www.apache.org/licenses/LICENSE-2.0
8+
9+
Unless required by applicable law or agreed to in writing, software
10+
distributed under the License is distributed on an "AS IS" BASIS,
11+
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
See the License for the specific language governing permissions and
13+
limitations under the License. */
14+
#include "paddle/fluid/operators/elementwise/elementwise_floordiv_op.h"
15+
#include "paddle/fluid/platform/float16.h"
16+
17+
namespace ops = paddle::operators;
18+
namespace plat = paddle::platform;
19+
20+
REGISTER_OP_CUDA_KERNEL(
21+
elementwise_floordiv,
22+
ops::ElementwiseFloorDivKernel<plat::CUDADeviceContext, int>,
23+
ops::ElementwiseFloorDivKernel<plat::CUDADeviceContext, int64_t>);
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,55 @@
1+
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.
2+
3+
Licensed under the Apache License, Version 2.0 (the "License");
4+
you may not use this file except in compliance with the License.
5+
You may obtain a copy of the License at
6+
7+
http://www.apache.org/licenses/LICENSE-2.0
8+
9+
Unless required by applicable law or agreed to in writing, software
10+
distributed under the License is distributed on an "AS IS" BASIS,
11+
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
See the License for the specific language governing permissions and
13+
limitations under the License. */
14+
15+
#pragma once
16+
17+
#include "paddle/fluid/framework/eigen.h"
18+
#include "paddle/fluid/operators/elementwise/elementwise_op.h"
19+
#include "paddle/fluid/operators/elementwise/elementwise_op_function.h"
20+
#include "paddle/fluid/operators/math/blas.h"
21+
22+
namespace paddle {
23+
namespace operators {
24+
25+
template <typename T>
26+
struct FloorDivFunctor {
27+
inline HOSTDEVICE T operator()(T a, T b) const { return a / b; }
28+
};
29+
30+
template <typename DeviceContext, typename T>
31+
void elementwise_floor_div(const framework::ExecutionContext &ctx,
32+
const framework::Tensor *x,
33+
const framework::Tensor *y, framework::Tensor *z) {
34+
int axis = ctx.Attr<int>("axis");
35+
ElementwiseComputeEx<FloorDivFunctor<T>, DeviceContext, T>(
36+
ctx, x, y, axis, FloorDivFunctor<T>(), z);
37+
}
38+
39+
template <typename DeviceContext, typename T>
40+
class ElementwiseFloorDivKernel : public framework::OpKernel<T> {
41+
public:
42+
void Compute(const framework::ExecutionContext &ctx) const override {
43+
auto *x = ctx.Input<framework::LoDTensor>("X");
44+
auto *y = ctx.Input<framework::LoDTensor>("Y");
45+
auto *z = ctx.Output<framework::LoDTensor>("Out");
46+
47+
z->mutable_data<T>(ctx.GetPlace());
48+
49+
// dtype of x and y is int64 or int32
50+
elementwise_floor_div<DeviceContext, T>(ctx, x, y, z);
51+
}
52+
};
53+
54+
} // namespace operators
55+
} // namespace paddle
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,36 @@
1+
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.
2+
3+
Licensed under the Apache License, Version 2.0 (the "License");
4+
you may not use this file except in compliance with the License.
5+
You may obtain a copy of the License at
6+
7+
http://www.apache.org/licenses/LICENSE-2.0
8+
9+
Unless required by applicable law or agreed to in writing, software
10+
distributed under the License is distributed on an "AS IS" BASIS,
11+
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
See the License for the specific language governing permissions and
13+
limitations under the License. */
14+
15+
#include "paddle/fluid/operators/elementwise/elementwise_mod_op.h"
16+
#include <string>
17+
#include "paddle/fluid/operators/elementwise/elementwise_op.h"
18+
19+
namespace paddle {
20+
namespace operators {
21+
class ElementwiseModOpMaker : public ElementwiseOpMaker {
22+
protected:
23+
std::string GetName() const override { return "Mod"; }
24+
std::string GetEquation() const override { return "Out = X % Y"; }
25+
};
26+
} // namespace operators
27+
} // namespace paddle
28+
29+
namespace ops = paddle::operators;
30+
REGISTER_OP_WITHOUT_GRADIENT(elementwise_mod, ops::ElementwiseOp,
31+
ops::ElementwiseModOpMaker);
32+
33+
REGISTER_OP_CPU_KERNEL(
34+
elementwise_mod,
35+
ops::ElementwiseModKernel<paddle::platform::CPUDeviceContext, int>,
36+
ops::ElementwiseModKernel<paddle::platform::CPUDeviceContext, int64_t>);
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,22 @@
1+
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.
2+
3+
Licensed under the Apache License, Version 2.0 (the "License");
4+
you may not use this file except in compliance with the License.
5+
You may obtain a copy of the License at
6+
7+
http://www.apache.org/licenses/LICENSE-2.0
8+
9+
Unless required by applicable law or agreed to in writing, software
10+
distributed under the License is distributed on an "AS IS" BASIS,
11+
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
See the License for the specific language governing permissions and
13+
limitations under the License. */
14+
#include "paddle/fluid/operators/elementwise/elementwise_mod_op.h"
15+
#include "paddle/fluid/platform/float16.h"
16+
17+
namespace ops = paddle::operators;
18+
namespace plat = paddle::platform;
19+
20+
REGISTER_OP_CUDA_KERNEL(
21+
elementwise_mod, ops::ElementwiseModKernel<plat::CUDADeviceContext, int>,
22+
ops::ElementwiseModKernel<plat::CUDADeviceContext, int64_t>);
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,55 @@
1+
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.
2+
3+
Licensed under the Apache License, Version 2.0 (the "License");
4+
you may not use this file except in compliance with the License.
5+
You may obtain a copy of the License at
6+
7+
http://www.apache.org/licenses/LICENSE-2.0
8+
9+
Unless required by applicable law or agreed to in writing, software
10+
distributed under the License is distributed on an "AS IS" BASIS,
11+
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
See the License for the specific language governing permissions and
13+
limitations under the License. */
14+
15+
#pragma once
16+
17+
#include "paddle/fluid/framework/eigen.h"
18+
#include "paddle/fluid/operators/elementwise/elementwise_op.h"
19+
#include "paddle/fluid/operators/elementwise/elementwise_op_function.h"
20+
#include "paddle/fluid/operators/math/blas.h"
21+
22+
namespace paddle {
23+
namespace operators {
24+
25+
template <typename T>
26+
struct ModFunctor {
27+
inline HOSTDEVICE T operator()(T a, T b) const { return a % b; }
28+
};
29+
30+
template <typename DeviceContext, typename T>
31+
void elementwise_mod(const framework::ExecutionContext &ctx,
32+
const framework::Tensor *x, const framework::Tensor *y,
33+
framework::Tensor *z) {
34+
int axis = ctx.Attr<int>("axis");
35+
ElementwiseComputeEx<ModFunctor<T>, DeviceContext, T>(ctx, x, y, axis,
36+
ModFunctor<T>(), z);
37+
}
38+
39+
template <typename DeviceContext, typename T>
40+
class ElementwiseModKernel : public framework::OpKernel<T> {
41+
public:
42+
void Compute(const framework::ExecutionContext &ctx) const override {
43+
auto *x = ctx.Input<framework::LoDTensor>("X");
44+
auto *y = ctx.Input<framework::LoDTensor>("Y");
45+
auto *z = ctx.Output<framework::LoDTensor>("Out");
46+
47+
z->mutable_data<T>(ctx.GetPlace());
48+
49+
// dtype of x and y is int64 or int32
50+
elementwise_mod<DeviceContext, T>(ctx, x, y, z);
51+
}
52+
};
53+
54+
} // namespace operators
55+
} // namespace paddle

python/paddle/fluid/layers/math_op_patch.py

+2
Original file line numberDiff line numberDiff line change
@@ -174,6 +174,8 @@ def __impl__(self, other_var):
174174
("__rtruediv__", "elementwise_div", True),
175175
("__pow__", "elementwise_pow", False),
176176
("__rpow__", "elementwise_pow", True),
177+
("__floordiv__", "elementwise_floordiv", False),
178+
("__mod__", "elementwise_mod", False),
177179
# for logical compare
178180
("__eq__", "equal", False),
179181
("__ne__", "not_equal", False),

python/paddle/fluid/layers/nn.py

+17-2
Original file line numberDiff line numberDiff line change
@@ -8887,9 +8887,24 @@ def elementwise_pow(x, y, axis=-1, act=None, name=None):
88878887
return _elementwise_op(LayerHelper('elementwise_pow', **locals()))
88888888

88898889

8890+
def elementwise_mod(x, y, axis=-1, act=None, name=None):
8891+
return _elementwise_op(LayerHelper('elementwise_mod', **locals()))
8892+
8893+
8894+
def elementwise_floordiv(x, y, axis=-1, act=None, name=None):
8895+
return _elementwise_op(LayerHelper('elementwise_floordiv', **locals()))
8896+
8897+
88908898
for func in [
8891-
elementwise_add, elementwise_div, elementwise_sub, elementwise_mul,
8892-
elementwise_max, elementwise_min, elementwise_pow
8899+
elementwise_add,
8900+
elementwise_div,
8901+
elementwise_sub,
8902+
elementwise_mul,
8903+
elementwise_max,
8904+
elementwise_min,
8905+
elementwise_pow,
8906+
elementwise_mod,
8907+
elementwise_floordiv,
88938908
]:
88948909
op_proto = OpProtoHolder.instance().get_op_proto(func.__name__)
88958910
func.__doc__ = _generate_doc_string_(
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,69 @@
1+
# Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# http://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
15+
from __future__ import print_function
16+
import unittest
17+
import numpy as np
18+
import paddle.fluid.core as core
19+
from op_test import OpTest
20+
21+
import random
22+
23+
24+
class TestElementwiseModOp(OpTest):
25+
def init_kernel_type(self):
26+
self.use_mkldnn = False
27+
28+
def setUp(self):
29+
self.op_type = "elementwise_floordiv"
30+
self.dtype = np.int32
31+
self.axis = -1
32+
self.init_dtype()
33+
self.init_input_output()
34+
self.init_kernel_type()
35+
self.init_axis()
36+
37+
self.inputs = {
38+
'X': OpTest.np_dtype_to_fluid_dtype(self.x),
39+
'Y': OpTest.np_dtype_to_fluid_dtype(self.y)
40+
}
41+
self.attrs = {'axis': self.axis, 'use_mkldnn': self.use_mkldnn}
42+
self.outputs = {'Out': self.out}
43+
44+
def test_check_output(self):
45+
self.check_output()
46+
47+
def init_input_output(self):
48+
self.x = np.random.uniform(0, 10000, [10, 10]).astype(self.dtype)
49+
self.y = np.random.uniform(0, 1000, [10, 10]).astype(self.dtype)
50+
self.out = np.floor_divide(self.x, self.y)
51+
52+
def init_dtype(self):
53+
pass
54+
55+
def init_axis(self):
56+
pass
57+
58+
59+
class TestElementwiseModOp_scalar(TestElementwiseModOp):
60+
def init_input_output(self):
61+
scale_x = random.randint(0, 100000000)
62+
scale_y = random.randint(1, 100000000)
63+
self.x = (np.random.rand(2, 3, 4) * scale_x).astype(self.dtype)
64+
self.y = (np.random.rand(1) * scale_y + 1).astype(self.dtype)
65+
self.out = np.floor_divide(self.x, self.y)
66+
67+
68+
if __name__ == '__main__':
69+
unittest.main()

0 commit comments

Comments
 (0)