Add maxout operator. #5571

Merged: 31 commits merged on Nov 20, 2017
Changes from 4 commits
Commits (31 total)
4a428c8
this for maxout op new add
Nov 11, 2017
058bdd3
this for maxout op new add
Nov 11, 2017
784fd82
resolve conflicts
Nov 11, 2017
6c7e136
Merge branch 'develop' into my_maxout_op
sweetsky0901 Nov 13, 2017
fe1e16b
Merge branch 'develop' into my_maxout_op
sweetsky0901 Nov 13, 2017
ab9c71d
Merge branch 'develop' of https://github.com/PaddlePaddle/Paddle into…
Nov 13, 2017
bd773b9
modify for maxoutop code review
Nov 14, 2017
494edc6
Merge branch 'develop' of https://github.com/PaddlePaddle/Paddle into…
Nov 14, 2017
bb1be5d
merge cmakelist
Nov 14, 2017
9954496
Merge branch 'develop' of https://github.com/PaddlePaddle/Paddle into…
Nov 14, 2017
f57cd1e
del a err comments
Nov 14, 2017
f319fb1
Merge branch 'develop' of https://github.com/PaddlePaddle/Paddle into…
Nov 14, 2017
8d9babf
maxout code review 2nd
Nov 15, 2017
3ef776e
Merge branch 'develop' of https://github.com/PaddlePaddle/Paddle into…
Nov 15, 2017
5802880
update maxoutop for code review 3
Nov 19, 2017
63f8c5f
Merge branch 'develop' of https://github.com/PaddlePaddle/Paddle into…
Nov 19, 2017
a6a01c1
add test_maxout_op framework to fluis
Nov 19, 2017
4c113cc
Merge branch 'develop' of https://github.com/PaddlePaddle/Paddle into…
Nov 20, 2017
25d76bc
modify for add a space in maxout op
Nov 20, 2017
2d7a652
del framework test_maxout_op
Nov 20, 2017
13d39ea
Merge branch 'develop' of https://github.com/PaddlePaddle/Paddle into…
Nov 20, 2017
c645d06
add a space + *
Nov 20, 2017
52f2366
Merge branch 'develop' of https://github.com/PaddlePaddle/Paddle into…
Nov 20, 2017
76fc1a8
for code review 4
Nov 20, 2017
6ac4237
Merge branch 'develop' of https://github.com/PaddlePaddle/Paddle into…
Nov 20, 2017
4e5c989
rename back
sweetsky0901 Nov 20, 2017
350cc61
Merge branch 'develop' of https://github.com/PaddlePaddle/Paddle into…
Nov 20, 2017
3fbff1e
for code review 5
Nov 20, 2017
95cbbd7
Merge branch 'develop' of https://github.com/PaddlePaddle/Paddle into…
Nov 20, 2017
04fd989
for code review 6
Nov 20, 2017
9cb2ff6
del num_channels
Nov 20, 2017
2 changes: 2 additions & 0 deletions paddle/operators/CMakeLists.txt
@@ -162,6 +162,7 @@ set(DEPS_OPS
softmax_with_cross_entropy_op
sum_op
pool_op
maxout_op
pool_with_index_op
conv_op
lstm_op
@@ -182,6 +183,7 @@ op_library(softmax_with_cross_entropy_op DEPS cross_entropy softmax)
op_library(conv_op DEPS vol2col)
op_library(sum_op DEPS net_op selected_rows_functor)
op_library(pool_op DEPS pooling)
op_library(maxout_op DEPS maxouting)
op_library(pool_with_index_op DEPS pooling)
op_library(lod_rank_table_op SRCS lod_rank_table_op.cc DEPS lod_rank_table)
op_library(lod_tensor_to_array_op SRCS lod_tensor_to_array_op.cc DEPS lod_rank_table_op)
2 changes: 2 additions & 0 deletions paddle/operators/math/CMakeLists.txt
@@ -8,6 +8,7 @@ if(WITH_GPU)
nv_library(softmax SRCS softmax.cc softmax.cu DEPS operator)
nv_library(cross_entropy SRCS cross_entropy.cc cross_entropy.cu DEPS operator)
nv_library(pooling SRCS pooling.cc pooling.cu DEPS device_context)
nv_library(maxouting SRCS maxouting.cc maxouting.cu DEPS device_context)
nv_library(sequence_pooling SRCS sequence_pooling.cc sequence_pooling.cu DEPS device_context math_function)
nv_library(vol2col SRCS vol2col.cc vol2col.cu DEPS device_context)
nv_library(context_project SRCS context_project.cc context_project.cu DEPS device_context)
@@ -20,6 +21,7 @@ else()
cc_library(softmax SRCS softmax.cc DEPS operator)
cc_library(cross_entropy SRCS cross_entropy.cc DEPS operator)
cc_library(pooling SRCS pooling.cc DEPS device_context)
cc_library(maxouting SRCS maxouting.cc DEPS device_context)
cc_library(sequence_pooling SRCS sequence_pooling.cc DEPS device_context math_function)
cc_library(vol2col SRCS vol2col.cc DEPS device_context)
cc_library(context_project SRCS context_project.cc DEPS device_context)
117 changes: 117 additions & 0 deletions paddle/operators/math/maxouting.cc
@@ -0,0 +1,117 @@
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

#include "paddle/operators/math/maxouting.h"

namespace paddle {
namespace operators {
namespace math {

/*
* All tensors are in NCHW format.
* Ksize, strides, paddings are two elements. These two elements represent
* height and width, respectively.
*/
Contributor
The comments are not right, same as following.

Contributor Author
done

Contributor
// All tensors are in NCHW format, and the groups must be greater than 1.

Contributor Author
done

template <typename MaxOutProcess, typename T>
class MaxOutFunctor<platform::CPUPlace, MaxOutProcess, T> {
public:
void operator()(const platform::DeviceContext& context,
const framework::Tensor& input, framework::Tensor& output,
Contributor
framework::Tensor* for output.

Contributor Author
done

int groups, int num_channels, MaxOutProcess maxout_process) {
Contributor
Seems the num_channels can be got from the input tensor. There is no need to pass this argument.

Contributor Author
done

Contributor
Please reorder the parameters. Function_Parameter_Ordering

Contributor Author
For this ordering, num_channels has been removed; the other positions are unchanged for now. Overall, inputs come first and outputs after, but groups is a global attribute.

Contributor
I think it is better to put all input-only parameters before any output parameters.

Contributor Author
The inner ones have been adjusted; this one is left as-is for now, otherwise there would be many places to change for little benefit. The ordering here is not viewed strictly as input/output; input and output refer more to the matrices.
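As an editorial illustration of the ordering the reviewer asks for, here is a minimal sketch with placeholder types (not the real Paddle framework::Tensor or DeviceContext API, and not the PR's actual signature): input-only parameters first, attributes next, the mutable output last and passed by pointer.

// Placeholder types; illustrative only.
struct DeviceContext {};
struct Tensor {};

void MaxOutForward(const DeviceContext& ctx,  // read-only context
                   const Tensor& input,       // input-only parameter
                   int groups,                // attribute
                   Tensor* output) {          // output, last and by pointer
  (void)ctx; (void)input; (void)groups; (void)output;
}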

const int batch_size = input.dims()[0];
const int input_height = input.dims()[2];
const int input_width = input.dims()[3];
const int output_channels = num_channels/groups;
Contributor
The output_channels can be got from the output tensor. Then the groups can be calculated from the channel numbers of input and output.

const int output_channels = output.dims()[1];
const int group = input.dims()[1] / output_channels;

Contributor Author
groups is what gets passed in at the very beginning, and output is computed from it, so I have kept carrying groups along.
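A tiny worked example of the relationship discussed here (illustrative numbers, not from the PR):

// If the NCHW input has C_in = 6 channels and groups = 3, then
int c_in = 6, groups = 3;
int output_channels = c_in / groups;          // 2
int derived_groups = c_in / output_channels;  // 3, recoverable from the two shapes alone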


int fea_size = input_height * input_width;
int c_size = fea_size * output_channels;
Contributor
c_size -> out_size ?

Contributor Author
done


const T* input_data = input.data<T>();
T* output_data = output.mutable_data<T>(context.GetPlace());

for (int i = 0; i < batch_size; i++) {
int new_bindex = c_size * i;
for (int c = 0; c < output_channels; ++c) {
int new_cindex = fea_size * c;
for (int f = 0; f < fea_size; f++) {
T ele = maxout_process.initial();
for (int ph = 0; ph < groups; ++ph) {
maxout_process.compute(ele,
input_data[(new_bindex+new_cindex) * groups+ph*fea_size+f]);
}
maxout_process.finalize(ele, (static_cast<T>(groups)));
Contributor (qingqing01, Nov 13, 2017)
I think the MaxOutProcess maxout_process is not necessary. The implementation can be expanded here.

Contributor Author
I still kept the first two, but removed finalize, which was indeed unused.

output_data[(new_bindex+new_cindex+f)] = ele;
Contributor
Suggest removing maxout_process and just comparing values directly here.

Contributor Author
done

}
}
}
}
};
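For reference, a minimal standalone sketch (editorial, not part of the PR) of the computation the CPU functor above performs: each output channel takes the element-wise maximum over `groups` consecutive input channels of an NCHW buffer.

#include <algorithm>
#include <cstdio>
#include <vector>

// Illustrative only; the loop structure mirrors the functor above, but names are hypothetical.
void MaxOutForwardRef(const std::vector<float>& input, std::vector<float>* output,
                      int batch, int in_channels, int height, int width, int groups) {
  const int out_channels = in_channels / groups;
  const int fea_size = height * width;
  for (int n = 0; n < batch; ++n) {
    for (int c = 0; c < out_channels; ++c) {
      for (int f = 0; f < fea_size; ++f) {
        // Max over `groups` consecutive input channels.
        float best = input[(n * in_channels + c * groups) * fea_size + f];
        for (int g = 1; g < groups; ++g) {
          best = std::max(best, input[(n * in_channels + c * groups + g) * fea_size + f]);
        }
        (*output)[(n * out_channels + c) * fea_size + f] = best;
      }
    }
  }
}

int main() {
  // 1 sample, 4 input channels, 1x2 spatial, groups = 2 -> 2 output channels.
  std::vector<float> in = {1, 2, 5, 0, 3, 3, 4, 7};
  std::vector<float> out(4);
  MaxOutForwardRef(in, &out, 1, 4, 1, 2, 2);
  for (float v : out) std::printf("%g ", v);  // prints: 5 2 4 7
  return 0;
}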



template <class T>
class MaxOutGradFunctor<platform::CPUPlace, T> {
public:
void operator()(const platform::DeviceContext& context,
const framework::Tensor& input,
framework::Tensor& input_grad,
Contributor
framework::Tensor& input_grad -> framework::Tensor* input_grad. Maybe better to put the output as the last argument.

Contributor Author
done

const framework::Tensor& output,
const framework::Tensor& output_grad,
int groups, int num_channels) {
Contributor
The same with MaxOutFunctor.

Contributor Author
For this ordering, num_channels has been removed; the other positions are unchanged for now. Overall, inputs come first and outputs after, but groups is a global attribute.

const int batch_size = input.dims()[0];
const int input_height = input.dims()[2];
const int input_width = input.dims()[3];
const int output_channels = num_channels / groups;

int fea_size = input_height * input_width;

const T* input_data = input.data<T>();
const T* output_data = output.data<T>();
const T* output_grad_data = output_grad.data<T>();
T* input_grad_data = input_grad.mutable_data<T>(context.GetPlace());

for (int i = 0; i < batch_size; i++) {
int blen = fea_size * output_channels * i;
for (int c = 0; c < output_channels; ++c) {
int clen = fea_size * c;
for (int f = 0; f < fea_size; f++) {
int input_idx = 0;
bool stop = false;
int output_idx = blen + clen + f;
for (int g = 0; g < groups && !stop; g++) {
input_idx = (blen + clen) * groups + fea_size * g + f;
Contributor
To reduce redundant computation:

  • Change line 90 to: int input_idx0 = (blen + clen) * groups + f; and line 94 to: int input_idx = input_idx0 + fea_size * g;
  • Change line 91 to bool continue = true and adjust lines 93 and 98 accordingly, so the check on line 93 does not need an extra negation each time.

Contributor Author
done
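A hedged, self-contained sketch of the restructuring suggested above (names are illustrative, and the reviewer's `continue` is renamed since it is a C++ keyword):

// Processes one output element of the backward pass: hoist the g-invariant part
// of the index out of the group loop and check a plain boolean each iteration.
void MaxOutBackwardOneElement(const float* input_data, const float* output_data,
                              const float* output_grad_data, float* input_grad_data,
                              int blen, int clen, int f, int output_idx,
                              int groups, int fea_size) {
  const int input_idx0 = (blen + clen) * groups + f;  // does not depend on g
  bool keep_searching = true;
  for (int g = 0; g < groups && keep_searching; ++g) {
    const int input_idx = input_idx0 + fea_size * g;
    if (input_data[input_idx] == output_data[output_idx]) {
      input_grad_data[input_idx] += output_grad_data[output_idx];
      keep_searching = false;  // stop after the first matching group
    }
  }
}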

input_grad_data[input_idx] = 0;
Contributor
Please remove line 88.

Contributor Author
It is initialized to a value; why should that be removed? Memory often contains dirty data.

Contributor
You could initialize it outside the loop, like this.

Contributor Author
input_idx changes inside the for loop, unless we memset outside it.

Contributor Author
done
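As an editorial sketch of the alternative being discussed (illustrative names, not the PR's code): zero the whole gradient buffer once up front, so the inner loops only accumulate at the matching index instead of writing 0 per element.

#include <cstring>

// Assumes input_grad_data points to `size` floats; zero them once before the nested loops.
void ZeroInputGrad(float* input_grad_data, size_t size) {
  std::memset(input_grad_data, 0, size * sizeof(float));
}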

if (input_data[input_idx] == output_data[output_idx]) {
input_grad_data[input_idx] += output_grad_data[output_idx];
stop = true;
} else {
input_grad_data[input_idx] = 0;
Contributor
Please remove the else {...}. You have already set the value in line 94.

Contributor Author
done

}
}
}
}
}
}
};

template class MaxOutGradFunctor<platform::CPUPlace, float>;
template class MaxOutGradFunctor<platform::CPUPlace, double>;
template class MaxOutFunctor<platform::CPUPlace,
paddle::operators::math::MaxOut<float>, float>;
template class MaxOutFunctor<platform::CPUPlace,
paddle::operators::math::MaxOut<double>, double>;
Contributor
paddle::operators:: can be removed, because MaxOutFunctor is in the paddle::operators namespace.

Contributor Author
done


} // namespace math
} // namespace operators
} // namespace paddle
161 changes: 161 additions & 0 deletions paddle/operators/math/maxouting.cu
@@ -0,0 +1,161 @@
/* Copyright (c) 2016 paddlepaddle Authors. All Rights Reserve.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

#include "paddle/operators/math/maxouting.h"
#include "paddle/platform/cuda_helper.h"

namespace paddle {
namespace operators {
namespace math {

template <typename MaxOutProcess, typename T>
__global__ void KernelMaxOut(const int nthreads, const T* input_data,
T* output_data, const int channels,
const int input_height, const int input_width,
int groups, MaxOutProcess maxout_process) {
Contributor
Please note the order of parameters.

Contributor Author
done

int size = input_height * input_width * channels / groups;
int featLen = input_height * input_width;
Contributor
size and featLen should be const int. Please keep the variable naming convention: featLen -> feat_len.

Contributor Author
done

for (int index = blockIdx.x * blockDim.x + threadIdx.x; index < nthreads;
Contributor
Change the form of line 30 as follows, to reduce repeated computation:

int index = blockIdx.x * blockDim.x + threadIdx.x;
int offset = blockDim.x * gridDim.x;
for (int i = index; i < nthreads; i += offset)
...

Contributor Author
done

index += blockDim.x * gridDim.x) {
int batch_idx = index / size;
int i = index % size;
Contributor
Please do not use i there, especially in the loop. You can use temp.

Contributor Author
done

int channel_idx = i / featLen;
int feat_idx = i % featLen;
int data_idx =
(batch_idx * size + channel_idx * featLen) * groups + feat_idx;
T ele = maxout_process.initial();
for (int g = 0; g < groups; g++) {
Contributor
Using ++g is better.

Contributor Author
done

Contributor
Other parts of the code have similar usage; please correct them one by one.

Contributor Author
done

maxout_process.compute(ele, input_data[data_idx + g * featLen]);
}
maxout_process.finalize(ele, (static_cast<T>(groups)));
output_data[index] = ele;
}
}
template <typename T>
__global__ void KernelMaxoutGrad(
const int nthreads, const T* input_data, const T* output_data,
const T* output_grad, T* input_grad, const int channels,
const int input_height, const int input_width, int groups) {
int size = input_height * input_width * channels / groups;
int featLen = input_height * input_width;
for (int index = blockIdx.x * blockDim.x + threadIdx.x; index < nthreads;
Contributor
Same as above.

Contributor Author
done

index += blockDim.x * gridDim.x) {
int batch_idx = index / size;
int i = index % size;
Contributor
Same as above.

Contributor Author
done

int channel_idx = i / featLen;
int feat_idx = i % featLen;
int data_idx =
(batch_idx * size + channel_idx * featLen) * groups + feat_idx;
int maxIndex = -1;
Contributor
maxIndex -> max_index

Contributor Author
done

bool stop = false;
for (int g = 0; g < groups && !stop; g++) {
if (input_data[data_idx + g * featLen] == output_data[index]) {
maxIndex = data_idx + g * featLen;
stop = true;
}
}
if (maxIndex != -1) {
// atomic add
platform::CudaAtomicAdd(input_grad + maxIndex, output_grad[index]);
}
}
}
/*
* All tensors are in NCHW format.
* Ksize, strides, paddings are two elements. These two elements represent
* height and width, respectively.
*/
template <typename MaxOutProcess, typename T>
class MaxOutFunctor<platform::GPUPlace, MaxOutProcess, T> {
public:
void operator()(const platform::DeviceContext& context,
const framework::Tensor& input, framework::Tensor& output,
int groups, int num_channels,
MaxOutProcess maxout_process) {
Contributor
The same with above.

Contributor Author
done

const int batch_size = input.dims()[0];
const int input_channels = input.dims()[1];
const int input_height = input.dims()[2];
const int input_width = input.dims()[3];
const int output_channels = num_channels / groups;
const int output_height = output.dims()[2];
const int output_width = output.dims()[3];

const T* input_data = input.data<T>();
T* output_data = output.mutable_data<T>(context.GetPlace());

int nthreads = batch_size * output_channels * output_height * output_width;
Contributor
int nthreads = output.numel();

Contributor Author
done

int blocks = (nthreads + 1024 - 1) / 1024;
dim3 threads(1024, 1);
dim3 grid(blocks, 1);

KernelMaxOut<
MaxOutProcess,
T><<<grid, threads, 0,
reinterpret_cast<const platform::CUDADeviceContext&>(context)
.stream()>>>(nthreads, input_data, output_data, input_channels,
input_height, input_width, groups,
maxout_process);
}
};
/*
* All tensors are in NCHW format.
* Ksize, strides, paddings are two elements. These two elements represent
* height and width, respectively.
*/
template <typename T>
class MaxOutGradFunctor<platform::GPUPlace, T> {
public:
void operator()(const platform::DeviceContext& context,
const framework::Tensor& input, framework::Tensor& input_grad,
const framework::Tensor& output,
const framework::Tensor& output_grad,
int groups, int num_channels) {
const int batch_size = input.dims()[0];
const int input_channels = input.dims()[1];
const int input_height = input.dims()[2];
const int input_width = input.dims()[3];
const int output_channels = output.dims()[1];
const int output_height = output.dims()[2];
const int output_width = output.dims()[3];

const T* input_data = input.data<T>();
const T* output_data = output.data<T>();
const T* output_grad_data = output_grad.data<T>();
T* input_grad_data = input_grad.mutable_data<T>(context.GetPlace());

int nthreads = batch_size * output_channels * output_height * output_width;
Contributor
int nthreads = output.numel();

Contributor Author
done

int blocks = (nthreads + 1024 - 1) / 1024;
dim3 threads(1024, 1);
dim3 grid(blocks, 1);

KernelMaxoutGrad<
T><<<grid, threads, 0,
reinterpret_cast<const platform::CUDADeviceContext&>(context)
.stream()>>>(
nthreads, input_data, output_data, output_grad_data, input_grad_data,
input_channels, input_height, input_width, groups);
}
};

template class MaxOutGradFunctor<platform::GPUPlace, float>;
template class MaxOutGradFunctor<platform::GPUPlace, double>;

template class MaxOutFunctor<platform::GPUPlace,
paddle::operators::math::MaxOut<float>, float>;
template class MaxOutFunctor<platform::GPUPlace,
paddle::operators::math::MaxOut<double>, double>;

} // namespace math
} // namespace operators
} // namespace paddle