12
12
// See the License for the specific language governing permissions and
13
13
// limitations under the License.
14
14
15
+ #include " paddle/ir/ir_context.h"
16
+
15
17
#include < unordered_map>
16
18
19
+ #include " paddle/ir/builtin_dialect.h"
17
20
#include " paddle/ir/builtin_type.h"
18
- #include " paddle/ir/ir_context .h"
21
+ #include " paddle/ir/dialect .h"
19
22
#include " paddle/ir/spin_lock.h"
20
23
#include " paddle/ir/type_base.h"
21
24
22
25
namespace ir {
23
- // The implementation class of the IrContext class
26
+ // The implementation class of the IrContext class, cache registered
27
+ // AbstractType, TypeStorage, Dialect.
24
28
class IrContextImpl {
25
29
public:
26
30
IrContextImpl () {}
27
31
28
32
~IrContextImpl () {
29
- std::lock_guard<ir::SpinLock> guard (registed_abstract_types_lock_ );
30
- for (auto abstract_type_map : registed_abstract_types_) {
33
+ std::lock_guard<ir::SpinLock> guard (destructor_lock_ );
34
+ for (auto & abstract_type_map : registed_abstract_types_) {
31
35
delete abstract_type_map.second ;
32
36
}
33
37
registed_abstract_types_.clear ();
38
+
39
+ for (auto &dialect_map : registed_dialect_) {
40
+ delete dialect_map.second ;
41
+ }
42
+ registed_dialect_.clear ();
34
43
}
35
44
36
45
void RegisterAbstractType (ir::TypeId type_id, AbstractType *abstract_type) {
37
46
std::lock_guard<ir::SpinLock> guard (registed_abstract_types_lock_);
38
- VLOG (4 ) << " IrContext register an abstract_type of: [TypeId_hash="
47
+ VLOG (4 ) << " Register an abstract_type of: [TypeId_hash="
39
48
<< std::hash<ir::TypeId>()(type_id)
40
49
<< " , AbstractType_ptr=" << abstract_type << " ]." ;
41
50
registed_abstract_types_.emplace (type_id, abstract_type);
42
51
}
43
52
44
- AbstractType *lookup (ir::TypeId type_id) {
53
+ AbstractType *GetAbstractType (ir::TypeId type_id) {
45
54
std::lock_guard<ir::SpinLock> guard (registed_abstract_types_lock_);
46
55
auto iter = registed_abstract_types_.find (type_id);
47
- if (iter == registed_abstract_types_.end ()) {
48
- VLOG (4 ) << " IrContext not fonund cached abstract_type of: [TypeId_hash="
49
- << std::hash<ir::TypeId>()(type_id) << " ]." ;
50
- return nullptr ;
51
- } else {
52
- VLOG (4 ) << " IrContext fonund a cached abstract_type of: [TypeId_hash="
56
+ if (iter != registed_abstract_types_.end ()) {
57
+ VLOG (4 ) << " Fonund a cached abstract_type of: [TypeId_hash="
53
58
<< std::hash<ir::TypeId>()(type_id)
54
59
<< " , AbstractType_ptr=" << iter->second << " ]." ;
55
60
return iter->second ;
56
61
}
62
+ LOG (WARNING) << " No cache found abstract_type of: [TypeId_hash="
63
+ << std::hash<ir::TypeId>()(type_id) << " ]." ;
64
+ return nullptr ;
57
65
}
58
66
59
- ir::SpinLock registed_abstract_types_lock_;
67
+ void RegisterDialect (std::string name, Dialect *dialect) {
68
+ std::lock_guard<ir::SpinLock> guard (registed_dialect_lock_);
69
+ VLOG (4 ) << " Register a dialect of: [name=" << name
70
+ << " , dialect_ptr=" << dialect << " ]." ;
71
+ registed_dialect_.emplace (name, dialect);
72
+ }
73
+
74
+ Dialect *GetDialect (std::string name) {
75
+ std::lock_guard<ir::SpinLock> guard (registed_dialect_lock_);
76
+ auto iter = registed_dialect_.find (name);
77
+ if (iter != registed_dialect_.end ()) {
78
+ VLOG (4 ) << " Fonund a cached dialect of: [name=" << name
79
+ << " , dialect_ptr=" << iter->second << " ]." ;
80
+ return iter->second ;
81
+ }
82
+ LOG (WARNING) << " No cache fonund dialect of: [name=" << name << " ]." ;
83
+ return nullptr ;
84
+ }
60
85
61
86
// Cached AbstractType instances.
62
87
std::unordered_map<TypeId, AbstractType *> registed_abstract_types_;
63
88
89
+ ir::SpinLock registed_abstract_types_lock_;
90
+
64
91
// TypeStorage uniquer and cache instances.
65
92
StorageManager registed_storage_manager_;
66
93
67
- // Some built-in type.
94
+ // The dialcet registered in the context.
95
+ std::unordered_map<std::string, Dialect *> registed_dialect_;
96
+
97
+ ir::SpinLock registed_dialect_lock_;
98
+
99
+ // Some built-in types.
68
100
Float32Type fp32_type;
69
101
Int32Type int32_type;
102
+
103
+ ir::SpinLock destructor_lock_;
70
104
};
71
105
72
106
IrContext *IrContext::Instance () {
@@ -75,13 +109,12 @@ IrContext *IrContext::Instance() {
75
109
}
76
110
77
111
IrContext::IrContext () : impl_(new IrContextImpl()) {
78
- VLOG (4 ) << " IrContext register built-in type..." ;
79
- REGISTER_TYPE_2_IRCONTEXT (Float32Type, this );
112
+ VLOG (4 ) << " BuiltinDialect registered into IrContext. ===>" ;
113
+ GetOrRegisterDialect<BuiltinDialect>();
114
+ VLOG (4 ) << " ==============================================" ;
115
+
80
116
impl_->fp32_type = TypeManager::get<Float32Type>(this );
81
- VLOG (4 ) << " Float32Type registration complete" ;
82
- REGISTER_TYPE_2_IRCONTEXT (Int32Type, this );
83
117
impl_->int32_type = TypeManager::get<Int32Type>(this );
84
- VLOG (4 ) << " Int32Type registration complete" ;
85
118
}
86
119
87
120
void IrContext::RegisterAbstractType (ir::TypeId type_id,
@@ -98,12 +131,41 @@ std::unordered_map<TypeId, AbstractType *>
98
131
return impl ().registed_abstract_types_ ;
99
132
}
100
133
101
- const AbstractType & AbstractType::lookup (TypeId type_id, IrContext *ctx) {
102
- VLOG ( 4 ) << " Lookup abstract type [TypeId_hash= "
103
- << std::hash<ir::TypeId>()(type_id) << " ] from IrContext [ptr =" << ctx
134
+ Dialect * IrContext::GetOrRegisterDialect (
135
+ std::string dialect_name, std::function<Dialect *()> constructor) {
136
+ VLOG ( 4 ) << " Try to get or register a Dialect of: [name =" << dialect_name
104
137
<< " ]." ;
138
+ Dialect *dialect = impl ().GetDialect (dialect_name);
139
+ if (dialect == nullptr ) {
140
+ VLOG (4 ) << " Create and register a new Dialect of: [name=" << dialect_name
141
+ << " ]." ;
142
+ dialect = constructor ();
143
+ impl ().RegisterDialect (dialect_name, dialect);
144
+ }
145
+ return dialect;
146
+ }
147
+
148
+ std::vector<Dialect *> IrContext::GetRegisteredDialects () {
149
+ std::vector<Dialect *> result;
150
+ for (auto dialect_map : impl ().registed_dialect_ ) {
151
+ result.push_back (dialect_map.second );
152
+ }
153
+ return result;
154
+ }
155
+
156
+ Dialect *IrContext::GetRegisteredDialect (const std::string &dialect_name) {
157
+ for (auto dialect_map : impl ().registed_dialect_ ) {
158
+ if (dialect_map.first == dialect_name) {
159
+ return dialect_map.second ;
160
+ }
161
+ }
162
+ LOG (WARNING) << " No dialect registered for " << dialect_name;
163
+ return nullptr ;
164
+ }
165
+
166
+ const AbstractType &AbstractType::lookup (TypeId type_id, IrContext *ctx) {
105
167
auto &impl = ctx->impl ();
106
- AbstractType *abstract_type = impl.lookup (type_id);
168
+ AbstractType *abstract_type = impl.GetAbstractType (type_id);
107
169
if (abstract_type) {
108
170
return *abstract_type;
109
171
} else {
0 commit comments