From ce3f6a0eab9a41035322807c3255ab5de42e1d41 Mon Sep 17 00:00:00 2001
From: winter-wang <1030748926@qq.com>
Date: Fri, 14 Jul 2023 07:35:04 +0000
Subject: [PATCH] [IR] finetune the StrAttribute interface.
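
StrAttribute::data() is renamed to AsString() and ArrayAttribute::data()
to AsVector(); array element access moves from operator[] to the
bounds-checked at(), and both attributes gain a static get(ctx, value)
builder. On the storage side, StrAttributeStorage keeps strings of at
most sizeof(void*) bytes inline in a union buffer instead of always
calling malloc, and ArrayAttributeStorage allocates its Attribute array
with a correctly aligned operator new.

A minimal usage sketch of the new interface (illustrative only, not part
of the diff; the sample values are placeholders):

    ir::IrContext *ctx = ir::IrContext::Instance();

    // Build a string attribute and read it back.
    ir::StrAttribute str_attr = ir::StrAttribute::get(ctx, "pd.fetch");
    std::string name = str_attr.AsString();  // was: str_attr.data()

    // Build an array attribute; at() IR_ENFORCEs the index bound.
    ir::ArrayAttribute arr_attr = ir::ArrayAttribute::get(ctx, {str_attr});
    ir::Attribute first = arr_attr.at(0);    // was: arr_attr[0]
    size_t n = arr_attr.size();              // stored size, no vector copy

Note that StrAttributeStorage::data() returns a buffer that is not
null-terminated; pair it with size() or use AsString().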
---
 .../instruction/phi_kernel_instruction.cc     |   6 +-
 .../interpreter/interpreter_util.cc           |   5 +-
 paddle/fluid/ir/dialect/kernel_op.cc          |   4 +-
 .../ir/dialect/op_generator/op_build_gen.py   |  17 +-
 .../ir/dialect/op_generator/op_verify_gen.py  |   2 +-
 .../phi_kernel_adaptor/phi_kernel_adaptor.h   |   5 +-
 .../ir/phi_kernel_adaptor/phi_kernel_util.cc  |  14 +-
 .../ir/phi_kernel_adaptor/phi_kernel_util.h   |  14 +-
 .../transforms/transform_general_functions.cc |   2 +-
 .../translator/program_translator.cc          |   2 +-
 .../dead_code_elimination_pass.cc             |   2 +-
 paddle/ir/core/builtin_attribute.cc           |  64 ++++++--
 paddle/ir/core/builtin_attribute.h            |  57 +++----
 paddle/ir/core/builtin_attribute_storage.h    | 145 ++++++++----------
 paddle/ir/core/ir_printer.cc                  |   4 +-
 .../pattern_rewrite/pattern_rewrite_test.cc   |  37 +++--
 16 files changed, 222 insertions(+), 158 deletions(-)

diff --git a/paddle/fluid/framework/new_executor/instruction/phi_kernel_instruction.cc b/paddle/fluid/framework/new_executor/instruction/phi_kernel_instruction.cc
index 32c7e265e7ba6b..53cef292d9fa4a 100644
--- a/paddle/fluid/framework/new_executor/instruction/phi_kernel_instruction.cc
+++ b/paddle/fluid/framework/new_executor/instruction/phi_kernel_instruction.cc
@@ -47,7 +47,7 @@ OpFuncType AnalyseOpFuncType(ir::Operation* op, const platform::Place& place) {
   // and so that they would be dispatched to host thread.
   auto op_attributes = op->attributes();
   auto op_name =
-      op_attributes.at("op_name").dyn_cast<::ir::StrAttribute>().data();
+      op_attributes.at("op_name").dyn_cast<::ir::StrAttribute>().AsString();
   if (op_name == kCoalesceTensor &&
       (!platform::is_xpu_place(place) ||
        op->attribute<ir::BoolAttribute>("persist_output").data() == false) &&
@@ -77,7 +77,7 @@ PhiKernelInstruction::PhiKernelInstruction(
     : InstructionBase(id, place) {
   auto op_attributes = op->attributes();
   auto op_name =
-      op_attributes.at("op_name").dyn_cast<::ir::StrAttribute>().data();
+      op_attributes.at("op_name").dyn_cast<::ir::StrAttribute>().AsString();
   ir::OpInfo op_info = ir::IrContext::Instance()->GetRegisteredOpInfo(op_name);

   phi_op_name_ = op_name;
@@ -142,7 +142,7 @@ PhiKernelInstruction::PhiKernelInstruction(
   VLOG(6) << "finish process infer meta context";

   auto kernel_name =
-      op_attributes.at("kernel_name").dyn_cast<ir::StrAttribute>().data();
+      op_attributes.at("kernel_name").dyn_cast<ir::StrAttribute>().AsString();
   auto kernel_key = op_attributes.at("kernel_key")
                         .dyn_cast<paddle::dialect::KernelAttribute>()
                         .data();
diff --git a/paddle/fluid/framework/new_executor/interpreter/interpreter_util.cc b/paddle/fluid/framework/new_executor/interpreter/interpreter_util.cc
index bd63d20c21510f..34899261dbde70 100644
--- a/paddle/fluid/framework/new_executor/interpreter/interpreter_util.cc
+++ b/paddle/fluid/framework/new_executor/interpreter/interpreter_util.cc
@@ -952,7 +952,8 @@ void BuildOpFuncList(
     OpFuncNode op_func_node;
     auto attr_map = (*it)->attributes();

-    auto op_name = attr_map.at("op_name").dyn_cast<::ir::StrAttribute>().data();
+    auto op_name =
+        attr_map.at("op_name").dyn_cast<::ir::StrAttribute>().AsString();
     op_func_node.phi_op_name_ = op_name;

     if (op_name == "builtin.combine" || op_name == "pd.feed" ||
@@ -986,7 +987,7 @@ void BuildOpFuncList(
           &(op_func_node.infer_meta_context_));

       auto kernel_name =
-          attr_map.at("kernel_name").dyn_cast<ir::StrAttribute>().data();
+          attr_map.at("kernel_name").dyn_cast<ir::StrAttribute>().AsString();
       auto kernel_key = attr_map.at("kernel_key")
                             .dyn_cast<paddle::dialect::KernelAttribute>()
                             .data();
diff --git a/paddle/fluid/ir/dialect/kernel_op.cc b/paddle/fluid/ir/dialect/kernel_op.cc
index e46a874045d73c..45dce38f11ba4f 100644
--- a/paddle/fluid/ir/dialect/kernel_op.cc
+++ b/paddle/fluid/ir/dialect/kernel_op.cc
@@ -45,10 +45,10 @@ void PhiKernelOp::Verify() {
 }

 std::string PhiKernelOp::op_name() {
-  return attributes().at("op_name").dyn_cast<ir::StrAttribute>().data();
+  return attributes().at("op_name").dyn_cast<ir::StrAttribute>().AsString();
 }
 std::string PhiKernelOp::kernel_name() {
-  return attributes().at("kernel_name").dyn_cast<ir::StrAttribute>().data();
+  return attributes().at("kernel_name").dyn_cast<ir::StrAttribute>().AsString();
 }
 phi::KernelKey PhiKernelOp::kernel_key() {
   return attributes().at("kernel_key").dyn_cast<dialect::KernelAttribute>().data();
diff --git a/paddle/fluid/ir/dialect/op_generator/op_build_gen.py b/paddle/fluid/ir/dialect/op_generator/op_build_gen.py
index 76c48bdde5e1fe..5c3696d02c88c9 100644
--- a/paddle/fluid/ir/dialect/op_generator/op_build_gen.py
+++ b/paddle/fluid/ir/dialect/op_generator/op_build_gen.py
@@ -542,11 +542,14 @@ def gen_build_func_str(
     GET_ATTRIBUTES_FROM_MAP_TEMPLATE = """
   {attr_type} {attribute_name} = attributes.at("{attribute_name}").dyn_cast<{attr_ir_type}>().data();
+"""
+    GET_STR_ATTRIBUTES_FROM_MAP_TEMPLATE = """
+  {attr_type} {attribute_name} = attributes.at("{attribute_name}").dyn_cast<ir::StrAttribute>().AsString();
 """
     GET_ARRAY_ATTRIBUTE_FROM_MAP_TEMPLATE = """
   {attr_type} {attribute_name};
   for (size_t i = 0; i < attributes.at("{attribute_name}").dyn_cast<ir::ArrayAttribute>().size(); i++) {{
-    {attribute_name}.push_back(attributes.at("{attribute_name}").dyn_cast<ir::ArrayAttribute>()[i].dyn_cast<{inner_type}>().data());
+    {attribute_name}.push_back(attributes.at("{attribute_name}").dyn_cast<ir::ArrayAttribute>().at(i).dyn_cast<{inner_type}>().{data_name}());
   }}
 """
     GET_INTARRAY_ATTRIBUTE_FROM_MAP_TEMPLATE = """
@@ -566,11 +569,15 @@ def gen_build_func_str(
         # attr_type = "std::vector"
         if "ir::ArrayAttribute" in op_attribute_type_list[idx]:
             inner_type = op_attribute_type_list[idx][19:-1]
+            data_name = "data"
+            if inner_type == "ir::StrAttribute":
+                data_name = "AsString"
             get_attributes_str += (
                 GET_ARRAY_ATTRIBUTE_FROM_MAP_TEMPLATE.format(
                     attr_type=attr_type,
                     attribute_name=op_attribute_name_list[idx],
                     inner_type=inner_type,
+                    data_name=data_name,
                 )
             )
         elif (
@@ -593,6 +600,14 @@ def gen_build_func_str(
                     attribute_name=op_attribute_name_list[idx],
                 )
             )
+        elif "ir::StrAttribute" in op_attribute_type_list[idx]:
+            get_attributes_str += (
+                GET_STR_ATTRIBUTES_FROM_MAP_TEMPLATE.format(
+                    attr_type=attr_type,
+                    attribute_name=op_attribute_name_list[idx],
+                    attr_ir_type=op_attribute_type_list[idx],
+                )
+            )
         else:
             get_attributes_str += GET_ATTRIBUTES_FROM_MAP_TEMPLATE.format(
                 attr_type=attr_type,
diff --git a/paddle/fluid/ir/dialect/op_generator/op_verify_gen.py b/paddle/fluid/ir/dialect/op_generator/op_verify_gen.py
index 9c1e8a8d1ebf20..f5f0711534f927 100644
--- a/paddle/fluid/ir/dialect/op_generator/op_verify_gen.py
+++ b/paddle/fluid/ir/dialect/op_generator/op_verify_gen.py
@@ -78,7 +78,7 @@
   PADDLE_ENFORCE(attributes.count("{attribute_name}")>0 && attributes.at("{attribute_name}").isa<ir::ArrayAttribute>(),
                  phi::errors::PreconditionNotMet("Type of attribute: {attribute_name} is not right."));
   for (size_t i = 0; i < attributes.at("{attribute_name}").dyn_cast<ir::ArrayAttribute>().size(); i++) {{
-    PADDLE_ENFORCE(attributes.at("{attribute_name}").dyn_cast<ir::ArrayAttribute>()[i].isa<{standard}>(),
+    PADDLE_ENFORCE(attributes.at("{attribute_name}").dyn_cast<ir::ArrayAttribute>().at(i).isa<{standard}>(),
                    phi::errors::PreconditionNotMet("Type of attribute: {attribute_name} is not right."));
   }}"""
 OUTPUT_TYPE_CHECK_TEMPLATE = """
diff --git a/paddle/fluid/ir/phi_kernel_adaptor/phi_kernel_adaptor.h b/paddle/fluid/ir/phi_kernel_adaptor/phi_kernel_adaptor.h
index fd4aecbada17bf..1466a580ff0141 100644
--- a/paddle/fluid/ir/phi_kernel_adaptor/phi_kernel_adaptor.h
+++ b/paddle/fluid/ir/phi_kernel_adaptor/phi_kernel_adaptor.h
@@ -76,7 +76,8 @@ class PhiKernelAdaptor {
     for (auto it = block->begin(); it != block->end(); ++it) {
       auto attr_map = (*it)->attributes();

-      auto op_name = attr_map.at("op_name").dyn_cast<ir::StrAttribute>().data();
+      auto op_name =
+          attr_map.at("op_name").dyn_cast<ir::StrAttribute>().AsString();

       ir::OpInfo op1_info = ctx->GetRegisteredOpInfo(op_name);

@@ -104,7 +105,7 @@ class PhiKernelAdaptor {
       infer_meta_impl->infer_meta_(&ctx);

       auto kernel_name =
-          attr_map.at("kernel_name").dyn_cast<ir::StrAttribute>().data();
+          attr_map.at("kernel_name").dyn_cast<ir::StrAttribute>().AsString();
       auto kernel_key = attr_map.at("kernel_key")
                             .dyn_cast<paddle::dialect::KernelAttribute>()
                             .data();
diff --git a/paddle/fluid/ir/phi_kernel_adaptor/phi_kernel_util.cc b/paddle/fluid/ir/phi_kernel_adaptor/phi_kernel_util.cc
index c44a674275f444..a0ce5775735387 100644
--- a/paddle/fluid/ir/phi_kernel_adaptor/phi_kernel_util.cc
+++ b/paddle/fluid/ir/phi_kernel_adaptor/phi_kernel_util.cc
@@ -171,7 +171,7 @@ void HandleForSpecialOp(
   std::string op_name = op->name();
   if (op->attributes().count("op_name")) {
     op_name =
-        op->attributes().at("op_name").dyn_cast<ir::StrAttribute>().data();
+        op->attributes().at("op_name").dyn_cast<ir::StrAttribute>().AsString();
   }

   if (op_name == "pd.fetch") {
@@ -244,7 +244,7 @@ void HandleForSpecialOp(
     auto param_name = op->attributes()
                           .at("parameter_name")
                           .dyn_cast<ir::StrAttribute>()
-                          .data();
+                          .AsString();

     auto value = op->operand(0);
     // change opreand name to param_name
@@ -262,7 +262,7 @@ void HandleForSpecialOp(
     auto param_name = op->attributes()
                           .at("parameter_name")
                           .dyn_cast<ir::StrAttribute>()
-                          .data();
+                          .AsString();

     auto value = op->result(0);
     value_2_var_name->emplace(value, param_name);
   }
@@ -306,7 +306,7 @@ void HandleForInplaceOp(
   std::string op_name = op->name();
   if (op->attributes().count("op_name")) {
     op_name =
-        op->attributes().at("op_name").dyn_cast<ir::StrAttribute>().data();
+        op->attributes().at("op_name").dyn_cast<ir::StrAttribute>().AsString();
   }

   ir::OpInfo op_info = ctx->GetRegisteredOpInfo(op_name);
@@ -356,8 +356,10 @@ void BuildScope(const ir::Block& block,
     std::string op_name = op->name();
     if (op->attributes().count("op_name")) {
-      op_name =
-          op->attributes().at("op_name").dyn_cast<ir::StrAttribute>().data();
+      op_name = op->attributes()
+                    .at("op_name")
+                    .dyn_cast<ir::StrAttribute>()
+                    .AsString();
     }

     VLOG(4) << "build op:" << op_name;
diff --git a/paddle/fluid/ir/phi_kernel_adaptor/phi_kernel_util.h b/paddle/fluid/ir/phi_kernel_adaptor/phi_kernel_util.h
index 7f6a804382921a..235f3dc9f353ba 100644
--- a/paddle/fluid/ir/phi_kernel_adaptor/phi_kernel_util.h
+++ b/paddle/fluid/ir/phi_kernel_adaptor/phi_kernel_util.h
@@ -188,10 +188,10 @@ void BuildPhiContext(
     } else if (attr_type_name == "ir::BoolAttribute") {
       ctx->EmplaceBackAttr(attr_map[t].dyn_cast<ir::BoolAttribute>().data());
     } else if (attr_type_name == "ir::StrAttribute") {
-      ctx->EmplaceBackAttr(attr_map[t].dyn_cast<ir::StrAttribute>().data());
+      ctx->EmplaceBackAttr(attr_map[t].dyn_cast<ir::StrAttribute>().AsString());
     } else if (attr_type_name == "ir::ArrayAttribute") {
-      auto array_list = attr_map[t].dyn_cast<ir::ArrayAttribute>().data();
+      auto array_list = attr_map[t].dyn_cast<ir::ArrayAttribute>().AsVector();
       std::vector vec_res;
       if (array_list.size() > 0) {
         PADDLE_ENFORCE_EQ(
@@ -207,7 +207,7 @@ void BuildPhiContext(
       }
       ctx->EmplaceBackAttr(vec_res);
     } else if (attr_type_name == "ir::ArrayAttribute") {
-      auto array_list = attr_map[t].dyn_cast<ir::ArrayAttribute>().data();
+      auto array_list = attr_map[t].dyn_cast<ir::ArrayAttribute>().AsVector();
       std::vector vec_res;
       if (array_list.size() > 0) {
         PADDLE_ENFORCE_EQ(
@@ -222,7 +222,7 @@ void BuildPhiContext(
       }
       ctx->EmplaceBackAttr(vec_res);
     } else if (attr_type_name == "ir::ArrayAttribute") {
-      auto array_list = attr_map[t].dyn_cast<ir::ArrayAttribute>().data();
+      auto array_list = attr_map[t].dyn_cast<ir::ArrayAttribute>().AsVector();
       std::vector vec_res;
       if (array_list.size() > 0) {
         if (array_list[0].isa()) {
@@ -238,7 +238,7 @@ void BuildPhiContext(
       }
       ctx->EmplaceBackAttr(vec_res);
     } else if (attr_type_name == "ir::ArrayAttribute") {
-      auto array_list = attr_map[t].dyn_cast<ir::ArrayAttribute>().data();
+      auto array_list = attr_map[t].dyn_cast<ir::ArrayAttribute>().AsVector();
       std::vector vec_res;

       if (array_list.size() > 0) {
@@ -255,7 +255,7 @@ void BuildPhiContext(
       }
       ctx->EmplaceBackAttr(vec_res);
     } else if (attr_type_name == "ir::ArrayAttribute") {
-      auto array_list = attr_map[t].dyn_cast<ir::ArrayAttribute>().data();
+      auto array_list = attr_map[t].dyn_cast<ir::ArrayAttribute>().AsVector();
       std::vector vec_res;

       if (array_list.size() > 0) {
@@ -286,7 +286,7 @@ void BuildPhiContext(

   // TODO(phlrain): use var type instead of op name
   if (op->attributes().count("op_name") &&
-      (op->attributes().at("op_name").dyn_cast<ir::StrAttribute>().data() ==
+      (op->attributes().at("op_name").dyn_cast<ir::StrAttribute>().AsString() ==
        "pd.fetch")) {
     // process fetch op
     auto fetch_var = inner_scope->FindVar("fetch");
diff --git a/paddle/fluid/ir/transforms/transform_general_functions.cc b/paddle/fluid/ir/transforms/transform_general_functions.cc
index 0de36ffd20b215..966e4035fc3b39 100644
--- a/paddle/fluid/ir/transforms/transform_general_functions.cc
+++ b/paddle/fluid/ir/transforms/transform_general_functions.cc
@@ -34,7 +34,7 @@ std::pair GetParameterFromValue(ir::Value value) {
   std::string name = op->attributes()
                          .at(op.attributes_name[0])
                          .dyn_cast<ir::StrAttribute>()
-                         .data();
+                         .AsString();
   ir::Parameter* param = program->GetParameter(name);
   PADDLE_ENFORCE_NOT_NULL(
       param, phi::errors::InvalidArgument("Parameter should not be null."));
diff --git a/paddle/fluid/ir_adaptor/translator/program_translator.cc b/paddle/fluid/ir_adaptor/translator/program_translator.cc
index 3e2d88c7a4ec9c..2a3e7cf6074144 100644
--- a/paddle/fluid/ir_adaptor/translator/program_translator.cc
+++ b/paddle/fluid/ir_adaptor/translator/program_translator.cc
@@ -223,7 +223,7 @@ void ProgramTranslator::SetStopGradientAttributeForAllValue(
       if (defining_op->HasAttribute(kAttrStopGradients)) {
         stop_gradients = defining_op->attribute(kAttrStopGradients)
                              .dyn_cast<ir::ArrayAttribute>()
-                             .data();
+                             .AsVector();
       } else {
         stop_gradients = std::vector<ir::Attribute>(
             defining_op->num_results(), ir::BoolAttribute::get(ctx_, false));
diff --git a/paddle/ir/builtin_transforms/dead_code_elimination_pass.cc b/paddle/ir/builtin_transforms/dead_code_elimination_pass.cc
index c595a7cae033a6..6693214e117aaf 100644
--- a/paddle/ir/builtin_transforms/dead_code_elimination_pass.cc
+++ b/paddle/ir/builtin_transforms/dead_code_elimination_pass.cc
@@ -56,7 +56,7 @@ class DeadCodeEliminationPass : public ir::Pass {
             get_parameter_op->attributes()
                 .at(get_parameter_op.attributes_name[0])
                 .dyn_cast<ir::StrAttribute>()
-                .data());
+                .AsString());
       }
       block->erase(*op);
     }
diff --git a/paddle/ir/core/builtin_attribute.cc b/paddle/ir/core/builtin_attribute.cc
index d6c2b3f829dafc..38ca80cb1f9d77 100644
--- a/paddle/ir/core/builtin_attribute.cc
+++ b/paddle/ir/core/builtin_attribute.cc
@@ -15,27 +15,69 @@
 #include "paddle/ir/core/builtin_attribute.h"

 namespace ir {
-std::string StrAttribute::data() const { return storage()->GetAsKey(); }
-
-uint32_t StrAttribute::size() const { return storage()->GetAsKey().size(); }
+bool BoolAttribute::data() const { return storage()->data(); }

-bool BoolAttribute::data() const { return storage()->GetAsKey(); }
+float FloatAttribute::data() const { return storage()->data(); }

-float FloatAttribute::data() const { return storage()->GetAsKey(); }
+double DoubleAttribute::data() const { return storage()->data(); }

-double DoubleAttribute::data() const { return storage()->GetAsKey(); }
+int32_t Int32Attribute::data() const { return storage()->data(); }

-int32_t Int32Attribute::data() const { return storage()->GetAsKey(); }
+int64_t Int64Attribute::data() const { return storage()->data(); }

-int64_t Int64Attribute::data() const { return storage()->GetAsKey(); }
+void* PointerAttribute::data() const { return storage()->data(); }

-std::vector<Attribute> ArrayAttribute::data() const {
-  return storage()->GetAsKey();
+Type TypeAttribute::data() const { return storage()->data(); }
+
+bool StrAttribute::operator<(const StrAttribute& right) const {
+  return storage() < right.storage();
+}
+
+std::string StrAttribute::AsString() const { return storage()->AsString(); }
+
+size_t StrAttribute::size() const { return storage()->size(); }
+
+StrAttribute StrAttribute::get(ir::IrContext* ctx, const std::string& value) {
+  return AttributeManager::get<StrAttribute>(ctx, value);
+}
+
+std::vector<Attribute> ArrayAttribute::AsVector() const {
+  return storage()->AsVector();
+}
+
+size_t ArrayAttribute::size() const { return storage()->size(); }
+
+bool ArrayAttribute::empty() const { return storage()->empty(); }
+
+Attribute ArrayAttribute::at(size_t index) const {
+  return storage()->at(index);
 }

-void* PointerAttribute::data() const { return storage()->GetAsKey(); }
+ArrayAttribute ArrayAttribute::get(IrContext* ctx,
+                                   const std::vector<Attribute>& value) {
+  return AttributeManager::get<ArrayAttribute>(ctx, value);
+}

-Type TypeAttribute::data() const { return storage()->GetAsKey(); }
+ArrayAttributeStorage::ArrayAttributeStorage(const ParamKey& key)
+    : size_(key.size()) {
+  constexpr size_t align = alignof(Attribute);
+  if (align > __STDCPP_DEFAULT_NEW_ALIGNMENT__) {
+    data_ = static_cast<Attribute*>(
+        ::operator new(size_ * sizeof(Attribute), std::align_val_t(align)));
+  } else {
+    data_ = static_cast<Attribute*>(::operator new(size_ * sizeof(Attribute)));
+  }
+  memcpy(data_, key.data(), sizeof(Attribute) * size_);
+}
+
+ArrayAttributeStorage::~ArrayAttributeStorage() {
+  constexpr size_t align = alignof(Attribute);
+  if (align > __STDCPP_DEFAULT_NEW_ALIGNMENT__) {
+    ::operator delete(data_, std::align_val_t(align));
+  } else {
+    ::operator delete(data_);
+  }
+}
 }  // namespace ir
diff --git a/paddle/ir/core/builtin_attribute.h b/paddle/ir/core/builtin_attribute.h
index 8d8efbc4d79f7a..3969d962e1f4e4 100644
--- a/paddle/ir/core/builtin_attribute.h
+++ b/paddle/ir/core/builtin_attribute.h
@@ -19,21 +19,6 @@
 #include "paddle/ir/core/utils.h"

 namespace ir {
-class IR_API StrAttribute : public Attribute {
- public:
-  using Attribute::Attribute;
-
-  DECLARE_ATTRIBUTE_UTILITY_FUNCTOR(StrAttribute, StrAttributeStorage);
-
-  bool operator<(const StrAttribute& right) const {
-    return storage() < right.storage();
-  }
-
-  std::string data() const;
-
-  uint32_t size() const;
-};
-
 class IR_API BoolAttribute : public Attribute {
  public:
   using Attribute::Attribute;
@@ -79,37 +64,55 @@ class IR_API Int64Attribute : public Attribute {
   int64_t data() const;
 };

-class IR_API ArrayAttribute : public Attribute {
+class IR_API PointerAttribute : public Attribute {
  public:
   using Attribute::Attribute;

-  DECLARE_ATTRIBUTE_UTILITY_FUNCTOR(ArrayAttribute, ArrayAttributeStorage);
+  DECLARE_ATTRIBUTE_UTILITY_FUNCTOR(PointerAttribute, PointerAttributeStorage);

-  std::vector<Attribute> data() const;
+  void* data() const;
+};

-  size_t size() const { return data().size(); }
+class IR_API TypeAttribute : public Attribute {
+ public:
+  using Attribute::Attribute;

-  bool empty() const { return data().empty(); }
+  DECLARE_ATTRIBUTE_UTILITY_FUNCTOR(TypeAttribute, TypeAttributeStorage);

-  Attribute operator[](size_t index) const { return data()[index]; }
+  Type data() const;
 };

-class IR_API PointerAttribute : public Attribute {
+class IR_API StrAttribute : public Attribute {
  public:
   using Attribute::Attribute;

-  DECLARE_ATTRIBUTE_UTILITY_FUNCTOR(PointerAttribute, PointerAttributeStorage);
+  DECLARE_ATTRIBUTE_UTILITY_FUNCTOR(StrAttribute, StrAttributeStorage);

-  void* data() const;
+  bool operator<(const StrAttribute& right) const;
+
+  std::string AsString() const;
+
+  size_t size() const;
+
+  static StrAttribute get(IrContext* ctx, const std::string& value);
 };

-class IR_API TypeAttribute : public Attribute {
+class IR_API ArrayAttribute : public Attribute {
  public:
   using Attribute::Attribute;

-  DECLARE_ATTRIBUTE_UTILITY_FUNCTOR(TypeAttribute, TypeAttributeStorage);
+  DECLARE_ATTRIBUTE_UTILITY_FUNCTOR(ArrayAttribute, ArrayAttributeStorage);

-  Type data() const;
+  std::vector<Attribute> AsVector() const;
+
+  size_t size() const;
+
+  bool empty() const;
+
+  Attribute at(size_t index) const;
+
+  static ArrayAttribute get(IrContext* ctx,
+                            const std::vector<Attribute>& value);
 };
 }  // namespace ir
diff --git a/paddle/ir/core/builtin_attribute_storage.h b/paddle/ir/core/builtin_attribute_storage.h
index 891a0691186f59..7050555c1d0374 100644
--- a/paddle/ir/core/builtin_attribute_storage.h
+++ b/paddle/ir/core/builtin_attribute_storage.h
@@ -20,32 +20,41 @@

 #include "paddle/ir/core/attribute.h"
 #include "paddle/ir/core/attribute_base.h"
+#include "paddle/ir/core/enforce.h"
 #include "paddle/ir/core/type.h"
 #include "paddle/ir/core/utils.h"

 namespace ir {

-#define DECLARE_BASE_TYPE_ATTRIBUTE_STORAGE(concrete_storage, base_type) \
-  struct concrete_storage : public ir::AttributeStorage {                \
-    using ParamKey = base_type;                                          \
-                                                                         \
-    explicit concrete_storage(const ParamKey &key) { data_ = key; }      \
-                                                                         \
-    static concrete_storage *Construct(const ParamKey &key) {            \
-      return new concrete_storage(key);                                  \
-    }                                                                    \
-                                                                         \
-    static std::size_t HashValue(const ParamKey &key) {                  \
-      return std::hash<ParamKey>()(key);                                 \
-    }                                                                    \
-                                                                         \
-    bool operator==(const ParamKey &key) const { return data_ == key; }  \
-                                                                         \
-    ParamKey GetAsKey() const { return data_; }                          \
-                                                                         \
-   private:                                                              \
-    ParamKey data_;                                                      \
-  };
+#define DECLARE_BASE_TYPE_ATTRIBUTE_STORAGE(ConcreteStorage, BaseType) \
+  struct ConcreteStorage : public AttributeStorage {                   \
+    using ParamKey = BaseType;                                         \
+                                                                       \
+    explicit ConcreteStorage(ParamKey key) { data_ = key; }            \
+                                                                       \
+    static ConcreteStorage *Construct(ParamKey key) {                  \
+      return new ConcreteStorage(key);                                 \
+    }                                                                  \
+                                                                       \
+    static size_t HashValue(ParamKey key) {                            \
+      return std::hash<ParamKey>{}(key);                               \
+    }                                                                  \
+                                                                       \
+    bool operator==(ParamKey key) const { return data_ == key; }       \
+                                                                       \
+    BaseType data() const { return data_; }                            \
+                                                                       \
+   private:                                                            \
+    BaseType data_;                                                    \
+  }
+
+DECLARE_BASE_TYPE_ATTRIBUTE_STORAGE(BoolAttributeStorage, bool);
+DECLARE_BASE_TYPE_ATTRIBUTE_STORAGE(FloatAttributeStorage, float);
+DECLARE_BASE_TYPE_ATTRIBUTE_STORAGE(DoubleAttributeStorage, double);
+DECLARE_BASE_TYPE_ATTRIBUTE_STORAGE(Int32AttributeStorage, int32_t);
+DECLARE_BASE_TYPE_ATTRIBUTE_STORAGE(Int64AttributeStorage, int64_t);
+DECLARE_BASE_TYPE_ATTRIBUTE_STORAGE(PointerAttributeStorage, void *);
+DECLARE_BASE_TYPE_ATTRIBUTE_STORAGE(TypeAttributeStorage, Type);

 ///
 /// \brief Define Parametric AttributeStorage for StrAttribute.
@@ -53,53 +62,53 @@
 struct StrAttributeStorage : public AttributeStorage {
   using ParamKey = std::string;

-  explicit StrAttributeStorage(const ParamKey &key) {
-    data_ = reinterpret_cast<char *>(malloc(key.size()));
-    memcpy(data_, key.c_str(), key.size());
-    size_ = key.size();
+  explicit StrAttributeStorage(const ParamKey &key) : size_(key.size()) {
+    if (size_ > kLocalSize) {
+      data_ = static_cast<char *>(::operator new(size_));
+      memcpy(data_, key.c_str(), size_);
+    } else {
+      memcpy(buff_, key.c_str(), size_);
+    }
   }

-  ~StrAttributeStorage() { free(data_); }
+  ~StrAttributeStorage() {
+    if (size_ > kLocalSize) ::operator delete(data_);
+  }

   static StrAttributeStorage *Construct(const ParamKey &key) {
     return new StrAttributeStorage(key);
   }

-  static std::size_t HashValue(const ParamKey &key) {
-    return std::hash<ParamKey>()(key);
+  static size_t HashValue(const ParamKey &key) {
+    return std::hash<ParamKey>{}(key);
   }

   bool operator==(const ParamKey &key) const {
-    return std::equal(data_, data_ + size_, key.c_str());
+    if (size_ != key.size()) return false;
+    const char *data = size_ > kLocalSize ? data_ : buff_;
+    return std::equal(data, data + size_, key.c_str());
   }

-  ParamKey GetAsKey() const { return ParamKey(data_, size_); }
+  // Note: The returned const char* is not null-terminated.
+  const char *data() const { return size_ > kLocalSize ? data_ : buff_; }
+  size_t size() const { return size_; }
+  std::string AsString() const { return std::string(data(), size_); }

  private:
-  char *data_;
-  uint32_t size_;
+  static constexpr size_t kLocalSize = sizeof(void *) / sizeof(char);
+  union {
+    char *data_;
+    char buff_[kLocalSize];
+  };
+  const size_t size_;
 };

-DECLARE_BASE_TYPE_ATTRIBUTE_STORAGE(BoolAttributeStorage, bool);
-DECLARE_BASE_TYPE_ATTRIBUTE_STORAGE(FloatAttributeStorage, float);
-DECLARE_BASE_TYPE_ATTRIBUTE_STORAGE(DoubleAttributeStorage, double);
-DECLARE_BASE_TYPE_ATTRIBUTE_STORAGE(Int32AttributeStorage, int32_t);
-DECLARE_BASE_TYPE_ATTRIBUTE_STORAGE(Int64AttributeStorage, int64_t);
-DECLARE_BASE_TYPE_ATTRIBUTE_STORAGE(PointerAttributeStorage, void *);
-
 struct ArrayAttributeStorage : public AttributeStorage {
   using ParamKey = std::vector<Attribute>;

-  explicit ArrayAttributeStorage(const ParamKey &key) {
-    data_ =
-        reinterpret_cast<Attribute *>(malloc(sizeof(Attribute) * key.size()));
-    memcpy(reinterpret_cast<void *>(data_),
-           reinterpret_cast<const void *>(key.data()),
-           sizeof(Attribute) * key.size());
-    length_ = key.size();
-  }
+  explicit ArrayAttributeStorage(const ParamKey &key);

-  ~ArrayAttributeStorage() { free(reinterpret_cast<void *>(data_)); }
+  ~ArrayAttributeStorage();

   static ArrayAttributeStorage *Construct(const ParamKey &key) {
     return new ArrayAttributeStorage(key);
   }
@@ -114,43 +123,25 @@ struct ArrayAttributeStorage : public AttributeStorage {
   }

   bool operator==(const ParamKey &key) const {
-    if (key.size() != length_) {
-      return false;
-    }
-    for (size_t i = 0; i < length_; ++i) {
-      if (data_[i] != key[i]) {
-        return false;
-      }
-    }
-    return true;
+    return key.size() == size_ && std::equal(key.begin(), key.end(), data_);
   }

-  ParamKey GetAsKey() const { return ParamKey(data_, data_ + length_); }
-
- private:
-  Attribute *data_ = nullptr;
-  size_t length_ = 0;
-};
-
-struct TypeAttributeStorage : public AttributeStorage {
-  using ParamKey = Type;
-
-  explicit TypeAttributeStorage(const ParamKey &key) : value_(key) {}
-
-  static TypeAttributeStorage *Construct(ParamKey key) {
-    return new TypeAttributeStorage(key);
+  std::vector<Attribute> AsVector() const {
+    return std::vector<Attribute>(data_, data_ + size_);
   }

-  static std::size_t HashValue(const ParamKey &key) {
-    return std::hash<Type>()(key);
-  }
+  size_t size() const { return size_; }

-  bool operator==(const ParamKey &key) const { return value_ == key; }
+  bool empty() const { return size_ == 0u; }

-  ParamKey GetAsKey() const { return value_; }
+  Attribute at(size_t index) const {
+    IR_ENFORCE(index < size_, "Invalid index");
+    return data_[index];
+  }

  private:
-  Type value_;
+  Attribute *data_;
+  const size_t size_;
 };
 }  // namespace ir
diff --git a/paddle/ir/core/ir_printer.cc b/paddle/ir/core/ir_printer.cc
index 7a9642bd042861..a322e8fca9ffd6 100644
--- a/paddle/ir/core/ir_printer.cc
+++ b/paddle/ir/core/ir_printer.cc
@@ -85,7 +85,7 @@ void BasicIrPrinter::PrintAttribute(Attribute attr) {
   }

   if (auto s = attr.dyn_cast<StrAttribute>()) {
-    os << s.data();
+    os << s.AsString();
   } else if (auto b = attr.dyn_cast<BoolAttribute>()) {
     os << b.data();
   } else if (auto f = attr.dyn_cast<FloatAttribute>()) {
@@ -99,7 +99,7 @@ void BasicIrPrinter::PrintAttribute(Attribute attr) {
   } else if (auto p = attr.dyn_cast<PointerAttribute>()) {
     os << p.data();
   } else if (auto arr = attr.dyn_cast<ArrayAttribute>()) {
-    const auto& vec = arr.data();
+    const auto& vec = arr.AsVector();
     os << "array[";
     PrintInterleave(
         vec.begin(),
diff --git a/test/cpp/ir/pattern_rewrite/pattern_rewrite_test.cc b/test/cpp/ir/pattern_rewrite/pattern_rewrite_test.cc
index 1edfdead8e769e..ebb6144753e2aa 100644
--- a/test/cpp/ir/pattern_rewrite/pattern_rewrite_test.cc
+++ b/test/cpp/ir/pattern_rewrite/pattern_rewrite_test.cc
@@ -232,7 +232,7 @@ class RedundantTransposeFusePattern

  private:
   std::vector<int> GetAxis(paddle::dialect::TransposeOp op) const {
-    auto array_attr = op.attribute<ir::ArrayAttribute>("perm").data();
+    auto array_attr = op.attribute<ir::ArrayAttribute>("perm").AsVector();
     std::vector<int> axis(array_attr.size());
     for (size_t i = 0; i < array_attr.size(); ++i) {
       axis[i] = array_attr[i].dyn_cast<ir::Int32Attribute>().data();
@@ -333,7 +333,7 @@ class Conv2dBnFusePattern
     phi::DDim new_conv2d_out_shape = ir::GetShapeFromValue(new_conv2d_op.out());
     std::vector<int64_t> new_bias_new_shape(new_conv2d_out_shape.size(), 1);
     std::string data_format =
-        new_conv2d_op.attribute<ir::StrAttribute>("data_format").data();
+        new_conv2d_op.attribute<ir::StrAttribute>("data_format").AsString();
     IR_ENFORCE(data_format == "NCHW", "Only support NCHW now.");
     new_bias_new_shape[1] = new_conv2d_out_shape[1];
     paddle::dialect::ReshapeOp reshape_bias_op =
@@ -503,7 +503,8 @@ void Conv2dFusionOpTest::Build(ir::Builder &builder,
        i < attributes.at("strides").dyn_cast<ir::ArrayAttribute>().size();
        i++) {
     strides.push_back(attributes.at("strides")
-                          .dyn_cast<ir::ArrayAttribute>()[i]
+                          .dyn_cast<ir::ArrayAttribute>()
+                          .at(i)
                           .dyn_cast<ir::Int32Attribute>()
                           .data());
   }
@@ -513,27 +514,30 @@ void Conv2dFusionOpTest::Build(ir::Builder &builder,
        i < attributes.at("paddings_t").dyn_cast<ir::ArrayAttribute>().size();
        i++) {
     paddings_t.push_back(attributes.at("paddings_t")
-                             .dyn_cast<ir::ArrayAttribute>()[i]
+                             .dyn_cast<ir::ArrayAttribute>()
+                             .at(i)
                              .dyn_cast<ir::Int32Attribute>()
                              .data());
   }
-  std::string padding_algorithm =
-      attributes.at("padding_algorithm").dyn_cast<ir::StrAttribute>().data();
+  std::string padding_algorithm = attributes.at("padding_algorithm")
+                                      .dyn_cast<ir::StrAttribute>()
+                                      .AsString();
   std::vector<int32_t> dilations_t;
   for (size_t i = 0;
        i < attributes.at("dilations_t").dyn_cast<ir::ArrayAttribute>().size();
        i++) {
     dilations_t.push_back(attributes.at("dilations_t")
-                              .dyn_cast<ir::ArrayAttribute>()[i]
+                              .dyn_cast<ir::ArrayAttribute>()
+                              .at(i)
                               .dyn_cast<ir::Int32Attribute>()
                               .data());
   }
   int groups = attributes.at("groups").dyn_cast<ir::Int32Attribute>().data();
   std::string data_format =
-      attributes.at("data_format").dyn_cast<ir::StrAttribute>().data();
+      attributes.at("data_format").dyn_cast<ir::StrAttribute>().AsString();
   std::string activation =
-      attributes.at("activation").dyn_cast<ir::StrAttribute>().data();
+      attributes.at("activation").dyn_cast<ir::StrAttribute>().AsString();
   bool exhaustive_search =
attributes.at("exhaustive_search").dyn_cast().data(); std::vector channels; @@ -541,7 +545,8 @@ void Conv2dFusionOpTest::Build(ir::Builder &builder, i < attributes.at("channels").dyn_cast().size(); i++) { channels.push_back(attributes.at("channels") - .dyn_cast()[i] + .dyn_cast() + .at(i) .dyn_cast() .data()); } @@ -776,7 +781,8 @@ void Conv2dFusionOpTest::Verify() { i < attributes.at("strides").dyn_cast().size(); i++) { PADDLE_ENFORCE(attributes.at("strides") - .dyn_cast()[i] + .dyn_cast() + .at(i) .isa(), phi::errors::PreconditionNotMet( "Type of attribute: strides is not right.")); @@ -789,7 +795,8 @@ void Conv2dFusionOpTest::Verify() { i < attributes.at("paddings_t").dyn_cast().size(); i++) { PADDLE_ENFORCE(attributes.at("paddings_t") - .dyn_cast()[i] + .dyn_cast() + .at(i) .isa(), phi::errors::PreconditionNotMet( "Type of attribute: paddings_t is not right.")); @@ -807,7 +814,8 @@ void Conv2dFusionOpTest::Verify() { i < attributes.at("dilations_t").dyn_cast().size(); i++) { PADDLE_ENFORCE(attributes.at("dilations_t") - .dyn_cast()[i] + .dyn_cast() + .at(i) .isa(), phi::errors::PreconditionNotMet( "Type of attribute: dilations_t is not right.")); @@ -837,7 +845,8 @@ void Conv2dFusionOpTest::Verify() { i < attributes.at("channels").dyn_cast().size(); i++) { PADDLE_ENFORCE(attributes.at("channels") - .dyn_cast()[i] + .dyn_cast() + .at(i) .isa(), phi::errors::PreconditionNotMet( "Type of attribute: channels is not right."));