diff --git a/paddle/cinn/hlir/framework/new_ir/op_lowering_impl.cc b/paddle/cinn/hlir/framework/new_ir/op_lowering_impl.cc index 235d545dc331f2..56282996b9e26a 100644 --- a/paddle/cinn/hlir/framework/new_ir/op_lowering_impl.cc +++ b/paddle/cinn/hlir/framework/new_ir/op_lowering_impl.cc @@ -53,9 +53,7 @@ std::vector CollectInputTensor( std::vector* func_args, std::unordered_map<::pir::Value, ir::Tensor>* tensor_map) { std::vector tensors; - for (auto& operand : op->operands()) { - CHECK(operand); - auto in_value = operand.source(); + for (auto in_value : op->operands_source()) { VLOG(4) << "input tensor name: " << CompatibleInfo::ValueName(in_value); // NOTE(Aurelius84): Need always to create placeholder for input tensor. ir::Tensor tensor = details::GetTensor(in_value); @@ -72,7 +70,7 @@ std::vector CollectInputTensor( return tensors; } -void CollectOutputInfo(const ::pir::Operation* op, +void CollectOutputInfo(::pir::Operation* op, std::vector* out_types, std::vector>* out_shapes) { auto op_results = op->results(); @@ -359,7 +357,7 @@ std::vector OpLowererImpl::LowerOps( std::vector OpLowererImpl::DoOpLower( std::shared_ptr op_impl, - const ::pir::Operation* op, + ::pir::Operation* op, std::unordered_map<::pir::Value, ir::Tensor>* tensor_map, std::vector* op_func_arg_tensors) { VLOG(4) << "Do lower with Compute, op: " << op->name(); diff --git a/paddle/cinn/hlir/framework/new_ir/op_lowering_impl.h b/paddle/cinn/hlir/framework/new_ir/op_lowering_impl.h index 81e36d8bb7b3b8..705c1f6f8c12d7 100644 --- a/paddle/cinn/hlir/framework/new_ir/op_lowering_impl.h +++ b/paddle/cinn/hlir/framework/new_ir/op_lowering_impl.h @@ -131,7 +131,7 @@ class OpLowererImpl : public OpLowererImplBase { */ std::vector DoOpLower( std::shared_ptr op_impl, - const ::pir::Operation* op, + ::pir::Operation* op, std::unordered_map<::pir::Value, ir::Tensor>* tensor_map, std::vector* op_func_arg_tensors); diff --git a/paddle/cinn/hlir/framework/new_ir/utils.cc b/paddle/cinn/hlir/framework/new_ir/utils.cc index b027992af8c472..3f938981390fbc 100644 --- a/paddle/cinn/hlir/framework/new_ir/utils.cc +++ b/paddle/cinn/hlir/framework/new_ir/utils.cc @@ -74,8 +74,7 @@ std::vector CompatibleInfo::InputNames(const ::pir::Operation& op, return names; } -std::vector CompatibleInfo::OutputNames( - const ::pir::Operation& op) { +std::vector CompatibleInfo::OutputNames(::pir::Operation& op) { std::vector names; for (int i = 0; i < op.num_results(); ++i) { auto value = op.result(i); diff --git a/paddle/cinn/hlir/framework/new_ir/utils.h b/paddle/cinn/hlir/framework/new_ir/utils.h index 2a70cd9eedc17a..953dc6672bc18f 100644 --- a/paddle/cinn/hlir/framework/new_ir/utils.h +++ b/paddle/cinn/hlir/framework/new_ir/utils.h @@ -40,7 +40,7 @@ struct CompatibleInfo { static std::vector InputNames(const ::pir::Operation& op, bool allow_duplicate = false); - static std::vector OutputNames(const ::pir::Operation& op); + static std::vector OutputNames(::pir::Operation& op); // NOLINT }; } // namespace newir diff --git a/paddle/fluid/pybind/ir.cc b/paddle/fluid/pybind/ir.cc index 35ea8fb93ecf2f..71692a3e2cc9e0 100644 --- a/paddle/fluid/pybind/ir.cc +++ b/paddle/fluid/pybind/ir.cc @@ -236,10 +236,7 @@ void BindOperation(py::module *m) { )DOC"); op.def("name", &Operation::name) .def("get_parent_block", - py::overload_cast<>(&Operation::GetParent), - return_value_policy::reference) - .def("get_parent_block", - py::overload_cast<>(&Operation::GetParent, py::const_), + &Operation::GetParent, return_value_policy::reference) .def("num_operands", &Operation::num_operands) .def("num_results", &Operation::num_results) diff --git a/paddle/pir/core/block.cc b/paddle/pir/core/block.cc index 5561ea345b6883..d78a0d7b38f5dc 100644 --- a/paddle/pir/core/block.cc +++ b/paddle/pir/core/block.cc @@ -90,7 +90,7 @@ void Block::AddArgument(Type type) { bool Block::TopoOrderCheck(const OpListType &op_list) { std::unordered_set visited_values; - for (const Operation *op : op_list) { + for (Operation *op : op_list) { if (op->num_operands() > 0) { for (size_t i = 0; i < op->num_operands(); ++i) { auto operand = op->operand_source(i); diff --git a/paddle/pir/core/block_argument.cc b/paddle/pir/core/block_argument.cc index 4d05cc54b279eb..4acbfe9176ef91 100644 --- a/paddle/pir/core/block_argument.cc +++ b/paddle/pir/core/block_argument.cc @@ -29,12 +29,12 @@ namespace detail { class BlockArgumentImpl : public ValueImpl { public: static bool classof(const ValueImpl &value) { - return value.kind() == BLOCK_ARGUMENT_INDEX; + return value.kind() == BLOCK_ARG_IDX; } private: BlockArgumentImpl(Type type, Block *owner, uint32_t index) - : ValueImpl(type, BLOCK_ARGUMENT_INDEX), owner_(owner), index_(index) {} + : ValueImpl(type, BLOCK_ARG_IDX), owner_(owner), index_(index) {} ~BlockArgumentImpl(); // access construction and owner diff --git a/paddle/pir/core/ir_printer.cc b/paddle/pir/core/ir_printer.cc index 7fa8e076ad1471..68a0eb99bc5989 100644 --- a/paddle/pir/core/ir_printer.cc +++ b/paddle/pir/core/ir_printer.cc @@ -139,7 +139,7 @@ void IrPrinter::PrintOperation(Operation* op) { PrintGeneralOperation(op); } -void IrPrinter::PrintGeneralOperation(const Operation* op) { +void IrPrinter::PrintGeneralOperation(Operation* op) { // TODO(lyk): add API to get opresults directly PrintOpResult(op); os << " ="; @@ -160,7 +160,7 @@ void IrPrinter::PrintGeneralOperation(const Operation* op) { PrintOpReturnType(op); } -void IrPrinter::PrintFullOperation(const Operation* op) { +void IrPrinter::PrintFullOperation(Operation* op) { PrintGeneralOperation(op); if (op->num_regions() > 0) { os << newline; @@ -186,7 +186,7 @@ void IrPrinter::PrintBlock(const Block* block) { os << "}\n"; } -void IrPrinter::PrintValue(const Value& v) { +void IrPrinter::PrintValue(Value v) { if (!v) { os << "<>"; return; @@ -204,7 +204,7 @@ void IrPrinter::PrintValue(const Value& v) { os << new_name; } -void IrPrinter::PrintOpResult(const Operation* op) { +void IrPrinter::PrintOpResult(Operation* op) { os << " ("; auto num_op_result = op->num_results(); std::vector op_results; @@ -220,7 +220,7 @@ void IrPrinter::PrintOpResult(const Operation* op) { os << ")"; } -void IrPrinter::PrintAttributeMap(const Operation* op) { +void IrPrinter::PrintAttributeMap(Operation* op) { AttributeMap attributes = op->attributes(); std::map> order_attributes( attributes.begin(), attributes.end()); @@ -239,7 +239,7 @@ void IrPrinter::PrintAttributeMap(const Operation* op) { os << "}"; } -void IrPrinter::PrintOpOperands(const Operation* op) { +void IrPrinter::PrintOpOperands(Operation* op) { os << " ("; auto num_op_operands = op->num_operands(); std::vector op_operands; @@ -255,7 +255,7 @@ void IrPrinter::PrintOpOperands(const Operation* op) { os << ")"; } -void IrPrinter::PrintOperandsType(const Operation* op) { +void IrPrinter::PrintOperandsType(Operation* op) { auto num_op_operands = op->num_operands(); std::vector op_operand_types; op_operand_types.reserve(num_op_operands); @@ -276,7 +276,7 @@ void IrPrinter::PrintOperandsType(const Operation* op) { os << ")"; } -void IrPrinter::PrintOpReturnType(const Operation* op) { +void IrPrinter::PrintOpReturnType(Operation* op) { auto num_op_result = op->num_results(); std::vector op_result_types; op_result_types.reserve(num_op_result); diff --git a/paddle/pir/core/ir_printer.h b/paddle/pir/core/ir_printer.h index a845bec52490c8..929da4fe332e1c 100644 --- a/paddle/pir/core/ir_printer.h +++ b/paddle/pir/core/ir_printer.h @@ -51,24 +51,24 @@ class IR_API IrPrinter : public BasicIrPrinter { /// @brief dispatch to custom printer function or PrintGeneralOperation void PrintOperation(Operation* op); /// @brief print operation itself without its regions - void PrintGeneralOperation(const Operation* op); + void PrintGeneralOperation(Operation* op); /// @brief print operation and its regions - void PrintFullOperation(const Operation* op); + void PrintFullOperation(Operation* op); void PrintRegion(const Region& Region); void PrintBlock(const Block* block); - void PrintValue(const Value& v); + void PrintValue(Value v); - void PrintOpResult(const Operation* op); + void PrintOpResult(Operation* op); - void PrintAttributeMap(const Operation* op); + void PrintAttributeMap(Operation* op); - void PrintOpOperands(const Operation* op); + void PrintOpOperands(Operation* op); - void PrintOperandsType(const Operation* op); + void PrintOperandsType(Operation* op); - void PrintOpReturnType(const Operation* op); + void PrintOpReturnType(Operation* op); private: size_t cur_var_number_{0}; diff --git a/paddle/pir/core/op_result.cc b/paddle/pir/core/op_result.cc index 55db307d314338..d14a3c830c8d22 100644 --- a/paddle/pir/core/op_result.cc +++ b/paddle/pir/core/op_result.cc @@ -47,6 +47,6 @@ bool OpResult::operator==(const OpResult &other) const { return impl_ == other.impl_; } -OpResult::OpResult(const detail::OpResultImpl *impl) : Value(impl) {} +OpResult::OpResult(detail::OpResultImpl *impl) : Value(impl) {} } // namespace pir diff --git a/paddle/pir/core/op_result.h b/paddle/pir/core/op_result.h index cadc9c141f13c5..8860473fe33395 100644 --- a/paddle/pir/core/op_result.h +++ b/paddle/pir/core/op_result.h @@ -35,7 +35,7 @@ class IR_API OpResult : public Value { private: friend Operation; - OpResult(const detail::OpResultImpl *impl); // NOLINT + OpResult(detail::OpResultImpl *impl); // NOLINT // Access classof annd dyn_cast_from. friend Value; static bool classof(Value value); diff --git a/paddle/pir/core/op_result_impl.cc b/paddle/pir/core/op_result_impl.cc index 699e8d2e3fb021..d731de937bd5d3 100644 --- a/paddle/pir/core/op_result_impl.cc +++ b/paddle/pir/core/op_result_impl.cc @@ -12,8 +12,7 @@ // See the License for the specific language governing permissions and // limitations under the License. #include "paddle/pir/core/op_result_impl.h" - -#include +#include "paddle/pir/core/operation.h" namespace pir { namespace detail { @@ -22,31 +21,31 @@ uint32_t OpResultImpl::index() const { if (const auto *outline_result = dyn_cast(this)) { return outline_result->index(); } - return dyn_cast(this)->index(); + return static_cast(this)->index(); } -OpResultImpl::~OpResultImpl() { assert(use_empty()); } +OpResultImpl::~OpResultImpl() { + if (!use_empty()) { + LOG(FATAL) << "Destoryed a op_result that is still in use. \n" + << "The owner op type is:" << owner()->name(); + } +} -Operation *OpResultImpl::owner() const { +Operation *OpResultImpl::owner() { // For inline result, pointer offset index to obtain the address of op. - if (const auto *result = dyn_cast(this)) { + if (auto *result = dyn_cast(this)) { result += result->index() + 1; - return reinterpret_cast( - const_cast(result)); + return reinterpret_cast(result); } // For outline result, pointer offset outline_index to obtain the address of // maximum inline result. - const OpOutlineResultImpl *outline_result = - (const OpOutlineResultImpl *)(this); - outline_result += - (outline_result->outline_index_ - GetMaxInlineResultIndex()); + auto *outline_result = static_cast(this); + outline_result += (outline_result->index() - MAX_INLINE_RESULT_IDX); // The offset of the maximum inline result distance op is // GetMaxInlineResultIndex. - const auto *inline_result = - reinterpret_cast(outline_result); - inline_result += (GetMaxInlineResultIndex() + 1); - return reinterpret_cast( - const_cast(inline_result)); + auto *inline_result = reinterpret_cast(outline_result); + inline_result += OUTLINE_RESULT_IDX; + return reinterpret_cast(inline_result); } } // namespace detail diff --git a/paddle/pir/core/op_result_impl.h b/paddle/pir/core/op_result_impl.h index 99c5573fd30cd7..8183eb9ef02838 100644 --- a/paddle/pir/core/op_result_impl.h +++ b/paddle/pir/core/op_result_impl.h @@ -26,27 +26,20 @@ class OpResultImpl : public ValueImpl { using ValueImpl::ValueImpl; static bool classof(const ValueImpl &value) { - return value.kind() <= OUTLINE_OP_RESULT_INDEX; + return value.kind() <= OUTLINE_RESULT_IDX; } /// /// \brief Get the parent operation of this result.(op_ptr = value_ptr + /// index) /// - Operation *owner() const; + Operation *owner(); /// /// \brief Get the result index of the operation result. /// uint32_t index() const; - /// - /// \brief Get the maximum number of results that can be stored inline. - /// - static uint32_t GetMaxInlineResultIndex() { - return OUTLINE_OP_RESULT_INDEX - 1; - } - ~OpResultImpl(); }; @@ -58,13 +51,13 @@ class OpInlineResultImpl : public OpResultImpl { public: OpInlineResultImpl(Type type, uint32_t result_index) : OpResultImpl(type, result_index) { - if (result_index > GetMaxInlineResultIndex()) { + if (result_index > MAX_INLINE_RESULT_IDX) { throw("Inline result index should not exceed MaxInlineResultIndex(5)"); } } static bool classof(const ValueImpl &value) { - return value.kind() < OUTLINE_OP_RESULT_INDEX; + return value.kind() < OUTLINE_RESULT_IDX; } uint32_t index() const { return kind(); } @@ -77,15 +70,15 @@ class OpInlineResultImpl : public OpResultImpl { class OpOutlineResultImpl : public OpResultImpl { public: OpOutlineResultImpl(Type type, uint32_t outline_index) - : OpResultImpl(type, OUTLINE_OP_RESULT_INDEX), - outline_index_(outline_index) {} + : OpResultImpl(type, OUTLINE_RESULT_IDX), outline_index_(outline_index) {} static bool classof(const ValueImpl &value) { - return value.kind() == OUTLINE_OP_RESULT_INDEX; + return value.kind() == OUTLINE_RESULT_IDX; } uint32_t index() const { return outline_index_; } + private: uint32_t outline_index_; }; diff --git a/paddle/pir/core/operation.cc b/paddle/pir/core/operation.cc index 15f5a3cfbd523d..b6400ea638681e 100644 --- a/paddle/pir/core/operation.cc +++ b/paddle/pir/core/operation.cc @@ -26,7 +26,12 @@ #include "paddle/pir/core/utils.h" namespace pir { -Operation *Operation::Create(OperationArgument &&argument) { +using detail::OpInlineResultImpl; +using detail::OpOperandImpl; +using detail::OpOutlineResultImpl; +using detail::OpResultImpl; + +Operation *Operation::Create(const OperationArgument &argument) { std::vector inputs; for (auto op_result : argument.inputs) { inputs.emplace_back(op_result); @@ -53,8 +58,7 @@ Operation *Operation::Create(const std::vector &inputs, uint32_t num_results = output_types.size(); uint32_t num_operands = inputs.size(); uint32_t num_successors = successors.size(); - uint32_t max_inline_result_num = - detail::OpResultImpl::GetMaxInlineResultIndex() + 1; + uint32_t max_inline_result_num = MAX_INLINE_RESULT_IDX + 1; size_t result_mem_size = num_results > max_inline_result_num ? sizeof(detail::OpOutlineResultImpl) * @@ -163,13 +167,11 @@ void Operation::Destroy() { } // 5. Free memory. - uint32_t max_inline_result_num = - detail::OpResultImpl::GetMaxInlineResultIndex() + 1; size_t result_mem_size = - num_results_ > max_inline_result_num + num_results_ > OUTLINE_RESULT_IDX ? sizeof(detail::OpOutlineResultImpl) * - (num_results_ - max_inline_result_num) + - sizeof(detail::OpInlineResultImpl) * max_inline_result_num + (num_results_ - OUTLINE_RESULT_IDX) + + sizeof(detail::OpInlineResultImpl) * OUTLINE_RESULT_IDX : sizeof(detail::OpInlineResultImpl) * num_results_; void *aligned_ptr = reinterpret_cast(this) - result_mem_size; @@ -195,67 +197,43 @@ Operation::Operation(const AttributeMap &attributes, num_regions_(num_regions), num_successors_(num_successors) {} -pir::OpResult Operation::result(uint32_t index) const { - if (index >= num_results_) { - IR_THROW("index exceeds OP output range."); - } - uint32_t max_inline_idx = detail::OpResultImpl::GetMaxInlineResultIndex(); - const char *ptr = - (index > max_inline_idx) - ? reinterpret_cast(this) - - (max_inline_idx + 1) * sizeof(detail::OpInlineResultImpl) - - (index - max_inline_idx) * sizeof(detail::OpOutlineResultImpl) - : reinterpret_cast(this) - - (index + 1) * sizeof(detail::OpInlineResultImpl); - if (index > max_inline_idx) { - return pir::OpResult( - reinterpret_cast(ptr)); - } else { - return pir::OpResult( - reinterpret_cast(ptr)); +/// +/// \brief op ouput related public interfaces implementation +/// +std::vector Operation::results() { + std::vector res; + for (uint32_t i = 0; i < num_results(); ++i) { + res.push_back(result(i)); } + return res; } -OpOperand Operation::operand(uint32_t index) const { - if (index >= num_operands_) { - IR_THROW("index exceeds OP input range."); +/// +/// \brief op input related public interfaces +/// +std::vector Operation::operands() { + std::vector res; + for (uint32_t i = 0; i < num_operands(); ++i) { + res.push_back(operand(i)); } - const char *ptr = reinterpret_cast(this) + sizeof(Operation) + - (index) * sizeof(detail::OpOperandImpl); - return OpOperand(reinterpret_cast(ptr)); + return res; } - Value Operation::operand_source(uint32_t index) const { - OpOperand val = operand(index); - return val ? val.source() : Value(); + auto val = op_operand_impl(index); + return val ? val->source() : nullptr; } -std::string Operation::name() const { - auto p_name = info_.name(); - return p_name ? p_name : ""; -} - -Attribute Operation::attribute(const std::string &key) const { - IR_ENFORCE(HasAttribute(key), "operation(%s): no attribute %s", name(), key); - return attributes_.at(key); -} - -Region *Operation::GetParentRegion() { - return parent_ ? parent_->GetParent() : nullptr; -} - -Operation *Operation::GetParentOp() const { - return parent_ ? parent_->GetParentOp() : nullptr; -} - -const Program *Operation::GetParentProgram() const { - Operation *op = const_cast(this); - while (Operation *parent_op = op->GetParentOp()) { - op = parent_op; +std::vector Operation::operands_source() const { + std::vector res; + for (uint32_t i = 0; i < num_operands(); ++i) { + res.push_back(operand_source(i)); } - ModuleOp module_op = op->dyn_cast(); - return module_op ? module_op.program() : nullptr; + return res; } + +/// +/// \brief op successor related public interfaces +/// BlockOperand Operation::block_operand(uint32_t index) const { IR_ENFORCE(index < num_successors_, "Invalid block_operand index"); return block_operands_ + index; @@ -263,27 +241,49 @@ BlockOperand Operation::block_operand(uint32_t index) const { Block *Operation::successor(uint32_t index) const { return block_operand(index).source(); } - void Operation::set_successor(Block *block, unsigned index) { IR_ENFORCE(index < num_operands_, "Invalid block_operand index"); (block_operands_ + index)->set_source(block); } +/// +/// \brief region related public interfaces implementation +/// Region &Operation::region(unsigned index) { IR_ENFORCE(index < num_regions_, "invalid region index"); return regions_[index]; } - const Region &Operation::region(unsigned index) const { IR_ENFORCE(index < num_regions_, "invalid region index"); return regions_[index]; } +/// +/// \brief parent related public interfaces implementation +/// +Region *Operation::GetParentRegion() const { + return parent_ ? parent_->GetParent() : nullptr; +} +Operation *Operation::GetParentOp() const { + return parent_ ? parent_->GetParentOp() : nullptr; +} +Program *Operation::GetParentProgram() { + auto op = this; + while (Operation *parent_op = op->GetParentOp()) { + op = parent_op; + } + ModuleOp module_op = op->dyn_cast(); + return module_op ? module_op.program() : nullptr; +} void Operation::SetParent(Block *parent, const Block::Iterator &position) { parent_ = parent; position_ = position; } +std::string Operation::name() const { + auto p_name = info_.name(); + return p_name ? p_name : ""; +} void Operation::ReplaceAllUsesWith(const std::vector &values) { IR_ENFORCE(num_results_ == values.size(), "the num of result should be the same."); @@ -306,20 +306,39 @@ void Operation::Verify() { } } -std::vector Operation::operands() const { - std::vector res; - for (uint32_t i = 0; i < num_operands(); ++i) { - res.push_back(operand(i)); +int32_t Operation::ComputeOpResultOffset(uint32_t index) const { + if (index >= num_results_) { + LOG(FATAL) << "index exceeds OP op result range."; } - return res; + if (index < OUTLINE_RESULT_IDX) { + return -static_cast((index + 1u) * sizeof(OpInlineResultImpl)); + } + constexpr uint32_t anchor = OUTLINE_RESULT_IDX * sizeof(OpInlineResultImpl); + index = index - MAX_INLINE_RESULT_IDX; + return -static_cast(index * sizeof(OpOutlineResultImpl) + anchor); } -std::vector Operation::results() const { - std::vector res; - for (uint32_t i = 0; i < num_results(); ++i) { - res.push_back(result(i)); +int32_t Operation::ComputeOpOperandOffset(uint32_t index) const { + if (index >= num_operands_) { + LOG(FATAL) << "index exceeds OP op operand range."; } - return res; + return static_cast(index * sizeof(OpOperandImpl) + + sizeof(Operation)); } +#define COMPONENT_IMPL(component_lower, componnent_upper) \ + componnent_upper##Impl *Operation::component_lower##_impl(uint32_t index) { \ + int32_t offset = Compute##componnent_upper##Offset(index); \ + return reinterpret_cast( \ + reinterpret_cast(this) + offset); \ + } \ + const componnent_upper##Impl *Operation::component_lower##_impl( \ + uint32_t index) const { \ + int32_t offset = Compute##componnent_upper##Offset(index); \ + return reinterpret_cast( \ + reinterpret_cast(this) + offset); \ + } + +COMPONENT_IMPL(op_result, OpResult) +COMPONENT_IMPL(op_operand, OpOperand) } // namespace pir diff --git a/paddle/pir/core/operation.h b/paddle/pir/core/operation.h index d5821084ba7941..ec2bf586a06eaa 100644 --- a/paddle/pir/core/operation.h +++ b/paddle/pir/core/operation.h @@ -29,6 +29,11 @@ class Program; class OpOperand; class OpResult; +namespace detail { +class OpResultImpl; +class OpOperendImpl; +} // namespace detail + class IR_API alignas(8) Operation final { public: /// @@ -43,8 +48,7 @@ class IR_API alignas(8) Operation final { pir::OpInfo op_info, size_t num_regions = 0, const std::vector &successors = {}); - static Operation *Create(OperationArgument &&op_argument); - + static Operation *Create(const OperationArgument &op_argument); /// /// \brief Destroy the operation objects and free memory by create(). /// @@ -54,50 +58,70 @@ class IR_API alignas(8) Operation final { Dialect *dialect() const; - OpResult result(uint32_t index) const; - - OpOperand operand(uint32_t index) const; - - Value operand_source(uint32_t index) const; - - uint32_t num_successors() const { return num_successors_; } - BlockOperand block_operand(uint32_t index) const; - Block *successor(uint32_t index) const; - void set_successor(Block *block, unsigned index); - bool HasSuccessors() { return num_successors_ != 0; } - - /// Returns the region held by this operation at position 'index'. - Region ®ion(unsigned index); - const Region ®ion(unsigned index) const; - uint32_t num_regions() const { return num_regions_; } - - void Print(std::ostream &os); - + /// + /// \brief op attribute related public interfaces + /// + Attribute attribute(const std::string &key) const { + return attributes_.at(key); + } const AttributeMap &attributes() const { return attributes_; } - template T attribute(const std::string &name) { Attribute attr = attribute(name); IR_ENFORCE(attr.isa(), "Attribute (%s) type is not right.", name); return attr.dyn_cast(); } - void set_attribute(const std::string &key, Attribute value) { attributes_[key] = value; } - - Attribute attribute(const std::string &key) const; - bool HasAttribute(const std::string &key) const { return attributes_.find(key) != attributes_.end(); } - pir::OpInfo info() const { return info_; } - + /// + /// \brief op ouput related public interfaces + /// uint32_t num_results() const { return num_results_; } + OpResult result(uint32_t index) { return op_result_impl(index); } + std::vector results(); + /// + /// \brief op input related public interfaces + /// uint32_t num_operands() const { return num_operands_; } + OpOperand operand(uint32_t index) { return op_operand_impl(index); } + std::vector operands(); + Value operand_source(uint32_t index) const; + std::vector operands_source() const; + /// + /// \brief op successor related public interfaces + /// + uint32_t num_successors() const { return num_successors_; } + BlockOperand block_operand(uint32_t index) const; + Block *successor(uint32_t index) const; + void set_successor(Block *block, unsigned index); + bool HasSuccessors() { return num_successors_ != 0; } + + /// + /// \brief region related public interfaces + /// + uint32_t num_regions() const { return num_regions_; } + Region ®ion(unsigned index); + const Region ®ion(unsigned index) const; + + /// + /// \brief parent related public interfaces + /// + Block *GetParent() const { return parent_; } + Region *GetParentRegion() const; + Operation *GetParentOp() const; + Program *GetParentProgram(); + operator Block::Iterator() { return position_; } + operator Block::ConstIterator() const { return position_; } + + void Print(std::ostream &os); + pir::OpInfo info() const { return info_; } std::string name() const; template @@ -120,28 +144,6 @@ class IR_API alignas(8) Operation final { return info_.HasInterface(); } - const Block *GetParent() const { return parent_; } - - Block *GetParent() { - return const_cast( - const_cast(this)->GetParent()); - } - - Region *GetParentRegion(); - - Operation *GetParentOp() const; - - const Program *GetParentProgram() const; - - Program *GetParentProgram() { - return const_cast( - const_cast(this)->GetParentProgram()); - } - - operator Block::Iterator() { return position_; } - - operator Block::ConstIterator() const { return position_; } - /// Replace all uses of results of this operation with the provided 'values'. void ReplaceAllUsesWith(const std::vector &values); @@ -153,10 +155,6 @@ class IR_API alignas(8) Operation final { void Verify(); - std::vector operands() const; - - std::vector results() const; - private: DISABLE_COPY_AND_ASSIGN(Operation); Operation(const AttributeMap &attribute, @@ -166,6 +164,14 @@ class IR_API alignas(8) Operation final { uint32_t num_regions, uint32_t num_successors); + int32_t ComputeOpResultOffset(uint32_t index) const; + detail::OpResultImpl *op_result_impl(uint32_t index); + const detail::OpResultImpl *op_result_impl(uint32_t index) const; + + int32_t ComputeOpOperandOffset(uint32_t index) const; + detail::OpOperandImpl *op_operand_impl(uint32_t index); + const detail::OpOperandImpl *op_operand_impl(uint32_t index) const; + template struct CastUtil { static T call(Operation *op) { diff --git a/paddle/pir/core/parser/ir_parser.cc b/paddle/pir/core/parser/ir_parser.cc index 57eea42ae12b9d..960ba9fd49610f 100644 --- a/paddle/pir/core/parser/ir_parser.cc +++ b/paddle/pir/core/parser/ir_parser.cc @@ -207,11 +207,11 @@ void IrParser::ParseBlock(Block& block) { // NOLINT ConsumeAToken("}"); } -// Operation := OpResultList ":=" Opname "(" OprandList ? ")" AttributeMap ":" +// Operation := ValueList ":=" Opname "(" OprandList ? ")" AttributeMap ":" // FunctionType // FunctionType := "(" TypeList ")" "->" TypeList Operation* IrParser::ParseOperation() { - std::vector opresultindex = ParseOpResultList(); + std::vector value_index = ParseValueList(); ConsumeAToken("="); OpInfo opinfo = ParseOpInfo(); @@ -232,31 +232,30 @@ Operation* IrParser::ParseOperation() { Operation::Create(inputs, attributeMap, type_vector, opinfo, 0); for (uint32_t i = 0; i < op->num_results(); i++) { - std::string key_t = opresultindex[i]; - opresultmap[key_t] = op->result(i); + std::string key_t = value_index[i]; + value_map[key_t] = op->result(i); } return op; } -// OpResultList := ValueList // ValueList := ValueId(,ValueId)* -std::vector IrParser::ParseOpResultList() { - std::vector opresultindex{}; +std::vector IrParser::ParseValueList() { + std::vector value_index{}; ConsumeAToken("("); Token index_token = ConsumeToken(); while (index_token.val_ != ")") { if (index_token.token_type_ == NULL_) { - opresultindex.push_back("null"); + value_index.push_back("null"); } else { std::string str = index_token.val_; - opresultindex.push_back(str); + value_index.push_back(str); } if (ConsumeToken().val_ == ")") break; index_token = ConsumeToken(); } - return opresultindex; + return value_index; } // OpName := "\"" StringIdentifer "." StringIdentifer "\"" @@ -279,7 +278,7 @@ std::vector IrParser::ParseOprandList() { inputs.emplace_back(); } else { t = ind_token.val_; - inputs.push_back(opresultmap[t]); + inputs.push_back(value_map[t]); } Token token = ConsumeToken(); if (token.val_ == ")") { diff --git a/paddle/pir/core/parser/ir_parser.h b/paddle/pir/core/parser/ir_parser.h index eee732d3e8bf82..f345e28215f95f 100644 --- a/paddle/pir/core/parser/ir_parser.h +++ b/paddle/pir/core/parser/ir_parser.h @@ -18,7 +18,7 @@ #include "paddle/pir/core/parser/lexer.h" #include "paddle/pir/core/program.h" -using OpResultMap = std::map; +using ValueMap = std::map; using AttributeMap = std::unordered_map; using OpAttributeInfoMap = std::map; @@ -27,7 +27,7 @@ class IrParser { public: std::unique_ptr lexer; IrContext* ctx; - OpResultMap opresultmap; + ValueMap value_map; std::unique_ptr builder; public: @@ -49,7 +49,7 @@ class IrParser { OpInfo ParseOpInfo(); - std::vector ParseOpResultList(); + std::vector ParseValueList(); std::vector ParseOprandList(); diff --git a/paddle/pir/core/value_impl.cc b/paddle/pir/core/value_impl.cc index c17e44ffa3aa5a..999a78e24063d5 100644 --- a/paddle/pir/core/value_impl.cc +++ b/paddle/pir/core/value_impl.cc @@ -41,9 +41,9 @@ std::string ValueImpl::PrintUdChain() { return result.str(); } ValueImpl::ValueImpl(Type type, uint32_t kind) { - if (kind > BLOCK_ARGUMENT_INDEX) { + if (kind > BLOCK_ARG_IDX) { LOG(FATAL) << "The kind of value_impl(" << kind - << "), is bigger than BLOCK_ARGUMENT_INDEX(7)"; + << "), is bigger than BLOCK_ARG_IDX(7)"; } type_ = type; first_use_offseted_by_kind_ = reinterpret_cast( diff --git a/paddle/pir/core/value_impl.h b/paddle/pir/core/value_impl.h index ccd5e835abdeb2..0720360d563bc2 100644 --- a/paddle/pir/core/value_impl.h +++ b/paddle/pir/core/value_impl.h @@ -17,10 +17,11 @@ #include "paddle/pir/core/op_operand_impl.h" #include "paddle/pir/core/value.h" -namespace pir { -constexpr const uint32_t OUTLINE_OP_RESULT_INDEX = 6; -constexpr const uint32_t BLOCK_ARGUMENT_INDEX = OUTLINE_OP_RESULT_INDEX + 1; +#define OUTLINE_RESULT_IDX 6u +#define MAX_INLINE_RESULT_IDX (OUTLINE_RESULT_IDX - 1u) +#define BLOCK_ARG_IDX (OUTLINE_RESULT_IDX + 1u) +namespace pir { class Operation; namespace detail {