-
Notifications
You must be signed in to change notification settings - Fork 5.7k
Add maxout operator. #5571
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Merged
sweetsky0901
merged 31 commits into
PaddlePaddle:develop
from
sweetsky0901:my_maxout_op
Nov 20, 2017
Merged
Add maxout operator. #5571
Changes from all commits
Commits
Show all changes
31 commits
Select commit
Hold shift + click to select a range
4a428c8
this for maxout op new add
058bdd3
this for maxout op new add
784fd82
resolve conflicts
6c7e136
Merge branch 'develop' into my_maxout_op
sweetsky0901 fe1e16b
Merge branch 'develop' into my_maxout_op
sweetsky0901 ab9c71d
Merge branch 'develop' of https://github.com/PaddlePaddle/Paddle into…
bd773b9
modify for maxoutop code review
494edc6
Merge branch 'develop' of https://github.com/PaddlePaddle/Paddle into…
bb1be5d
merge cmakelist
9954496
Merge branch 'develop' of https://github.com/PaddlePaddle/Paddle into…
f57cd1e
del a err comments
f319fb1
Merge branch 'develop' of https://github.com/PaddlePaddle/Paddle into…
8d9babf
maxout code review 2nd
3ef776e
Merge branch 'develop' of https://github.com/PaddlePaddle/Paddle into…
5802880
update maxoutop for code review 3
63f8c5f
Merge branch 'develop' of https://github.com/PaddlePaddle/Paddle into…
a6a01c1
add test_maxout_op framework to fluis
4c113cc
Merge branch 'develop' of https://github.com/PaddlePaddle/Paddle into…
25d76bc
modify for add a space in maxout op
2d7a652
del framework test_maxout_op
13d39ea
Merge branch 'develop' of https://github.com/PaddlePaddle/Paddle into…
c645d06
add a space + *
52f2366
Merge branch 'develop' of https://github.com/PaddlePaddle/Paddle into…
76fc1a8
for code review 4
6ac4237
Merge branch 'develop' of https://github.com/PaddlePaddle/Paddle into…
4e5c989
rename back
sweetsky0901 350cc61
Merge branch 'develop' of https://github.com/PaddlePaddle/Paddle into…
3fbff1e
for code review 5
95cbbd7
Merge branch 'develop' of https://github.com/PaddlePaddle/Paddle into…
04fd989
for code review 6
9cb2ff6
del num_channels
File filter
Filter by extension
Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
There are no files selected for viewing
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,106 @@ | ||
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve. | ||
|
||
Licensed under the Apache License, Version 2.0 (the "License"); | ||
you may not use this file except in compliance with the License. | ||
You may obtain a copy of the License at | ||
|
||
http://www.apache.org/licenses/LICENSE-2.0 | ||
|
||
Unless required by applicable law or agreed to in writing, software | ||
distributed under the License is distributed on an "AS IS" BASIS, | ||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
See the License for the specific language governing permissions and | ||
limitations under the License. */ | ||
|
||
#include "paddle/operators/math/maxouting.h" | ||
|
||
namespace paddle { | ||
namespace operators { | ||
namespace math { | ||
|
||
// All tensors are in NCHW format, and the groups must be greater than 1 | ||
template <typename T> | ||
class MaxOutFunctor<platform::CPUPlace, T> { | ||
public: | ||
void operator()(const platform::DeviceContext& context, | ||
const framework::Tensor& input, | ||
framework::Tensor * output, | ||
int groups) { | ||
const int batch_size = input.dims()[0]; | ||
const int input_height = input.dims()[2]; | ||
const int input_width = input.dims()[3]; | ||
const int output_channels = output->dims()[1]; | ||
int fea_size = input_height * input_width; | ||
// c_size means the output size of each sample | ||
int c_size = fea_size * output_channels; | ||
const T* input_data = input.data<T>(); | ||
T* output_data = output->mutable_data<T>(context.GetPlace()); | ||
|
||
for (int i = 0; i < batch_size; ++i) { | ||
int new_bindex = c_size * i; | ||
for (int c = 0; c < output_channels; ++c) { | ||
int new_cindex = fea_size * c; | ||
for (int f = 0; f < fea_size; ++f) { | ||
T ele = static_cast<T>(-FLT_MAX); | ||
for (int ph = 0; ph < groups; ++ph) { | ||
T x = input_data[(new_bindex + new_cindex) * groups | ||
+ ph * fea_size + f]; | ||
ele = ele > x ? ele : x; | ||
} | ||
output_data[(new_bindex+new_cindex+f)] = ele; | ||
} | ||
} | ||
} | ||
} | ||
}; | ||
|
||
|
||
|
||
template <class T> | ||
class MaxOutGradFunctor<platform::CPUPlace, T> { | ||
public: | ||
void operator()(const platform::DeviceContext& context, | ||
const framework::Tensor& input, | ||
framework::Tensor * input_grad, | ||
const framework::Tensor& output, | ||
const framework::Tensor& output_grad, | ||
int groups) { | ||
const int batch_size = input.dims()[0]; | ||
const int input_height = input.dims()[2]; | ||
const int input_width = input.dims()[3]; | ||
const int output_channels = output.dims()[1]; | ||
int fea_size = input_height * input_width; | ||
const T* input_data = input.data<T>(); | ||
const T* output_data = output.data<T>(); | ||
const T* output_grad_data = output_grad.data<T>(); | ||
T* input_grad_data = input_grad->mutable_data<T>(context.GetPlace()); | ||
|
||
for (int i = 0; i < batch_size; ++i) { | ||
int blen = fea_size * output_channels * i; | ||
for (int c = 0; c < output_channels; ++c) { | ||
int clen = fea_size * c; | ||
for (int f = 0; f < fea_size; ++f) { | ||
int input_idx0 = (blen + clen) * groups + f; | ||
bool continue_match = true; | ||
int output_idx = blen + clen + f; | ||
for (int g = 0; g < groups && continue_match; ++g) { | ||
int input_idx = input_idx0 + fea_size * g; | ||
if (input_data[input_idx] == output_data[output_idx]) { | ||
input_grad_data[input_idx] += output_grad_data[output_idx]; | ||
continue_match = false; | ||
} | ||
} | ||
} | ||
} | ||
} | ||
} | ||
}; | ||
|
||
template class MaxOutGradFunctor<platform::CPUPlace, float>; | ||
template class MaxOutGradFunctor<platform::CPUPlace, double>; | ||
template class MaxOutFunctor<platform::CPUPlace, float>; | ||
template class MaxOutFunctor<platform::CPUPlace, double>; | ||
|
||
} // namespace math | ||
} // namespace operators | ||
} // namespace paddle |
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,154 @@ | ||
/* Copyright (c) 2016 paddlepaddle Authors. All Rights Reserve. | ||
|
||
Licensed under the Apache License, Version 2.0 (the "License"); | ||
you may not use this file except in compliance with the License. | ||
You may obtain a copy of the License at | ||
|
||
http://www.apache.org/licenses/LICENSE-2.0 | ||
|
||
Unless required by applicable law or agreed to in writing, software | ||
distributed under the License is distributed on an "AS IS" BASIS, | ||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
See the License for the specific language governing permissions and | ||
limitations under the License. */ | ||
|
||
#include "paddle/operators/math/maxouting.h" | ||
#include "paddle/platform/cuda_helper.h" | ||
|
||
namespace paddle { | ||
namespace operators { | ||
namespace math { | ||
|
||
template <typename T> | ||
__global__ void KernelMaxOut(const int nthreads, const T* input_data, | ||
const int channels, | ||
const int input_height, const int input_width, | ||
int groups, T* output_data ) { | ||
const int size = input_height * input_width * channels / groups; | ||
const int feat_len = input_height * input_width; | ||
int index = blockIdx.x * blockDim.x + threadIdx.x; | ||
int offset = blockDim.x * gridDim.x; | ||
for (int i = index; i < nthreads; i += offset) { | ||
int batch_idx = i / size; | ||
int batch_offset = i % size; | ||
int channel_idx = batch_offset / feat_len; | ||
int feat_idx = batch_offset % feat_len; | ||
int data_idx = | ||
(batch_idx * size + channel_idx * feat_len) * groups + feat_idx; | ||
T ele = static_cast<T>(-FLT_MAX); | ||
for (int g = 0; g < groups; ++g) { | ||
T x = input_data[data_idx + g * feat_len]; | ||
ele = ele > x ? ele : x; | ||
} | ||
output_data[i] = ele; | ||
} | ||
} | ||
template <typename T> | ||
__global__ void KernelMaxoutGrad( | ||
const int nthreads, const T* input_data, const T* output_data, | ||
const T* output_grad, T* input_grad, const int channels, | ||
const int input_height, const int input_width, int groups) { | ||
const int size = input_height * input_width * channels / groups; | ||
const int feat_len = input_height * input_width; | ||
int index = blockIdx.x * blockDim.x + threadIdx.x; | ||
int offset = blockDim.x * gridDim.x; | ||
for (int i = index; i < nthreads; i += offset) { | ||
int batch_idx = i / size; | ||
int batch_offset = i % size; | ||
int channel_idx = batch_offset / feat_len; | ||
int feat_idx = batch_offset % feat_len; | ||
int data_idx = | ||
(batch_idx * size + channel_idx * feat_len) * groups + feat_idx; | ||
int max_index = -1; | ||
bool continue_match = true; | ||
for (int g = 0; g < groups && continue_match; ++g) { | ||
if (input_data[data_idx + g * feat_len] == output_data[i]) { | ||
max_index = data_idx + g * feat_len; | ||
continue_match = false; | ||
break; | ||
} | ||
} | ||
if (max_index != -1) { | ||
input_grad[max_index] += output_grad[index]; | ||
} | ||
} | ||
} | ||
/* | ||
* All tensors are in NCHW format. | ||
*/ | ||
template <typename T> | ||
class MaxOutFunctor<platform::GPUPlace, T> { | ||
public: | ||
void operator()(const platform::DeviceContext& context, | ||
const framework::Tensor& input, framework::Tensor * output, | ||
int groups) { | ||
const int batch_size = input.dims()[0]; | ||
const int input_channels = input.dims()[1]; | ||
const int input_height = input.dims()[2]; | ||
const int input_width = input.dims()[3]; | ||
const int output_channels = output->dims()[1]; | ||
const int output_height = output->dims()[2]; | ||
const int output_width = output->dims()[3]; | ||
|
||
const T* input_data = input.data<T>(); | ||
T* output_data = output->mutable_data<T>(context.GetPlace()); | ||
int nthreads = output->numel(); | ||
int blocks = (nthreads + 1024 - 1) / 1024; | ||
dim3 threads(1024, 1); | ||
dim3 grid(blocks, 1); | ||
|
||
KernelMaxOut< | ||
T><<<grid, threads, 0, | ||
reinterpret_cast<const platform::CUDADeviceContext&>(context) | ||
.stream()>>>(nthreads, input_data, input_channels, | ||
input_height, input_width, groups, | ||
output_data); | ||
} | ||
}; | ||
/* | ||
* All tensors are in NCHW format. | ||
*/ | ||
template <typename T> | ||
class MaxOutGradFunctor<platform::GPUPlace, T> { | ||
public: | ||
void operator()(const platform::DeviceContext& context, | ||
const framework::Tensor& input, | ||
framework::Tensor * input_grad, | ||
const framework::Tensor& output, | ||
const framework::Tensor& output_grad, | ||
int groups) { | ||
const int batch_size = input.dims()[0]; | ||
const int input_channels = input.dims()[1]; | ||
const int input_height = input.dims()[2]; | ||
const int input_width = input.dims()[3]; | ||
const int output_channels = output.dims()[1]; | ||
const int output_height = output.dims()[2]; | ||
const int output_width = output.dims()[3]; | ||
|
||
const T* input_data = input.data<T>(); | ||
const T* output_data = output.data<T>(); | ||
const T* output_grad_data = output_grad.data<T>(); | ||
T* input_grad_data = input_grad->mutable_data<T>(context.GetPlace()); | ||
int nthreads = output.numel(); | ||
int blocks = (nthreads + 1024 - 1) / 1024; | ||
dim3 threads(1024, 1); | ||
dim3 grid(blocks, 1); | ||
|
||
KernelMaxoutGrad< | ||
T><<<grid, threads, 0, | ||
reinterpret_cast<const platform::CUDADeviceContext&>(context) | ||
.stream()>>>( | ||
nthreads, input_data, output_data, output_grad_data, input_grad_data, | ||
input_channels, input_height, input_width, groups); | ||
} | ||
}; | ||
|
||
template class MaxOutGradFunctor<platform::GPUPlace, float>; | ||
template class MaxOutGradFunctor<platform::GPUPlace, double>; | ||
|
||
template class MaxOutFunctor<platform::GPUPlace, float>; | ||
template class MaxOutFunctor<platform::GPUPlace, double>; | ||
|
||
} // namespace math | ||
} // namespace operators | ||
} // namespace paddle |
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,47 @@ | ||
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve. | ||
|
||
Licensed under the Apache License, Version 2.0 (the "License"); | ||
you may not use this file except in compliance with the License. | ||
You may obtain a copy of the License at | ||
|
||
http://www.apache.org/licenses/LICENSE-2.0 | ||
|
||
Unless required by applicable law or agreed to in writing, software | ||
distributed under the License is distributed on an "AS IS" BASIS, | ||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
See the License for the specific language governing permissions and | ||
limitations under the License. */ | ||
|
||
#pragma once | ||
#include "paddle/framework/tensor.h" | ||
#include "paddle/platform/device_context.h" | ||
#include "paddle/platform/hostdevice.h" | ||
|
||
namespace paddle { | ||
namespace operators { | ||
namespace math { | ||
|
||
#define FLT_MAX \ | ||
__FLT_MAX__ | ||
|
||
template <typename Place, typename T> | ||
|
||
class MaxOutFunctor { | ||
public: | ||
void operator()(const platform::DeviceContext& context, | ||
const framework::Tensor& input, framework::Tensor * output, | ||
int groups); | ||
}; | ||
|
||
template <typename Place, class T> | ||
class MaxOutGradFunctor { | ||
public: | ||
void operator()(const platform::DeviceContext& context, | ||
const framework::Tensor& input, | ||
framework::Tensor * input_grad, | ||
const framework::Tensor& output, | ||
const framework::Tensor& output_grad, int groups); | ||
}; | ||
} // namespace math | ||
} // namespace operators | ||
} // namespace paddle |
Oops, something went wrong.
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
You should check whether
output_channels
andinput.dims()[1] / groups
are equal.There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
这个在外面是这么初始化出来的,所以才没有再检查