16
16
17
17
#include " paddle/fluid/framework/convert_utils.h"
18
18
#include " paddle/fluid/framework/data_type.h"
19
+ #include " paddle/fluid/ir/dialect/pd_attribute.h"
19
20
#include " paddle/fluid/ir/dialect/pd_type_storage.h"
20
21
#include " paddle/ir/core/builtin_attribute.h"
21
22
#include " paddle/ir/core/builtin_type.h"
23
+ #include " paddle/phi/common/int_array.h"
22
24
#include " paddle/phi/common/scalar.h"
23
25
24
26
namespace paddle {
25
27
namespace dialect {
28
+
29
+ using VariantType = paddle::variant<bool ,
30
+ int ,
31
+ int64_t ,
32
+ float ,
33
+ double ,
34
+ std::string,
35
+ std::vector<bool >,
36
+ std::vector<int >,
37
+ std::vector<int64_t >,
38
+ std::vector<float >,
39
+ std::vector<double >,
40
+ std::vector<std::string>,
41
+ phi::Scalar,
42
+ std::vector<phi::Scalar>,
43
+ phi::IntArray,
44
+ phi::DataType,
45
+ phi::DataLayout,
46
+ phi::Place>;
47
+
26
48
// TODO(zhangbo): The builtin type needs to cover all data types of
27
49
// phi::DataType.
28
50
static inline phi::DataType TransToPhiDataType (ir::Type dtype) {
@@ -58,7 +80,7 @@ static inline phi::DataType TransToPhiDataType(ir::Type dtype) {
58
80
}
59
81
60
82
static inline ir::Type TransToIrDataType (phi::DataType dtype,
61
- ir::IrContext * ctx = nullptr ) {
83
+ ir::IrContext* ctx = nullptr ) {
62
84
if (ctx == nullptr ) {
63
85
ctx = ir::IrContext::Instance ();
64
86
}
@@ -96,7 +118,7 @@ static inline ir::Type TransToIrDataType(phi::DataType dtype,
96
118
}
97
119
98
120
static inline ir::Attribute TransToIrAttribute (phi::Scalar scalar,
99
- ir::IrContext * ctx = nullptr ) {
121
+ ir::IrContext* ctx = nullptr ) {
100
122
if (ctx == nullptr ) {
101
123
ctx = ir::IrContext::Instance ();
102
124
}
@@ -119,5 +141,155 @@ static inline ir::Attribute TransToIrAttribute(phi::Scalar scalar,
119
141
}
120
142
}
121
143
144
+ enum class AttrType {
145
+ UNDEFINED = 0 ,
146
+ BOOL,
147
+ INT32,
148
+ INT64,
149
+
150
+ FLOAT,
151
+ DOUBLE,
152
+
153
+ ARRAY,
154
+ INT_ARRAY,
155
+
156
+ SCALAR,
157
+ DATA_TYPE,
158
+ DATA_LAYOUT,
159
+ PLACE,
160
+
161
+ STRING,
162
+
163
+ NUM_ATTR_TYPES,
164
+ };
165
+
166
+ static inline AttrType GetAttributeType (const ir::Attribute& attr) {
167
+ if (attr.isa <ir::BoolAttribute>()) {
168
+ return AttrType::BOOL;
169
+ } else if (attr.isa <ir::FloatAttribute>()) {
170
+ return AttrType::FLOAT;
171
+ } else if (attr.isa <ir::DoubleAttribute>()) {
172
+ return AttrType::DOUBLE;
173
+ } else if (attr.isa <ir::Int32Attribute>()) {
174
+ return AttrType::INT32;
175
+ } else if (attr.isa <ir::Int64Attribute>()) {
176
+ return AttrType::INT64;
177
+ } else if (attr.isa <ir::ArrayAttribute>()) {
178
+ return AttrType::ARRAY;
179
+ } else if (attr.isa <ir::StrAttribute>()) {
180
+ return AttrType::STRING;
181
+ } else if (attr.isa <paddle::dialect::IntArrayAttribute>()) {
182
+ return AttrType::INT_ARRAY;
183
+ } else if (attr.isa <paddle::dialect::DataTypeAttribute>()) {
184
+ return AttrType::DATA_TYPE;
185
+ } else if (attr.isa <paddle::dialect::PlaceAttribute>()) {
186
+ return AttrType::PLACE;
187
+ } else {
188
+ PADDLE_THROW (phi::errors::Unimplemented (
189
+ " Unsupported ir Attribute type when casting it into "
190
+ " AttrType." ));
191
+ }
192
+ }
193
+
194
+ static std::unordered_map<AttrType,
195
+ std::function<VariantType(const ir::Attribute& attr)>>
196
+ attr_cast_map = {
197
+ {AttrType::BOOL,
198
+ [](const ir::Attribute& attr) {
199
+ return VariantType{attr.dyn_cast <ir::BoolAttribute>().data ()};
200
+ }},
201
+ {AttrType::FLOAT,
202
+ [](const ir::Attribute& attr) {
203
+ return VariantType{attr.dyn_cast <ir::FloatAttribute>().data ()};
204
+ }},
205
+ {AttrType::DOUBLE,
206
+ [](const ir::Attribute& attr) {
207
+ return VariantType{attr.dyn_cast <ir::DoubleAttribute>().data ()};
208
+ }},
209
+ {AttrType::INT32,
210
+ [](const ir::Attribute& attr) {
211
+ return VariantType{attr.dyn_cast <ir::Int32Attribute>().data ()};
212
+ }},
213
+ {AttrType::INT64,
214
+ [](const ir::Attribute& attr) {
215
+ return VariantType{attr.dyn_cast <ir::Int64Attribute>().data ()};
216
+ }},
217
+ {AttrType::INT_ARRAY,
218
+ [](const ir::Attribute& attr) {
219
+ return VariantType{
220
+ attr.dyn_cast <paddle::dialect::IntArrayAttribute>()
221
+ .data ()
222
+ .GetData ()};
223
+ }},
224
+ {AttrType::STRING,
225
+ [](const ir::Attribute& attr) {
226
+ return VariantType{attr.dyn_cast <ir::StrAttribute>().AsString ()};
227
+ }},
228
+ {AttrType::DATA_TYPE,
229
+ [](const ir::Attribute& attr) {
230
+ return VariantType{
231
+ attr.dyn_cast <paddle::dialect::DataTypeAttribute>().data ()};
232
+ }},
233
+ {AttrType::PLACE,
234
+ [](const ir::Attribute& attr) {
235
+ return VariantType{
236
+ attr.dyn_cast <paddle::dialect::PlaceAttribute>().data ()};
237
+ }},
238
+ {AttrType::ARRAY,
239
+ [](const ir::Attribute& attr) {
240
+ auto attr_vec = attr.dyn_cast <ir::ArrayAttribute>().AsVector ();
241
+ if (attr_vec.size () == 0 ) {
242
+ return VariantType{std::vector<int >()};
243
+ }
244
+ AttrType element_type = GetAttributeType (attr_vec[0 ]);
245
+
246
+ if (element_type == AttrType::BOOL) {
247
+ std::vector<bool > vec_bools;
248
+ for (auto vec_element : attr_vec) {
249
+ vec_bools.push_back (
250
+ vec_element.dyn_cast <ir::BoolAttribute>().data ());
251
+ }
252
+ return VariantType{vec_bools};
253
+ } else if (element_type == AttrType::INT32) {
254
+ std::vector<int > vec_int32;
255
+ for (auto vec_element : attr_vec) {
256
+ vec_int32.push_back (
257
+ vec_element.dyn_cast <ir::Int32Attribute>().data ());
258
+ }
259
+ return VariantType{vec_int32};
260
+ } else if (element_type == AttrType::INT64) {
261
+ std::vector<int64_t > vec_int64;
262
+ for (auto vec_element : attr_vec) {
263
+ vec_int64.push_back (
264
+ vec_element.dyn_cast <ir::Int64Attribute>().data ());
265
+ }
266
+ return VariantType{vec_int64};
267
+ } else if (element_type == AttrType::FLOAT) {
268
+ std::vector<float > vec_float;
269
+ for (auto vec_element : attr_vec) {
270
+ vec_float.push_back (
271
+ vec_element.dyn_cast <ir::FloatAttribute>().data ());
272
+ }
273
+ return VariantType{vec_float};
274
+ } else if (element_type == AttrType::DOUBLE) {
275
+ std::vector<double > vec_double;
276
+ for (auto vec_element : attr_vec) {
277
+ vec_double.push_back (
278
+ vec_element.dyn_cast <ir::DoubleAttribute>().data ());
279
+ }
280
+ return VariantType{vec_double};
281
+ } else {
282
+ PADDLE_THROW (phi::errors::Unimplemented (
283
+ " Unsupported ir Attribute type when casting it into "
284
+ " vector." ));
285
+ }
286
+ }},
287
+ };
288
+
289
+ static inline VariantType GetAttributeData (const ir::Attribute& attr) {
290
+ AttrType attr_type = GetAttributeType (attr);
291
+ return attr_cast_map[attr_type](attr);
292
+ }
293
+
122
294
} // namespace dialect
123
295
} // namespace paddle
0 commit comments