Skip to content

[DimExpr] DimExpr support print #60146

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 9 commits into from
Dec 27, 2023
Merged
Show file tree
Hide file tree
Changes from 5 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
60 changes: 60 additions & 0 deletions paddle/pir/dialect/shape/utils/dim_expr.cc
Original file line number Diff line number Diff line change
Expand Up @@ -124,4 +124,64 @@ bool DimExpr::operator!=(const DimExpr& other) const {
return !(*this == other);
}

namespace {

std::string ToTxtStringImpl(std::int64_t dim_expr) {
return std::to_string(dim_expr);
}

std::string ToTxtStringImpl(const std::string& dim_expr) { return dim_expr; }

std::string ToTxtStringImpl(const Negative<DimExpr>& dim_expr) {
return "-" + ToTxtString(dim_expr->data);
}

std::string ToTxtStringImpl(const Reciprocal<DimExpr>& dim_expr) {
return "1 / (" + ToTxtString(dim_expr->data) + ")";
}

std::string ListDimExprToTxtString(const List<DimExpr>& dim_exprs,
const std::string& delim = ", ") {
std::string ret;
for (std::size_t i = 0; i < dim_exprs->size(); ++i) {
if (i > 0) {
ret += delim;
}
ret += ToTxtString(dim_exprs->at(i));
}
return ret;
}

std::string ToTxtStringImpl(const Add<DimExpr>& dim_expr) {
return "Add(" + ListDimExprToTxtString(dim_expr.operands, ", ") + ")";
}

std::string ToTxtStringImpl(const Mul<DimExpr>& dim_expr) {
return "Mul(" + ListDimExprToTxtString(dim_expr.operands, ", ") + ")";
}

std::string ToTxtStringImpl(const Max<DimExpr>& dim_expr) {
return "Max(" + ListDimExprToTxtString(dim_expr.operands, ", ") + ")";
}

std::string ToTxtStringImpl(const Min<DimExpr>& dim_expr) {
return "Min(" + ListDimExprToTxtString(dim_expr.operands, ", ") + ")";
}

std::string ToTxtStringImpl(const Broadcast<DimExpr>& dim_expr) {
return "Broadcast(" + ListDimExprToTxtString(dim_expr.operands, ", ") + ")";
}

} // namespace

std::string ToTxtString(const DimExpr& dim_expr) {
return std::visit([](const auto& impl) { return ToTxtStringImpl(impl); },
dim_expr.variant());
}

std::ostream& operator<<(std::ostream& stream, const DimExpr& dim_expr) {
stream << ToTxtString(dim_expr);
return stream;
}

} // namespace symbol
5 changes: 5 additions & 0 deletions paddle/pir/dialect/shape/utils/dim_expr.h
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
#include <functional>
#include <memory>
#include <optional>
#include <ostream>
#include <string>
#include <variant>
#include <vector>
Expand Down Expand Up @@ -239,4 +240,8 @@ class ValueShape {

using ValueShapeDimExprs = ValueShape<DimExpr>;
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

这个地方不该应再叫ValueShape了吧?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

命名修改完毕:#60056


std::string ToTxtString(const DimExpr& dim_expr);

std::ostream& operator<<(std::ostream&, const DimExpr& dim_expr);

} // namespace symbol
13 changes: 13 additions & 0 deletions test/cpp/pir/shape_dialect/symbol_dim_expr_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -92,4 +92,17 @@ TEST(DimExpr, equal) {
builder.Broadcast(DimExpr("S0"), constant1));
}

TEST(DimExpr, print) {
DimExprBuilder builder{nullptr};
DimExpr sym0 = DimExpr("S0");
DimExpr sym1 = DimExpr("S1");
ASSERT_EQ((ToTxtString(sym0 + sym1)), "Add(S0, S1)");
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

函数名是否有点冗长,ToString就可以吧,或者PIR里其实都用的是Print,我觉得可以保持一致

Copy link
Contributor Author

@jiahy0825 jiahy0825 Dec 25, 2023

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

有道理,本函数的返回值是 std::string,因此改成了 ToString
DimExpr 已经重载了 operator<< 函数,因此打印时还可以直接调用 VLOG,例如 VLOG(1)<<dim_expr

ASSERT_EQ((ToTxtString(sym0 - sym1)), "Add(S0, -S1)");
ASSERT_EQ((ToTxtString(sym0 * sym1)), "Mul(S0, S1)");
ASSERT_EQ((ToTxtString(sym0 / sym1)), "Mul(S0, 1 / (S1))");
ASSERT_EQ((ToTxtString(builder.Max(sym0, sym1))), "Max(S0, S1)");
ASSERT_EQ((ToTxtString(builder.Min(sym0, sym1))), "Min(S0, S1)");
ASSERT_EQ((ToTxtString(builder.Broadcast(sym0, sym1))), "Broadcast(S0, S1)");
}

} // namespace symbol::test