[IR] finetune the StrAttribute interface. #55439

Merged: 1 commit, Jul 17, 2023
@@ -47,7 +47,7 @@ OpFuncType AnalyseOpFuncType(ir::Operation* op, const platform::Place& place) {
// and so that they would be dispatched to host thread.
auto op_attributes = op->attributes();
auto op_name =
op_attributes.at("op_name").dyn_cast<::ir::StrAttribute>().data();
op_attributes.at("op_name").dyn_cast<::ir::StrAttribute>().AsString();
if (op_name == kCoalesceTensor &&
(!platform::is_xpu_place(place) ||
op->attribute<ir::BoolAttribute>("persist_output").data() == false) &&
@@ -77,7 +77,7 @@ PhiKernelInstruction::PhiKernelInstruction(
: InstructionBase(id, place) {
auto op_attributes = op->attributes();
auto op_name =
op_attributes.at("op_name").dyn_cast<::ir::StrAttribute>().data();
op_attributes.at("op_name").dyn_cast<::ir::StrAttribute>().AsString();
ir::OpInfo op_info = ir::IrContext::Instance()->GetRegisteredOpInfo(op_name);

phi_op_name_ = op_name;
@@ -142,7 +142,7 @@ PhiKernelInstruction::PhiKernelInstruction(
VLOG(6) << "finish process infer meta context";

auto kernel_name =
op_attributes.at("kernel_name").dyn_cast<ir::StrAttribute>().data();
op_attributes.at("kernel_name").dyn_cast<ir::StrAttribute>().AsString();
auto kernel_key = op_attributes.at("kernel_key")
.dyn_cast<paddle::dialect::KernelAttribute>()
.data();
@@ -952,7 +952,8 @@ void BuildOpFuncList(
OpFuncNode op_func_node;
auto attr_map = (*it)->attributes();

auto op_name = attr_map.at("op_name").dyn_cast<::ir::StrAttribute>().data();
auto op_name =
attr_map.at("op_name").dyn_cast<::ir::StrAttribute>().AsString();
op_func_node.phi_op_name_ = op_name;

if (op_name == "builtin.combine" || op_name == "pd.feed" ||
@@ -986,7 +987,7 @@
&(op_func_node.infer_meta_context_));

auto kernel_name =
attr_map.at("kernel_name").dyn_cast<ir::StrAttribute>().data();
attr_map.at("kernel_name").dyn_cast<ir::StrAttribute>().AsString();
auto kernel_key = attr_map.at("kernel_key")
.dyn_cast<paddle::dialect::KernelAttribute>()
.data();
4 changes: 2 additions & 2 deletions paddle/fluid/ir/dialect/kernel_op.cc
@@ -45,10 +45,10 @@ void PhiKernelOp::Verify() {
}

std::string PhiKernelOp::op_name() {
return attributes().at("op_name").dyn_cast<ir::StrAttribute>().data();
return attributes().at("op_name").dyn_cast<ir::StrAttribute>().AsString();
}
std::string PhiKernelOp::kernel_name() {
return attributes().at("kernel_name").dyn_cast<ir::StrAttribute>().data();
return attributes().at("kernel_name").dyn_cast<ir::StrAttribute>().AsString();
}
phi::KernelKey PhiKernelOp::kernel_key() {
return attributes().at("kernel_key").dyn_cast<KernelAttribute>().data();
17 changes: 16 additions & 1 deletion paddle/fluid/ir/dialect/op_generator/op_build_gen.py
@@ -542,11 +542,14 @@ def gen_build_func_str(

GET_ATTRIBUTES_FROM_MAP_TEMPLATE = """
{attr_type} {attribute_name} = attributes.at("{attribute_name}").dyn_cast<{attr_ir_type}>().data();
"""
GET_STR_ATTRIBUTES_FROM_MAP_TEMPLATE = """
{attr_type} {attribute_name} = attributes.at("{attribute_name}").dyn_cast<ir::StrAttribute>().AsString();
"""
GET_ARRAY_ATTRIBUTE_FROM_MAP_TEMPLATE = """
{attr_type} {attribute_name};
for (size_t i = 0; i < attributes.at("{attribute_name}").dyn_cast<ir::ArrayAttribute>().size(); i++) {{
{attribute_name}.push_back(attributes.at("{attribute_name}").dyn_cast<ir::ArrayAttribute>()[i].dyn_cast<{inner_type}>().data());
{attribute_name}.push_back(attributes.at("{attribute_name}").dyn_cast<ir::ArrayAttribute>().at(i).dyn_cast<{inner_type}>().{data_name}());
}}
"""
GET_INTARRAY_ATTRIBUTE_FROM_MAP_TEMPLATE = """
@@ -566,11 +569,15 @@ def gen_build_func_str(
# attr_type = "std::vector<int>"
if "ir::ArrayAttribute" in op_attribute_type_list[idx]:
inner_type = op_attribute_type_list[idx][19:-1]
data_name = "data"
if inner_type == "ir::StrAttribute":
data_name = "AsString"
get_attributes_str += (
GET_ARRAY_ATTRIBUTE_FROM_MAP_TEMPLATE.format(
attr_type=attr_type,
attribute_name=op_attribute_name_list[idx],
inner_type=inner_type,
data_name=data_name,
)
)
elif (
@@ -593,6 +600,14 @@ attribute_name=op_attribute_name_list[idx],
attribute_name=op_attribute_name_list[idx],
)
)
elif "ir::StrAttribute" in op_attribute_type_list[idx]:
get_attributes_str += (
GET_STR_ATTRIBUTES_FROM_MAP_TEMPLATE.format(
attr_type=attr_type,
attribute_name=op_attribute_name_list[idx],
attr_ir_type=op_attribute_type_list[idx],
)
)
else:
get_attributes_str += GET_ATTRIBUTES_FROM_MAP_TEMPLATE.format(
attr_type=attr_type,
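For reference, a rough sketch of the C++ that the updated templates above would emit, for a hypothetical op with a plain string attribute "data_format" and a string-array attribute "names" (the attribute names are made up for illustration; only the template structure comes from this PR):

  std::string data_format =
      attributes.at("data_format").dyn_cast<ir::StrAttribute>().AsString();

  std::vector<std::string> names;
  for (size_t i = 0;
       i < attributes.at("names").dyn_cast<ir::ArrayAttribute>().size();
       i++) {
    names.push_back(attributes.at("names")
                        .dyn_cast<ir::ArrayAttribute>()
                        .at(i)
                        .dyn_cast<ir::StrAttribute>()
                        .AsString());
  }
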
2 changes: 1 addition & 1 deletion paddle/fluid/ir/dialect/op_generator/op_verify_gen.py
@@ -78,7 +78,7 @@
PADDLE_ENFORCE(attributes.count("{attribute_name}")>0 && attributes.at("{attribute_name}").isa<ir::ArrayAttribute>(),
phi::errors::PreconditionNotMet("Type of attribute: {attribute_name} is not right."));
for (size_t i = 0; i < attributes.at("{attribute_name}").dyn_cast<ir::ArrayAttribute>().size(); i++) {{
PADDLE_ENFORCE(attributes.at("{attribute_name}").dyn_cast<ir::ArrayAttribute>()[i].isa<{standard}>(),
PADDLE_ENFORCE(attributes.at("{attribute_name}").dyn_cast<ir::ArrayAttribute>().at(i).isa<{standard}>(),
phi::errors::PreconditionNotMet("Type of attribute: {attribute_name} is not right."));
}}"""
OUTPUT_TYPE_CHECK_TEMPLATE = """
5 changes: 3 additions & 2 deletions paddle/fluid/ir/phi_kernel_adaptor/phi_kernel_adaptor.h
@@ -76,7 +76,7 @@ class PhiKernelAdaptor {
for (auto it = block->begin(); it != block->end(); ++it) {
auto attr_map = (*it)->attributes();

auto op_name = attr_map.at("op_name").dyn_cast<ir::StrAttribute>().data();
auto op_name =
attr_map.at("op_name").dyn_cast<ir::StrAttribute>().AsString();

ir::OpInfo op1_info = ctx->GetRegisteredOpInfo(op_name);

@@ -104,7 +105,7 @@ class PhiKernelAdaptor {
infer_meta_impl->infer_meta_(&ctx);

auto kernel_name =
attr_map.at("kernel_name").dyn_cast<ir::StrAttribute>().data();
attr_map.at("kernel_name").dyn_cast<ir::StrAttribute>().AsString();
auto kernel_key = attr_map.at("kernel_key")
.dyn_cast<paddle::dialect::KernelAttribute>()
.data();
14 changes: 8 additions & 6 deletions paddle/fluid/ir/phi_kernel_adaptor/phi_kernel_util.cc
@@ -171,7 +171,7 @@ void HandleForSpecialOp(
std::string op_name = op->name();
if (op->attributes().count("op_name")) {
op_name =
op->attributes().at("op_name").dyn_cast<ir::StrAttribute>().data();
op->attributes().at("op_name").dyn_cast<ir::StrAttribute>().AsString();
}

if (op_name == "pd.fetch") {
@@ -244,7 +244,7 @@ void HandleForSpecialOp(
auto param_name = op->attributes()
.at("parameter_name")
.dyn_cast<ir::StrAttribute>()
.data();
.AsString();

auto value = op->operand(0);
// change operand name to param_name
@@ -262,7 +262,7 @@ void HandleForSpecialOp(
auto param_name = op->attributes()
.at("parameter_name")
.dyn_cast<ir::StrAttribute>()
.data();
.AsString();
auto value = op->result(0);
value_2_var_name->emplace(value, param_name);
}
@@ -306,7 +306,7 @@ void HandleForInplaceOp(
std::string op_name = op->name();
if (op->attributes().count("op_name")) {
op_name =
op->attributes().at("op_name").dyn_cast<ir::StrAttribute>().data();
op->attributes().at("op_name").dyn_cast<ir::StrAttribute>().AsString();
}

ir::OpInfo op_info = ctx->GetRegisteredOpInfo(op_name);
@@ -356,8 +356,10 @@ void BuildScope(const ir::Block& block,

std::string op_name = op->name();
if (op->attributes().count("op_name")) {
op_name =
op->attributes().at("op_name").dyn_cast<ir::StrAttribute>().data();
op_name = op->attributes()
.at("op_name")
.dyn_cast<ir::StrAttribute>()
.AsString();
}
VLOG(4) << "build op:" << op_name;

14 changes: 7 additions & 7 deletions paddle/fluid/ir/phi_kernel_adaptor/phi_kernel_util.h
@@ -188,10 +188,10 @@ void BuildPhiContext(
} else if (attr_type_name == "ir::BoolAttribute") {
ctx->EmplaceBackAttr(attr_map[t].dyn_cast<ir::BoolAttribute>().data());
} else if (attr_type_name == "ir::StrAttribute") {
ctx->EmplaceBackAttr(attr_map[t].dyn_cast<ir::StrAttribute>().data());
ctx->EmplaceBackAttr(attr_map[t].dyn_cast<ir::StrAttribute>().AsString());
} else if (attr_type_name ==
"ir::ArrayAttribute<paddle::dialect::ScalarAttribute>") {
auto array_list = attr_map[t].dyn_cast<ir::ArrayAttribute>().data();
auto array_list = attr_map[t].dyn_cast<ir::ArrayAttribute>().AsVector();
std::vector<phi::Scalar> vec_res;
if (array_list.size() > 0) {
PADDLE_ENFORCE_EQ(
@@ -207,7 +207,7 @@
}
ctx->EmplaceBackAttr(vec_res);
} else if (attr_type_name == "ir::ArrayAttribute<ir::Int32Attribute>") {
auto array_list = attr_map[t].dyn_cast<ir::ArrayAttribute>().data();
auto array_list = attr_map[t].dyn_cast<ir::ArrayAttribute>().AsVector();
std::vector<int32_t> vec_res;
if (array_list.size() > 0) {
PADDLE_ENFORCE_EQ(
@@ -222,7 +222,7 @@
}
ctx->EmplaceBackAttr(vec_res);
} else if (attr_type_name == "ir::ArrayAttribute<ir::FloatAttribute>") {
auto array_list = attr_map[t].dyn_cast<ir::ArrayAttribute>().data();
auto array_list = attr_map[t].dyn_cast<ir::ArrayAttribute>().AsVector();
std::vector<float> vec_res;
if (array_list.size() > 0) {
if (array_list[0].isa<ir::FloatAttribute>()) {
@@ -238,7 +238,7 @@
}
ctx->EmplaceBackAttr(vec_res);
} else if (attr_type_name == "ir::ArrayAttribute<ir::Int64Attribute>") {
auto array_list = attr_map[t].dyn_cast<ir::ArrayAttribute>().data();
auto array_list = attr_map[t].dyn_cast<ir::ArrayAttribute>().AsVector();

std::vector<int64_t> vec_res;
if (array_list.size() > 0) {
@@ -255,7 +255,7 @@
}
ctx->EmplaceBackAttr(vec_res);
} else if (attr_type_name == "ir::ArrayAttribute<ir::Int64Attribute>") {
auto array_list = attr_map[t].dyn_cast<ir::ArrayAttribute>().data();
auto array_list = attr_map[t].dyn_cast<ir::ArrayAttribute>().AsVector();

std::vector<int64_t> vec_res;
if (array_list.size() > 0) {
@@ -286,7 +286,7 @@

// TODO(phlrain): use var type instead of op name
if (op->attributes().count("op_name") &&
(op->attributes().at("op_name").dyn_cast<ir::StrAttribute>().data() ==
(op->attributes().at("op_name").dyn_cast<ir::StrAttribute>().AsString() ==
"pd.fetch")) {
// process fetch op
auto fetch_var = inner_scope->FindVar("fetch");
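As a side note, the unpacking pattern used above for integer array attributes could be factored into a small helper; a minimal sketch, assuming the paddle/ir attribute headers, and not something this PR adds:

  // Reads an ir::ArrayAttribute holding ir::Int32Attribute elements into a
  // std::vector<int32_t>, using the new AsVector()/data() accessors.
  std::vector<int32_t> UnpackInt32Array(ir::Attribute attr) {
    auto array_list = attr.dyn_cast<ir::ArrayAttribute>().AsVector();
    std::vector<int32_t> vec_res;
    vec_res.reserve(array_list.size());
    for (auto& elem : array_list) {
      vec_res.push_back(elem.dyn_cast<ir::Int32Attribute>().data());
    }
    return vec_res;
  }
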
2 changes: 1 addition & 1 deletion paddle/fluid/ir/transforms/transform_general_functions.cc
@@ -34,7 +34,7 @@ std::pair<std::string, ir::Parameter*> GetParameterFromValue(ir::Value value) {
std::string name = op->attributes()
.at(op.attributes_name[0])
.dyn_cast<ir::StrAttribute>()
.data();
.AsString();
ir::Parameter* param = program->GetParameter(name);
PADDLE_ENFORCE_NOT_NULL(
param, phi::errors::InvalidArgument("Parameter should not be null."));
2 changes: 1 addition & 1 deletion paddle/fluid/ir_adaptor/translator/program_translator.cc
@@ -223,7 +223,7 @@ void ProgramTranslator::SetStopGradientAttributeForAllValue(
if (defining_op->HasAttribute(kAttrStopGradients)) {
stop_gradients = defining_op->attribute(kAttrStopGradients)
.dyn_cast<ir::ArrayAttribute>()
.data();
.AsVector();
} else {
stop_gradients = std::vector<ir::Attribute>(
defining_op->num_results(), ir::BoolAttribute::get(ctx_, false));
2 changes: 1 addition & 1 deletion paddle/ir/builtin_transforms/dead_code_elimination_pass.cc
@@ -56,7 +56,7 @@ class DeadCodeEliminationPass : public ir::Pass {
get_parameter_op->attributes()
.at(get_parameter_op.attributes_name[0])
.dyn_cast<ir::StrAttribute>()
.data());
.AsString());
}
block->erase(*op);
}
64 changes: 53 additions & 11 deletions paddle/ir/core/builtin_attribute.cc
@@ -15,27 +15,69 @@
#include "paddle/ir/core/builtin_attribute.h"

namespace ir {
std::string StrAttribute::data() const { return storage()->GetAsKey(); }

uint32_t StrAttribute::size() const { return storage()->GetAsKey().size(); }
bool BoolAttribute::data() const { return storage()->data(); }

bool BoolAttribute::data() const { return storage()->GetAsKey(); }
float FloatAttribute::data() const { return storage()->data(); }

float FloatAttribute::data() const { return storage()->GetAsKey(); }
double DoubleAttribute::data() const { return storage()->data(); }

double DoubleAttribute::data() const { return storage()->GetAsKey(); }
int32_t Int32Attribute::data() const { return storage()->data(); }

int32_t Int32Attribute::data() const { return storage()->GetAsKey(); }
int64_t Int64Attribute::data() const { return storage()->data(); }

int64_t Int64Attribute::data() const { return storage()->GetAsKey(); }
void* PointerAttribute::data() const { return storage()->data(); }

std::vector<Attribute> ArrayAttribute::data() const {
return storage()->GetAsKey();
Type TypeAttribute::data() const { return storage()->data(); }

bool StrAttribute::operator<(const StrAttribute& right) const {
return storage() < right.storage();
}
std::string StrAttribute::AsString() const { return storage()->AsString(); }

size_t StrAttribute::size() const { return storage()->size(); }

StrAttribute StrAttribute::get(ir::IrContext* ctx, const std::string& value) {
return AttributeManager::get<StrAttribute>(ctx, value);
}

std::vector<Attribute> ArrayAttribute::AsVector() const {
return storage()->AsVector();
}

size_t ArrayAttribute::size() const { return storage()->size(); }

bool ArrayAttribute::empty() const { return storage()->empty(); }

Attribute ArrayAttribute::at(size_t index) const {
return storage()->at(index);
}

void* PointerAttribute::data() const { return storage()->GetAsKey(); }
ArrayAttribute ArrayAttribute::get(IrContext* ctx,
const std::vector<Attribute>& value) {
return AttributeManager::get<ArrayAttribute>(ctx, value);
}

Type TypeAttribute::data() const { return storage()->GetAsKey(); }
ArrayAttributeStorage::ArrayAttributeStorage(const ParamKey& key)
: size_(key.size()) {
constexpr size_t align = alignof(Attribute);
if (align > __STDCPP_DEFAULT_NEW_ALIGNMENT__) {
data_ = static_cast<Attribute*>(
::operator new(size_ * sizeof(Attribute), std::align_val_t(align)));
} else {
data_ = static_cast<Attribute*>(::operator new(size_ * sizeof(Attribute)));
}
memcpy(data_, key.data(), sizeof(Attribute) * size_);
}

ArrayAttributeStorage::~ArrayAttributeStorage() {
constexpr size_t align = alignof(Attribute);
if (align > __STDCPP_DEFAULT_NEW_ALIGNMENT__) {
::operator delete(data_, std::align_val_t(align));
} else {
::operator delete(data_);
}
}

} // namespace ir
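
To illustrate the reworked interface end to end, a hypothetical usage sketch (the example values are invented; the calls themselves are the ones declared in this file):

  ir::IrContext* ctx = ir::IrContext::Instance();

  // StrAttribute now exposes AsString() and size() instead of data().
  ir::StrAttribute name = ir::StrAttribute::get(ctx, "pd.matmul");
  std::string s = name.AsString();  // "pd.matmul"
  size_t n = name.size();           // 9

  // ArrayAttribute exposes AsVector(), size(), empty(), and at(i).
  ir::ArrayAttribute arr = ir::ArrayAttribute::get(ctx, {name});
  if (!arr.empty()) {
    ir::Attribute first = arr.at(0);
    std::vector<ir::Attribute> all = arr.AsVector();
  }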
