@@ -13,57 +13,58 @@ using common::CINNValuePack;
13
13
using framework::OpStrategy;
14
14
using framework::StrategyFunction;
15
15
16
- std::shared_ptr<OpStrategy> StrategyForAdd (const framework::NodeAttr &attrs,
17
- const std::vector<ir::Tensor> &inputs,
18
- const std::vector<Type> &out_type,
19
- const Target &target) {
20
- framework::CINNCompute add_compute ([](lang::Args args, lang::RetValue *ret) {
16
+ std::shared_ptr<OpStrategy> StrategyForRelu (const framework::NodeAttr &attrs,
17
+ const std::vector<ir::Tensor> &inputs,
18
+ const std::vector<Type> &out_type,
19
+ const Target &target) {
20
+ framework::CINNCompute relu_compute ([](lang::Args args, lang::RetValue *ret) {
21
21
CINNValuePack a = args[0 ];
22
22
ir::Expr A = a[0 ];
23
- ir::Expr B = a[1 ];
24
23
CHECK (A.as_tensor ());
25
- CHECK (B.as_tensor ());
26
- auto out = pe::Add (A.as_tensor_ref (), B.as_tensor_ref (), UniqName (" C" ));
27
-
24
+ auto out = pe::Relu<float >(A.as_tensor_ref (), 0.0 , UniqName (" Relu_output" ));
28
25
auto stages = CreateStages ({out});
29
26
*ret = CINNValuePack{{CINNValue (ir::Expr (out.get ())), CINNValue (stages)}};
30
27
});
31
28
32
- framework::CINNSchedule add_schedule ([](lang::Args args, lang::RetValue *ret) {
29
+ framework::CINNSchedule relu_schedule ([](lang::Args args, lang::RetValue *ret) {
33
30
CINNValuePack arg_pack = args[0 ];
34
31
ir::Expr A [[maybe_unused]] = arg_pack[0 ];
35
32
CHECK_EQ (arg_pack.size (), 2UL );
36
33
*ret = arg_pack;
37
34
});
38
35
39
36
auto strategy = std::make_shared<framework::OpStrategy>();
40
- strategy->AddImpl (add_compute, add_schedule, " strategy.add.x86" , 1 );
41
-
37
+ CHECK (out_type.size ()) << " Out_type of relu op is empty! Please check." ;
38
+ if (out_type[0 ] == Float (32 )) {
39
+ strategy->AddImpl (relu_compute, relu_schedule, " strategy.relu.x86" , 1 );
40
+ } else {
41
+ LOG (INFO) << " Relu op with dtype != float32 is not implemented yet!" ;
42
+ }
42
43
return strategy;
43
44
}
44
45
45
- std::vector<std::vector<int >> InferShapeForAdd (const std::vector<std::vector<int >> &inputs_shape,
46
- const framework::NodeAttr &attrs) {
46
+ std::vector<std::vector<int >> InferShapeForRelu (const std::vector<std::vector<int >> &inputs_shape,
47
+ const framework::NodeAttr &attrs) {
47
48
CHECK (!inputs_shape.empty () && !inputs_shape[0 ].empty ()) << " The input's shape size is 0! Please check again." ;
48
49
std::vector<std::vector<int >> res{inputs_shape[0 ]};
49
50
return res;
50
51
}
51
52
52
- std::vector<Type> InferDtypeForAdd (const std::vector<Type> &inputs_type, const framework::NodeAttr &attrs) {
53
+ std::vector<Type> InferDtypeForRelu (const std::vector<Type> &inputs_type, const framework::NodeAttr &attrs) {
53
54
CHECK (!inputs_type.empty ()) << " The input's type size is 0! Please check again." ;
54
55
std::vector<Type> res{inputs_type[0 ]};
55
56
return res;
56
57
}
57
58
58
- std::shared_ptr<OpStrategy> StrategyForRelu (const framework::NodeAttr &attrs,
59
- const std::vector<ir::Tensor> &inputs,
60
- const std::vector<Type> &out_type,
61
- const Target &target) {
59
+ std::shared_ptr<OpStrategy> StrategyForRelu6 (const framework::NodeAttr &attrs,
60
+ const std::vector<ir::Tensor> &inputs,
61
+ const std::vector<Type> &out_type,
62
+ const Target &target) {
62
63
framework::CINNCompute relu_compute ([](lang::Args args, lang::RetValue *ret) {
63
64
CINNValuePack a = args[0 ];
64
65
ir::Expr A = a[0 ];
65
66
CHECK (A.as_tensor ());
66
- auto out = pe::Relu <float >(A.as_tensor_ref (), 0.0 , UniqName (" Relu_output " ));
67
+ auto out = pe::Relu6 <float >(A.as_tensor_ref (), 0.0 , UniqName (" Relu6_output " ));
67
68
auto stages = CreateStages ({out});
68
69
*ret = CINNValuePack{{CINNValue (ir::Expr (out.get ())), CINNValue (stages)}};
69
70
});
@@ -76,28 +77,15 @@ std::shared_ptr<OpStrategy> StrategyForRelu(const framework::NodeAttr &attrs,
76
77
});
77
78
78
79
auto strategy = std::make_shared<framework::OpStrategy>();
79
- CHECK (out_type.size ()) << " Out_type of relu op is empty! Please check." ;
80
+ CHECK (out_type.size ()) << " Out_type of relu6 op is empty! Please check." ;
80
81
if (out_type[0 ] == Float (32 )) {
81
- strategy->AddImpl (relu_compute, relu_schedule, " strategy.relu .x86" , 1 );
82
+ strategy->AddImpl (relu_compute, relu_schedule, " strategy.relu6 .x86" , 1 );
82
83
} else {
83
- LOG (INFO) << " Relu op with dtype != float32 is not implemented yet!" ;
84
+ LOG (INFO) << " Relu6 op with dtype != float32 is not implemented yet!" ;
84
85
}
85
86
return strategy;
86
87
}
87
88
88
- std::vector<std::vector<int >> InferShapeForRelu (const std::vector<std::vector<int >> &inputs_shape,
89
- const framework::NodeAttr &attrs) {
90
- CHECK (!inputs_shape.empty () && !inputs_shape[0 ].empty ()) << " The input's shape size is 0! Please check again." ;
91
- std::vector<std::vector<int >> res{inputs_shape[0 ]};
92
- return res;
93
- }
94
-
95
- std::vector<Type> InferDtypeForRelu (const std::vector<Type> &inputs_type, const framework::NodeAttr &attrs) {
96
- CHECK (!inputs_type.empty ()) << " The input's type size is 0! Please check again." ;
97
- std::vector<Type> res{inputs_type[0 ]};
98
- return res;
99
- }
100
-
101
89
std::shared_ptr<OpStrategy> StrategyForConv2d (const framework::NodeAttr &attrs,
102
90
const std::vector<ir::Tensor> &inputs,
103
91
const std::vector<Type> &out_type,
@@ -245,14 +233,6 @@ std::vector<Type> InferDtypeForBatchNorm(const std::vector<Type> &inputs_type, c
245
233
} // namespace cinn
246
234
247
235
CINN_REGISTER_HELPER (nn_ops) {
248
- CINN_REGISTER_OP (add)
249
- .describe (" Add two tensors" )
250
- .set_num_inputs (2 )
251
- .set_num_outputs (1 )
252
- .set_attr <cinn::hlir::framework::StrategyFunction>(" CINNStrategy" , cinn::hlir::op::StrategyForAdd)
253
- .set_attr (" infershape" , std::function (cinn::hlir::op::InferShapeForAdd))
254
- .set_attr (" inferdtype" , std::function (cinn::hlir::op::InferDtypeForAdd))
255
- .set_support_level (4 );
256
236
CINN_REGISTER_OP (relu)
257
237
.describe (" Output 0 for each input element < 0. Output itself for each input element >= 0." )
258
238
.set_num_inputs (1 )
@@ -261,6 +241,14 @@ CINN_REGISTER_HELPER(nn_ops) {
261
241
.set_attr (" infershape" , std::function (cinn::hlir::op::InferShapeForRelu))
262
242
.set_attr (" inferdtype" , std::function (cinn::hlir::op::InferDtypeForRelu))
263
243
.set_support_level (4 );
244
+ CINN_REGISTER_OP (relu6)
245
+ .describe (" Output 0 for each input element < 0. Output itself for each input element >= 0 and <=6." )
246
+ .set_num_inputs (1 )
247
+ .set_num_outputs (1 )
248
+ .set_attr <cinn::hlir::framework::StrategyFunction>(" CINNStrategy" , cinn::hlir::op::StrategyForRelu6)
249
+ .set_attr (" infershape" , std::function (cinn::hlir::op::InferShapeForRelu))
250
+ .set_attr (" inferdtype" , std::function (cinn::hlir::op::InferDtypeForRelu))
251
+ .set_support_level (4 );
264
252
CINN_REGISTER_OP (conv2d)
265
253
.describe (" Do a 2-D convolution with an NCHW-layout." )
266
254
.set_num_inputs (2 ) // here we consider filter as anohter input
0 commit comments