Skip to content

Commit e35ef6f

Browse files
authored
Merge pull request #212 from cocodark/develop
fix #210: add pool & pool test; fix #211: add executor for testing op
2 parents b811721 + 70c0597 commit e35ef6f

22 files changed

+714
-65
lines changed

src/framework/executor.cpp

Lines changed: 22 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -32,26 +32,28 @@ Executor<Dtype>::Executor(const Program<Dtype> p) : program_(p) {
3232
to_predict_program_ = program_.originProgram;
3333
}
3434

35-
const std::vector<std::shared_ptr<BlockDesc>> blocks =
36-
to_predict_program_->Blocks();
37-
for (int i = 0; i < blocks.size(); ++i) {
38-
std::shared_ptr<BlockDesc> block_desc = blocks[i];
39-
std::vector<std::shared_ptr<OpDesc>> ops = block_desc->Ops();
40-
for (int j = 0; j < ops.size(); ++j) {
41-
std::shared_ptr<OpDesc> op = ops[j];
42-
if (op->Type() == "conv2d" && op->Input("Input")[0] == "pixel") {
43-
Attribute strides_attr = op->GetAttrMap().at("strides");
44-
std::vector<int> stride = strides_attr.Get<std::vector<int>>();
45-
for (int k = 0; k < stride.size(); ++k) {
46-
}
47-
std::shared_ptr<operators::ConvOp<Dtype, float>> conv =
48-
std::make_shared<operators::ConvOp<Dtype, float>>(
49-
op->Type(), op->GetInputs(), op->GetOutputs(),
50-
op->GetAttrMap(), program_.scope);
51-
ops_of_block_[*block_desc.get()].push_back(conv);
52-
}
53-
}
54-
}
35+
// const std::vector<std::shared_ptr<BlockDesc>> blocks =
36+
to_predict_program_->Blocks();
37+
// for (int i = 0; i < blocks.size(); ++i) {
38+
// std::shared_ptr<BlockDesc> block_desc = blocks[i];
39+
// std::vector<std::shared_ptr<OpDesc>> ops = block_desc->Ops();
40+
// for (int j = 0; j < ops.size(); ++j) {
41+
// std::shared_ptr<OpDesc> op = ops[j];
42+
// if (op->Type() == "conv2d" && op->Input("Input")[0] ==
43+
// "pixel") {
44+
// Attribute strides_attr = op->GetAttrMap().at("strides");
45+
// std::vector<int> stride =
46+
// strides_attr.Get<std::vector<int>>(); for (int k = 0; k <
47+
// stride.size(); ++k) {
48+
// }
49+
// std::shared_ptr<operators::ConvOp<Dtype, float>> conv =
50+
// std::make_shared<operators::ConvOp<Dtype, float>>(
51+
// op->Type(), op->GetInputs(), op->GetOutputs(),
52+
// op->GetAttrMap(), program_.scope);
53+
// ops_of_block_[*block_desc.get()].push_back(conv);
54+
// }
55+
// }
56+
// }
5557
}
5658

5759
template <typename Dtype>
@@ -82,7 +84,6 @@ void Executor<Dtype>::predict(const Tensor &t, int block_id) {
8284
to_predict_program_->Block(block_id);
8385
for (int j = 0; j < ops_of_block_[*to_predict_block.get()].size(); ++j) {
8486
auto op = ops_of_block_[*to_predict_block.get()][j];
85-
// std::cout << "开始run" << std::endl;
8687
op->Run();
8788
}
8889
}

