Skip to content

Commit 02e6347

Browse files
authored
[New IR]Add attrs Interface for Python (#55974)
* add attrs and dtype interface * fix compile bugs * fix some bugs * fix windows bugs
1 parent 6b10c0e commit 02e6347

File tree

18 files changed

+642
-117
lines changed

18 files changed

+642
-117
lines changed

paddle/fluid/ir/dialect/pd_api.cc

Lines changed: 28 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -23,7 +23,34 @@ ir::OpResult mean(ir::OpResult x, std::vector<int64_t> axis, bool keepdim) {
2323
paddle::dialect::MeanOp mean_op =
2424
APIBuilder::Instance().GetBuilder()->Build<paddle::dialect::MeanOp>(
2525
x, axis, keepdim);
26-
return mean_op.result(0);
26+
return mean_op.out();
27+
}
28+
29+
ir::OpResult sum(ir::OpResult x,
30+
std::vector<int64_t> axis,
31+
phi::DataType dtype,
32+
bool keepdim) {
33+
paddle::dialect::SumOp sum_op =
34+
APIBuilder::Instance().GetBuilder()->Build<paddle::dialect::SumOp>(
35+
x, axis, dtype, keepdim);
36+
return sum_op.out();
37+
}
38+
39+
ir::OpResult divide(ir::OpResult x, ir::OpResult y) {
40+
paddle::dialect::DivideOp divide_op =
41+
APIBuilder::Instance().GetBuilder()->Build<paddle::dialect::DivideOp>(x,
42+
y);
43+
return divide_op.out();
44+
}
45+
46+
ir::OpResult full(std::vector<int64_t> shape,
47+
float value,
48+
phi::DataType dtype,
49+
phi::Place place) {
50+
paddle::dialect::FullOp full_op =
51+
APIBuilder::Instance().GetBuilder()->Build<paddle::dialect::FullOp>(
52+
shape, value, dtype, place);
53+
return full_op.out();
2754
}
2855

2956
} // namespace dialect

paddle/fluid/ir/dialect/pd_api.h

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,8 @@
1717
#include <vector>
1818

1919
#include "paddle/ir/core/value.h"
20+
#include "paddle/phi/common/data_type.h"
21+
#include "paddle/phi/common/place.h"
2022

2123
namespace paddle {
2224
namespace dialect {
@@ -25,5 +27,17 @@ ir::OpResult mean(ir::OpResult x,
2527
std::vector<int64_t> axis = {},
2628
bool keepdim = false);
2729

30+
ir::OpResult sum(ir::OpResult x,
31+
std::vector<int64_t> axis = {},
32+
phi::DataType dtype = phi::DataType::UNDEFINED,
33+
bool keepdim = false);
34+
35+
ir::OpResult divide(ir::OpResult x, ir::OpResult y);
36+
37+
ir::OpResult full(std::vector<int64_t> shape,
38+
float value,
39+
phi::DataType dtype = phi::DataType::FLOAT32,
40+
phi::Place place = phi::CPUPlace());
41+
2842
} // namespace dialect
2943
} // namespace paddle

paddle/fluid/ir/dialect/utils.h

Lines changed: 174 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -16,13 +16,35 @@
1616

1717
#include "paddle/fluid/framework/convert_utils.h"
1818
#include "paddle/fluid/framework/data_type.h"
19+
#include "paddle/fluid/ir/dialect/pd_attribute.h"
1920
#include "paddle/fluid/ir/dialect/pd_type_storage.h"
2021
#include "paddle/ir/core/builtin_attribute.h"
2122
#include "paddle/ir/core/builtin_type.h"
23+
#include "paddle/phi/common/int_array.h"
2224
#include "paddle/phi/common/scalar.h"
2325

