Skip to content
This repository was archived by the owner on Jan 24, 2024. It is now read-only.

Commit e490391

Browse files
committed
fix pybind, ComparisionKind enum export
1 parent 70b04f0 commit e490391

File tree

3 files changed

+57
-12
lines changed

3 files changed

+57
-12
lines changed

cinn/hlir/op/reduction.cc

+10-9
Original file line numberDiff line numberDiff line change
@@ -465,17 +465,18 @@ std::vector<shape_t> InferShapeForReduction(const std::vector<shape_t> &inputs_s
465465
if (attrs.find("keep_dim") != attrs.end()) {
466466
keep_dim = absl::get<bool>(attrs.at("keep_dim"));
467467
}
468-
CHECK(!dim.empty()) << "should have reduce dim, please check!";
469-
CHECK_LE(dim.size(), inputs_shape[0].size()) << "reduce dim should no more than the input size";
470468
std::vector<int> out_shapes;
471-
auto ndim = inputs_shape[0].size();
472-
for (size_t i = 0; i < ndim; ++i) {
473-
if (std::find(dim.begin(), dim.end(), i) != dim.end()) {
474-
if (keep_dim) {
475-
out_shapes.push_back(1);
469+
if (!dim.empty()) {
470+
CHECK_LE(dim.size(), inputs_shape[0].size()) << "reduce dim should no more than the input size";
471+
auto ndim = inputs_shape[0].size();
472+
for (size_t i = 0; i < ndim; ++i) {
473+
if (std::find(dim.begin(), dim.end(), i) != dim.end()) {
474+
if (keep_dim) {
475+
out_shapes.push_back(1);
476+
}
477+
} else {
478+
out_shapes.push_back(inputs_shape[0][i]);
476479
}
477-
} else {
478-
out_shapes.push_back(inputs_shape[0][i]);
479480
}
480481
}
481482

cinn/pybind/frontend.cc

100755100644
+21-3
Original file line numberDiff line numberDiff line change
@@ -431,6 +431,24 @@ void BindFrontend(pybind11::module *m) {
431431
py::arg("padding_algorithm") = "EXPLICIT")
432432
.def("sum", &NetBuilder::sum, py::arg("inputs"));
433433

434+
py::enum_<ComparisonKind>(*m, "ComparisonKind")
435+
.value("kUnk", ComparisonKind::kUnk)
436+
.value("kEq", ComparisonKind::kEq)
437+
.value("kNe", ComparisonKind::kNe)
438+
.value("kGe", ComparisonKind::kGe)
439+
.value("kGt", ComparisonKind::kGt)
440+
.value("kLe", ComparisonKind::kLe)
441+
.value("kLt", ComparisonKind::kLt)
442+
.export_values();
443+
444+
py::enum_<ReduceKind>(*m, "ReduceKind")
445+
.value("kUnk", ReduceKind::kUnk)
446+
.value("kSum", ReduceKind::kSum)
447+
.value("kProd", ReduceKind::kProd)
448+
.value("kMax", ReduceKind::kMax)
449+
.value("kMin", ReduceKind::kMin)
450+
.export_values();
451+
434452
py::class_<CinnBuilder, BaseBuilder>(*m, "CinnBuilder")
435453
.def(py::init<const std::string &>(), py::arg("name") = "")
436454
.def("const_scalar", &CinnBuilder::ConstScalar<bool>)
@@ -455,12 +473,12 @@ void BindFrontend(pybind11::module *m) {
455473
py::arg("data_format") = "NCHW",
456474
py::arg("padding_algorithm") = "EXPLICIT",
457475
py::arg("output_shape") = std::vector<int>{})
458-
.def("compare", &CinnBuilder::Compare, py::arg("lhs"), py::arg("rhs"), py::arg("kind"))
476+
.def("compare", &CinnBuilder::Compare, py::arg("lhs"), py::arg("rhs"), py::arg("kind") = ComparisonKind::kEq)
459477
.def("reduce",
460478
&CinnBuilder::Reduce,
461479
py::arg("operand"),
462-
py::arg("kind"),
463-
py::arg("dim"),
480+
py::arg("kind") = ReduceKind::kSum,
481+
py::arg("dim") = std::vector<int>{},
464482
py::arg("keep_dim") = false)
465483
.def("broadcast_to",
466484
&CinnBuilder::BroadcastTo,

python/tests/test_cinnbuilder.py

+26
Original file line numberDiff line numberDiff line change
@@ -32,6 +32,32 @@
3232
enable_gpu = sys.argv.pop()
3333

3434

35+
class TestCinnBuildBasic(unittest.TestCase):
36+
def setUp(self):
37+
pass
38+
39+
def test_compare(self):
40+
builder = CinnBuilder("test_compare")
41+
a = builder.create_input(Float(32), (1, 24, 56, 56), "A")
42+
b = builder.create_input(Float(32), (1, 24, 56, 56), "B")
43+
# default compare kind is ComparisonKind.kEq
44+
c = builder.compare(a, b)
45+
d = builder.compare(a, c, ComparisonKind.kNe)
46+
prog = builder.build()
47+
for i in range(prog.size()):
48+
print(prog[i])
49+
50+
def test_reduce(self):
51+
builder = CinnBuilder("test_compare")
52+
a = builder.create_input(Float(32), (1, 24, 56, 56), "A")
53+
b = builder.reduce(a)
54+
c = builder.reduce(a, ReduceKind.kMax)
55+
d = builder.add(b, c)
56+
prog = builder.build()
57+
for i in range(prog.size()):
58+
print(prog[i])
59+
60+
3561
class TestCinnBuilder(unittest.TestCase):
3662
def setUp(self):
3763
if enable_gpu == "ON":

0 commit comments

Comments
 (0)