Skip to content

Commit b0a604c

Browse files
authored
[IR] Type system stage3: add class Dialect (#50959)
* add dialect * add some interface for dialect * add some dialect interfaces for class Type * set WITH_NEWIR=OFF * refine code by comment * polish code * refine include style * refine log for debug
1 parent 8f156fd commit b0a604c

13 files changed

+462
-64
lines changed

paddle/ir/builtin_dialect.cc

+29
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,29 @@
1+
// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
2+
//
3+
// Licensed under the Apache License, Version 2.0 (the "License");
4+
// you may not use this file except in compliance with the License.
5+
// You may obtain a copy of the License at
6+
//
7+
// http://www.apache.org/licenses/LICENSE-2.0
8+
//
9+
// Unless required by applicable law or agreed to in writing, software
10+
// distributed under the License is distributed on an "AS IS" BASIS,
11+
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
// See the License for the specific language governing permissions and
13+
// limitations under the License.
14+
15+
#include "paddle/ir/builtin_dialect.h"
16+
#include "paddle/ir/builtin_type.h"
17+
18+
namespace ir {
19+
BuiltinDialect::BuiltinDialect(ir::IrContext *context)
20+
: ir::Dialect(name(), context, ir::TypeId::get<BuiltinDialect>()) {
21+
initialize();
22+
}
23+
24+
void BuiltinDialect::initialize() {
25+
// Register all built-in types defined in builtin_type.h.
26+
RegisterTypes<GET_BUILT_IN_TYPE_LIST>();
27+
}
28+
29+
} // namespace ir

paddle/ir/builtin_dialect.h

+40
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,40 @@
1+
// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
2+
//
3+
// Licensed under the Apache License, Version 2.0 (the "License");
4+
// you may not use this file except in compliance with the License.
5+
// You may obtain a copy of the License at
6+
//
7+
// http://www.apache.org/licenses/LICENSE-2.0
8+
//
9+
// Unless required by applicable law or agreed to in writing, software
10+
// distributed under the License is distributed on an "AS IS" BASIS,
11+
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
// See the License for the specific language governing permissions and
13+
// limitations under the License.
14+
15+
#pragma once
16+
17+
#include "paddle/ir/dialect.h"
18+
19+
namespace ir {
20+
///
21+
/// \brief Built-in Dialect: automatically registered into global IrContext,
22+
/// all built-in types defined in builtin_type.h will be registered in this
23+
/// Dialect.
24+
///
25+
class BuiltinDialect : public ir::Dialect {
26+
public:
27+
explicit BuiltinDialect(ir::IrContext *context);
28+
///
29+
/// \brief Each Dialect needs to provide a name function to return the name of
30+
/// the Dialect.
31+
///
32+
/// \return The name of this Dialect.
33+
///
34+
static const char *name() { return "builtin"; }
35+
36+
private:
37+
void initialize();
38+
};
39+
40+
} // namespace ir

paddle/ir/builtin_type.h

+10-1
Original file line numberDiff line numberDiff line change
@@ -17,9 +17,18 @@
1717
#include "paddle/ir/type.h"
1818

1919
namespace ir {
20+
///
21+
/// \brief This macro is used to get a list of all built-in types in this file.
22+
///
23+
#define GET_BUILT_IN_TYPE_LIST ir::Float32Type, ir::Int32Type
24+
2025
///
2126
/// \brief Definitions of built-in type classes. The built-in type object get
22-
/// method is as follows: Type fp32 = Float32Type::get(ctx);
27+
/// method is as follows:
28+
/// \code{cpp}
29+
/// ir::IrContext *ctx = ir::IrContext::Instance();
30+
/// Type fp32 = Float32Type::get(ctx);
31+
/// \endcode
2332
///
2433
class Float32Type : public ir::Type {
2534
public:

paddle/ir/dialect.cc

+28
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,28 @@
1+
// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
2+
//
3+
// Licensed under the Apache License, Version 2.0 (the "License");
4+
// you may not use this file except in compliance with the License.
5+
// You may obtain a copy of the License at
6+
//
7+
// http://www.apache.org/licenses/LICENSE-2.0
8+
//
9+
// Unless required by applicable law or agreed to in writing, software
10+
// distributed under the License is distributed on an "AS IS" BASIS,
11+
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
// See the License for the specific language governing permissions and
13+
// limitations under the License.
14+
15+
#include "paddle/ir/dialect.h"
16+
17+
namespace ir {
18+
Dialect::Dialect(std::string name, ir::IrContext *context, ir::TypeId id)
19+
: name_(std::move(name)), context_(context), id_(id) {}
20+
21+
void Dialect::RegisterType(ir::AbstractType &&abstract_type) {
22+
ir::AbstractType *new_abstract_type =
23+
new ir::AbstractType(std::move(abstract_type));
24+
this->ir_context()->RegisterAbstractType(new_abstract_type->type_id(),
25+
new_abstract_type);
26+
}
27+
28+
} // namespace ir

paddle/ir/dialect.h

+77
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,77 @@
1+
// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
2+
//
3+
// Licensed under the Apache License, Version 2.0 (the "License");
4+
// you may not use this file except in compliance with the License.
5+
// You may obtain a copy of the License at
6+
//
7+
// http://www.apache.org/licenses/LICENSE-2.0
8+
//
9+
// Unless required by applicable law or agreed to in writing, software
10+
// distributed under the License is distributed on an "AS IS" BASIS,
11+
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
// See the License for the specific language governing permissions and
13+
// limitations under the License.
14+
15+
#pragma once
16+
17+
#include "paddle/ir/ir_context.h"
18+
#include "paddle/ir/type_base.h"
19+
20+
namespace ir {
21+
///
22+
/// \brief Dialect can basically be understood as a namespace. In Dialect, we
23+
/// can define a series of types, operations, etc. An instance of the dialect
24+
/// object will be loaded into the global IrContext. Specific compilers only
25+
/// need to combine existing dialects and add their own extensions or
26+
/// customizations.
27+
///
28+
class Dialect {
29+
public:
30+
Dialect(std::string name, ir::IrContext *context, ir::TypeId id);
31+
32+
const std::string &name() const { return name_; }
33+
34+
ir::IrContext *ir_context() const { return context_; }
35+
36+
ir::TypeId id() const { return id_; }
37+
38+
///
39+
/// \brief Register all types contained in the template parameter Args.
40+
/// To register only one Type, you can use the RegisterType template function.
41+
///
42+
template <typename... Args>
43+
void RegisterTypes() {
44+
(void)std::initializer_list<int>{0, (RegisterType<Args>(), 0)...};
45+
}
46+
47+
///
48+
/// \brief Register type of class T.
49+
///
50+
template <typename T>
51+
void RegisterType() {
52+
VLOG(4) << "Type registered into Dialect. --->";
53+
ir::AbstractType *abstract_type =
54+
new ir::AbstractType(std::move(ir::AbstractType::get<T>(*this)));
55+
this->ir_context()->RegisterAbstractType(ir::TypeId::get<T>(),
56+
abstract_type);
57+
ir::TypeManager::RegisterType<T>(this->ir_context());
58+
VLOG(4) << "----------------------------------";
59+
}
60+
61+
///
62+
/// \brief Register abstract_type into context.
63+
/// NOTE: It's not recommended to use this interface directly. This interface
64+
/// only registers abstract_type. To register TypeStorage into context, you
65+
/// need to call ir::TypeManager::RegisterType<T>() additionally,
66+
/// RegisterType<T>() is recommended to use.
67+
///
68+
void RegisterType(ir::AbstractType &&abstract_type);
69+
70+
private:
71+
std::string name_;
72+
73+
ir::IrContext *context_; // not owned
74+
75+
ir::TypeId id_;
76+
};
77+
} // namespace ir

paddle/ir/ir_context.cc

+85-23
Original file line numberDiff line numberDiff line change
@@ -12,61 +12,95 @@
1212
// See the License for the specific language governing permissions and
1313
// limitations under the License.
1414

15+
#include "paddle/ir/ir_context.h"
16+
1517
#include <unordered_map>
1618

19+
#include "paddle/ir/builtin_dialect.h"
1720
#include "paddle/ir/builtin_type.h"
18-
#include "paddle/ir/ir_context.h"
21+
#include "paddle/ir/dialect.h"
1922
#include "paddle/ir/spin_lock.h"
2023
#include "paddle/ir/type_base.h"
2124

2225
namespace ir {
23-
// The implementation class of the IrContext class
26+
// The implementation class of the IrContext class, cache registered
27+
// AbstractType, TypeStorage, Dialect.
2428
class IrContextImpl {
2529
public:
2630
IrContextImpl() {}
2731

2832
~IrContextImpl() {
29-
std::lock_guard<ir::SpinLock> guard(registed_abstract_types_lock_);
30-
for (auto abstract_type_map : registed_abstract_types_) {
33+
std::lock_guard<ir::SpinLock> guard(destructor_lock_);
34+
for (auto &abstract_type_map : registed_abstract_types_) {
3135
delete abstract_type_map.second;
3236
}
3337
registed_abstract_types_.clear();
38+
39+
for (auto &dialect_map : registed_dialect_) {
40+
delete dialect_map.second;
41+
}
42+
registed_dialect_.clear();
3443
}
3544

3645
void RegisterAbstractType(ir::TypeId type_id, AbstractType *abstract_type) {
3746
std::lock_guard<ir::SpinLock> guard(registed_abstract_types_lock_);
38-
VLOG(4) << "IrContext register an abstract_type of: [TypeId_hash="
47+
VLOG(4) << "Register an abstract_type of: [TypeId_hash="
3948
<< std::hash<ir::TypeId>()(type_id)
4049
<< ", AbstractType_ptr=" << abstract_type << "].";
4150
registed_abstract_types_.emplace(type_id, abstract_type);
4251
}
4352

44-
AbstractType *lookup(ir::TypeId type_id) {
53+
AbstractType *GetAbstractType(ir::TypeId type_id) {
4554
std::lock_guard<ir::SpinLock> guard(registed_abstract_types_lock_);
4655
auto iter = registed_abstract_types_.find(type_id);
47-
if (iter == registed_abstract_types_.end()) {
48-
VLOG(4) << "IrContext not fonund cached abstract_type of: [TypeId_hash="
49-
<< std::hash<ir::TypeId>()(type_id) << "].";
50-
return nullptr;
51-
} else {
52-
VLOG(4) << "IrContext fonund a cached abstract_type of: [TypeId_hash="
56+
if (iter != registed_abstract_types_.end()) {
57+
VLOG(4) << "Fonund a cached abstract_type of: [TypeId_hash="
5358
<< std::hash<ir::TypeId>()(type_id)
5459
<< ", AbstractType_ptr=" << iter->second << "].";
5560
return iter->second;
5661
}
62+
LOG(WARNING) << "No cache found abstract_type of: [TypeId_hash="
63+
<< std::hash<ir::TypeId>()(type_id) << "].";
64+
return nullptr;
5765
}
5866

59-
ir::SpinLock registed_abstract_types_lock_;
67+
void RegisterDialect(std::string name, Dialect *dialect) {
68+
std::lock_guard<ir::SpinLock> guard(registed_dialect_lock_);
69+
VLOG(4) << "Register a dialect of: [name=" << name
70+
<< ", dialect_ptr=" << dialect << "].";
71+
registed_dialect_.emplace(name, dialect);
72+
}
73+
74+
Dialect *GetDialect(std::string name) {
75+
std::lock_guard<ir::SpinLock> guard(registed_dialect_lock_);
76+
auto iter = registed_dialect_.find(name);
77+
if (iter != registed_dialect_.end()) {
78+
VLOG(4) << "Fonund a cached dialect of: [name=" << name
79+
<< ", dialect_ptr=" << iter->second << "].";
80+
return iter->second;
81+
}
82+
LOG(WARNING) << "No cache fonund dialect of: [name=" << name << "].";
83+
return nullptr;
84+
}
6085

6186
// Cached AbstractType instances.
6287
std::unordered_map<TypeId, AbstractType *> registed_abstract_types_;
6388

89+
ir::SpinLock registed_abstract_types_lock_;
90+
6491
// TypeStorage uniquer and cache instances.
6592
StorageManager registed_storage_manager_;
6693

67-
// Some built-in type.
94+
// The dialcet registered in the context.
95+
std::unordered_map<std::string, Dialect *> registed_dialect_;
96+
97+
ir::SpinLock registed_dialect_lock_;
98+
99+
// Some built-in types.
68100
Float32Type fp32_type;
69101
Int32Type int32_type;
102+
103+
ir::SpinLock destructor_lock_;
70104
};
71105

72106
IrContext *IrContext::Instance() {
@@ -75,13 +109,12 @@ IrContext *IrContext::Instance() {
75109
}
76110

77111
IrContext::IrContext() : impl_(new IrContextImpl()) {
78-
VLOG(4) << "IrContext register built-in type...";
79-
REGISTER_TYPE_2_IRCONTEXT(Float32Type, this);
112+
VLOG(4) << "BuiltinDialect registered into IrContext. ===>";
113+
GetOrRegisterDialect<BuiltinDialect>();
114+
VLOG(4) << "==============================================";
115+
80116
impl_->fp32_type = TypeManager::get<Float32Type>(this);
81-
VLOG(4) << "Float32Type registration complete";
82-
REGISTER_TYPE_2_IRCONTEXT(Int32Type, this);
83117
impl_->int32_type = TypeManager::get<Int32Type>(this);
84-
VLOG(4) << "Int32Type registration complete";
85118
}
86119

87120
void IrContext::RegisterAbstractType(ir::TypeId type_id,
@@ -98,12 +131,41 @@ std::unordered_map<TypeId, AbstractType *>
98131
return impl().registed_abstract_types_;
99132
}
100133

101-
const AbstractType &AbstractType::lookup(TypeId type_id, IrContext *ctx) {
102-
VLOG(4) << "Lookup abstract type [TypeId_hash="
103-
<< std::hash<ir::TypeId>()(type_id) << "] from IrContext [ptr=" << ctx
134+
Dialect *IrContext::GetOrRegisterDialect(
135+
std::string dialect_name, std::function<Dialect *()> constructor) {
136+
VLOG(4) << "Try to get or register a Dialect of: [name=" << dialect_name
104137
<< "].";
138+
Dialect *dialect = impl().GetDialect(dialect_name);
139+
if (dialect == nullptr) {
140+
VLOG(4) << "Create and register a new Dialect of: [name=" << dialect_name
141+
<< "].";
142+
dialect = constructor();
143+
impl().RegisterDialect(dialect_name, dialect);
144+
}
145+
return dialect;
146+
}
147+
148+
std::vector<Dialect *> IrContext::GetRegisteredDialects() {
149+
std::vector<Dialect *> result;
150+
for (auto dialect_map : impl().registed_dialect_) {
151+
result.push_back(dialect_map.second);
152+
}
153+
return result;
154+
}
155+
156+
Dialect *IrContext::GetRegisteredDialect(const std::string &dialect_name) {
157+
for (auto dialect_map : impl().registed_dialect_) {
158+
if (dialect_map.first == dialect_name) {
159+
return dialect_map.second;
160+
}
161+
}
162+
LOG(WARNING) << "No dialect registered for " << dialect_name;
163+
return nullptr;
164+
}
165+
166+
const AbstractType &AbstractType::lookup(TypeId type_id, IrContext *ctx) {
105167
auto &impl = ctx->impl();
106-
AbstractType *abstract_type = impl.lookup(type_id);
168+
AbstractType *abstract_type = impl.GetAbstractType(type_id);
107169
if (abstract_type) {
108170
return *abstract_type;
109171
} else {

0 commit comments

Comments
 (0)