This repository was archived by the owner on Jan 24, 2024. It is now read-only.

Commit e46f501: "modify codes"
1 parent: 2d414ff

14 files changed: +185, -225 lines

cinn/hlir/framework/print_graph_pass_test.cc (+7, -3)

@@ -9,6 +9,7 @@
 #include "cinn/hlir/framework/pass.h"
 #include "cinn/hlir/op/use_ops.h"
 #include "cinn/lang/packed_func.h"
+#include "cinn/utils/string.h"

 namespace cinn {
 namespace hlir {
@@ -50,9 +51,12 @@ TEST(Operator, GetAttrs) {
   ApplyPass(g, "PrintGraph");
   auto s = g->GetAttrs<std::string>("print_graph");
   LOG(INFO) << s;
-  ASSERT_EQ(s,
-            "0:elementwise_add(elementwise_add_0)\n1:elementwise_add(elementwise_add_1)\n2:elementwise_add(elementwise_"
-            "add_2)\n");
+  std::string target_str = R"ROC(
+0:elementwise_add(elementwise_add_0)
+1:elementwise_add(elementwise_add_1)
+2:elementwise_add(elementwise_add_2)
+)ROC";
+  ASSERT_EQ(utils::Trim(s), utils::Trim(target_str));
 }

 }  // namespace framework
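
Comparing trimmed strings lets the expected graph dump be written as a readable multi-line raw string literal instead of a single escaped line. Below is a minimal, self-contained sketch of the idea; the Trim helper here is a hypothetical stand-in, and the real utils::Trim in cinn/utils/string.h may differ in signature and behavior.

#include <cassert>
#include <string>

// Stand-in for utils::Trim: strip leading/trailing whitespace (assumption).
std::string Trim(const std::string& s) {
  const char* ws = " \t\n\r";
  size_t begin = s.find_first_not_of(ws);
  if (begin == std::string::npos) return "";
  size_t end = s.find_last_not_of(ws);
  return s.substr(begin, end - begin + 1);
}

int main() {
  // The raw string literal keeps the expected dump readable; the leading and
  // trailing newlines introduced by R"ROC( ... )ROC" are trimmed away.
  std::string expected = R"ROC(
0:elementwise_add(elementwise_add_0)
1:elementwise_add(elementwise_add_1)
2:elementwise_add(elementwise_add_2)
)ROC";
  std::string actual = "0:elementwise_add(elementwise_add_0)\n"
                       "1:elementwise_add(elementwise_add_1)\n"
                       "2:elementwise_add(elementwise_add_2)\n";
  assert(Trim(actual) == Trim(expected));
  return 0;
}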

cinn/hlir/op/nn.cc (+3)

@@ -241,6 +241,7 @@ CINN_REGISTER_HELPER(nn_ops) {
       .set_attr("infershape", std::function(cinn::hlir::op::InferShapeForRelu))
       .set_attr("inferdtype", std::function(cinn::hlir::op::InferDtypeForRelu))
       .set_support_level(4);
+
   CINN_REGISTER_OP(relu6)
       .describe("Output 0 for each input element < 0. Output itself for each input element >= 0 and <= 6.")
       .set_num_inputs(1)
@@ -249,6 +250,7 @@ CINN_REGISTER_HELPER(nn_ops) {
       .set_attr("infershape", std::function(cinn::hlir::op::InferShapeForRelu))
       .set_attr("inferdtype", std::function(cinn::hlir::op::InferDtypeForRelu))
       .set_support_level(4);
+
   CINN_REGISTER_OP(conv2d)
       .describe("Do a 2-D convolution with an NCHW-layout.")
       .set_num_inputs(2)  // here we consider filter as another input
@@ -257,6 +259,7 @@ CINN_REGISTER_HELPER(nn_ops) {
       .set_attr("infershape", std::function(cinn::hlir::op::InferShapeForConv2d))
       .set_attr("inferdtype", std::function(cinn::hlir::op::InferDtypeForConv2d))
       .set_support_level(4);
+
   CINN_REGISTER_OP(batchnorm)
       .describe("Can be used as a normalizer function for convolution or fully_connected operations.")
       .set_num_inputs(2)  // here we consider batchnorm's 4 attrs (mean, variance, scale, bias) as another input

cinn/hlir/op/op_broadcast_test.cc (+2, -4)

@@ -28,8 +28,7 @@ TEST(Operator, Operator_ElementWise_Add_Test0) {
   NodeAttr attrs;
   std::vector<ir::Tensor> inputs{A.tensor(), B.tensor()};
   std::vector<Type> type{Float(32)};
-  common::Target target;
-  target.arch = common::Target::Arch::X86;
+  common::Target target = common::DefaultHostTarget();
   auto impl = OpStrategy::SelectImpl(strategy[add](attrs, inputs, type, target));
   common::CINNValuePack cinn_input = common::CINNValuePack{{common::CINNValue(A), common::CINNValue(B)}};
   common::CINNValuePack rets = impl->fcompute(cinn_input);
@@ -61,8 +60,7 @@ TEST(Operator, Operator_ElementWise_Add_Test1) {
   attrs.attr_store["axis"] = 1;
   std::vector<ir::Tensor> inputs{A.tensor(), B.tensor()};
   std::vector<Type> type{Float(32)};
-  common::Target target;
-  target.arch = common::Target::Arch::X86;
+  common::Target target = common::DefaultHostTarget();
   auto impl = OpStrategy::SelectImpl(strategy[add](attrs, inputs, type, target));
   common::CINNValuePack cinn_input = common::CINNValuePack{{common::CINNValue(A), common::CINNValue(B)}};
   common::CINNValuePack rets = impl->fcompute(cinn_input);
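
This commit replaces hand-assembled Target structs with common::DefaultHostTarget() throughout the tests. The sketch below shows what such a factory plausibly centralizes; the types are stand-ins, and the field values (X86 arch, 32-bit, Linux) are inferred from the assignments deleted in the pe_*_test.cc files further down.

// Minimal stand-in types mirroring the fields the tests used to set by hand.
struct Target {
  enum class Arch { Unk, X86 };
  enum class Bit { Unk, k32, k64 };
  enum class OS { Unk, Linux };
  Arch arch = Arch::Unk;
  Bit bits = Bit::Unk;
  OS os = OS::Unk;
};

// Sketch of a DefaultHostTarget()-style factory: one place to define the
// host target instead of three assignments repeated in every test.
Target DefaultHostTarget() {
  Target t;
  t.arch = Target::Arch::X86;
  t.bits = Target::Bit::k32;
  t.os = Target::OS::Linux;
  return t;
}

int main() {
  Target target = DefaultHostTarget();
  return target.arch == Target::Arch::X86 ? 0 : 1;
}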

cinn/hlir/pe/broadcast.cc (+5, -2)

@@ -21,7 +21,9 @@ void GetBroadcastShape(const std::vector<Expr>& shape1,
                        std::vector<bool>* broadcast_flag1,
                        std::vector<bool>* broadcast_flag2,
                        const Expr& axis) {
-  CHECK(common_shape && broadcast_flag1 && broadcast_flag2);
+  CHECK(common_shape);
+  CHECK(broadcast_flag1);
+  CHECK(broadcast_flag2);
   std::vector<Expr> shape2_new = shape2;
   if (axis.defined()) {
     int axis_val = axis.as_int32();
@@ -93,7 +95,8 @@ void GetBroadcastIndice(const std::vector<Expr>& indice,
                         std::vector<Expr>* broadcast_indice2,
                         const std::vector<bool>& broadcast_flags1,
                         const std::vector<bool>& broadcast_flags2) {
-  CHECK(broadcast_indice1 && broadcast_indice2);
+  CHECK(broadcast_indice1);
+  CHECK(broadcast_indice2);
   if (broadcast_indice1->empty() && broadcast_indice2->empty()) {
     int flag_size = broadcast_flags1.size();
     int i;
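
Splitting CHECK(a && b && c) into one CHECK per pointer is behavior-preserving but improves diagnostics: the failure message names the exact null argument instead of the whole conjunction. A self-contained sketch, using a simplified CHECK macro and a simplified signature rather than CINN's actual glog macro and types:

#include <cstdio>
#include <cstdlib>

// Stand-in for glog's CHECK: aborts and prints the failing expression's text.
#define CHECK(cond)                                      \
  do {                                                   \
    if (!(cond)) {                                       \
      std::fprintf(stderr, "Check failed: %s\n", #cond); \
      std::abort();                                      \
    }                                                    \
  } while (0)

void GetBroadcastShape(int* common_shape, bool* flag1, bool* flag2) {
  // One CHECK per pointer: the message pinpoints which argument was null,
  // whereas CHECK(common_shape && flag1 && flag2) only reports the whole
  // conjunction.
  CHECK(common_shape);
  CHECK(flag1);
  CHECK(flag2);
}

int main() {
  int shape = 0;
  bool f1 = false;
  GetBroadcastShape(&shape, &f1, nullptr);  // aborts: "Check failed: flag2"
  return 0;
}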

cinn/hlir/pe/pe_broadcast_test.cc (+1, -4)

@@ -26,10 +26,7 @@ void TestBroadcastPE(
 
   auto stages = CreateStages({C});
 
-  Target target;
-  target.arch = Target::Arch::X86;
-  target.bits = Target::Bit::k32;
-  target.os = Target::OS::Linux;
+  Target target = common::DefaultHostTarget();
   Module::Builder builder("module0", target);
   auto func = Lower("fn", stages, {A, B, C});
   builder.AddFunction(func);

cinn/hlir/pe/pe_elementwise_test.cc (+1, -4)

@@ -27,10 +27,7 @@ void TestElementwisePE(const std::string &fn_name,
 
   auto stages = CreateStages({A_out});
 
-  Target target;
-  target.arch = Target::Arch::X86;
-  target.bits = Target::Bit::k32;
-  target.os = Target::OS::Linux;
+  Target target = common::DefaultHostTarget();
   Module::Builder builder("module0", target);
   auto func = Lower("fn", stages, {A, A_out});
   LOG(INFO) << "func:\n" << func;

cinn/hlir/pe/pe_transform_test.cc (+1, -4)

@@ -25,10 +25,7 @@ TEST(MatmulPE, PE_Matmul_Test0) {
 
   auto stages = CreateStages({C});
 
-  Target target;
-  target.arch = Target::Arch::X86;
-  target.bits = Target::Bit::k32;
-  target.os = Target::OS::Linux;
+  Target target = common::DefaultHostTarget();
   Module::Builder builder("module0", target);
   auto func = Lower("fn", stages, {A, B, C});
   builder.AddFunction(func);

cinn/hlir/pe/reduction.cc (+2, -2)

@@ -107,8 +107,8 @@ std::vector<Tensor> DoReduce(const Tensor& tensor,
                              const std::string& output_name) {
   std::vector<Var> reduce_axes;
   for (auto& axis : real_axes) {
-    std::string name = "k" + std::to_string(axis);
-    reduce_axes.push_back(_Var_::Make(Expr(0), tensor->shape[axis], name));
+    std::string name = UniqName("k");
+    reduce_axes.push_back(Var(tensor->shape[axis], name));
   }
   auto compute = [&](const std::vector<Expr>& indices) {
     std::vector<Expr> eval_indice;
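
A name like "k" + std::to_string(axis) is only unique within a single reduction, so two reductions lowered into the same module could both emit a variable named k0. The sketch below shows one plausible counter-based scheme for a UniqName-style helper; the real CINN helper's naming format and thread-safety are assumptions here.

#include <string>
#include <unordered_map>

// Sketch of a UniqName-style helper: appends a monotonically increasing
// counter per prefix, so two reductions in the same module get distinct
// axis names (k_0, k_1, ...) instead of both emitting "k0".
std::string UniqName(const std::string& prefix) {
  static std::unordered_map<std::string, int> counters;
  return prefix + "_" + std::to_string(counters[prefix]++);
}

int main() {
  std::string a = UniqName("k");  // "k_0"
  std::string b = UniqName("k");  // "k_1"
  return a != b ? 0 : 1;
}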

cinn/hlir/pe/transform.cc (+20, -4)

@@ -38,6 +38,8 @@ void GetMatmulOutputShape(const std::vector<Expr>& shape1,
 void GetMatmulIndice(const std::vector<Expr>& shape1_new,
                      const std::vector<Expr>& shape2_new,
                      const std::vector<Expr>& indices,
+                     bool trans_a,
+                     bool trans_b,
                      int x_num_col_dims,
                      int y_num_col_dims,
                      std::vector<Expr>* indice1,
@@ -54,8 +56,8 @@ void GetMatmulIndice(const std::vector<Expr>& shape1_new,
     // A reduce axes
     for (size_t i = x_num_col_dims; i < shape1_new.size(); i++) {
       reduce_shape1 = reduce_shape1 * shape1_new[i];
-      std::string reduce_name = "k" + std::to_string(count);
-      auto k = _Var_::Make(Expr(0), shape1_new[i], reduce_name);
+      std::string reduce_name = UniqName("k");
+      auto k = Var(shape1_new[i], reduce_name);
       reduce_axes->emplace_back(k);
       indice1->emplace_back(k);
       count++;
@@ -73,6 +75,12 @@ void GetMatmulIndice(const std::vector<Expr>& shape1_new,
     for (size_t i = y_num_col_dims; i < shape2_new.size(); i++) {
       indice2->emplace_back(indices[x_num_col_dims + i - y_num_col_dims]);
     }
+    if (trans_a) {
+      reverse(indice1->begin(), indice1->end());
+    }
+    if (trans_b) {
+      reverse(indice2->begin(), indice2->end());
+    }
   }
 }
 
@@ -92,8 +100,16 @@ Tensor Matmul(const Tensor& A,
   GetMatmulOutputShape(
       A->shape, B->shape, &shape1_new, &shape2_new, &output_shape, trans_a, trans_b, x_num_col_dims, y_num_col_dims);
   auto fn = [&](const std::vector<Expr>& indices) {
-    GetMatmulIndice(
-        shape1_new, shape2_new, indices, x_num_col_dims, y_num_col_dims, &A_indice, &B_indice, &reduce_axes);
+    GetMatmulIndice(shape1_new,
+                    shape2_new,
+                    indices,
+                    trans_a,
+                    trans_b,
+                    x_num_col_dims,
+                    y_num_col_dims,
+                    &A_indice,
+                    &B_indice,
+                    &reduce_axes);
     return ReduceSum(A(A_indice) * B(B_indice), Expr());
   };
   return Compute(output_shape, fn, name, reduce_axes);
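
The new trans_a/trans_b flags implement transposition by reversing an operand's index vector before the tensor is accessed. In the 2-D case, where reversing {i, k} to {k, i} is exactly transposition, the trick looks like this plain-array sketch (not CINN's IR):

#include <cstdio>

// 2-D illustration of the trans_a index trick: with indices {i, k} for A,
// reversing them to {k, i} reads the transposed operand, so
// C[i][j] = sum_k A[k][i] * B[k][j] computes A^T * B.
// (Reversal coincides with transposition here only because A is 2-D.)
int main() {
  const int M = 2, K = 3, N = 2;
  double A[K][M] = {{1, 4}, {2, 5}, {3, 6}};  // A is K x M, so A^T is M x K
  double B[K][N] = {{1, 0}, {0, 1}, {1, 1}};
  double C[M][N] = {};
  for (int i = 0; i < M; ++i)
    for (int j = 0; j < N; ++j)
      for (int k = 0; k < K; ++k)
        C[i][j] += A[k][i] * B[k][j];  // reversed indices: A[k][i], not A[i][k]
  std::printf("%g %g\n%g %g\n", C[0][0], C[0][1], C[1][0], C[1][1]);
  return 0;
}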

python/CMakeLists.txt (-5)

@@ -42,11 +42,6 @@ ADD_TEST(NAME test_cinn_pe_transform
         python3 ${CMAKE_CURRENT_SOURCE_DIR}/tests/test_pe_transform.py
 )
 
-ADD_TEST(NAME test_cinn_op_broadcast
-        COMMAND ${CMAKE_COMMAND} -E env PYTHONPATH=${CMAKE_BINARY_DIR}/python:$ENV{PYTHONPATH}
-        python3 ${CMAKE_CURRENT_SOURCE_DIR}/tests/test_op_broadcast.py
-)
-
 ADD_TEST(NAME test_cinn_op_nn
         COMMAND ${CMAKE_COMMAND} -E env PYTHONPATH=${CMAKE_BINARY_DIR}/python:$ENV{PYTHONPATH}
         python3 ${CMAKE_CURRENT_SOURCE_DIR}/tests/test_op_nn.py

python/tests/test_op_broadcast.py (+6, -92)

@@ -1,5 +1,6 @@
 #!/usr/bin/env python3
 import unittest
+import math
 import numpy as np
 import cinn
 from cinn import frontend
@@ -10,117 +11,30 @@
 from cinn import common
 from cinn.poly import create_stages
 import logging
-
-
-class SingleOpTester(unittest.TestCase):
-    '''
-    A unittest framework for testing a single operator.
-
-    Two methods one should override for each Operator's unittest
-
-    1. create_target_data
-    2. test_op
-    '''
-
-    def setUp(self):
-        self.counter = 0
-        self.target = common.Target()
-        self.target.arch = common.Target.Arch.X86
-        self.target.bits = common.Target.Bit.k32
-        self.target.os = common.Target.OS.Linux
-
-    def create_target_data(self, inputs_data):
-        '''
-        create the target of the operator's execution output.
-        '''
-        raise NotImplemented
-
-    def test_op(self):
-        '''
-        USER API
-
-        The real use case should implement this method!
-        '''
-        pass
-
-    def to_test_op(self, input_shapes, output_shape, op_name, attrs):
-        '''
-        Test the operator.
-        '''
-        self.compiler = cinn.Compiler.create(self.target)
-        inputs = []
-        inputs_data = []
-
-        for i_shape in input_shapes:
-            expr_shape = []
-            inputs_data.append(
-                np.around(np.random.random(i_shape).astype("float32"), 3))
-
-            for dim_shape in i_shape:
-                expr_shape.append(ir.Expr(dim_shape))
-
-            inputs.append(
-                lang.Placeholder("float32", self.__gen_var_name(),
-                                 expr_shape).to_tensor())
-        module = self.__codegen(op_name, inputs, attrs)
-        self.compiler.build(module)
-        fn = self.compiler.lookup(op_name)
-        out = runtime.cinn_buffer_t(
-            np.zeros(output_shape).astype("float32"), runtime.cinn_x86_device)
-
-        args = []
-        temp_inputs = []
-        for in_data in inputs_data:
-            temp_inputs.append(
-                runtime.cinn_buffer_t(in_data, runtime.cinn_x86_device))
-        for in_data in temp_inputs:
-            args.append(runtime.cinn_pod_value_t(in_data))
-
-        args.append(runtime.cinn_pod_value_t(out))
-
-        fn(args)
-        self.assertTrue(
-            np.allclose(
-                out.numpy(), self.create_target_data(inputs_data), atol=1e-4))
-
-    def __codegen(self, op_name, inputs, attrs):
-        types = [common.Float(32)]
-        strategy_map = framework.Operator.get_op_attrs("CINNStrategy")
-        res = strategy_map.apply_strategy(op_name, attrs, inputs, types,
-                                          self.target)
-        stages = create_stages(res)
-        func = lang.lower(op_name, stages, res)
-        logging.warning('func:\n\n%s\n', func)
-        builder = lang.Module.Builder(op_name, self.target)
-        builder.add_function(func)
-        return builder.build()
-
-    def __gen_var_name(self):
-        self.counter = self.counter + 1
-        return "Var_" + str(self.counter)
+from test_utils import SingleOpTester
 
 
 class OpTest_add_0(SingleOpTester):
     def create_target_data(self, inputs_data):
-        X, Y = inputs_data
+        [X, Y] = inputs_data
         return X + Y
 
     def test_op(self):
         attrs = framework.NodeAttr()
         attrs.attr_store = {"axis": 0}
-        self.to_test_op([[100, 32], [100, 32]], [100, 32], "elementwise_add",
+        self.to_test_op([[100, 32], [100, 32]], [[100, 32]], "elementwise_add",
                         attrs)
 
 
 class OpTest_add_1(SingleOpTester):
     def create_target_data(self, inputs_data):
-        X, Y = inputs_data
+        [X, Y] = inputs_data
         return X + Y
 
     def test_op(self):
         attrs = framework.NodeAttr()
         attrs.attr_store = {"axis": 1}
-        self.to_test_op([[3, 2], [2]], [3, 2], "elementwise_add", attrs)
+        self.to_test_op([[3, 2], [2]], [[3, 2]], "elementwise_add", attrs)
 
 
 if __name__ == "__main__":
