
Commit 7ce06c8

Merge pull request #5571 from sweetsky0901/my_maxout_op
Add maxout operator.
2 parents d5be1d4 + 9cb2ff6

9 files changed: +541 lines added, 0 removed.

paddle/operators/CMakeLists.txt (+2)

@@ -184,6 +184,7 @@ set(DEPS_OPS
 sequence_softmax_op
 sum_op
 pool_op
+maxout_op
 pool_with_index_op
 conv_op
 conv_transpose_op
@@ -210,6 +211,7 @@ op_library(sgd_op DEPS selected_rows_functor)
 op_library(adagrad_op DEPS selected_rows_functor)
 op_library(conv_op DEPS vol2col)
 op_library(pool_op DEPS pooling)
+op_library(maxout_op DEPS maxouting)
 op_library(pool_with_index_op DEPS pooling)
 op_library(lod_rank_table_op SRCS lod_rank_table_op.cc DEPS lod_rank_table)
 op_library(lod_tensor_to_array_op SRCS lod_tensor_to_array_op.cc DEPS lod_rank_table_op)

paddle/operators/math/CMakeLists.txt (+2)

@@ -14,6 +14,7 @@ if(WITH_GPU)
 nv_library(sequence2batch SRCS sequence2batch.cc sequence2batch.cu DEPS device_context)
 nv_library(lstm_compute SRCS lstm_compute.cc lstm_compute.cu DEPS device_context activation_functions)
 nv_library(gru_compute SRCS gru_compute.cc gru_compute.cu DEPS device_context activation_functions math_function)
+nv_library(maxouting SRCS maxouting.cc maxouting.cu DEPS device_context)
 else()
 cc_library(math_function SRCS math_function.cc im2col.cc DEPS cblas device_context framework_proto)
 cc_library(selected_rows_functor SRCS selected_rows_functor.cc DEPS selected_rows math_function)
@@ -26,6 +27,7 @@ else()
 cc_library(sequence2batch SRCS sequence2batch.cc DEPS device_context)
 cc_library(lstm_compute SRCS lstm_compute.cc DEPS device_context activation_functions)
 cc_library(gru_compute SRCS gru_compute.cc DEPS device_context activation_functions math_function)
+cc_library(maxouting SRCS maxouting.cc DEPS device_context)
 endif()

 cc_test(math_function_test SRCS math_function_test.cc DEPS math_function tensor)

paddle/operators/math/maxouting.cc (new file, +106)

/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

#include "paddle/operators/math/maxouting.h"

namespace paddle {
namespace operators {
namespace math {

// All tensors are in NCHW format, and the groups must be greater than 1
template <typename T>
class MaxOutFunctor<platform::CPUPlace, T> {
 public:
  void operator()(const platform::DeviceContext& context,
                  const framework::Tensor& input,
                  framework::Tensor* output,
                  int groups) {
    const int batch_size = input.dims()[0];
    const int input_height = input.dims()[2];
    const int input_width = input.dims()[3];
    const int output_channels = output->dims()[1];
    int fea_size = input_height * input_width;
    // c_size means the output size of each sample
    int c_size = fea_size * output_channels;
    const T* input_data = input.data<T>();
    T* output_data = output->mutable_data<T>(context.GetPlace());

    for (int i = 0; i < batch_size; ++i) {
      int new_bindex = c_size * i;
      for (int c = 0; c < output_channels; ++c) {
        int new_cindex = fea_size * c;
        for (int f = 0; f < fea_size; ++f) {
          T ele = static_cast<T>(-FLT_MAX);
          for (int ph = 0; ph < groups; ++ph) {
            T x = input_data[(new_bindex + new_cindex) * groups +
                             ph * fea_size + f];
            ele = ele > x ? ele : x;
          }
          output_data[new_bindex + new_cindex + f] = ele;
        }
      }
    }
  }
};

template <class T>
class MaxOutGradFunctor<platform::CPUPlace, T> {
 public:
  void operator()(const platform::DeviceContext& context,
                  const framework::Tensor& input,
                  framework::Tensor* input_grad,
                  const framework::Tensor& output,
                  const framework::Tensor& output_grad,
                  int groups) {
    const int batch_size = input.dims()[0];
    const int input_height = input.dims()[2];
    const int input_width = input.dims()[3];
    const int output_channels = output.dims()[1];
    int fea_size = input_height * input_width;
    const T* input_data = input.data<T>();
    const T* output_data = output.data<T>();
    const T* output_grad_data = output_grad.data<T>();
    T* input_grad_data = input_grad->mutable_data<T>(context.GetPlace());

    for (int i = 0; i < batch_size; ++i) {
      int blen = fea_size * output_channels * i;
      for (int c = 0; c < output_channels; ++c) {
        int clen = fea_size * c;
        for (int f = 0; f < fea_size; ++f) {
          int input_idx0 = (blen + clen) * groups + f;
          bool continue_match = true;
          int output_idx = blen + clen + f;
          for (int g = 0; g < groups && continue_match; ++g) {
            int input_idx = input_idx0 + fea_size * g;
            if (input_data[input_idx] == output_data[output_idx]) {
              input_grad_data[input_idx] += output_grad_data[output_idx];
              continue_match = false;
            }
          }
        }
      }
    }
  }
};

template class MaxOutGradFunctor<platform::CPUPlace, float>;
template class MaxOutGradFunctor<platform::CPUPlace, double>;
template class MaxOutFunctor<platform::CPUPlace, float>;
template class MaxOutFunctor<platform::CPUPlace, double>;

} // namespace math
} // namespace operators
} // namespace paddle
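
The nested loops above implement maxout over channel groups: output channel c is the element-wise maximum of input channels c * groups + 0 through c * groups + (groups - 1). The following self-contained C++ sketch (illustration only, not part of the commit; shapes and values are made up) applies the same flat-index rule to an NCHW buffer:

#include <algorithm>
#include <cfloat>
#include <iostream>
#include <vector>

int main() {
  // Hypothetical toy shapes: N=1, C_in=4, H=1, W=2, groups=2 -> C_out=2.
  const int batch = 1, in_channels = 4, height = 1, width = 2, groups = 2;
  const int out_channels = in_channels / groups;
  const int fea_size = height * width;

  std::vector<float> input = {1.f, 5.f,   // input channel 0
                              3.f, 2.f,   // input channel 1
                              0.f, 7.f,   // input channel 2
                              4.f, 6.f};  // input channel 3
  std::vector<float> output(batch * out_channels * fea_size);

  for (int n = 0; n < batch; ++n) {
    for (int c = 0; c < out_channels; ++c) {
      for (int f = 0; f < fea_size; ++f) {
        float ele = -FLT_MAX;
        for (int g = 0; g < groups; ++g) {
          // Same flat-index formula as the CPU functor: input channel c * groups + g.
          float x = input[((n * out_channels + c) * groups + g) * fea_size + f];
          ele = std::max(ele, x);
        }
        output[(n * out_channels + c) * fea_size + f] = ele;
      }
    }
  }

  // Expected output: 3 5 4 7 (max of channels 0/1, then max of channels 2/3).
  for (float v : output) std::cout << v << ' ';
  std::cout << '\n';
}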

paddle/operators/math/maxouting.cu (new file, +154)

/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

#include "paddle/operators/math/maxouting.h"
#include "paddle/platform/cuda_helper.h"

namespace paddle {
namespace operators {
namespace math {

template <typename T>
__global__ void KernelMaxOut(const int nthreads, const T* input_data,
                             const int channels, const int input_height,
                             const int input_width, int groups,
                             T* output_data) {
  const int size = input_height * input_width * channels / groups;
  const int feat_len = input_height * input_width;
  int index = blockIdx.x * blockDim.x + threadIdx.x;
  int offset = blockDim.x * gridDim.x;
  for (int i = index; i < nthreads; i += offset) {
    int batch_idx = i / size;
    int batch_offset = i % size;
    int channel_idx = batch_offset / feat_len;
    int feat_idx = batch_offset % feat_len;
    int data_idx =
        (batch_idx * size + channel_idx * feat_len) * groups + feat_idx;
    T ele = static_cast<T>(-FLT_MAX);
    for (int g = 0; g < groups; ++g) {
      T x = input_data[data_idx + g * feat_len];
      ele = ele > x ? ele : x;
    }
    output_data[i] = ele;
  }
}

template <typename T>
__global__ void KernelMaxoutGrad(
    const int nthreads, const T* input_data, const T* output_data,
    const T* output_grad, T* input_grad, const int channels,
    const int input_height, const int input_width, int groups) {
  const int size = input_height * input_width * channels / groups;
  const int feat_len = input_height * input_width;
  int index = blockIdx.x * blockDim.x + threadIdx.x;
  int offset = blockDim.x * gridDim.x;
  for (int i = index; i < nthreads; i += offset) {
    int batch_idx = i / size;
    int batch_offset = i % size;
    int channel_idx = batch_offset / feat_len;
    int feat_idx = batch_offset % feat_len;
    int data_idx =
        (batch_idx * size + channel_idx * feat_len) * groups + feat_idx;
    int max_index = -1;
    bool continue_match = true;
    for (int g = 0; g < groups && continue_match; ++g) {
      if (input_data[data_idx + g * feat_len] == output_data[i]) {
        max_index = data_idx + g * feat_len;
        continue_match = false;
        break;
      }
    }
    if (max_index != -1) {
      // Accumulate the gradient into the input element that produced the max;
      // the grid-stride loop variable i indexes the current output position.
      input_grad[max_index] += output_grad[i];
    }
  }
}

/*
 * All tensors are in NCHW format.
 */
template <typename T>
class MaxOutFunctor<platform::GPUPlace, T> {
 public:
  void operator()(const platform::DeviceContext& context,
                  const framework::Tensor& input, framework::Tensor* output,
                  int groups) {
    const int batch_size = input.dims()[0];
    const int input_channels = input.dims()[1];
    const int input_height = input.dims()[2];
    const int input_width = input.dims()[3];
    const int output_channels = output->dims()[1];
    const int output_height = output->dims()[2];
    const int output_width = output->dims()[3];

    const T* input_data = input.data<T>();
    T* output_data = output->mutable_data<T>(context.GetPlace());
    int nthreads = output->numel();
    int blocks = (nthreads + 1024 - 1) / 1024;
    dim3 threads(1024, 1);
    dim3 grid(blocks, 1);

    KernelMaxOut<
        T><<<grid, threads, 0,
             reinterpret_cast<const platform::CUDADeviceContext&>(context)
                 .stream()>>>(nthreads, input_data, input_channels,
                              input_height, input_width, groups, output_data);
  }
};

/*
 * All tensors are in NCHW format.
 */
template <typename T>
class MaxOutGradFunctor<platform::GPUPlace, T> {
 public:
  void operator()(const platform::DeviceContext& context,
                  const framework::Tensor& input,
                  framework::Tensor* input_grad,
                  const framework::Tensor& output,
                  const framework::Tensor& output_grad,
                  int groups) {
    const int batch_size = input.dims()[0];
    const int input_channels = input.dims()[1];
    const int input_height = input.dims()[2];
    const int input_width = input.dims()[3];
    const int output_channels = output.dims()[1];
    const int output_height = output.dims()[2];
    const int output_width = output.dims()[3];

    const T* input_data = input.data<T>();
    const T* output_data = output.data<T>();
    const T* output_grad_data = output_grad.data<T>();
    T* input_grad_data = input_grad->mutable_data<T>(context.GetPlace());
    int nthreads = output.numel();
    int blocks = (nthreads + 1024 - 1) / 1024;
    dim3 threads(1024, 1);
    dim3 grid(blocks, 1);

    KernelMaxoutGrad<
        T><<<grid, threads, 0,
             reinterpret_cast<const platform::CUDADeviceContext&>(context)
                 .stream()>>>(nthreads, input_data, output_data,
                              output_grad_data, input_grad_data,
                              input_channels, input_height, input_width,
                              groups);
  }
};

template class MaxOutGradFunctor<platform::GPUPlace, float>;
template class MaxOutGradFunctor<platform::GPUPlace, double>;

template class MaxOutFunctor<platform::GPUPlace, float>;
template class MaxOutFunctor<platform::GPUPlace, double>;

} // namespace math
} // namespace operators
} // namespace paddle
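
Neither forward path saves an argmax; both backward implementations rediscover the winning element by comparing each grouped input value against the stored output, and the incoming gradient is routed only to the first match (lowest g), so duplicate maxima receive nothing. A small host-side C++ sketch of that routing rule with made-up numbers (illustration only, not part of the commit):

#include <iostream>
#include <vector>

int main() {
  const int groups = 3;
  std::vector<float> input = {2.f, 5.f, 5.f};  // grouped inputs at one output position
  const float out = 5.f;                       // forward result: max of the group
  const float out_grad = 1.f;                  // incoming gradient for this output
  std::vector<float> in_grad(groups, 0.f);

  for (int g = 0; g < groups; ++g) {
    if (input[g] == out) {
      in_grad[g] += out_grad;  // only the first matching element gets the gradient
      break;
    }
  }

  // Prints "0 1 0": the duplicate maximum at index 2 receives no gradient.
  for (float v : in_grad) std::cout << v << ' ';
  std::cout << '\n';
}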

paddle/operators/math/maxouting.h (new file, +47)

/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

#pragma once
#include "paddle/framework/tensor.h"
#include "paddle/platform/device_context.h"
#include "paddle/platform/hostdevice.h"

namespace paddle {
namespace operators {
namespace math {

#define FLT_MAX \
  __FLT_MAX__

template <typename Place, typename T>
class MaxOutFunctor {
 public:
  void operator()(const platform::DeviceContext& context,
                  const framework::Tensor& input, framework::Tensor* output,
                  int groups);
};

template <typename Place, class T>
class MaxOutGradFunctor {
 public:
  void operator()(const platform::DeviceContext& context,
                  const framework::Tensor& input,
                  framework::Tensor* input_grad,
                  const framework::Tensor& output,
                  const framework::Tensor& output_grad, int groups);
};
} // namespace math
} // namespace operators
} // namespace paddle
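
These templates are what an operator kernel would instantiate for a concrete Place and element type; explicit instantiations for float and double are provided by maxouting.cc (CPUPlace) and maxouting.cu (GPUPlace). A minimal, hypothetical sketch of a caller follows; RunMaxOutForward is an assumed helper name, and the tensors are assumed to be framework tensors already shaped [N, C, H, W] and [N, C / groups, H, W]:

#include "paddle/operators/math/maxouting.h"

// Hypothetical wrapper, not part of the commit: forwards its arguments to the
// CPU maxout functor declared above.
template <typename T>
void RunMaxOutForward(const paddle::platform::DeviceContext& ctx,
                      const paddle::framework::Tensor& x,
                      paddle::framework::Tensor* out, int groups) {
  paddle::operators::math::MaxOutFunctor<paddle::platform::CPUPlace, T> maxout;
  maxout(ctx, x, out, groups);  // the gradient path uses MaxOutGradFunctor analogously
}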
