Skip to content

[PIR] standardize the use of value[-2]. #55322

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 1 commit into from
Sep 17, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 3 additions & 5 deletions paddle/cinn/hlir/framework/new_ir/op_lowering_impl.cc
Original file line number Diff line number Diff line change
Expand Up @@ -53,9 +53,7 @@ std::vector<ir::Tensor> CollectInputTensor(
std::vector<ir::Tensor>* func_args,
std::unordered_map<::pir::Value, ir::Tensor>* tensor_map) {
std::vector<ir::Tensor> tensors;
for (auto& operand : op->operands()) {
CHECK(operand);
auto in_value = operand.source();
for (auto in_value : op->operands_source()) {
VLOG(4) << "input tensor name: " << CompatibleInfo::ValueName(in_value);
// NOTE(Aurelius84): Need always to create placeholder for input tensor.
ir::Tensor tensor = details::GetTensor(in_value);
Expand All @@ -72,7 +70,7 @@ std::vector<ir::Tensor> CollectInputTensor(
return tensors;
}

void CollectOutputInfo(const ::pir::Operation* op,
void CollectOutputInfo(::pir::Operation* op,
std::vector<Type>* out_types,
std::vector<std::vector<int>>* out_shapes) {
auto op_results = op->results();
Expand Down Expand Up @@ -359,7 +357,7 @@ std::vector<ir::Expr> OpLowererImpl::LowerOps(

std::vector<ir::LoweredFunc> OpLowererImpl::DoOpLower(
std::shared_ptr<hlir::framework::OpImpl> op_impl,
const ::pir::Operation* op,
::pir::Operation* op,
std::unordered_map<::pir::Value, ir::Tensor>* tensor_map,
std::vector<ir::Tensor>* op_func_arg_tensors) {
VLOG(4) << "Do lower with Compute, op: " << op->name();
Expand Down
2 changes: 1 addition & 1 deletion paddle/cinn/hlir/framework/new_ir/op_lowering_impl.h
Original file line number Diff line number Diff line change
Expand Up @@ -131,7 +131,7 @@ class OpLowererImpl : public OpLowererImplBase<GroupPtr> {
*/
std::vector<ir::LoweredFunc> DoOpLower(
std::shared_ptr<hlir::framework::OpImpl> op_impl,
const ::pir::Operation* op,
::pir::Operation* op,
std::unordered_map<::pir::Value, ir::Tensor>* tensor_map,
std::vector<ir::Tensor>* op_func_arg_tensors);

Expand Down
3 changes: 1 addition & 2 deletions paddle/cinn/hlir/framework/new_ir/utils.cc
Original file line number Diff line number Diff line change
Expand Up @@ -74,8 +74,7 @@ std::vector<std::string> CompatibleInfo::InputNames(const ::pir::Operation& op,
return names;
}

std::vector<std::string> CompatibleInfo::OutputNames(
const ::pir::Operation& op) {
std::vector<std::string> CompatibleInfo::OutputNames(::pir::Operation& op) {
std::vector<std::string> names;
for (int i = 0; i < op.num_results(); ++i) {
auto value = op.result(i);
Expand Down
2 changes: 1 addition & 1 deletion paddle/cinn/hlir/framework/new_ir/utils.h
Original file line number Diff line number Diff line change
Expand Up @@ -40,7 +40,7 @@ struct CompatibleInfo {
static std::vector<std::string> InputNames(const ::pir::Operation& op,
bool allow_duplicate = false);

static std::vector<std::string> OutputNames(const ::pir::Operation& op);
static std::vector<std::string> OutputNames(::pir::Operation& op); // NOLINT
};

} // namespace newir
Expand Down
5 changes: 1 addition & 4 deletions paddle/fluid/pybind/ir.cc
Original file line number Diff line number Diff line change
Expand Up @@ -236,10 +236,7 @@ void BindOperation(py::module *m) {
)DOC");
op.def("name", &Operation::name)
.def("get_parent_block",
py::overload_cast<>(&Operation::GetParent),
return_value_policy::reference)
.def("get_parent_block",
py::overload_cast<>(&Operation::GetParent, py::const_),
&Operation::GetParent,
return_value_policy::reference)
.def("num_operands", &Operation::num_operands)
.def("num_results", &Operation::num_results)
Expand Down
2 changes: 1 addition & 1 deletion paddle/pir/core/block.cc
Original file line number Diff line number Diff line change
Expand Up @@ -90,7 +90,7 @@ void Block::AddArgument(Type type) {

bool Block::TopoOrderCheck(const OpListType &op_list) {
std::unordered_set<Value> visited_values;
for (const Operation *op : op_list) {
for (Operation *op : op_list) {
if (op->num_operands() > 0) {
for (size_t i = 0; i < op->num_operands(); ++i) {
auto operand = op->operand_source(i);
Expand Down
4 changes: 2 additions & 2 deletions paddle/pir/core/block_argument.cc
Original file line number Diff line number Diff line change
Expand Up @@ -29,12 +29,12 @@ namespace detail {
class BlockArgumentImpl : public ValueImpl {
public:
static bool classof(const ValueImpl &value) {
return value.kind() == BLOCK_ARGUMENT_INDEX;
return value.kind() == BLOCK_ARG_IDX;
}

private:
BlockArgumentImpl(Type type, Block *owner, uint32_t index)
: ValueImpl(type, BLOCK_ARGUMENT_INDEX), owner_(owner), index_(index) {}
: ValueImpl(type, BLOCK_ARG_IDX), owner_(owner), index_(index) {}

~BlockArgumentImpl();
// access construction and owner
Expand Down
16 changes: 8 additions & 8 deletions paddle/pir/core/ir_printer.cc
Original file line number Diff line number Diff line change
Expand Up @@ -139,7 +139,7 @@ void IrPrinter::PrintOperation(Operation* op) {
PrintGeneralOperation(op);
}

void IrPrinter::PrintGeneralOperation(const Operation* op) {
void IrPrinter::PrintGeneralOperation(Operation* op) {
// TODO(lyk): add API to get opresults directly
PrintOpResult(op);
os << " =";
Expand All @@ -160,7 +160,7 @@ void IrPrinter::PrintGeneralOperation(const Operation* op) {
PrintOpReturnType(op);
}

void IrPrinter::PrintFullOperation(const Operation* op) {
void IrPrinter::PrintFullOperation(Operation* op) {
PrintGeneralOperation(op);
if (op->num_regions() > 0) {
os << newline;
Expand All @@ -186,7 +186,7 @@ void IrPrinter::PrintBlock(const Block* block) {
os << "}\n";
}

void IrPrinter::PrintValue(const Value& v) {
void IrPrinter::PrintValue(Value v) {
if (!v) {
os << "<<NULL VALUE>>";
return;
Expand All @@ -204,7 +204,7 @@ void IrPrinter::PrintValue(const Value& v) {
os << new_name;
}

void IrPrinter::PrintOpResult(const Operation* op) {
void IrPrinter::PrintOpResult(Operation* op) {
os << " (";
auto num_op_result = op->num_results();
std::vector<OpResult> op_results;
Expand All @@ -220,7 +220,7 @@ void IrPrinter::PrintOpResult(const Operation* op) {
os << ")";
}

void IrPrinter::PrintAttributeMap(const Operation* op) {
void IrPrinter::PrintAttributeMap(Operation* op) {
AttributeMap attributes = op->attributes();
std::map<std::string, Attribute, std::less<std::string>> order_attributes(
attributes.begin(), attributes.end());
Expand All @@ -239,7 +239,7 @@ void IrPrinter::PrintAttributeMap(const Operation* op) {
os << "}";
}

void IrPrinter::PrintOpOperands(const Operation* op) {
void IrPrinter::PrintOpOperands(Operation* op) {
os << " (";
auto num_op_operands = op->num_operands();
std::vector<Value> op_operands;
Expand All @@ -255,7 +255,7 @@ void IrPrinter::PrintOpOperands(const Operation* op) {
os << ")";
}

void IrPrinter::PrintOperandsType(const Operation* op) {
void IrPrinter::PrintOperandsType(Operation* op) {
auto num_op_operands = op->num_operands();
std::vector<Type> op_operand_types;
op_operand_types.reserve(num_op_operands);
Expand All @@ -276,7 +276,7 @@ void IrPrinter::PrintOperandsType(const Operation* op) {
os << ")";
}

void IrPrinter::PrintOpReturnType(const Operation* op) {
void IrPrinter::PrintOpReturnType(Operation* op) {
auto num_op_result = op->num_results();
std::vector<Type> op_result_types;
op_result_types.reserve(num_op_result);
Expand Down
16 changes: 8 additions & 8 deletions paddle/pir/core/ir_printer.h
Original file line number Diff line number Diff line change
Expand Up @@ -51,24 +51,24 @@ class IR_API IrPrinter : public BasicIrPrinter {
/// @brief dispatch to custom printer function or PrintGeneralOperation
void PrintOperation(Operation* op);
/// @brief print operation itself without its regions
void PrintGeneralOperation(const Operation* op);
void PrintGeneralOperation(Operation* op);
/// @brief print operation and its regions
void PrintFullOperation(const Operation* op);
void PrintFullOperation(Operation* op);

void PrintRegion(const Region& Region);
void PrintBlock(const Block* block);

void PrintValue(const Value& v);
void PrintValue(Value v);

void PrintOpResult(const Operation* op);
void PrintOpResult(Operation* op);

void PrintAttributeMap(const Operation* op);
void PrintAttributeMap(Operation* op);

void PrintOpOperands(const Operation* op);
void PrintOpOperands(Operation* op);

void PrintOperandsType(const Operation* op);
void PrintOperandsType(Operation* op);

void PrintOpReturnType(const Operation* op);
void PrintOpReturnType(Operation* op);

private:
size_t cur_var_number_{0};
Expand Down
2 changes: 1 addition & 1 deletion paddle/pir/core/op_result.cc
Original file line number Diff line number Diff line change
Expand Up @@ -47,6 +47,6 @@ bool OpResult::operator==(const OpResult &other) const {
return impl_ == other.impl_;
}

OpResult::OpResult(const detail::OpResultImpl *impl) : Value(impl) {}
OpResult::OpResult(detail::OpResultImpl *impl) : Value(impl) {}

} // namespace pir
2 changes: 1 addition & 1 deletion paddle/pir/core/op_result.h
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,7 @@ class IR_API OpResult : public Value {

private:
friend Operation;
OpResult(const detail::OpResultImpl *impl); // NOLINT
OpResult(detail::OpResultImpl *impl); // NOLINT
// Access classof annd dyn_cast_from.
friend Value;
static bool classof(Value value);
Expand Down
33 changes: 16 additions & 17 deletions paddle/pir/core/op_result_impl.cc
Original file line number Diff line number Diff line change
Expand Up @@ -12,8 +12,7 @@
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/pir/core/op_result_impl.h"

#include <cassert>
#include "paddle/pir/core/operation.h"

namespace pir {
namespace detail {
Expand All @@ -22,31 +21,31 @@ uint32_t OpResultImpl::index() const {
if (const auto *outline_result = dyn_cast<OpOutlineResultImpl>(this)) {
return outline_result->index();
}
return dyn_cast<OpInlineResultImpl>(this)->index();
return static_cast<const OpInlineResultImpl *>(this)->index();
}

OpResultImpl::~OpResultImpl() { assert(use_empty()); }
OpResultImpl::~OpResultImpl() {
if (!use_empty()) {
LOG(FATAL) << "Destoryed a op_result that is still in use. \n"
<< "The owner op type is:" << owner()->name();
}
}

Operation *OpResultImpl::owner() const {
Operation *OpResultImpl::owner() {
// For inline result, pointer offset index to obtain the address of op.
if (const auto *result = dyn_cast<OpInlineResultImpl>(this)) {
if (auto *result = dyn_cast<OpInlineResultImpl>(this)) {
result += result->index() + 1;
return reinterpret_cast<Operation *>(
const_cast<OpInlineResultImpl *>(result));
return reinterpret_cast<Operation *>(result);
}
// For outline result, pointer offset outline_index to obtain the address of
// maximum inline result.
const OpOutlineResultImpl *outline_result =
(const OpOutlineResultImpl *)(this);
outline_result +=
(outline_result->outline_index_ - GetMaxInlineResultIndex());
auto *outline_result = static_cast<OpOutlineResultImpl *>(this);
outline_result += (outline_result->index() - MAX_INLINE_RESULT_IDX);
// The offset of the maximum inline result distance op is
// GetMaxInlineResultIndex.
const auto *inline_result =
reinterpret_cast<const OpInlineResultImpl *>(outline_result);
inline_result += (GetMaxInlineResultIndex() + 1);
return reinterpret_cast<Operation *>(
const_cast<OpInlineResultImpl *>(inline_result));
auto *inline_result = reinterpret_cast<OpInlineResultImpl *>(outline_result);
inline_result += OUTLINE_RESULT_IDX;
return reinterpret_cast<Operation *>(inline_result);
}

} // namespace detail
Expand Down
21 changes: 7 additions & 14 deletions paddle/pir/core/op_result_impl.h
Original file line number Diff line number Diff line change
Expand Up @@ -26,27 +26,20 @@ class OpResultImpl : public ValueImpl {
using ValueImpl::ValueImpl;

static bool classof(const ValueImpl &value) {
return value.kind() <= OUTLINE_OP_RESULT_INDEX;
return value.kind() <= OUTLINE_RESULT_IDX;
}

///
/// \brief Get the parent operation of this result.(op_ptr = value_ptr +
/// index)
///
Operation *owner() const;
Operation *owner();

///
/// \brief Get the result index of the operation result.
///
uint32_t index() const;

///
/// \brief Get the maximum number of results that can be stored inline.
///
static uint32_t GetMaxInlineResultIndex() {
return OUTLINE_OP_RESULT_INDEX - 1;
}

~OpResultImpl();
};

Expand All @@ -58,13 +51,13 @@ class OpInlineResultImpl : public OpResultImpl {
public:
OpInlineResultImpl(Type type, uint32_t result_index)
: OpResultImpl(type, result_index) {
if (result_index > GetMaxInlineResultIndex()) {
if (result_index > MAX_INLINE_RESULT_IDX) {
throw("Inline result index should not exceed MaxInlineResultIndex(5)");
}
}

static bool classof(const ValueImpl &value) {
return value.kind() < OUTLINE_OP_RESULT_INDEX;
return value.kind() < OUTLINE_RESULT_IDX;
}

uint32_t index() const { return kind(); }
Expand All @@ -77,15 +70,15 @@ class OpInlineResultImpl : public OpResultImpl {
class OpOutlineResultImpl : public OpResultImpl {
public:
OpOutlineResultImpl(Type type, uint32_t outline_index)
: OpResultImpl(type, OUTLINE_OP_RESULT_INDEX),
outline_index_(outline_index) {}
: OpResultImpl(type, OUTLINE_RESULT_IDX), outline_index_(outline_index) {}

static bool classof(const ValueImpl &value) {
return value.kind() == OUTLINE_OP_RESULT_INDEX;
return value.kind() == OUTLINE_RESULT_IDX;
}

uint32_t index() const { return outline_index_; }

private:
uint32_t outline_index_;
};

Expand Down
Loading