This repository was archived by the owner on Jan 24, 2024. It is now read-only.

Commit 9ba1b7d (1 parent: e400a0c)

add PEs(matmul, all relus) and Ops(elementwise_add, relu6, mul) and C++/python tests

32 files changed: +976 −176 lines

cinn/common/ir_util.h (+8)

@@ -95,5 +95,13 @@ Expr make_const(Type t, T v) {
   return Expr();
 }
 
+template <typename FuncOp>
+Expr FoldExpr(FuncOp funcOp, Expr init_value, const std::vector<Expr> &values) {
+  for (const Expr &val : values) {
+    init_value = funcOp(init_value, val);
+  }
+  return init_value;
+}
+
 } // namespace common
 } // namespace cinn
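The new FoldExpr helper left-folds a binary functor over a list of Exprs. As a hedged illustration (not part of this commit), it could sum a list of expressions; treating ir::Add::Make as the usual binary-node builder and Int(32)/make_const as the usual CINN helpers is an assumption here:

// Sketch: sum a vector of Exprs by left-folding an addition functor.
// Assumption: ir::Add::Make(a, b) builds an addition node, as elsewhere in CINN.
std::vector<Expr> values = {Expr(1), Expr(2), Expr(3)};
Expr sum = common::FoldExpr(
    [](Expr a, Expr b) { return ir::Add::Make(a, b); },  // combining functor
    common::make_const(Int(32), 0),                      // initial accumulator
    values);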

cinn/frontend/syntax.cc (+1 −1)

@@ -28,7 +28,7 @@ Placeholder::operator Variable() {
 }
 
 Variable Program::add(const Variable& a, const Variable& b) {
-  Instruction instr("add");
+  Instruction instr("elementwise_add");
   instr.SetInputs({a, b});
   AddInstruction(instr);
   return instr.GetOutputs()[0];

cinn/frontend/syntax.h (+8 −1)

@@ -76,7 +76,14 @@ struct Variable : public common::Shared<_Variable_> {
  * Data of a Instruction.
  */
 struct _Instruction_ : public common::Object {
-  using attr_t = std::variant<int, float, std::string, std::vector<int>, std::vector<float>, std::vector<std::string>>;
+  using attr_t = std::variant<int,
+                              float,
+                              bool,
+                              std::string,
+                              std::vector<int>,
+                              std::vector<float>,
+                              std::vector<bool>,
+                              std::vector<std::string>>;
 
   std::string op_type;
   std::unordered_map<std::string, attr_t> attrs;
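With bool and std::vector<bool> added to attr_t, attribute values follow the standard std::variant access pattern. A minimal sketch of reading a bool attribute through an Instruction (the attribute name "trainable" is hypothetical, not one defined in this commit):

// Sketch: store and read a bool attribute via the widened attr_t variant.
// "trainable" is a hypothetical attribute name used only for illustration.
instr->attrs["trainable"] = true;  // stores the bool alternative
if (std::holds_alternative<bool>(instr->attrs.at("trainable"))) {
  bool trainable = std::get<bool>(instr->attrs.at("trainable"));
}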

cinn/hlir/framework/node.h (+16 −2)

@@ -23,7 +23,14 @@ using NodePtr = std::shared_ptr<Node>;
  * and other parameters like axis.
  */
 struct NodeAttr {
-  using attr_t = std::variant<int, float, std::string, std::vector<int>, std::vector<float>, std::vector<std::string>>;
+  using attr_t = std::variant<int,
+                              float,
+                              bool,
+                              std::string,
+                              std::vector<int>,
+                              std::vector<float>,
+                              std::vector<bool>,
+                              std::vector<std::string>>;
 
   /**
    * \brief The operator this node uses.

@@ -90,7 +97,14 @@ class Node : public common::GraphNode {
  * \brief NodeData represents the output data from an operator.
  */
 class NodeData : public common::GraphNode {
-  using attr_t = std::variant<int, float, std::string, std::vector<int>, std::vector<float>, std::vector<std::string>>;
+  using attr_t = std::variant<int,
+                              float,
+                              bool,
+                              std::string,
+                              std::vector<int>,
+                              std::vector<float>,
+                              std::vector<bool>,
+                              std::vector<std::string>>;
 
  public:
   NodeData(NodePtr node, uint32_t index, uint32_t version, std::string id)

cinn/hlir/framework/op_test.cc (+2 −2)

@@ -18,7 +18,7 @@ namespace framework {
 using CCompute = std::function<std::shared_ptr<ir::Tensor>(const std::vector<ir::Tensor>)>;
 
 TEST(Operator, GetAttrs) {
-  auto add = Operator::Get("add");
+  auto add = Operator::Get("elementwise_add");
   Operator temp = *add;
   auto strategy = Operator::GetAttrs<StrategyFunction>("CINNStrategy");
 
@@ -46,7 +46,7 @@ TEST(Operator, GetAttrs) {
   auto func = Lower("add1", rets.back(), inputs);
   LOG(INFO) << "Test Strategy Codegen:\n" << func;
 
-  ASSERT_EQ(impl->name, "strategy.add.x86");
+  ASSERT_EQ(impl->name, "strategy.elementwise_add.x86");
   ASSERT_EQ(add->description, "Add two tensors");
 }
 

cinn/hlir/framework/print_graph_pass_test.cc (+3 −1)

@@ -50,7 +50,9 @@ TEST(Operator, GetAttrs) {
   ApplyPass(g, "PrintGraph");
   auto s = g->GetAttrs<std::string>("print_graph");
   LOG(INFO) << s;
-  ASSERT_EQ(s, "0:add(add_0)\n1:add(add_1)\n2:add(add_2)\n");
+  ASSERT_EQ(s,
+            "0:elementwise_add(elementwise_add_0)\n1:elementwise_add(elementwise_add_1)\n2:elementwise_"
+            "add_2)\n");
 }
 
 } // namespace framework

cinn/hlir/op/CMakeLists.txt (+5 −1)

@@ -1,9 +1,13 @@
 set(srcs
-    nn.cc
+    nn.cc
+    broadcast.cc
+    transform.cc
 )
 
 foreach(cpp ${srcs})
   set(core_src
       "${core_src};cinn/hlir/op/${cpp}"
       CACHE INTERNAL "")
 endforeach()
+
+cc_test(test_op_broadcast SRCS op_broadcast_test.cc DEPS core)

cinn/hlir/op/broadcast.cc (+81, new file)

@@ -0,0 +1,81 @@
+#include "cinn/hlir/pe/broadcast.h"
+
+#include <iostream>
+#include "cinn/hlir/framework/node.h"
+#include "cinn/hlir/framework/op.h"
+#include "cinn/hlir/framework/op_strategy.h"
+
+namespace cinn {
+namespace hlir {
+namespace op {
+using common::_CINNValuePack_;
+using common::CINNValue;
+using common::CINNValuePack;
+using framework::OpStrategy;
+using framework::StrategyFunction;
+
+std::shared_ptr<OpStrategy> StrategyForElementwiseAdd(const framework::NodeAttr &attrs,
+                                                      const std::vector<ir::Tensor> &inputs,
+                                                      const std::vector<Type> &out_type,
+                                                      const Target &target) {
+  framework::CINNCompute add_compute([&attrs](lang::Args args, lang::RetValue *ret) {
+    CINNValuePack a = args[0];
+    ir::Expr A_expr = a[0];
+    ir::Expr B_expr = a[1];
+    CHECK(A_expr.as_tensor());
+    CHECK(B_expr.as_tensor());
+    ir::Tensor A = A_expr.as_tensor_ref();
+    ir::Tensor B = B_expr.as_tensor_ref();
+    auto attr_store = attrs.attr_store;
+    auto iter = attr_store.find("axis");
+    ir::Expr axis;
+    if (iter != attr_store.end()) {
+      axis = ir::Expr(std::get<int>(iter->second));
+    }
+
+    auto out = pe::Add(A, B, UniqName("C"), axis);
+
+    auto stages = CreateStages({out});
+    *ret = CINNValuePack{{CINNValue(ir::Expr(out.get())), CINNValue(stages)}};
+  });
+
+  framework::CINNSchedule add_schedule([](lang::Args args, lang::RetValue *ret) {
+    CINNValuePack arg_pack = args[0];
+    ir::Expr A [[maybe_unused]] = arg_pack[0];
+    CHECK_EQ(arg_pack.size(), 2UL);
+    *ret = arg_pack;
+  });
+
+  auto strategy = std::make_shared<framework::OpStrategy>();
+  strategy->AddImpl(add_compute, add_schedule, "strategy.elementwise_add.x86", 1);
+
+  return strategy;
+}
+
+std::vector<std::vector<int>> InferShapeForElementwiseAdd(const std::vector<std::vector<int>> &inputs_shape,
+                                                          const framework::NodeAttr &attrs) {
+  CHECK(!inputs_shape.empty() && !inputs_shape[0].empty()) << "The input's shape size is 0! Please check again.";
+  std::vector<std::vector<int>> res{inputs_shape[0]};
+  return res;
+}
+
+std::vector<Type> InferDtypeForElementwiseAdd(const std::vector<Type> &inputs_type, const framework::NodeAttr &attrs) {
+  CHECK(!inputs_type.empty()) << "The input's type size is 0! Please check again.";
+  std::vector<Type> res{inputs_type[0]};
+  return res;
+}
+
+} // namespace op
+} // namespace hlir
+} // namespace cinn
+
+CINN_REGISTER_HELPER(broadcast_ops) {
+  CINN_REGISTER_OP(elementwise_add)
+      .describe("Add two tensors")
+      .set_num_inputs(2)
+      .set_num_outputs(1)
+      .set_attr<cinn::hlir::framework::StrategyFunction>("CINNStrategy", cinn::hlir::op::StrategyForElementwiseAdd)
+      .set_attr("infershape", std::function(cinn::hlir::op::InferShapeForElementwiseAdd))
+      .set_attr("inferdtype", std::function(cinn::hlir::op::InferDtypeForElementwiseAdd))
+      .set_support_level(4);
+}
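Combined with the Program::add rename in cinn/frontend/syntax.cc above, the newly registered op is reachable from the frontend. A rough usage sketch (the Placeholder construction details and shapes are illustrative assumptions, not part of this commit):

// Sketch: build a program whose add lowers to the elementwise_add op.
// Placeholder construction details are assumed for illustration.
frontend::Program program;
frontend::Placeholder A(Float(32), {32, 16});
frontend::Placeholder B(Float(32), {32, 16});
frontend::Variable C = program.add(A, B);  // emits an "elementwise_add" Instruction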

cinn/hlir/op/nn.cc (+32 −44)

@@ -13,57 +13,58 @@ using common::CINNValuePack;
 using framework::OpStrategy;
 using framework::StrategyFunction;
 
-std::shared_ptr<OpStrategy> StrategyForAdd(const framework::NodeAttr &attrs,
-                                           const std::vector<ir::Tensor> &inputs,
-                                           const std::vector<Type> &out_type,
-                                           const Target &target) {
-  framework::CINNCompute add_compute([](lang::Args args, lang::RetValue *ret) {
+std::shared_ptr<OpStrategy> StrategyForRelu(const framework::NodeAttr &attrs,
+                                            const std::vector<ir::Tensor> &inputs,
+                                            const std::vector<Type> &out_type,
+                                            const Target &target) {
+  framework::CINNCompute relu_compute([](lang::Args args, lang::RetValue *ret) {
     CINNValuePack a = args[0];
     ir::Expr A = a[0];
-    ir::Expr B = a[1];
     CHECK(A.as_tensor());
-    CHECK(B.as_tensor());
-    auto out = pe::Add(A.as_tensor_ref(), B.as_tensor_ref(), UniqName("C"));
-
+    auto out = pe::Relu<float>(A.as_tensor_ref(), 0.0, UniqName("Relu_output"));
     auto stages = CreateStages({out});
     *ret = CINNValuePack{{CINNValue(ir::Expr(out.get())), CINNValue(stages)}};
   });
 
-  framework::CINNSchedule add_schedule([](lang::Args args, lang::RetValue *ret) {
+  framework::CINNSchedule relu_schedule([](lang::Args args, lang::RetValue *ret) {
     CINNValuePack arg_pack = args[0];
     ir::Expr A [[maybe_unused]] = arg_pack[0];
     CHECK_EQ(arg_pack.size(), 2UL);
     *ret = arg_pack;
   });
 
   auto strategy = std::make_shared<framework::OpStrategy>();
-  strategy->AddImpl(add_compute, add_schedule, "strategy.add.x86", 1);
-
+  CHECK(out_type.size()) << "Out_type of relu op is empty! Please check.";
+  if (out_type[0] == Float(32)) {
+    strategy->AddImpl(relu_compute, relu_schedule, "strategy.relu.x86", 1);
+  } else {
+    LOG(INFO) << "Relu op with dtype != float32 is not implemented yet!";
+  }
   return strategy;
 }
 
-std::vector<std::vector<int>> InferShapeForAdd(const std::vector<std::vector<int>> &inputs_shape,
-                                               const framework::NodeAttr &attrs) {
+std::vector<std::vector<int>> InferShapeForRelu(const std::vector<std::vector<int>> &inputs_shape,
+                                                const framework::NodeAttr &attrs) {
   CHECK(!inputs_shape.empty() && !inputs_shape[0].empty()) << "The input's shape size is 0! Please check again.";
   std::vector<std::vector<int>> res{inputs_shape[0]};
   return res;
 }
 
-std::vector<Type> InferDtypeForAdd(const std::vector<Type> &inputs_type, const framework::NodeAttr &attrs) {
+std::vector<Type> InferDtypeForRelu(const std::vector<Type> &inputs_type, const framework::NodeAttr &attrs) {
   CHECK(!inputs_type.empty()) << "The input's type size is 0! Please check again.";
   std::vector<Type> res{inputs_type[0]};
   return res;
 }
 
-std::shared_ptr<OpStrategy> StrategyForRelu(const framework::NodeAttr &attrs,
-                                            const std::vector<ir::Tensor> &inputs,
-                                            const std::vector<Type> &out_type,
-                                            const Target &target) {
+std::shared_ptr<OpStrategy> StrategyForRelu6(const framework::NodeAttr &attrs,
+                                             const std::vector<ir::Tensor> &inputs,
+                                             const std::vector<Type> &out_type,
+                                             const Target &target) {
   framework::CINNCompute relu_compute([](lang::Args args, lang::RetValue *ret) {
     CINNValuePack a = args[0];
     ir::Expr A = a[0];
     CHECK(A.as_tensor());
-    auto out = pe::Relu<float>(A.as_tensor_ref(), 0.0, UniqName("Relu_output"));
+    auto out = pe::Relu6<float>(A.as_tensor_ref(), 0.0, UniqName("Relu6_output"));
     auto stages = CreateStages({out});
     *ret = CINNValuePack{{CINNValue(ir::Expr(out.get())), CINNValue(stages)}};
   });

@@ -76,28 +77,15 @@ std::shared_ptr<OpStrategy> StrategyForRelu(const framework::NodeAttr &attrs,
   });
 
   auto strategy = std::make_shared<framework::OpStrategy>();
-  CHECK(out_type.size()) << "Out_type of relu op is empty! Please check.";
+  CHECK(out_type.size()) << "Out_type of relu6 op is empty! Please check.";
   if (out_type[0] == Float(32)) {
-    strategy->AddImpl(relu_compute, relu_schedule, "strategy.relu.x86", 1);
+    strategy->AddImpl(relu_compute, relu_schedule, "strategy.relu6.x86", 1);
   } else {
-    LOG(INFO) << "Relu op with dtype != float32 is not implemented yet!";
+    LOG(INFO) << "Relu6 op with dtype != float32 is not implemented yet!";
   }
   return strategy;
 }
 
-std::vector<std::vector<int>> InferShapeForRelu(const std::vector<std::vector<int>> &inputs_shape,
-                                                const framework::NodeAttr &attrs) {
-  CHECK(!inputs_shape.empty() && !inputs_shape[0].empty()) << "The input's shape size is 0! Please check again.";
-  std::vector<std::vector<int>> res{inputs_shape[0]};
-  return res;
-}
-
-std::vector<Type> InferDtypeForRelu(const std::vector<Type> &inputs_type, const framework::NodeAttr &attrs) {
-  CHECK(!inputs_type.empty()) << "The input's type size is 0! Please check again.";
-  std::vector<Type> res{inputs_type[0]};
-  return res;
-}
-
 std::shared_ptr<OpStrategy> StrategyForConv2d(const framework::NodeAttr &attrs,
                                               const std::vector<ir::Tensor> &inputs,
                                               const std::vector<Type> &out_type,

@@ -245,14 +233,6 @@ std::vector<Type> InferDtypeForBatchNorm(const std::vector<Type> &inputs_type, c
 } // namespace cinn
 
 CINN_REGISTER_HELPER(nn_ops) {
-  CINN_REGISTER_OP(add)
-      .describe("Add two tensors")
-      .set_num_inputs(2)
-      .set_num_outputs(1)
-      .set_attr<cinn::hlir::framework::StrategyFunction>("CINNStrategy", cinn::hlir::op::StrategyForAdd)
-      .set_attr("infershape", std::function(cinn::hlir::op::InferShapeForAdd))
-      .set_attr("inferdtype", std::function(cinn::hlir::op::InferDtypeForAdd))
-      .set_support_level(4);
   CINN_REGISTER_OP(relu)
       .describe("Output 0 for each input element < 0. Output itself for each input element >= 0.")
       .set_num_inputs(1)

@@ -261,6 +241,14 @@ CINN_REGISTER_HELPER(nn_ops) {
       .set_attr("infershape", std::function(cinn::hlir::op::InferShapeForRelu))
       .set_attr("inferdtype", std::function(cinn::hlir::op::InferDtypeForRelu))
       .set_support_level(4);
+  CINN_REGISTER_OP(relu6)
+      .describe("Output 0 for each input element < 0. Output itself for each input element >= 0 and <=6.")
+      .set_num_inputs(1)
+      .set_num_outputs(1)
+      .set_attr<cinn::hlir::framework::StrategyFunction>("CINNStrategy", cinn::hlir::op::StrategyForRelu6)
+      .set_attr("infershape", std::function(cinn::hlir::op::InferShapeForRelu))
+      .set_attr("inferdtype", std::function(cinn::hlir::op::InferDtypeForRelu))
+      .set_support_level(4);
   CINN_REGISTER_OP(conv2d)
       .describe("Do a 2-D convolution with an NCHW-layout.")
       .set_num_inputs(2)  // here we consider filter as anohter input
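For reference, relu6 follows the usual definition relu6(x) = min(max(x, 0), 6), which is why it can reuse relu's shape and dtype inference: output shape and type always match the input. A scalar sketch of the semantics (plain C++ for illustration, not the pe::Relu6 primitive itself):

// Scalar reference semantics for relu6; the PE applies this elementwise.
#include <algorithm>
inline float relu6_ref(float x, float threshold = 0.f) {
  return std::min(std::max(x, threshold), 6.f);
}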
