Add maxout operator. #5571
CPU implementation (new file):

@@ -0,0 +1,117 @@
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

#include "paddle/operators/math/maxouting.h"

namespace paddle {
namespace operators {
namespace math {

/*
 * All tensors are in NCHW format.
 * Ksize, strides, paddings are two elements. These two elements represent
 * height and width, respectively.
 */

Review: This comment is wrong for maxout; it should read: "All tensors are in NCHW format, and the groups must be greater than 1."
Author: Done.

template <typename MaxOutProcess, typename T>
class MaxOutFunctor<platform::CPUPlace, MaxOutProcess, T> {
 public:
  void operator()(const platform::DeviceContext& context,
                  const framework::Tensor& input, framework::Tensor& output,
                  int groups, int num_channels, MaxOutProcess maxout_process) {

Review: Please reorder the parameters.
Author: In this ordering num_channels has been dropped; the other positions are left unchanged for now. In general inputs come first and outputs after, but groups is an attribute of the whole op.
Review: I think it is better to put all input-only parameters before any output parameters.
Author: The inner calls were reordered; this one is left as is for now, since many call sites would have to change for little gain. The order here is not strictly input-then-output; input and output refer specifically to the tensors.

    const int batch_size = input.dims()[0];
    const int input_height = input.dims()[2];
    const int input_width = input.dims()[3];
    const int output_channels = num_channels/groups;

Review: num_channels can be derived from the tensors instead of passed in:
    const int output_channels = output.dims()[1];
    const int group = input.dims()[1] / output_channels;
Author: groups is what the caller passes in originally, and output is computed from it, so I keep groups throughout.

    int fea_size = input_height * input_width;
    int c_size = fea_size * output_channels;

Review: c_size -> out_size ?
Author: Done.

    const T* input_data = input.data<T>();
    T* output_data = output.mutable_data<T>(context.GetPlace());

    for (int i = 0; i < batch_size; i++) {
      int new_bindex = c_size * i;
      for (int c = 0; c < output_channels; ++c) {
        int new_cindex = fea_size * c;
        for (int f = 0; f < fea_size; f++) {
          T ele = maxout_process.initial();
          for (int ph = 0; ph < groups; ++ph) {
            maxout_process.compute(ele,
                input_data[(new_bindex + new_cindex) * groups + ph * fea_size + f]);
          }
          maxout_process.finalize(ele, (static_cast<T>(groups)));

Review: I think the initial/compute/finalize indirection is not all needed here.
Author: I still use the first two, but finalize has been removed; it was indeed unused.

          output_data[(new_bindex + new_cindex + f)] = ele;

Review: Suggest dropping maxout_process and comparing values directly here.
Author: Done.

        }
      }
    }
  }
};
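
The forward computation above is easiest to check outside the framework. Below is a minimal, self-contained sketch of the same pass over plain float arrays; the helper name maxout_forward_cpu and the hard-coded max reduction (standing in for the MaxOutProcess template) are assumptions for illustration, not code from this PR.

#include <algorithm>
#include <cassert>
#include <vector>

// Framework-free sketch of the maxout forward pass (hypothetical helper).
// Input is NCHW with `channels` channels; output has channels / groups
// channels, and each output element is the max over `groups` consecutive
// channel slices at the same spatial position.
void maxout_forward_cpu(const std::vector<float>& input,
                        std::vector<float>& output, int batch, int channels,
                        int height, int width, int groups) {
  const int out_channels = channels / groups;
  const int fea_size = height * width;
  const int c_size = fea_size * out_channels;
  for (int n = 0; n < batch; ++n) {
    for (int c = 0; c < out_channels; ++c) {
      for (int f = 0; f < fea_size; ++f) {
        // Same arithmetic as the functor: the g-th candidate for output
        // channel c lives at input channel c * groups + g.
        int base = (n * c_size + c * fea_size) * groups + f;
        float ele = input[base];
        for (int g = 1; g < groups; ++g) {
          ele = std::max(ele, input[base + g * fea_size]);
        }
        output[n * c_size + c * fea_size + f] = ele;
      }
    }
  }
}

int main() {
  // N=1, C=4, H=W=1, groups=2 -> 2 output channels.
  std::vector<float> in = {1.f, 5.f, -2.f, 3.f};
  std::vector<float> out(2);
  maxout_forward_cpu(in, out, 1, 4, 1, 1, 2);
  assert(out[0] == 5.f);  // max(1, 5)
  assert(out[1] == 3.f);  // max(-2, 3)
  return 0;
}
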
template <class T>
class MaxOutGradFunctor<platform::CPUPlace, T> {
 public:
  void operator()(const platform::DeviceContext& context,
                  const framework::Tensor& input,
                  framework::Tensor& input_grad,
                  const framework::Tensor& output,
                  const framework::Tensor& output_grad,
                  int groups, int num_channels) {

Review: Same as above; please follow the function-parameter-ordering convention (input-only parameters before output parameters).
Author: As above, num_channels has been dropped and the rest of the order is kept for now; groups is an attribute of the whole op.

    const int batch_size = input.dims()[0];
    const int input_height = input.dims()[2];
    const int input_width = input.dims()[3];
    const int output_channels = num_channels / groups;

    int fea_size = input_height * input_width;

    const T* input_data = input.data<T>();
    const T* output_data = output.data<T>();
    const T* output_grad_data = output_grad.data<T>();
    T* input_grad_data = input_grad.mutable_data<T>(context.GetPlace());

    for (int i = 0; i < batch_size; i++) {
      int blen = fea_size * output_channels * i;
      for (int c = 0; c < output_channels; ++c) {
        int clen = fea_size * c;
        for (int f = 0; f < fea_size; f++) {
          int input_idx = 0;
          bool stop = false;
          int output_idx = blen + clen + f;
          for (int g = 0; g < groups && !stop; g++) {
            input_idx = (blen + clen) * groups + fea_size * g + f;

Review: Restructure this index computation to avoid redundant work.
Author: Done.

            input_grad_data[input_idx] = 0;

Review: Please remove this line.
Author: It initializes the value; why remove it? Memory often holds stale data.
Review: You can initialize outside the loop instead, like this.
Author: input_idx changes inside the for loop, so that would require a memset outside. Done.

            if (input_data[input_idx] == output_data[output_idx]) {
              input_grad_data[input_idx] += output_grad_data[output_idx];
              stop = true;
            } else {
              input_grad_data[input_idx] = 0;

Review: Please remove this assignment as well.
Author: Done.

            }
          }
        }
      }
    }
  }
};
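
The backward pass routes each output gradient to the single input element that produced the max; when values tie within a group, only the first match receives it (the stop flag above). Here is a sketch of that routing which zero-initializes the gradient buffer up front, as suggested in the thread above, instead of assigning zeros inside the scan; maxout_backward_cpu is a hypothetical name, not code from this PR.

#include <algorithm>
#include <cassert>
#include <vector>

// Sketch of the maxout backward pass. Zero the whole gradient buffer first,
// then give each output gradient to the first input in its group whose
// value equals the output max.
void maxout_backward_cpu(const std::vector<float>& input,
                         std::vector<float>& input_grad,
                         const std::vector<float>& output,
                         const std::vector<float>& output_grad, int batch,
                         int channels, int height, int width, int groups) {
  const int out_channels = channels / groups;
  const int fea_size = height * width;
  std::fill(input_grad.begin(), input_grad.end(), 0.f);
  for (int n = 0; n < batch; ++n) {
    int blen = fea_size * out_channels * n;
    for (int c = 0; c < out_channels; ++c) {
      int clen = fea_size * c;
      for (int f = 0; f < fea_size; ++f) {
        int output_idx = blen + clen + f;
        for (int g = 0; g < groups; ++g) {
          int input_idx = (blen + clen) * groups + fea_size * g + f;
          if (input[input_idx] == output[output_idx]) {
            input_grad[input_idx] += output_grad[output_idx];
            break;  // ties: only the first match gets the gradient
          }
        }
      }
    }
  }
}

int main() {
  std::vector<float> in = {1.f, 5.f, -2.f, 3.f};  // N=1, C=4, H=W=1
  std::vector<float> out = {5.f, 3.f};            // forward result, groups=2
  std::vector<float> dout = {1.f, 1.f}, din(4);
  maxout_backward_cpu(in, din, out, dout, 1, 4, 1, 1, 2);
  // Gradient flows only to in[1] and in[3], the argmax of each group.
  assert(din[0] == 0.f && din[1] == 1.f && din[2] == 0.f && din[3] == 1.f);
  return 0;
}
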

template class MaxOutGradFunctor<platform::CPUPlace, float>;
template class MaxOutGradFunctor<platform::CPUPlace, double>;
template class MaxOutFunctor<platform::CPUPlace,
                             paddle::operators::math::MaxOut<float>, float>;
template class MaxOutFunctor<platform::CPUPlace,
                             paddle::operators::math::MaxOut<double>, double>;

}  // namespace math
}  // namespace operators
}  // namespace paddle

GPU implementation (new file):

@@ -0,0 +1,161 @@
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

#include "paddle/operators/math/maxouting.h"
#include "paddle/platform/cuda_helper.h"

namespace paddle {
namespace operators {
namespace math {
template <typename MaxOutProcess, typename T>
__global__ void KernelMaxOut(const int nthreads, const T* input_data,
                             T* output_data, const int channels,
                             const int input_height, const int input_width,
                             int groups, MaxOutProcess maxout_process) {

Review: Please note the order of parameters.
Author: Done.

  int size = input_height * input_width * channels / groups;
  int featLen = input_height * input_width;

  for (int index = blockIdx.x * blockDim.x + threadIdx.x; index < nthreads;
       index += blockDim.x * gridDim.x) {

Review: Change the format of this loop so the stride is not recomputed on every iteration.
Author: Done.

    int batch_idx = index / size;
    int i = index % size;

    int channel_idx = i / featLen;
    int feat_idx = i % featLen;
    int data_idx =
        (batch_idx * size + channel_idx * featLen) * groups + feat_idx;
    T ele = maxout_process.initial();
    for (int g = 0; g < groups; g++) {

Review: Other parts of the code have the same usage; please correct them one by one.
Author: Done.

      maxout_process.compute(ele, input_data[data_idx + g * featLen]);
    }
    maxout_process.finalize(ele, (static_cast<T>(groups)));
    output_data[index] = ele;
  }
}
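
KernelMaxOut assigns one output element per flat index and walks the output with a grid-stride loop, so any launch size covers all nthreads elements. The decomposition of index into batch item, output channel, and spatial offset can be verified on the host; below is a small sketch with arbitrarily chosen sizes, not code from this PR.

#include <cassert>

// Host-side check of the kernel's index decomposition. `size` is the number
// of output elements per batch item; `data_idx` is where the group's first
// candidate lives in the flattened NCHW input.
int main() {
  const int channels = 4, groups = 2, height = 3, width = 5;
  const int size = height * width * channels / groups;  // 30
  const int feat_len = height * width;                  // 15
  const int index = size + 7;          // flat output index in batch item 1
  const int batch_idx = index / size;  // 1
  const int rem = index % size;        // 7: offset within the batch item
  const int channel_idx = rem / feat_len;  // output channel 0
  const int feat_idx = rem % feat_len;     // spatial position 7
  const int data_idx =
      (batch_idx * size + channel_idx * feat_len) * groups + feat_idx;
  assert(batch_idx == 1 && channel_idx == 0 && feat_idx == 7);
  assert(data_idx == 67);  // batch 1's input block, group base, offset 7
  return 0;
}
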
template <typename T>
__global__ void KernelMaxoutGrad(
    const int nthreads, const T* input_data, const T* output_data,
    const T* output_grad, T* input_grad, const int channels,
    const int input_height, const int input_width, int groups) {
  int size = input_height * input_width * channels / groups;
  int featLen = input_height * input_width;
  for (int index = blockIdx.x * blockDim.x + threadIdx.x; index < nthreads;
       index += blockDim.x * gridDim.x) {

Review: Same as above (the loop format).
Author: Done.

    int batch_idx = index / size;
    int i = index % size;
    int channel_idx = i / featLen;
    int feat_idx = i % featLen;
    int data_idx =
        (batch_idx * size + channel_idx * featLen) * groups + feat_idx;
    int maxIndex = -1;

Review: maxIndex -> max_index
Author: Done.

    bool stop = false;
    for (int g = 0; g < groups && !stop; g++) {
      if (input_data[data_idx + g * featLen] == output_data[index]) {
        maxIndex = data_idx + g * featLen;
        stop = true;
      }
    }
    if (maxIndex != -1) {
      // atomic add
      platform::CudaAtomicAdd(input_grad + maxIndex, output_grad[index]);
    }
  }
}
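
The gradient write goes through platform::CudaAtomicAdd, which makes the read-modify-write on input_grad indivisible even if several threads were to target the same address. As a CPU analogy of that guarantee (an illustration only, not Paddle code), a compare-and-swap loop gives the same lost-update protection:

#include <atomic>
#include <cassert>
#include <thread>
#include <vector>

// CPU analogy for an atomic float add: retry with a compare-and-swap loop
// until the read-modify-write succeeds without interference.
void atomic_add(std::atomic<float>* target, float value) {
  float old = target->load();
  while (!target->compare_exchange_weak(old, old + value)) {
    // `old` is refreshed on failure; retry with the updated value.
  }
}

int main() {
  std::atomic<float> grad{0.f};
  std::vector<std::thread> workers;
  for (int t = 0; t < 4; ++t) {
    workers.emplace_back([&grad] { atomic_add(&grad, 1.f); });
  }
  for (auto& w : workers) w.join();
  assert(grad.load() == 4.f);  // no update lost
  return 0;
}
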
/*
 * All tensors are in NCHW format.
 * Ksize, strides, paddings are two elements. These two elements represent
 * height and width, respectively.
 */
template <typename MaxOutProcess, typename T>
class MaxOutFunctor<platform::GPUPlace, MaxOutProcess, T> {
 public:
  void operator()(const platform::DeviceContext& context,
                  const framework::Tensor& input, framework::Tensor& output,
                  int groups, int num_channels,
                  MaxOutProcess maxout_process) {

Review: The same as above.
Author: Done.

    const int batch_size = input.dims()[0];
    const int input_channels = input.dims()[1];
    const int input_height = input.dims()[2];
    const int input_width = input.dims()[3];
    const int output_channels = num_channels / groups;
    const int output_height = output.dims()[2];
    const int output_width = output.dims()[3];

    const T* input_data = input.data<T>();
    T* output_data = output.mutable_data<T>(context.GetPlace());

    int nthreads = batch_size * output_channels * output_height * output_width;

Review: int nthreads = output.numel();
Author: Done.

    int blocks = (nthreads + 1024 - 1) / 1024;
    dim3 threads(1024, 1);
    dim3 grid(blocks, 1);

    KernelMaxOut<
        MaxOutProcess,
        T><<<grid, threads, 0,
             reinterpret_cast<const platform::CUDADeviceContext&>(context)
                 .stream()>>>(nthreads, input_data, output_data, input_channels,
                              input_height, input_width, groups,
                              maxout_process);
  }
};
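
The grid size (nthreads + 1024 - 1) / 1024 is integer ceiling division: enough 1024-thread blocks to cover every output element, with at most 1023 idle threads in the last block. A quick sanity check of that arithmetic:

#include <cassert>

int main() {
  // Integer ceiling division used for the grid size above.
  auto blocks = [](int nthreads) { return (nthreads + 1024 - 1) / 1024; };
  assert(blocks(1) == 1);
  assert(blocks(1024) == 1);
  assert(blocks(1025) == 2);
  return 0;
}
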
/*
 * All tensors are in NCHW format.
 * Ksize, strides, paddings are two elements. These two elements represent
 * height and width, respectively.
 */
template <typename T>
class MaxOutGradFunctor<platform::GPUPlace, T> {
 public:
  void operator()(const platform::DeviceContext& context,
                  const framework::Tensor& input, framework::Tensor& input_grad,
                  const framework::Tensor& output,
                  const framework::Tensor& output_grad,
                  int groups, int num_channels) {
    const int batch_size = input.dims()[0];
    const int input_channels = input.dims()[1];
    const int input_height = input.dims()[2];
    const int input_width = input.dims()[3];
    const int output_channels = output.dims()[1];
    const int output_height = output.dims()[2];
    const int output_width = output.dims()[3];

    const T* input_data = input.data<T>();
    const T* output_data = output.data<T>();
    const T* output_grad_data = output_grad.data<T>();
    T* input_grad_data = input_grad.mutable_data<T>(context.GetPlace());

    int nthreads = batch_size * output_channels * output_height * output_width;

Review: int nthreads = output.numel();
Author: Done.

    int blocks = (nthreads + 1024 - 1) / 1024;
    dim3 threads(1024, 1);
    dim3 grid(blocks, 1);

    KernelMaxoutGrad<
        T><<<grid, threads, 0,
             reinterpret_cast<const platform::CUDADeviceContext&>(context)
                 .stream()>>>(
        nthreads, input_data, output_data, output_grad_data, input_grad_data,
        input_channels, input_height, input_width, groups);
  }
};

template class MaxOutGradFunctor<platform::GPUPlace, float>;
template class MaxOutGradFunctor<platform::GPUPlace, double>;

template class MaxOutFunctor<platform::GPUPlace,
                             paddle::operators::math::MaxOut<float>, float>;
template class MaxOutFunctor<platform::GPUPlace,
                             paddle::operators::math::MaxOut<double>, double>;

}  // namespace math
}  // namespace operators
}  // namespace paddle

Review: The comments are not right; the same applies to the following ones.
Author: Done.