src/framework/executor.h

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -36,13 +36,18 @@ namespace framework {
3636

3737
template <typename Dtype> class Executor {
3838
public:
39+
Executor();
40+
3941
Executor(const Program<Dtype> p);
42+
4043
std::shared_ptr<Tensor> predict(Tensor &t);
4144

42-
private:
45+
public:
4346
const framework::Program<Dtype> program_;
4447
std::shared_ptr<ProgramDesc> to_predict_program_;
48+
4549
void predict(const Tensor &t, int block_id);
50+
4651
std::map<framework::BlockDesc,
4752
std::vector<std::shared_ptr<OperatorBase<Dtype>>>>
4853
ops_of_block_;

src/framework/operator.h

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -60,10 +60,9 @@ template <typename Dtype> class OperatorBase : PaddleMobileObject {
6060
const VariableNameMap &Outputs() const { return outputs_; }
6161
const std::string &Type() const { return type_; }
6262
const AttributeMap &Attrs() const { return attrs_; }
63-
void ClearVariables() const {
63+
void ClearVariables(const std::vector<std::string> &var_names) const {
6464
if (this->scope_) {
65-
this->scope_->EraseVars(this->inputs_.at("Filter"));
66-
this->scope_->EraseVars(this->inputs_.at("Input"));
65+
this->scope_->EraseVars(var_names);
6766
}
6867
}
6968

src/io.cpp

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -194,6 +194,9 @@ Loader<Dtype, P>::Load(const std::string &dirname) {
194194
framework::proto::BlockDesc block = program_desc_proto.blocks()[i];
195195
LOG(kLOG_DEBUG) << "block: " << block.idx();
196196
for (int j = 0; j < block.ops().size(); ++j) {
197+
if (j == 2) {
198+
break;
199+
}
197200
framework::proto::OpDesc op = block.ops()[j];
198201
LOG(kLOG_DEBUG1) << "op: " << op.type();
199202
for (int m = 0; m < op.inputs_size(); ++m) {

src/operators/conv_op.h

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -40,9 +40,9 @@ class ConvOp : public framework::OperatorWithKernel<DeviceType> {
4040
void InferShape() const override;
4141

4242
void Run() const {
43-
operators::ConvKernel<DeviceType, T, ConvParam> kernel;
43+
operators::ConvKernel<DeviceType, T> kernel;
4444
kernel.Compute(param_);
45-
this->ClearVariables();
45+
this->ClearVariables({"Filter", "Input"});
4646
}
4747

4848
private:

src/operators/kernel/arm/conv_kernel.cpp

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -34,8 +34,7 @@ bool IsExpand(const std::vector<int64_t> &filter_dim,
3434
return !(filter_1 && strides_1 && padding_0 && dilation_1);
3535
}
3636

37-
template <>
38-
void ConvKernel<CPU, float, ConvParam>::Compute(const ConvParam &param) const {
37+
template <> void ConvKernel<CPU, float>::Compute(const ConvParam &param) const {
3938
LOG(kLOG_DEBUG) << param;
4039

4140
const Tensor *input = param.Input();
@@ -149,7 +148,7 @@ void ConvKernel<CPU, float, ConvParam>::Compute(const ConvParam &param) const {
149148
}
150149
}
151150

152-
template class ConvKernel<CPU, float, ConvParam>;
151+
template class ConvKernel<CPU, float>;
153152

154153
} // namespace operators
155154
} // namespace paddle_mobile
Lines changed: 77 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,77 @@
1+
/* Copyright (c) 2016 Baidu, Inc. All Rights Reserved.
2+
Permission is hereby granted, free of charge, to any person obtaining a copy
3+
of this software and associated documentation files (the "Software"), to deal
4+
in the Software without restriction, including without limitation the rights
5+
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
6+
copies of the Software, and to permit persons to whom the Software is
7+
furnished to do so, subject to the following conditions:
8+
The above copyright notice and this permission notice shall be included in all
9+
copies or substantial portions of the Software.
10+
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
11+
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
12+
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
13+
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
14+
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
15+
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
16+
SOFTWARE.
17+
==============================================================================*/
18+
#include "common/log.h"
19+
#include <operators/kernel/pool_kernel.h>
20+
21+
namespace paddle_mobile {
22+
namespace operators {
23+
24+
inline void PoolBasic(std::string pooling_type, std::vector<int> ksize,
25+
std::vector<int> strides, std::vector<int> paddings,
26+
const Tensor *in_x, Tensor *out) {
27+
if (pooling_type == "max") {
28+
math::PoolFunctor<CPU, math::MaxPool<float>, float> pool2d_forward;
29+
math::MaxPool<float> pool_process;
30+
pool2d_forward(*in_x, ksize, strides, paddings, pool_process, out);
31+
32+
} else if (pooling_type == "avg") {
33+
math::PoolFunctor<CPU, math::AvgPool<float>, float> pool2d_forward;
34+
math::AvgPool<float> pool_process;
35+
pool2d_forward(*in_x, ksize, strides, paddings, pool_process, out);
36+
}
37+
}
38+
39+
template <> void PoolKernel<CPU, float>::Compute(const PoolParam &param) const {
40+
const Tensor *in_x = param.Input();
41+
Tensor *out = param.Output();
42+
std::string pooling_type = param.PoolingType();
43+
44+
std::vector<int> ksize = param.Ksize();
45+
46+
std::vector<int> strides = param.Strides();
47+
48+
std::vector<int> paddings = param.Paddings();
49+
if (ksize.size() != 2) {
50+
LOG(paddle_mobile::LogLevel::kLOG_ERROR)
51+
<< "Pool op only supports 2D and 3D input.";
52+
}
53+
54+
if (param.isGlobalPooling()) {
55+
for (size_t i = 0; i < ksize.size(); ++i) {
56+
paddings[i] = 0;
57+
ksize[i] = static_cast<int>(in_x->dims()[i + 2]);
58+
}
59+
}
60+
61+
PoolBasic(pooling_type, ksize, strides, paddings, in_x, out);
62+
63+
// if (param.isGlobalPooling() || ksize[0] != ksize[1] ||
64+
// strides[0] != strides[1] || strides[1] != 2 ||
65+
// paddings[0] != paddings[1] || paddings[1] > 1) {
66+
// PoolBasic(pooling_type, ksize, strides, paddings, in_x, out);
67+
//
68+
// } else if (ksize[0] == 2) {
69+
//
70+
// } else if (ksize[0] == 3) {
71+
//
72+
// } else {
73+
// PoolBasic(pooling_type, ksize, strides, paddings, in_x, out);
74+
// }
75+
}
76+
} // namespace operators
77+
} // namespace paddle_mobile

src/operators/kernel/conv_kernel.h

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -29,7 +29,7 @@ namespace operators {
2929

3030
using namespace framework;
3131

32-
template <typename DeviceType, typename T, typename P>
32+
template <typename DeviceType, typename T>
3333
class ConvKernel : public framework::OpKernelBase<DeviceType, ConvParam> {
3434
public:
3535
void Compute(const ConvParam &param) const;

src/operators/kernel/pool_kernel.h

Lines changed: 35 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,35 @@
1+
/* Copyright (c) 2016 Baidu, Inc. All Rights Reserved.
2+
Permission is hereby granted, free of charge, to any person obtaining a copy
3+
of this software and associated documentation files (the "Software"), to deal
4+
in the Software without restriction, including without limitation the rights
5+
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
6+
copies of the Software, and to permit persons to whom the Software is
7+
furnished to do so, subject to the following conditions:
8+
The above copyright notice and this permission notice shall be included in all
9+
copies or substantial portions of the Software.
10+
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
11+
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
12+
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
13+
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
14+
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
15+
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
16+
SOFTWARE.
17+
==============================================================================*/
18+
#pragma once
19+
20+
#include "framework/operator.h"
21+
#include "operators/math/pooling.h"
22+
#include "operators/op_param.h"
23+
24+
namespace paddle_mobile {
25+
namespace operators {
26+
27+
using namespace framework;
28+
29+
template <typename DeviceType, typename T>
30+
class PoolKernel : public framework::OpKernelBase<DeviceType, PoolParam> {
31+
public:
32+
void Compute(const PoolParam &param) const;
33+
};
34+
} // namespace operators
35+
} // namespace paddle_mobile

src/operators/math/pool3x3.h

Lines changed: 30 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,30 @@
1+
/* Copyright (c) 2016 Baidu, Inc. All Rights Reserved.
2+
Permission is hereby granted, free of charge, to any person obtaining a copy
3+
of this software and associated documentation files (the "Software"), to deal
4+
in the Software without restriction, including without limitation the rights
5+
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
6+
copies of the Software, and to permit persons to whom the Software is
7+
furnished to do so, subject to the following conditions:
8+
The above copyright notice and this permission notice shall be included in all
9+
copies or substantial portions of the Software.
10+
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
11+
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
12+
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
13+
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
14+
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
15+
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
16+
SOFTWARE.
17+
==============================================================================*/
18+
#pragma once
19+
20+
#if __ARM_NEON
21+
#include <arm_neon.h>
22+
#endif // __ARM_NEON
23+
24+
static void Pool3x3Max() {
25+
// todo impl with neon
26+
}
27+
28+
static void Pool3x3Avg() {
29+
// todo impl with neon
30+
}

src/operators/math/pool_2x2.h

Lines changed: 30 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,30 @@
1+
/* Copyright (c) 2016 Baidu, Inc. All Rights Reserved.
2+
Permission is hereby granted, free of charge, to any person obtaining a copy
3+
of this software and associated documentation files (the "Software"), to deal
4+
in the Software without restriction, including without limitation the rights
5+
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
6+
copies of the Software, and to permit persons to whom the Software is
7+
furnished to do so, subject to the following conditions:
8+
The above copyright notice and this permission notice shall be included in all
9+
copies or substantial portions of the Software.
10+
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
11+
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
12+
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
13+
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
14+
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
15+
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
16+
SOFTWARE.
17+
==============================================================================*/
18+
#pragma once
19+
20+
#if __ARM_NEON
21+
#include <arm_neon.h>
22+
#endif // __ARM_NEON
23+
24+
static void Pool2x2Max() {
25+
// todo impl with neon
26+
}
27+
28+
static void Pool2x2Avg() {
29+
// todo impl with neon
30+
}

0 commit comments

Comments
 (0)