Skip to content

Commit 08ec995

Browse files
authored
[DimExpr] Provide DimExpr Interface (#59992)
* Dim Expr Interface * CreateDimExprBuilder * DimExpr ValueShape * Remove one unit test * Fix cmake * DimExpr basic definition * DimExpr basic definition * Change DimExpr comments * Unit test passed * Fix CI error * Delete useless test * Change cmake
1 parent bdb3cdb commit 08ec995

File tree

8 files changed

+415
-0
lines changed

8 files changed

+415
-0
lines changed
Lines changed: 37 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,37 @@
1+
// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
2+
//
3+
// Licensed under the Apache License, Version 2.0 (the "License");
4+
// you may not use this file except in compliance with the License.
5+
// You may obtain a copy of the License at
6+
//
7+
// http://www.apache.org/licenses/LICENSE-2.0
8+
//
9+
// Unless required by applicable law or agreed to in writing, software
10+
// distributed under the License is distributed on an "AS IS" BASIS,
11+
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
// See the License for the specific language governing permissions and
13+
// limitations under the License.
14+
15+
#include "paddle/pir/dialect/shape/utils/dim_expr.h"
16+
17+
namespace symbol {
18+
19+
DimExpr DimExpr::operator+(const DimExpr& other) const {
20+
return Add<DimExpr>(std::vector{*this, other});
21+
}
22+
23+
DimExpr DimExpr::operator-(const DimExpr& other) const {
24+
const DimExpr& neg = Negative<DimExpr>(other);
25+
return Add<DimExpr>(std::vector{*this, neg});
26+
}
27+
28+
DimExpr DimExpr::operator*(const DimExpr& other) const {
29+
return Mul<DimExpr>(std::vector{*this, other});
30+
}
31+
32+
DimExpr DimExpr::operator/(const DimExpr& other) const {
33+
const DimExpr& reciprocal = Reciprocal<DimExpr>(other);
34+
return Mul<DimExpr>(std::vector{*this, reciprocal});
35+
}
36+
37+
} // namespace symbol
Lines changed: 166 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,166 @@
1+
// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
2+
//
3+
// Licensed under the Apache License, Version 2.0 (the "License");
4+
// you may not use this file except in compliance with the License.
5+
// You may obtain a copy of the License at
6+
//
7+
// http://www.apache.org/licenses/LICENSE-2.0
8+
//
9+
// Unless required by applicable law or agreed to in writing, software
10+
// distributed under the License is distributed on an "AS IS" BASIS,
11+
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
// See the License for the specific language governing permissions and
13+
// limitations under the License.
14+
15+
#pragma once
16+
17+
#include <functional>
18+
#include <memory>
19+
#include <optional>
20+
#include <string>
21+
#include <variant>
22+
#include <vector>
23+
24+
#include "glog/logging.h"
25+
26+
namespace symbol {
27+
28+
#define SYMBOL_NOT_IMPLEMENTED LOG(FATAL) << "Not Implemented"
29+
30+
template <typename T>
31+
struct UnaryDimExpr {
32+
explicit UnaryDimExpr(const T& d) : data(std::make_shared<Data>(d)) {}
33+
struct Data {
34+
explicit Data(const T& d) : data(d) {}
35+
T data;
36+
};
37+
38+
const Data& operator*() const { return *data; }
39+
Data& operator*() { return *data; }
40+
const Data* operator->() const { return data.get(); }
41+
Data* operator->() { return data.get(); }
42+
43+
std::shared_ptr<Data> data;
44+
};
45+
46+
template <typename T>
47+
struct BinaryDimExpr {
48+
explicit BinaryDimExpr(const T& l, const T& r)
49+
: data(std::make_shared<Data>(l, r)) {}
50+
51+
struct Data {
52+
explicit Data(const T& l, const T& r) : lhs(l), rhs(r) {}
53+
T lhs;
54+
T rhs;
55+
};
56+
57+
const Data& operator*() const { return *data; }
58+
Data& operator*() { return *data; }
59+
const Data* operator->() const { return data.get(); }
60+
Data* operator->() { return data.get(); }
61+
62+
std::shared_ptr<Data> data;
63+
};
64+
65+
template <typename T>
66+
struct VariadicDimExpr {
67+
explicit VariadicDimExpr(const std::vector<T>& vec)
68+
: data(std::make_shared<Data>(vec)) {}
69+
70+
using Data = std::vector<T>;
71+
72+
const Data& operator*() const { return *data; }
73+
Data& operator*() { return *data; }
74+
const Data* operator->() const { return data.get(); }
75+
Data* operator->() { return data.get(); }
76+
77+
std::shared_ptr<Data> data;
78+
};
79+
80+
#define DEFINE_DIM_EXPR_SUBCLASS(class_name, base) \
81+
template <typename T> \
82+
struct class_name : public base<T> { \
83+
using base<T>::base; \
84+
};
85+
86+
DEFINE_DIM_EXPR_SUBCLASS(Negative, UnaryDimExpr);
87+
DEFINE_DIM_EXPR_SUBCLASS(Reciprocal, UnaryDimExpr);
88+
DEFINE_DIM_EXPR_SUBCLASS(Add, VariadicDimExpr);
89+
DEFINE_DIM_EXPR_SUBCLASS(Mul, VariadicDimExpr);
90+
DEFINE_DIM_EXPR_SUBCLASS(Max, VariadicDimExpr);
91+
DEFINE_DIM_EXPR_SUBCLASS(Min, VariadicDimExpr);
92+
DEFINE_DIM_EXPR_SUBCLASS(Broadcast, VariadicDimExpr);
93+
DEFINE_DIM_EXPR_SUBCLASS(Equal, BinaryDimExpr);
94+
DEFINE_DIM_EXPR_SUBCLASS(Broadcastable, BinaryDimExpr);
95+
96+
class DimExpr;
97+
98+
// DimExpr = std::int64_t
99+
// | std::string
100+
// | Negative DimExpr
101+
// | Reciprocal DimExpr
102+
// | Add DimExpr
103+
// | Mul DimExpr
104+
// | Max DimExpr
105+
// | Min DimExpr
106+
// | Broadcast DimExpr
107+
using DimExprBase = std::variant<std::int64_t,
108+
std::string,
109+
Negative<DimExpr>,
110+
Reciprocal<DimExpr>,
111+
Add<DimExpr>,
112+
Mul<DimExpr>,
113+
Max<DimExpr>,
114+
Min<DimExpr>,
115+
Broadcast<DimExpr>>;
116+
117+
class DimExpr : public DimExprBase {
118+
public:
119+
using DimExprBase::DimExprBase;
120+
121+
template <typename T>
122+
bool isa() const {
123+
return std::holds_alternative<T>(*this);
124+
}
125+
126+
template <typename T>
127+
const T& dyn_cast() const {
128+
return std::get<T>(*this);
129+
}
130+
131+
DimExpr operator+(const DimExpr& other) const;
132+
DimExpr operator-(const DimExpr& other) const;
133+
DimExpr operator*(const DimExpr& other) const;
134+
DimExpr operator/(const DimExpr& other) const;
135+
};
136+
137+
// DimExprConstraint = Equal DimExpr
138+
// | Broadcastable DimExpr
139+
using DimExprConstraint = std::variant<Equal<DimExpr>, Broadcastable<DimExpr>>;
140+
141+
// ValueShapeDimExprs = tShape [DimExpr] | tValue [DimExpr]
142+
template <typename T>
143+
class ValueShape {
144+
public:
145+
explicit ValueShape(const std::vector<T>& shape)
146+
: shape_(shape), value_(std::nullopt) {}
147+
148+
static ValueShape MakeConsistentValue(const std::vector<T>& value) {
149+
T size(std::int64_t(value.size()));
150+
return ValueShape(std::vector<T>{size}, value);
151+
}
152+
153+
const std::optional<std::vector<T>>& shape() const { return shape_; }
154+
const std::optional<std::vector<T>>& value() const { return value_; }
155+
156+
private:
157+
explicit ValueShape(const std::vector<T>& shape, const std::vector<T>& value)
158+
: shape_(shape), value_(value) {}
159+
160+
std::optional<std::vector<T>> shape_;
161+
std::optional<std::vector<T>> value_;
162+
};
163+
164+
using ValueShapeDimExprs = ValueShape<DimExpr>;
165+
166+
} // namespace symbol
Lines changed: 92 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,92 @@
1+
// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
2+
//
3+
// Licensed under the Apache License, Version 2.0 (the "License");
4+
// you may not use this file except in compliance with the License.
5+
// You may obtain a copy of the License at
6+
//
7+
// http://www.apache.org/licenses/LICENSE-2.0
8+
//
9+
// Unless required by applicable law or agreed to in writing, software
10+
// distributed under the License is distributed on an "AS IS" BASIS,
11+
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
// See the License for the specific language governing permissions and
13+
// limitations under the License.
14+
15+
#include "paddle/pir/dialect/shape/utils/dim_expr_builder.h"
16+
17+
namespace symbol {
18+
19+
using BroadcastDimExpr = Broadcast<DimExpr>;
20+
21+
DimExpr DimExprBuilder::ConstSize(std::int64_t dim) { SYMBOL_NOT_IMPLEMENTED; }
22+
23+
DimExpr DimExprBuilder::Symbol(const std::string& symbol_name) {
24+
SYMBOL_NOT_IMPLEMENTED;
25+
}
26+
27+
DimExpr DimExprBuilder::Add(const DimExpr& lhs, const DimExpr& rhs) {
28+
return lhs + rhs;
29+
}
30+
31+
DimExpr DimExprBuilder::Any(const DimExpr& lhs, const DimExpr& rhs) {
32+
SYMBOL_NOT_IMPLEMENTED;
33+
}
34+
35+
DimExpr DimExprBuilder::Mul(const DimExpr& lhs, const DimExpr& rhs) {
36+
return lhs * rhs;
37+
}
38+
39+
DimExpr DimExprBuilder::Div(const DimExpr& lhs, const DimExpr& rhs) {
40+
return lhs / rhs;
41+
}
42+
43+
DimExpr DimExprBuilder::Max(const DimExpr& lhs, const DimExpr& rhs) {
44+
SYMBOL_NOT_IMPLEMENTED;
45+
}
46+
47+
DimExpr DimExprBuilder::Min(const DimExpr& lhs, const DimExpr& rhs) {
48+
SYMBOL_NOT_IMPLEMENTED;
49+
}
50+
51+
DimExpr DimExprBuilder::Broadcast(const DimExpr& lhs, const DimExpr& rhs) {
52+
return BroadcastDimExpr(std::vector{lhs, rhs});
53+
}
54+
55+
std::vector<DimExpr> DimExprBuilder::ConstShape(
56+
const std::vector<std::int64_t>& dims) {
57+
SYMBOL_NOT_IMPLEMENTED;
58+
}
59+
60+
void DimExprBuilder::CstrBroadcastable(const DimExpr& lhs, const DimExpr& rhs) {
61+
SYMBOL_NOT_IMPLEMENTED;
62+
}
63+
64+
void DimExprBuilder::CstrBroadcastable(const std::vector<DimExpr>& lhs,
65+
const std::vector<DimExpr>& rhs) {
66+
SYMBOL_NOT_IMPLEMENTED;
67+
}
68+
69+
void DimExprBuilder::CstrEq(const DimExpr& lhs, const DimExpr& rhs) {
70+
constraints_->emplace_back(Equal<DimExpr>(lhs, rhs));
71+
}
72+
73+
void DimExprBuilder::CstrEq(const std::vector<DimExpr>& lhs,
74+
const std::vector<DimExpr>& rhs) {
75+
SYMBOL_NOT_IMPLEMENTED;
76+
}
77+
78+
std::vector<DimExpr> DimExprBuilder::Concat(const std::vector<DimExpr>& lhs,
79+
const std::vector<DimExpr>& rhs) {
80+
SYMBOL_NOT_IMPLEMENTED;
81+
}
82+
83+
std::pair<std::vector<DimExpr>, std::vector<DimExpr>> DimExprBuilder::SplitAt(
84+
const std::vector<DimExpr>, int index) {
85+
SYMBOL_NOT_IMPLEMENTED;
86+
}
87+
88+
const std::vector<DimExprConstraint>& DimExprBuilder::constaints() const {
89+
SYMBOL_NOT_IMPLEMENTED;
90+
}
91+
92+
} // namespace symbol
Lines changed: 54 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,54 @@
1+
// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
2+
//
3+
// Licensed under the Apache License, Version 2.0 (the "License");
4+
// you may not use this file except in compliance with the License.
5+
// You may obtain a copy of the License at
6+
//
7+
// http://www.apache.org/licenses/LICENSE-2.0
8+
//
9+
// Unless required by applicable law or agreed to in writing, software
10+
// distributed under the License is distributed on an "AS IS" BASIS,
11+
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
// See the License for the specific language governing permissions and
13+
// limitations under the License.
14+
15+
#pragma once
16+
17+
#include "paddle/pir/dialect/shape/utils/dim_expr.h"
18+
19+
namespace symbol {
20+
21+
class DimExprBuilder {
22+
public:
23+
explicit DimExprBuilder(std::vector<DimExprConstraint>* constraints)
24+
: constraints_(constraints) {}
25+
26+
DimExpr ConstSize(std::int64_t dim);
27+
DimExpr Symbol(const std::string& symbol_name);
28+
DimExpr Add(const DimExpr& lhs, const DimExpr& rhs);
29+
DimExpr Any(const DimExpr& lhs, const DimExpr& rhs);
30+
DimExpr Mul(const DimExpr& lhs, const DimExpr& rhs);
31+
DimExpr Div(const DimExpr& lhs, const DimExpr& rhs);
32+
DimExpr Max(const DimExpr& lhs, const DimExpr& rhs);
33+
DimExpr Min(const DimExpr& lhs, const DimExpr& rhs);
34+
DimExpr Broadcast(const DimExpr& lhs, const DimExpr& rhs);
35+
std::vector<DimExpr> ConstShape(const std::vector<std::int64_t>& dims);
36+
37+
void CstrBroadcastable(const DimExpr& lhs, const DimExpr& rhs);
38+
void CstrBroadcastable(const std::vector<DimExpr>& lhs,
39+
const std::vector<DimExpr>& rhs);
40+
void CstrEq(const DimExpr& lhs, const DimExpr& rhs);
41+
void CstrEq(const std::vector<DimExpr>& lhs, const std::vector<DimExpr>& rhs);
42+
43+
std::vector<DimExpr> Concat(const std::vector<DimExpr>& lhs,
44+
const std::vector<DimExpr>& rhs);
45+
std::pair<std::vector<DimExpr>, std::vector<DimExpr>> SplitAt(
46+
const std::vector<DimExpr>, int index);
47+
48+
const std::vector<DimExprConstraint>& constaints() const;
49+
50+
private:
51+
std::vector<DimExprConstraint>* constraints_;
52+
};
53+
54+
} // namespace symbol

paddle/pir/dialect/shape/utils/shape_utils.cc

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -147,6 +147,10 @@ ShapeConstraintIRAnalysis::GetOrCreateSymbolicDimsForRankedValue(
147147
return value_to_sym_dims_.at(value);
148148
}
149149

150+
symbol::DimExprBuilder ShapeConstraintIRAnalysis::CreateDimExprBuilder() {
151+
return symbol::DimExprBuilder(&constraints_);
152+
}
153+
150154
ShapeAnalysisManager& ShapeAnalysisManager::Instance() {
151155
static ShapeAnalysisManager instance;
152156
return instance;

paddle/pir/dialect/shape/utils/shape_utils.h

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,7 @@
1414

1515
#pragma once
1616

17+
#include "paddle/pir/dialect/shape/utils/dim_expr_builder.h"
1718
#include "paddle/pir/dialect/shape/utils/shape_optimization_utils.h"
1819
#include "paddle/pir/dialect/shape/utils/symbol_table.h"
1920

@@ -46,6 +47,8 @@ class IR_API ShapeAnalysis {
4647

4748
// Returns true if the two value have the same number elements.
4849
virtual bool IsSameNumElements(Value lhs, Value rhs);
50+
51+
virtual symbol::DimExprBuilder CreateDimExprBuilder() = 0;
4952
};
5053

5154
// A subclass to impement `ShapeAnalysis` on buffer level.
@@ -71,6 +74,8 @@ class IR_API ShapeConstraintIRAnalysis : public ShapeAnalysis {
7174
Value rhs,
7275
std::vector<int> rhs_dim_idxs) override;
7376

77+
symbol::DimExprBuilder CreateDimExprBuilder() override;
78+
7479
private:
7580
// The operation this analysis runs on.
7681
ModuleOp m_;
@@ -80,6 +85,7 @@ class IR_API ShapeConstraintIRAnalysis : public ShapeAnalysis {
8085
// dimension size of the memref value.
8186
std::unordered_map<Value, std::vector<shape::SymbolicDimOp>>
8287
value_to_sym_dims_;
88+
std::vector<symbol::DimExprConstraint> constraints_;
8389

8490
public:
8591
explicit ShapeConstraintIRAnalysis(std::shared_ptr<pir::Program>&& program)

0 commit comments

Comments
 (0)