Skip to content

Commit de9818f

Browse files
authored
[PIR] Unify dyn_cast interface with pir::cast (#57463)
* PR comment * unify dyn_cast_interface * rm detail
1 parent 820a387 commit de9818f

File tree

15 files changed

+124
-134
lines changed

15 files changed

+124
-134
lines changed

paddle/fluid/pir/dialect/kernel/ir/kernel_type.cc

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -21,7 +21,7 @@ const phi::Place& AllocatedDenseTensorType::place() const {
2121
return storage()->place_;
2222
}
2323

24-
const pir::Type& AllocatedDenseTensorType::dtype() const {
24+
pir::Type AllocatedDenseTensorType::dtype() const {
2525
return storage()->dense_tensor_type_.dtype();
2626
}
2727

paddle/fluid/pir/dialect/kernel/ir/kernel_type.h

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -51,7 +51,7 @@ class AllocatedDenseTensorType
5151

5252
const phi::Place &place() const;
5353

54-
const pir::Type &dtype() const;
54+
pir::Type dtype() const;
5555

5656
const phi::DDim &dims() const;
5757

paddle/fluid/pir/dialect/operator/interface/op_yaml_info.h

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -40,6 +40,7 @@ class OpYamlInfoInterface : public pir::OpInterfaceBase<OpYamlInfoInterface> {
4040
Model() : Concept(GetOpInfo) {}
4141
};
4242

43+
/// Constructor
4344
OpYamlInfoInterface(pir::Operation *op, Concept *impl)
4445
: pir::OpInterfaceBase<OpYamlInfoInterface>(op), impl_(impl) {}
4546

paddle/fluid/pir/dialect/operator/interface/vjp.h

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -43,6 +43,7 @@ class VjpInterface : public pir::OpInterfaceBase<VjpInterface> {
4343
Model() : Concept(Vjp) {}
4444
};
4545

46+
/// Constructor
4647
VjpInterface(pir::Operation* op, Concept* impl)
4748
: pir::OpInterfaceBase<VjpInterface>(op), impl_(impl) {}
4849

paddle/fluid/pir/dialect/operator/ir/op_dialect.cc

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -67,7 +67,7 @@ void OperatorDialect::PrintType(pir::Type type, std::ostream &os) const {
6767
if (auto tensor_type = type.dyn_cast<DenseTensorType>()) {
6868
os << "tensor<";
6969
for (auto d : phi::vectorize(tensor_type.dims())) {
70-
pir::ShapedTypeInterface::isDynamic(d) ? os << "?" : os << d;
70+
pir::ShapedTypeInterface::IsDynamic(d) ? os << "?" : os << d;
7171
os << "x";
7272
}
7373
tensor_type.dtype().Print(os);

paddle/pir/core/builtin_type.cc

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,7 @@
1717
namespace pir {
1818
std::vector<Type> VectorType::data() const { return storage()->GetAsKey(); }
1919

20-
const pir::Type& DenseTensorType::dtype() const { return storage()->dtype_; }
20+
pir::Type DenseTensorType::dtype() const { return storage()->dtype_; }
2121

2222
const DenseTensorTypeStorage::Dim& DenseTensorType::dims() const {
2323
return storage()->dims_;

paddle/pir/core/builtin_type.h

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -59,7 +59,7 @@ class DenseTensorType : public Type::TypeBase<DenseTensorType,
5959
public:
6060
using Base::Base;
6161

62-
const Type &dtype() const;
62+
Type dtype() const;
6363

6464
const DenseTensorTypeStorage::Dim &dims() const;
6565

paddle/pir/core/builtin_type_interfaces.cc

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -17,12 +17,12 @@
1717

1818
namespace pir {
1919

20-
Type ShapedTypeInterface::getElementType() const {
21-
return impl_->get_element_type_(*this);
20+
Type ShapedTypeInterface::GetElementType() const {
21+
return impl_->get_element_type(*this);
2222
}
2323

24-
phi::DDim ShapedTypeInterface::getShape() const {
25-
return impl_->get_shape_(*this);
24+
phi::DDim ShapedTypeInterface::GetShape() const {
25+
return impl_->get_shape(*this);
2626
}
2727

2828
} // namespace pir

paddle/pir/core/builtin_type_interfaces.h

Lines changed: 38 additions & 62 deletions
Original file line numberDiff line numberDiff line change
@@ -14,52 +14,16 @@
1414

1515
#pragma once
1616

17+
#include <algorithm>
1718
#include <vector>
19+
1820
#include "paddle/phi/core/tensor_base.h"
1921
#include "paddle/pir/core/cast_utils.h"
2022
#include "paddle/pir/core/enforce.h"
2123
#include "paddle/pir/core/type.h"
2224

2325
namespace pir {
2426

25-
namespace detail {
26-
27-
template <typename RangeT>
28-
constexpr auto begin_impl(RangeT &&range)
29-
-> decltype(std::begin(std::forward<RangeT>(range))) {
30-
return std::begin(std::forward<RangeT>(range));
31-
}
32-
33-
template <typename RangeT>
34-
constexpr auto end_impl(RangeT &&range)
35-
-> decltype(std::end(std::forward<RangeT>(range))) {
36-
return std::end(std::forward<RangeT>(range));
37-
}
38-
39-
template <typename RangeT>
40-
constexpr auto adl_begin(RangeT &&range)
41-
-> decltype(begin_impl(std::forward<RangeT>(range))) {
42-
return begin_impl(std::forward<RangeT>(range));
43-
}
44-
45-
template <typename RangeT>
46-
constexpr auto adl_end(RangeT &&range)
47-
-> decltype(end_impl(std::forward<RangeT>(range))) {
48-
return end_impl(std::forward<RangeT>(range));
49-
}
50-
51-
template <typename R, typename UnaryPredicate>
52-
bool any_of(R &&Range, UnaryPredicate P) {
53-
return std::any_of(adl_begin(Range), adl_end(Range), P);
54-
}
55-
56-
template <typename R, typename UnaryPredicate>
57-
auto count_if(R &&Range, UnaryPredicate P) {
58-
return std::count_if(adl_begin(Range), adl_end(Range), P);
59-
}
60-
61-
} // namespace detail
62-
6327
class ShapedTypeInterface : public TypeInterfaceBase<ShapedTypeInterface> {
6428
public:
6529
using DDim = phi::DDim;
@@ -68,10 +32,10 @@ class ShapedTypeInterface : public TypeInterfaceBase<ShapedTypeInterface> {
6832
/// Defined these methods with the interface.
6933
explicit Concept(DataType (*get_element_type)(Type),
7034
DDim (*get_shape)(Type))
71-
: get_element_type_(get_element_type), get_shape_(get_shape) {}
35+
: get_element_type(get_element_type), get_shape(get_shape) {}
7236

73-
DataType (*get_element_type_)(Type);
74-
DDim (*get_shape_)(Type);
37+
DataType (*get_element_type)(Type);
38+
DDim (*get_shape)(Type);
7539
};
7640

7741
template <class ConcreteType>
@@ -88,18 +52,27 @@ class ShapedTypeInterface : public TypeInterfaceBase<ShapedTypeInterface> {
8852
};
8953

9054
/// Constructor
55+
ShapedTypeInterface(std::nullptr_t) // NOLINT
56+
: TypeInterfaceBase<ShapedTypeInterface>(Type()), impl_(nullptr) {}
57+
58+
explicit ShapedTypeInterface(Type type = Type())
59+
: TypeInterfaceBase<ShapedTypeInterface>(type),
60+
impl_(type
61+
? type.abstract_type().GetInterfaceImpl<ShapedTypeInterface>()
62+
: nullptr) {}
63+
9164
ShapedTypeInterface(Type type, Concept *impl)
9265
: TypeInterfaceBase<ShapedTypeInterface>(type), impl_(impl) {}
9366

9467
///
9568
/// \brief Get the element type.
9669
///
97-
DataType getElementType() const;
70+
DataType GetElementType() const;
9871

9972
///
10073
/// \brief Get the shape of this type.
10174
///
102-
DDim getShape() const;
75+
DDim GetShape() const;
10376

10477
///
10578
/// \brief kDynamic
@@ -109,62 +82,65 @@ class ShapedTypeInterface : public TypeInterfaceBase<ShapedTypeInterface> {
10982
///
11083
/// \brief Check whether this type is ranked, currently return true.
11184
///
112-
bool hasRank() const { return true; }
85+
bool HasRank() const { return true; }
11386

11487
///
11588
/// If this is a ranked type, return the rank. Otherwise, abort.
11689
///
117-
int64_t getRank() const {
118-
IR_ENFORCE((*this).hasRank(), "Cannot query rank of unranked shaped type.");
119-
return (*this).getShape().size();
90+
int64_t GetRank() const {
91+
IR_ENFORCE((*this).HasRank(), "Cannot query rank of unranked shaped type.");
92+
return (*this).GetShape().size();
12093
}
12194

12295
///
12396
/// \brief Check whether the given dimension size is a dynamic dimension.
12497
///
125-
static constexpr bool isDynamic(int64_t dValue) { return dValue == kDynamic; }
98+
static constexpr bool IsDynamic(int64_t dValue) { return dValue == kDynamic; }
12699

127100
///
128101
/// \brief Check whether the given shape has any size indicating a dynamic
129102
/// dimension.
130103
///
131-
static bool isDynamicShape(DDim dSizes) {
132-
return detail::any_of(vectorize(dSizes),
133-
[](int64_t dSize) { return isDynamic(dSize); });
104+
static bool IsDynamicShape(DDim sizes) {
105+
auto size_vec = vectorize(sizes);
106+
return std::any_of(size_vec.begin(), size_vec.end(), [](int64_t size_vec) {
107+
return IsDynamic(size_vec);
108+
});
134109
}
135110

136111
///
137112
/// \brief Check whether shape has any size indicating a dynamic dimension.
138113
///
139-
bool hasStaticShape() const {
140-
return (*this).hasRank() && !isDynamicShape((*this).getShape());
114+
bool HasStaticShape() const {
115+
return (*this).HasRank() && !IsDynamicShape((*this).GetShape());
141116
}
142117

143118
///
144119
/// \brief Check whether the given dimension has a dynamic size.Aborts for
145120
/// unranked types.
146121
///
147-
bool isDynamicDim(unsigned idx) const {
148-
IR_ENFORCE(idx < getRank(), "Invalid index for shaped type.");
149-
return ShapedTypeInterface::isDynamic((*this).getShape()[idx]);
122+
bool IsDynamicDim(unsigned idx) const {
123+
IR_ENFORCE(idx < GetRank(), "Invalid index for shaped type.");
124+
return ShapedTypeInterface::IsDynamic((*this).GetShape()[idx]);
150125
}
151126

152127
///
153128
/// \brief Get the number of dimensions with dynamic size for a ranked type.
154129
/// Aborts for unranked types.
155130
///
156-
int64_t getNumDynamicDims() const {
157-
return detail::count_if(vectorize((*this).getShape()),
158-
ShapedTypeInterface::isDynamic);
131+
int64_t GetNumDynamicDims() const {
132+
auto shape_vec = vectorize((*this).GetShape());
133+
return std::count_if(
134+
shape_vec.begin(), shape_vec.end(), ShapedTypeInterface::IsDynamic);
159135
}
160136

161137
///
162138
/// \brief Get the size of the specified dimension for a ranked type. Aborts
163139
/// for unranked types.
164140
///
165-
int64_t getDimSize(unsigned idx) const {
166-
IR_ENFORCE(idx < getRank(), "Invalid index for shaped type.");
167-
return (*this).getShape()[idx];
141+
int64_t GetDimSize(unsigned idx) const {
142+
IR_ENFORCE(idx < GetRank(), "Invalid index for shaped type.");
143+
return (*this).GetShape()[idx];
168144
}
169145

170146
private:

paddle/pir/core/op_base.h

Lines changed: 10 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -91,9 +91,18 @@ class OpInterfaceBase : public OpBase {
9191
public:
9292
explicit OpInterfaceBase(Operation *op) : OpBase(op) {}
9393

94-
// Accessor for the ID of this interface.
94+
///
95+
/// \brief Accessor for the ID of this interface.
96+
///
9597
static TypeId GetInterfaceId() { return TypeId::get<ConcreteInterface>(); }
9698

99+
///
100+
/// \brief Checking if the given object defines the concrete interface.
101+
///
102+
static bool classof(Operation *op) {
103+
return op->HasInterface<ConcreteInterface>();
104+
}
105+
97106
static ConcreteInterface dyn_cast(Operation *op) {
98107
if (op && op->HasInterface<ConcreteInterface>()) {
99108
return ConcreteInterface(

paddle/pir/core/operation.h

Lines changed: 7 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -172,22 +172,22 @@ class IR_API alignas(8) Operation final {
172172
detail::OpOperandImpl *op_operand_impl(uint32_t index);
173173
const detail::OpOperandImpl *op_operand_impl(uint32_t index) const;
174174

175-
template <typename T, typename Enabler = void>
175+
template <typename To, typename Enabler = void>
176176
struct CastUtil {
177-
static T call(Operation *op) {
178-
throw("Can't dyn_cast to T, T should be a Op or Trait or Interface");
177+
static To call(Operation *op) {
178+
throw("Can't dyn_cast to To, To should be a Op or Trait or Interface");
179179
}
180180
};
181181

182182
// Allow access to 'SetParent'.
183183
friend class Block;
184184
void SetParent(Block *parent, const Block::Iterator &position);
185185

186-
template <typename T>
186+
template <typename To>
187187
struct CastUtil<
188-
T,
189-
typename std::enable_if<std::is_base_of<OpBase, T>::value>::type> {
190-
static T call(Operation *op) { return T::dyn_cast(op); }
188+
To,
189+
typename std::enable_if<std::is_base_of<OpBase, To>::value>::type> {
190+
static To call(Operation *op) { return To::dyn_cast(op); }
191191
};
192192

193193
AttributeMap attributes_;

paddle/pir/core/storage_manager_support.h

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -63,6 +63,13 @@ class StorageHelperBase : public BaseT {
6363
using InterfaceList =
6464
typename Filter<TypeInterfaceBase, std::tuple<TraitOrInterface...>>::Type;
6565

66+
static ConcreteT dyn_cast_impl(BaseT type) {
67+
if (type && type.abstract_type().type_id() == TypeId::get<ConcreteT>()) {
68+
return ConcreteT(type.storage());
69+
}
70+
return ConcreteT(nullptr);
71+
}
72+
6673
///
6774
/// \brief Access to the storage instance.
6875
///

0 commit comments

Comments
 (0)