2426
namespace paddle {
2527
namespace dialect {
28+
29+
using VariantType = paddle::variant<bool,
30+
int,
31+
int64_t,
32+
float,
33+
double,
34+
std::string,
35+
std::vector<bool>,
36+
std::vector<int>,
37+
std::vector<int64_t>,
38+
std::vector<float>,
39+
std::vector<double>,
40+
std::vector<std::string>,
41+
phi::Scalar,
42+
std::vector<phi::Scalar>,
43+
phi::IntArray,
44+
phi::DataType,
45+
phi::DataLayout,
46+
phi::Place>;
47+
2648
// TODO(zhangbo): The builtin type needs to cover all data types of
2749
// phi::DataType.
2850
static inline phi::DataType TransToPhiDataType(ir::Type dtype) {
@@ -58,7 +80,7 @@ static inline phi::DataType TransToPhiDataType(ir::Type dtype) {
5880
}
5981

6082
static inline ir::Type TransToIrDataType(phi::DataType dtype,
61-
ir::IrContext *ctx = nullptr) {
83+
ir::IrContext* ctx = nullptr) {
6284
if (ctx == nullptr) {
6385
ctx = ir::IrContext::Instance();
6486
}
@@ -96,7 +118,7 @@ static inline ir::Type TransToIrDataType(phi::DataType dtype,
96118
}
97119

98120
static inline ir::Attribute TransToIrAttribute(phi::Scalar scalar,
99-
ir::IrContext *ctx = nullptr) {
121+
ir::IrContext* ctx = nullptr) {
100122
if (ctx == nullptr) {
101123
ctx = ir::IrContext::Instance();
102124
}
@@ -119,5 +141,155 @@ static inline ir::Attribute TransToIrAttribute(phi::Scalar scalar,
119141
}
120142
}
121143

144+
enum class AttrType {
145+
UNDEFINED = 0,
146+
BOOL,
147+
INT32,
148+
INT64,
149+
150+
FLOAT,
151+
DOUBLE,
152+
153+
ARRAY,
154+
INT_ARRAY,
155+
156+
SCALAR,
157+
DATA_TYPE,
158+
DATA_LAYOUT,
159+
PLACE,
160+
161+
STRING,
162+
163+
NUM_ATTR_TYPES,
164+
};
165+
166+
static inline AttrType GetAttributeType(const ir::Attribute& attr) {
167+
if (attr.isa<ir::BoolAttribute>()) {
168+
return AttrType::BOOL;
169+
} else if (attr.isa<ir::FloatAttribute>()) {
170+
return AttrType::FLOAT;
171+
} else if (attr.isa<ir::DoubleAttribute>()) {
172+
return AttrType::DOUBLE;
173+
} else if (attr.isa<ir::Int32Attribute>()) {
174+
return AttrType::INT32;
175+
} else if (attr.isa<ir::Int64Attribute>()) {
176+
return AttrType::INT64;
177+
} else if (attr.isa<ir::ArrayAttribute>()) {
178+
return AttrType::ARRAY;
179+
} else if (attr.isa<ir::StrAttribute>()) {
180+
return AttrType::STRING;
181+
} else if (attr.isa<paddle::dialect::IntArrayAttribute>()) {
182+
return AttrType::INT_ARRAY;
183+
} else if (attr.isa<paddle::dialect::DataTypeAttribute>()) {
184+
return AttrType::DATA_TYPE;
185+
} else if (attr.isa<paddle::dialect::PlaceAttribute>()) {
186+
return AttrType::PLACE;
187+
} else {
188+
PADDLE_THROW(phi::errors::Unimplemented(
189+
"Unsupported ir Attribute type when casting it into "
190+
"AttrType."));
191+
}
192+
}
193+
194+
static std::unordered_map<AttrType,
195+
std::function<VariantType(const ir::Attribute& attr)>>
196+
attr_cast_map = {
197+
{AttrType::BOOL,
198+
[](const ir::Attribute& attr) {
199+
return VariantType{attr.dyn_cast<ir::BoolAttribute>().data()};
200+
}},
201+
{AttrType::FLOAT,
202+
[](const ir::Attribute& attr) {
203+
return VariantType{attr.dyn_cast<ir::FloatAttribute>().data()};
204+
}},
205+
{AttrType::DOUBLE,
206+
[](const ir::Attribute& attr) {
207+
return VariantType{attr.dyn_cast<ir::DoubleAttribute>().data()};
208+
}},
209+
{AttrType::INT32,
210+
[](const ir::Attribute& attr) {
211+
return VariantType{attr.dyn_cast<ir::Int32Attribute>().data()};
212+
}},
213+
{AttrType::INT64,
214+
[](const ir::Attribute& attr) {
215+
return VariantType{attr.dyn_cast<ir::Int64Attribute>().data()};
216+
}},
217+
{AttrType::INT_ARRAY,
218+
[](const ir::Attribute& attr) {
219+
return VariantType{
220+
attr.dyn_cast<paddle::dialect::IntArrayAttribute>()
221+
.data()
222+
.GetData()};
223+
}},
224+
{AttrType::STRING,
225+
[](const ir::Attribute& attr) {
226+
return VariantType{attr.dyn_cast<ir::StrAttribute>().AsString()};
227+
}},
228+
{AttrType::DATA_TYPE,
229+
[](const ir::Attribute& attr) {
230+
return VariantType{
231+
attr.dyn_cast<paddle::dialect::DataTypeAttribute>().data()};
232+
}},
233+
{AttrType::PLACE,
234+
[](const ir::Attribute& attr) {
235+
return VariantType{
236+
attr.dyn_cast<paddle::dialect::PlaceAttribute>().data()};
237+
}},
238+
{AttrType::ARRAY,
239+
[](const ir::Attribute& attr) {
240+
auto attr_vec = attr.dyn_cast<ir::ArrayAttribute>().AsVector();
241+
if (attr_vec.size() == 0) {
242+
return VariantType{std::vector<int>()};
243+
}
244+
AttrType element_type = GetAttributeType(attr_vec[0]);
245+
246+
if (element_type == AttrType::BOOL) {
247+
std::vector<bool> vec_bools;
248+
for (auto vec_element : attr_vec) {
249+
vec_bools.push_back(
250+
vec_element.dyn_cast<ir::BoolAttribute>().data());
251+
}
252+
return VariantType{vec_bools};
253+
} else if (element_type == AttrType::INT32) {
254+
std::vector<int> vec_int32;
255+
for (auto vec_element : attr_vec) {
256+
vec_int32.push_back(
257+
vec_element.dyn_cast<ir::Int32Attribute>().data());
258+
}
259+
return VariantType{vec_int32};
260+
} else if (element_type == AttrType::INT64) {
261+
std::vector<int64_t> vec_int64;
262+
for (auto vec_element : attr_vec) {
263+
vec_int64.push_back(
264+
vec_element.dyn_cast<ir::Int64Attribute>().data());
265+
}
266+
return VariantType{vec_int64};
267+
} else if (element_type == AttrType::FLOAT) {
268+
std::vector<float> vec_float;
269+
for (auto vec_element : attr_vec) {
270+
vec_float.push_back(
271+
vec_element.dyn_cast<ir::FloatAttribute>().data());
272+
}
273+
return VariantType{vec_float};
274+
} else if (element_type == AttrType::DOUBLE) {
275+
std::vector<double> vec_double;
276+
for (auto vec_element : attr_vec) {
277+
vec_double.push_back(
278+
vec_element.dyn_cast<ir::DoubleAttribute>().data());
279+
}
280+
return VariantType{vec_double};
281+
} else {
282+
PADDLE_THROW(phi::errors::Unimplemented(
283+
"Unsupported ir Attribute type when casting it into "
284+
"vector."));
285+
}
286+
}},
287+
};
288+
289+
static inline VariantType GetAttributeData(const ir::Attribute& attr) {
290+
AttrType attr_type = GetAttributeType(attr);
291+
return attr_cast_map[attr_type](attr);
292+
}
293+
122294
} // namespace dialect
123295
} // namespace paddle

paddle/fluid/pybind/CMakeLists.txt

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -124,6 +124,7 @@ set(PYBIND_SRCS
124124
pybind.cc
125125
imperative.cc
126126
inference_api.cc
127+
ops_api.cc
127128
static_op_function.cc
128129
ir.cc
129130
graph.cc

paddle/fluid/pybind/eager_utils.cc

Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -51,6 +51,7 @@ extern PyTypeObject* p_string_tensor_type;
5151
extern PyTypeObject* g_framework_scope_pytype;
5252
extern PyTypeObject* g_ir_opresult_pytype;
5353
extern PyTypeObject* g_vartype_pytype;
54+
extern PyTypeObject* g_data_type_pytype;
5455
extern PyTypeObject* g_place_pytype;
5556
extern PyTypeObject* g_cudaplace_pytype;
5657
extern PyTypeObject* g_cpuplace_pytype;
@@ -644,6 +645,24 @@ paddle::framework::proto::VarType::Type CastPyArg2ProtoType(PyObject* obj,
644645
return dtype;
645646
}
646647

648+
paddle::DataType CastPyArg2DataTypeDirectly(PyObject* obj,
649+
const std::string& op_type,
650+
ssize_t arg_pos) {
651+
paddle::DataType dtype;
652+
if (PyObject_TypeCheck(obj, g_data_type_pytype)) {
653+
dtype = ::pybind11::handle(obj).cast<paddle::DataType>();
654+
} else {
655+
PADDLE_THROW(platform::errors::InvalidArgument(
656+
"%s: argument (position %d) must be "
657+
"one of core.VarDesc.VarType, "
658+
"but got %s",
659+
op_type,
660+
arg_pos + 1,
661+
reinterpret_cast<PyTypeObject*>(obj->ob_type)->tp_name));
662+
}
663+
return dtype;
664+
}
665+
647666
paddle::framework::Vocab CastPyArg2Vocab(PyObject* obj, ssize_t arg_pos) {
648667
if (PyDict_Check(obj)) {
649668
paddle::framework::Vocab vocab;

paddle/fluid/pybind/eager_utils.h

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -306,6 +306,10 @@ paddle::DataType CastPyArg2DataType(PyObject* obj,
306306
const std::string& op_type,
307307
ssize_t arg_pos);
308308

309+
paddle::DataType CastPyArg2DataTypeDirectly(PyObject* obj,
310+
const std::string& op_type,
311+
ssize_t arg_pos);
312+
309313
#ifdef PADDLE_WITH_DISTRIBUTE
310314
std::shared_ptr<phi::distributed::auto_parallel::TensorDistAttr>
311315
CastPyArg2DistAttr(PyObject* obj, ssize_t arg_pos);

0 commit comments

Comments
 (0)