@@ -31,6 +31,15 @@ struct ModFunctor {
31
31
}
32
32
};
33
33
34
// Modulo with the operand roles reversed: returns `b mod a`.
//
// The built-in `%` yields a remainder whose sign follows the dividend (`b`).
// Whenever that remainder is non-zero and its sign differs from the sign of
// `a`, `a` is added once, so the returned value always carries the sign of
// the divisor `a`.
//
// NOTE(review): `a == 0` is not guarded here and reaches the built-in `%`
// unchanged.
template <typename T>
struct InverseModFunctor {
  inline HOSTDEVICE T operator()(T a, T b) const {
    T res = b % a;
    // Fold a remainder with the "wrong" sign onto the divisor's sign.
    if ((res != 0) && ((res < 0) != (a < 0))) res += a;
    return res;
  }
};
42
+
34
43
template <typename T>
35
44
struct ModFunctorFP {
36
45
inline HOSTDEVICE T operator ()(T a, T b) const {
@@ -40,22 +49,45 @@ struct ModFunctorFP {
40
49
}
41
50
};
42
51
52
// Floating-point modulo with the operand roles reversed: returns `b mod a`.
//
// `fmod(b, a)` produces a remainder whose sign follows `b`. When that
// remainder is non-zero and its sign differs from the sign of `a`, `a` is
// added once, so the returned value carries the sign of the divisor `a`.
//
// `fmod` is intentionally called unqualified, exactly as in the original, so
// that overload resolution is unchanged for every `T` this is instantiated
// with.
template <typename T>
struct InverseModFunctorFP {
  inline HOSTDEVICE T operator()(T a, T b) const {
    T res = fmod(b, a);
    // Fold a remainder with the "wrong" sign onto the divisor's sign.
    if ((res != 0) && ((a < 0) != (res < 0))) res += a;
    return res;
  }
};
60
+
43
61
template <typename DeviceContext, typename T>
44
62
void elementwise_mod (const framework::ExecutionContext &ctx,
45
63
const framework::Tensor *x, const framework::Tensor *y,
46
64
framework::Tensor *z) {
47
65
int axis = ctx.Attr <int >(" axis" );
48
- ElementwiseComputeEx<ModFunctor<T>, DeviceContext, T>(ctx, x, y, axis,
49
- ModFunctor<T>(), z);
66
+ auto x_dims = x->dims ();
67
+ auto y_dims = y->dims ();
68
+ if (x_dims.size () >= y_dims.size ()) {
69
+ ElementwiseComputeEx<ModFunctor<T>, DeviceContext, T>(ctx, x, y, axis,
70
+ ModFunctor<T>(), z);
71
+ } else {
72
+ ElementwiseComputeEx<InverseModFunctor<T>, DeviceContext, T>(
73
+ ctx, x, y, axis, InverseModFunctor<T>(), z);
74
+ }
50
75
}
51
76
52
77
template <typename DeviceContext, typename T>
53
78
void elementwise_mod_fp (const framework::ExecutionContext &ctx,
54
79
const framework::Tensor *x, const framework::Tensor *y,
55
80
framework::Tensor *z) {
56
81
int axis = ctx.Attr <int >(" axis" );
57
- ElementwiseComputeEx<ModFunctorFP<T>, DeviceContext, T>(ctx, x, y, axis,
58
- ModFunctorFP<T>(), z);
82
+ auto x_dims = x->dims ();
83
+ auto y_dims = y->dims ();
84
+ if (x_dims.size () >= y_dims.size ()) {
85
+ ElementwiseComputeEx<ModFunctorFP<T>, DeviceContext, T>(
86
+ ctx, x, y, axis, ModFunctorFP<T>(), z);
87
+ } else {
88
+ ElementwiseComputeEx<InverseModFunctorFP<T>, DeviceContext, T>(
89
+ ctx, x, y, axis, InverseModFunctorFP<T>(), z);
90
+ }
59
91
}
60
92
61
93
template <typename DeviceContext, typename T>
0 commit comments