Skip to content

Commit 34e9057

Browse files
authored
add indextype (#56112)
IR 的 builtin dialect 中加入 IndexType
1 parent 74eb309 commit 34e9057

13 files changed

+68
-0
lines changed

paddle/cinn/utils/attribute_util.h

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -87,6 +87,7 @@ common::Type ConvertIRType(::ir::Type type) {
8787
CASE_TYPE(Int16Type, I16)
8888
CASE_TYPE(Int32Type, I32)
8989
CASE_TYPE(Int64Type, I64)
90+
CASE_TYPE(IndexType, I32)
9091
CASE_TYPE(BoolType, UI1)
9192

9293
LOG(FATAL) << "unknown ir::Type " << type;

paddle/fluid/ir/dialect/utils.h

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -66,6 +66,8 @@ static inline phi::DataType TransToPhiDataType(ir::Type dtype) {
6666
return phi::DataType::INT32;
6767
} else if (dtype.isa<ir::Int64Type>()) {
6868
return phi::DataType::INT64;
69+
} else if (dtype.isa<ir::IndexType>()) {
70+
return phi::DataType::INT32;
6971
} else if (dtype.isa<ir::BoolType>()) {
7072
return phi::DataType::BOOL;
7173
} else if (dtype.isa<ir::Complex64Type>()) {
@@ -79,6 +81,8 @@ static inline phi::DataType TransToPhiDataType(ir::Type dtype) {
7981
}
8082
}
8183

84+
// use phi::DataType::INT32 for IndexType from builtin type to phi::DataType,
85+
// but only use INT32 not IndexType from phi::DataType type to builtin type.
8286
static inline ir::Type TransToIrDataType(phi::DataType dtype,
8387
ir::IrContext* ctx = nullptr) {
8488
if (ctx == nullptr) {

paddle/ir/core/builder.cc

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -49,6 +49,7 @@ BFloat16Type Builder::bfloat16_type() { return BFloat16Type::get(context_); }
4949
Float32Type Builder::float32_type() { return Float32Type::get(context_); }
5050

5151
Float64Type Builder::float64_type() { return Float64Type::get(context_); }
52+
IndexType Builder::index_type() { return IndexType::get(context_); }
5253
Int16Type Builder::int16_type() { return Int16Type::get(context_); }
5354
BoolType Builder::bool_type() { return BoolType::get(context_); }
5455
Complex64Type Builder::complex64_type() { return Complex64Type::get(context_); }

paddle/ir/core/builder.h

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -29,6 +29,7 @@ class BFloat16Type;
2929
class Float32Type;
3030
class Float64Type;
3131
class Int16Type;
32+
class IndexType;
3233
class BoolType;
3334
class Complex64Type;
3435
class Complex128Type;
@@ -114,6 +115,7 @@ class Builder {
114115
IR_API Int8Type int8_type();
115116
IR_API VectorType vec_type(const std::vector<Type> &);
116117
IR_API BFloat16Type bfloat16_type();
118+
IR_API IndexType index_type();
117119
IR_API Float32Type float32_type();
118120
IR_API Float64Type float64_type();
119121
IR_API Int16Type int16_type();

paddle/ir/core/builtin_dialect.cc

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -34,6 +34,7 @@ void BuiltinDialect::initialize() {
3434
Int16Type,
3535
Int32Type,
3636
Int64Type,
37+
IndexType,
3738
BoolType,
3839
Complex64Type,
3940
Complex128Type,

paddle/ir/core/builtin_type.cc

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -29,6 +29,7 @@ IR_DEFINE_EXPLICIT_TYPE_ID(ir::Float64Type)
2929
IR_DEFINE_EXPLICIT_TYPE_ID(ir::Int16Type)
3030
IR_DEFINE_EXPLICIT_TYPE_ID(ir::Int32Type)
3131
IR_DEFINE_EXPLICIT_TYPE_ID(ir::Int64Type)
32+
IR_DEFINE_EXPLICIT_TYPE_ID(ir::IndexType)
3233
IR_DEFINE_EXPLICIT_TYPE_ID(ir::BoolType)
3334
IR_DEFINE_EXPLICIT_TYPE_ID(ir::Complex64Type)
3435
IR_DEFINE_EXPLICIT_TYPE_ID(ir::Complex128Type)

paddle/ir/core/builtin_type.h

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -73,6 +73,7 @@ class IR_API VectorType : public Type {
7373
__macro(Int16Type); \
7474
__macro(Int32Type); \
7575
__macro(Int64Type); \
76+
__macro(IndexType); \
7677
__macro(BoolType); \
7778
__macro(Complex64Type); \
7879
__macro(Complex128Type);
@@ -95,5 +96,6 @@ IR_EXPORT_DECLARE_EXPLICIT_TYPE_ID(ir::Int16Type)
9596
IR_EXPORT_DECLARE_EXPLICIT_TYPE_ID(ir::Int32Type)
9697
IR_EXPORT_DECLARE_EXPLICIT_TYPE_ID(ir::Int64Type)
9798
IR_EXPORT_DECLARE_EXPLICIT_TYPE_ID(ir::BoolType)
99+
IR_EXPORT_DECLARE_EXPLICIT_TYPE_ID(ir::IndexType)
98100
IR_EXPORT_DECLARE_EXPLICIT_TYPE_ID(ir::Complex64Type)
99101
IR_EXPORT_DECLARE_EXPLICIT_TYPE_ID(ir::Complex128Type)

paddle/ir/core/ir_context.cc

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -156,6 +156,7 @@ class IrContextImpl {
156156
Float16Type fp16_type;
157157
Float32Type fp32_type;
158158
Float64Type fp64_type;
159+
IndexType index_type;
159160
UInt8Type uint8_type;
160161
Int8Type int8_type;
161162
Int16Type int16_type;
@@ -203,6 +204,7 @@ IrContext::IrContext() : impl_(new IrContextImpl()) {
203204
impl_->int16_type = TypeManager::get<Int16Type>(this);
204205
impl_->int32_type = TypeManager::get<Int32Type>(this);
205206
impl_->int64_type = TypeManager::get<Int64Type>(this);
207+
impl_->index_type = TypeManager::get<IndexType>(this);
206208
impl_->bool_type = TypeManager::get<BoolType>(this);
207209
impl_->complex64_type = TypeManager::get<Complex64Type>(this);
208210
impl_->complex128_type = TypeManager::get<Complex128Type>(this);
@@ -343,6 +345,8 @@ Int32Type Int32Type::get(IrContext *ctx) { return ctx->impl().int32_type; }
343345

344346
Int64Type Int64Type::get(IrContext *ctx) { return ctx->impl().int64_type; }
345347

348+
IndexType IndexType::get(IrContext *ctx) { return ctx->impl().index_type; }
349+
346350
Int8Type Int8Type::get(IrContext *ctx) { return ctx->impl().int8_type; }
347351

348352
UInt8Type UInt8Type::get(IrContext *ctx) { return ctx->impl().uint8_type; }

paddle/ir/core/ir_printer.cc

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -59,6 +59,8 @@ void BasicIrPrinter::PrintType(Type type) {
5959
os << "i32";
6060
} else if (type.isa<Int64Type>()) {
6161
os << "i64";
62+
} else if (type.isa<IndexType>()) {
63+
os << "index";
6264
} else if (type.isa<Complex64Type>()) {
6365
os << "c64";
6466
} else if (type.isa<Complex128Type>()) {

paddle/ir/core/type.h

Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -97,6 +97,27 @@ IR_API std::ostream &operator<<(std::ostream &os, Type type);
9797

9898
} // namespace ir
9999

100+
///
101+
/// \brief This class represents the base of a type interface.
102+
///
103+
104+
// template <typename ConcreteType>
105+
// class TypeInterface : public ir::DialectInterface<ConcreteType, Type> {
106+
// public:
107+
// using Base = TypeInterface<ConcreteType>;
108+
// using DialectInterfaceBase = ir::DialectInterface<ConcreteType, Type>;
109+
// using DialectInterfaceBase::Base;
110+
111+
// private:
112+
// /// Returns the impl interface instance for the given type.
113+
// static typename InterfaceBase::Concept *getInterfaceFor(Type type) {
114+
// return type.getAbstractType().getInterface<ConcreteType>();
115+
// }
116+
117+
// /// Allow access to 'getInterfaceFor'.
118+
// friend InterfaceBase;
119+
// };
120+
100121
namespace std {
101122
///
102123
/// \brief Enable hashing Type.

test/cpp/ir/core/ir_builder_test.cc

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -31,6 +31,7 @@ TEST(builder_test, type_api) {
3131
EXPECT_EQ(ir::BFloat16Type::get(&ctx), builder.bfloat16_type());
3232
EXPECT_EQ(ir::Float32Type::get(&ctx), builder.float32_type());
3333
EXPECT_EQ(ir::Float64Type::get(&ctx), builder.float64_type());
34+
EXPECT_EQ(ir::IndexType::get(&ctx), builder.index_type());
3435
EXPECT_EQ(ir::Int16Type::get(&ctx), builder.int16_type());
3536
EXPECT_EQ(ir::BoolType::get(&ctx), builder.bool_type());
3637
EXPECT_EQ(ir::Complex64Type::get(&ctx), builder.complex64_type());

test/cpp/ir/core/ir_type_converter_test.cc

Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -65,3 +65,23 @@ TEST(TypeConverterTest, paramterless_type) {
6565
ir::Complex64Type,
6666
ir::Complex128Type>();
6767
}
68+
69+
void test_index_type() {
70+
ir::IrContext* ctx = ir::IrContext::Instance();
71+
ctx->GetOrRegisterDialect<ir::BuiltinDialect>();
72+
73+
ir::Type type = ir::IndexType::get(ctx);
74+
std::stringstream ss;
75+
ss << type;
76+
EXPECT_GT(ss.str().size(), 0u);
77+
EXPECT_EQ(ss.str(), "index");
78+
EXPECT_NE(ss.str(), "<<NULL TYPE>>");
79+
phi::DataType phi_type = paddle::dialect::TransToPhiDataType(type);
80+
auto& type_translator = paddle::translator::TypeTranslator::instance();
81+
paddle::framework::VarDesc empty_var_desc("empty");
82+
auto proto_type = paddle::framework::TransToProtoVarType(phi_type);
83+
ir::Type final_type = type_translator[proto_type](ctx, empty_var_desc);
84+
EXPECT_EQ(paddle::dialect::TransToIrDataType(phi_type), final_type);
85+
}
86+
87+
TEST(IndexTypeConverterTest, index_type) { test_index_type(); }

test/cpp/ir/core/type_test.cc

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -89,6 +89,14 @@ TEST(type_test, built_in_type) {
8989
&ir::AbstractType::lookup(bfp16_1.type_id(), ctx));
9090
EXPECT_EQ(ir::BFloat16Type::classof(bfp16_1), 1);
9191

92+
ir::Type index_1 = ir::IndexType::get(ctx);
93+
ir::Type index_2 = ir::IndexType::get(ctx);
94+
EXPECT_EQ(index_1, index_2);
95+
EXPECT_EQ(index_1.type_id(), index_2.type_id());
96+
EXPECT_EQ(&index_1.abstract_type(),
97+
&ir::AbstractType::lookup(index_1.type_id(), ctx));
98+
EXPECT_EQ(ir::IndexType::classof(index_1), 1);
99+
92100
ir::Type fp16_1 = ir::Float16Type::get(ctx);
93101
ir::Type fp16_2 = ir::Float16Type::get(ctx);
94102
EXPECT_EQ(fp16_1, fp16_2);

0 commit comments

Comments
 (0)