Skip to content

Commit 418060a

Browse files
authored
fix the remainder (#26995) (#27026)
1 parent 897e574 commit 418060a

File tree

2 files changed

+44
-4
lines changed

2 files changed

+44
-4
lines changed

paddle/fluid/operators/elementwise/elementwise_mod_op.h

+36-4
Original file line numberDiff line numberDiff line change
@@ -31,6 +31,15 @@ struct ModFunctor {
3131
}
3232
};
3333

34+
template <typename T>
35+
struct InverseModFunctor {
36+
inline HOSTDEVICE T operator()(T a, T b) const {
37+
T res = b % a;
38+
if ((res != 0) && ((res < 0) != (a < 0))) res += a;
39+
return res;
40+
}
41+
};
42+
3443
template <typename T>
3544
struct ModFunctorFP {
3645
inline HOSTDEVICE T operator()(T a, T b) const {
@@ -40,22 +49,45 @@ struct ModFunctorFP {
4049
}
4150
};
4251

52+
template <typename T>
53+
struct InverseModFunctorFP {
54+
inline HOSTDEVICE T operator()(T a, T b) const {
55+
T res = fmod(b, a);
56+
if ((res != 0) && ((a < 0) != (res < 0))) res += a;
57+
return res;
58+
}
59+
};
60+
4361
template <typename DeviceContext, typename T>
4462
void elementwise_mod(const framework::ExecutionContext &ctx,
4563
const framework::Tensor *x, const framework::Tensor *y,
4664
framework::Tensor *z) {
4765
int axis = ctx.Attr<int>("axis");
48-
ElementwiseComputeEx<ModFunctor<T>, DeviceContext, T>(ctx, x, y, axis,
49-
ModFunctor<T>(), z);
66+
auto x_dims = x->dims();
67+
auto y_dims = y->dims();
68+
if (x_dims.size() >= y_dims.size()) {
69+
ElementwiseComputeEx<ModFunctor<T>, DeviceContext, T>(ctx, x, y, axis,
70+
ModFunctor<T>(), z);
71+
} else {
72+
ElementwiseComputeEx<InverseModFunctor<T>, DeviceContext, T>(
73+
ctx, x, y, axis, InverseModFunctor<T>(), z);
74+
}
5075
}
5176

5277
template <typename DeviceContext, typename T>
5378
void elementwise_mod_fp(const framework::ExecutionContext &ctx,
5479
const framework::Tensor *x, const framework::Tensor *y,
5580
framework::Tensor *z) {
5681
int axis = ctx.Attr<int>("axis");
57-
ElementwiseComputeEx<ModFunctorFP<T>, DeviceContext, T>(ctx, x, y, axis,
58-
ModFunctorFP<T>(), z);
82+
auto x_dims = x->dims();
83+
auto y_dims = y->dims();
84+
if (x_dims.size() >= y_dims.size()) {
85+
ElementwiseComputeEx<ModFunctorFP<T>, DeviceContext, T>(
86+
ctx, x, y, axis, ModFunctorFP<T>(), z);
87+
} else {
88+
ElementwiseComputeEx<InverseModFunctorFP<T>, DeviceContext, T>(
89+
ctx, x, y, axis, InverseModFunctorFP<T>(), z);
90+
}
5991
}
6092

6193
template <typename DeviceContext, typename T>

python/paddle/fluid/tests/unittests/test_elementwise_mod_op.py

+8
Original file line numberDiff line numberDiff line change
@@ -220,6 +220,14 @@ def test_dygraph(self):
220220
z_expected = np.array([0, 1, 1, -1])
221221
self.assertEqual(np.allclose(z_expected, z.numpy()), True)
222222

223+
np_x = np.array([-3, 3])
224+
np_y = np.array([[2, 3], [-2, -1]])
225+
x = paddle.to_tensor(np_x, dtype="int64")
226+
y = paddle.to_tensor(np_y, dtype="int64")
227+
z = x % y
228+
z_expected = np.array([[1, 0], [-1, 0]])
229+
self.assertEqual(np.allclose(z_expected, z.numpy()), True)
230+
223231

224232
if __name__ == '__main__':
225233
unittest.main()

0 commit comments

Comments
 (0)