|
1 | 1 | #!/usr/bin/env python3
|
2 | 2 | import unittest
|
| 3 | +import math |
3 | 4 | import numpy as np
|
4 | 5 | import cinn
|
5 | 6 | from cinn import frontend
|
|
10 | 11 | from cinn import common
|
11 | 12 | from cinn.poly import create_stages
|
12 | 13 | import logging
|
13 |
| - |
14 |
| - |
15 |
| -class SingleOpTester(unittest.TestCase): |
16 |
| - ''' |
17 |
| - A unittest framework for testing a single operator. |
18 |
| -
|
19 |
| - Two methods one should override for each Operator's unittest |
20 |
| -
|
21 |
| - 1. create_target_data |
22 |
| - 2. test_op |
23 |
| - ''' |
24 |
| - |
25 |
| - def setUp(self): |
26 |
| - self.counter = 0 |
27 |
| - self.target = common.Target() |
28 |
| - self.target.arch = common.Target.Arch.X86 |
29 |
| - self.target.bits = common.Target.Bit.k32 |
30 |
| - self.target.os = common.Target.OS.Linux |
31 |
| - |
32 |
| - def create_target_data(self, inputs_data): |
33 |
| - ''' |
34 |
| - create the target of the operator's execution output. |
35 |
| - ''' |
36 |
| - raise NotImplemented |
37 |
| - |
38 |
| - def test_op(self): |
39 |
| - ''' |
40 |
| - USER API |
41 |
| -
|
42 |
| - The real use case should implement this method! |
43 |
| - ''' |
44 |
| - pass |
45 |
| - |
46 |
| - def to_test_op(self, input_shapes, output_shape, op_name, attrs): |
47 |
| - ''' |
48 |
| - Test the operator. |
49 |
| - ''' |
50 |
| - self.compiler = cinn.Compiler.create(self.target) |
51 |
| - inputs = [] |
52 |
| - inputs_data = [] |
53 |
| - |
54 |
| - for i_shape in input_shapes: |
55 |
| - expr_shape = [] |
56 |
| - inputs_data.append( |
57 |
| - np.around(np.random.random(i_shape).astype("float32"), 3)) |
58 |
| - |
59 |
| - for dim_shape in i_shape: |
60 |
| - expr_shape.append(ir.Expr(dim_shape)) |
61 |
| - |
62 |
| - inputs.append( |
63 |
| - lang.Placeholder("float32", self.__gen_var_name(), |
64 |
| - expr_shape).to_tensor()) |
65 |
| - module = self.__codegen(op_name, inputs, attrs) |
66 |
| - self.compiler.build(module) |
67 |
| - fn = self.compiler.lookup(op_name) |
68 |
| - out = runtime.cinn_buffer_t( |
69 |
| - np.zeros(output_shape).astype("float32"), runtime.cinn_x86_device) |
70 |
| - |
71 |
| - args = [] |
72 |
| - temp_inputs = [] |
73 |
| - for in_data in inputs_data: |
74 |
| - temp_inputs.append( |
75 |
| - runtime.cinn_buffer_t(in_data, runtime.cinn_x86_device)) |
76 |
| - for in_data in temp_inputs: |
77 |
| - args.append(runtime.cinn_pod_value_t(in_data)) |
78 |
| - |
79 |
| - args.append(runtime.cinn_pod_value_t(out)) |
80 |
| - |
81 |
| - fn(args) |
82 |
| - self.assertTrue( |
83 |
| - np.allclose( |
84 |
| - out.numpy(), self.create_target_data(inputs_data), atol=1e-4)) |
85 |
| - |
86 |
| - def __codegen(self, op_name, inputs, attrs): |
87 |
| - types = [common.Float(32)] |
88 |
| - strategy_map = framework.Operator.get_op_attrs("CINNStrategy") |
89 |
| - res = strategy_map.apply_strategy(op_name, attrs, inputs, types, |
90 |
| - self.target) |
91 |
| - stages = create_stages(res) |
92 |
| - func = lang.lower(op_name, stages, res) |
93 |
| - logging.warning('func:\n\n%s\n', func) |
94 |
| - builder = lang.Module.Builder(op_name, self.target) |
95 |
| - builder.add_function(func) |
96 |
| - return builder.build() |
97 |
| - |
98 |
| - def __gen_var_name(self): |
99 |
| - self.counter = self.counter + 1 |
100 |
| - return "Var_" + str(self.counter) |
| 14 | +from test_utils import SingleOpTester |
101 | 15 |
|
102 | 16 |
|
103 | 17 | class OpTest_add_0(SingleOpTester):
|
104 | 18 | def create_target_data(self, inputs_data):
|
105 |
| - X, Y = inputs_data |
| 19 | + [X, Y] = inputs_data |
106 | 20 | return X + Y
|
107 | 21 |
|
108 | 22 | def test_op(self):
|
109 | 23 | attrs = framework.NodeAttr()
|
110 | 24 | attrs.attr_store = {"axis": 0}
|
111 |
| - self.to_test_op([[100, 32], [100, 32]], [100, 32], "elementwise_add", |
| 25 | + self.to_test_op([[100, 32], [100, 32]], [[100, 32]], "elementwise_add", |
112 | 26 | attrs)
|
113 | 27 |
|
114 | 28 |
|
115 | 29 | class OpTest_add_1(SingleOpTester):
|
116 | 30 | def create_target_data(self, inputs_data):
|
117 |
| - X, Y = inputs_data |
| 31 | + [X, Y] = inputs_data |
118 | 32 | return X + Y
|
119 | 33 |
|
120 | 34 | def test_op(self):
|
121 | 35 | attrs = framework.NodeAttr()
|
122 | 36 | attrs.attr_store = {"axis": 1}
|
123 |
| - self.to_test_op([[3, 2], [2]], [3, 2], "elementwise_add", attrs) |
| 37 | + self.to_test_op([[3, 2], [2]], [[3, 2]], "elementwise_add", attrs) |
124 | 38 |
|
125 | 39 |
|
126 | 40 | if __name__ == "__main__":
|
|
0 commit comments