@@ -19,13 +19,6 @@ limitations under the License. */
19
19
namespace paddle {
20
20
namespace operators {
21
21
22
- template <typename T, int MajorType = Eigen::RowMajor,
23
- typename IndexType = Eigen::DenseIndex>
24
- using EigenVector = framework::EigenVector<T, MajorType, IndexType>;
25
- template <typename T, int MajorType = Eigen::RowMajor,
26
- typename IndexType = Eigen::DenseIndex>
27
- using EigenMatrix = framework::EigenMatrix<T, MajorType, IndexType>;
28
-
29
22
template <typename DeviceContext, typename T, typename AttrType = T>
30
23
class NormKernel : public framework ::OpKernel<T> {
31
24
public:
@@ -42,29 +35,37 @@ class NormKernel : public framework::OpKernel<T> {
42
35
int fea_len = height * width;
43
36
auto * place =
44
37
context.template device_context <DeviceContext>().eigen_device ();
45
- auto x = EigenMatrix<T>::From (
46
- *in_x, framework::make_ddim ({batch_size, fea_len * channels}));
38
+ auto x =
39
+ framework::EigenMatrix<T, Eigen::RowMajor, Eigen::DenseIndex>::From (
40
+ *in_x, framework::make_ddim ({batch_size, fea_len * channels}));
47
41
// get square
48
42
framework::Tensor x_square;
49
43
x_square.mutable_data <T>(in_x->dims (), context.GetPlace ());
50
- auto x_square_eigen = EigenMatrix<T>::From (
51
- x_square, framework::make_ddim ({batch_size, fea_len * channels}));
44
+ auto x_square_eigen =
45
+ framework::EigenMatrix<T, Eigen::RowMajor, Eigen::DenseIndex>::From (
46
+ x_square, framework::make_ddim ({batch_size, fea_len * channels}));
52
47
x_square_eigen.device (*place) = x.square ();
53
- auto scale_eigen = EigenVector<T>::Flatten (*scale);
48
+ auto scale_eigen =
49
+ framework::EigenVector<T, Eigen::RowMajor, Eigen::DenseIndex>::Flatten (
50
+ *scale);
54
51
for (int n = 0 ; n < batch_size; ++n) {
55
52
framework::Tensor in_x_batch = in_x->Slice (n, n + 1 );
56
- auto in_x_batch_eigen = EigenMatrix<T>::From (
57
- in_x_batch, framework::make_ddim ({channels, fea_len}));
53
+ auto in_x_batch_eigen =
54
+ framework::EigenMatrix<T, Eigen::RowMajor, Eigen::DenseIndex>::From (
55
+ in_x_batch, framework::make_ddim ({channels, fea_len}));
58
56
framework::Tensor x_square_batch = x_square.Slice (n, n + 1 );
59
- auto x_square_batch_eigen = EigenMatrix<T>::From (
60
- x_square_batch, framework::make_ddim ({channels, fea_len}));
57
+ auto x_square_batch_eigen =
58
+ framework::EigenMatrix<T, Eigen::RowMajor, Eigen::DenseIndex>::From (
59
+ x_square_batch, framework::make_ddim ({channels, fea_len}));
61
60
framework::Tensor out_batch = out->Slice (n, n + 1 );
62
- auto out_batch_eigen = EigenMatrix<T>::From (
63
- out_batch, framework::make_ddim ({channels, fea_len}));
61
+ auto out_batch_eigen =
62
+ framework::EigenMatrix<T, Eigen::RowMajor, Eigen::DenseIndex>::From (
63
+ out_batch, framework::make_ddim ({channels, fea_len}));
64
64
framework::Tensor tmp_tensor;
65
65
tmp_tensor.mutable_data <T>(framework::make_ddim ({1 , fea_len}),
66
66
context.GetPlace ());
67
- auto tmp = EigenVector<T>::Flatten (tmp_tensor);
67
+ auto tmp = framework::EigenVector<T, Eigen::RowMajor,
68
+ Eigen::DenseIndex>::Flatten (tmp_tensor);
68
69
// get colsum and sqrt , inverse
69
70
auto dim = Eigen::array<int , 1 >({{0 }});
70
71
tmp.device (*place) = x_square_batch_eigen.sum (dim);
@@ -102,40 +103,52 @@ class NormGradKernel : public framework::OpKernel<T> {
102
103
auto * place =
103
104
context.template device_context <DeviceContext>().eigen_device ();
104
105
105
- auto scale_eigen = EigenVector<T>::Flatten (*scale);
106
- auto x = EigenMatrix<T>::From (
107
- *in_x, framework::make_ddim ({batch_size, fea_len * channels}));
106
+ auto scale_eigen =
107
+ framework::EigenVector<T, Eigen::RowMajor, Eigen::DenseIndex>::Flatten (
108
+ *scale);
109
+ auto x =
110
+ framework::EigenMatrix<T, Eigen::RowMajor, Eigen::DenseIndex>::From (
111
+ *in_x, framework::make_ddim ({batch_size, fea_len * channels}));
108
112
// get square
109
113
framework::Tensor x_square;
110
114
x_square.mutable_data <T>(in_x->dims (), context.GetPlace ());
111
- auto x_square_eigen = EigenMatrix<T>::From (
112
- x_square, framework::make_ddim ({batch_size, fea_len * channels}));
115
+ auto x_square_eigen =
116
+ framework::EigenMatrix<T, Eigen::RowMajor, Eigen::DenseIndex>::From (
117
+ x_square, framework::make_ddim ({batch_size, fea_len * channels}));
113
118
x_square_eigen.device (*place) = x.square ();
114
119
115
120
for (int n = 0 ; n < batch_size; ++n) {
116
121
framework::Tensor in_x_batch = in_x->Slice (n, n + 1 );
117
- auto in_x_batch_eigen = EigenMatrix<T>::From (
118
- in_x_batch, framework::make_ddim ({channels, fea_len}));
122
+ auto in_x_batch_eigen =
123
+ framework::EigenMatrix<T, Eigen::RowMajor, Eigen::DenseIndex>::From (
124
+ in_x_batch, framework::make_ddim ({channels, fea_len}));
119
125
framework::Tensor in_g_batch = in_x_grad->Slice (n, n + 1 );
120
- auto in_g_batch_eigen = EigenMatrix<T>::From (
121
- in_g_batch, framework::make_ddim ({channels, fea_len}));
126
+ auto in_g_batch_eigen =
127
+ framework::EigenMatrix<T, Eigen::RowMajor, Eigen::DenseIndex>::From (
128
+ in_g_batch, framework::make_ddim ({channels, fea_len}));
122
129
framework::Tensor x_square_batch = x_square.Slice (n, n + 1 );
123
- auto x_square_batch_eigen = EigenMatrix<T>::From (
124
- x_square_batch, framework::make_ddim ({channels, fea_len}));
130
+ auto x_square_batch_eigen =
131
+ framework::EigenMatrix<T, Eigen::RowMajor, Eigen::DenseIndex>::From (
132
+ x_square_batch, framework::make_ddim ({channels, fea_len}));
125
133
framework::Tensor outg_batch = out_grad->Slice (n, n + 1 );
126
- auto outg_batch_eigen = EigenMatrix<T>::From (
127
- outg_batch, framework::make_ddim ({channels, fea_len}));
134
+ auto outg_batch_eigen =
135
+ framework::EigenMatrix<T, Eigen::RowMajor, Eigen::DenseIndex>::From (
136
+ outg_batch, framework::make_ddim ({channels, fea_len}));
128
137
129
138
framework::Tensor tmp_tensor;
130
139
tmp_tensor.mutable_data <T>(framework::make_ddim ({1 , fea_len}),
131
140
context.GetPlace ());
132
- auto tmp_eigen = EigenVector<T>::Flatten (tmp_tensor);
141
+ auto tmp_eigen =
142
+ framework::EigenVector<T, Eigen::RowMajor,
143
+ Eigen::DenseIndex>::Flatten (tmp_tensor);
133
144
auto dim = Eigen::array<int , 1 >({{0 }});
134
145
tmp_eigen.device (*place) = (in_x_batch_eigen * outg_batch_eigen).sum (dim);
135
146
framework::Tensor norm_tmp_tensor;
136
147
norm_tmp_tensor.mutable_data <T>(framework::make_ddim ({1 , fea_len}),
137
148
context.GetPlace ());
138
- auto norm_tmp_eigen = EigenVector<T>::Flatten (norm_tmp_tensor);
149
+ auto norm_tmp_eigen =
150
+ framework::EigenVector<T, Eigen::RowMajor,
151
+ Eigen::DenseIndex>::Flatten (norm_tmp_tensor);
139
152
norm_tmp_eigen.device (*place) =
140
153
(x_square_batch_eigen.sum (dim) + epsilon).sqrt ();
141
154
Eigen::array<int , 2 > broadcast_dim_col;
0 commit comments