Commit a98f715

[Dynamic Shape] Add SplitGenerateShapeIntoShapeOpsPass (#60624)
* [Dynamic Shape] Add SplitGenerateShapeIntoShapeOpsPass
* Fix compile error
* Fix compile error
1 parent 504a590 commit a98f715
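
In short, the new pass lowers every cinn_op.generate_shape op into ordinary shape ops: constant dimensions become FullIntArrayOp values, symbol bindings are resolved via ShapeOp plus SliceOp on the bound input, symbolic arithmetic (Add/Mul/Max/Min/Broadcast over DimExprs) becomes the matching elementwise ops, and a final CombineOp gathers the per-dimension values. As a rough sketch of the emitted ops (illustrative only, not part of this commit; `x` and `rewriter` are placeholders), an output expression `S0 * 2` with `S0` bound to axis 0 of input `x` would lower to approximately:

    // Illustrative sketch: ops emitted for the output DimExpr `S0 * 2`,
    // with symbol S0 bound to axis 0 of the shape of input tensor `x`.
    pir::Value shape = rewriter.Build<paddle::dialect::ShapeOp>(x).out();
    pir::Value s0 = rewriter
                        .Build<paddle::dialect::SliceOp>(
                            shape,
                            std::vector<int64_t>{0},  // axes
                            std::vector<int64_t>{0},  // starts
                            std::vector<int64_t>{1},  // ends
                            std::vector<int64_t>{},
                            std::vector<int64_t>{})
                        .out();  // shape(x)[0:1]
    pir::Value two = rewriter
                         .Build<paddle::dialect::FullIntArrayOp>(
                             std::vector<int64_t>{2}, phi::DataType::INT64)
                         .out();  // constant [2]
    pir::Value dim =
        rewriter.Build<paddle::dialect::MultiplyOp>(s0, two).out();
    pir::Value out = rewriter
                         .Build<pir::CombineOp>(std::vector<pir::Value>{dim})
                         .out();  // the replacement for generate_shape's out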

File tree

4 files changed: +419 -2 lines changed


paddle/cinn/hlir/dialect/operator/transforms/CMakeLists.txt (+9)

@@ -38,4 +38,13 @@ if(NOT CINN_ONLY)
     cinn_op_dialect
     op_dialect_vjp)
 
+  cinn_cc_library(
+    split_generate_shape_into_shape_ops_pass
+    SRCS
+    split_generate_shape_into_shape_ops_pass.cc
+    DEPS
+    pir
+    cinn_op_dialect
+    op_dialect_vjp)
+
 endif()
paddle/cinn/hlir/dialect/operator/transforms/split_generate_shape_into_shape_ops_pass.cc (+368, new file)
@@ -0,0 +1,368 @@

// Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#include "paddle/cinn/hlir/dialect/operator/transforms/split_generate_shape_into_shape_ops_pass.h"

#include "paddle/cinn/common/dim_expr_simplify.h"
#include "paddle/cinn/hlir/dialect/operator/ir/cinn_op.h"
#include "paddle/cinn/hlir/dialect/operator/ir/generate_shape_util.h"
#include "paddle/cinn/hlir/dialect/operator/ir/manual_op.h"
#include "paddle/cinn/hlir/framework/pir/utils.h"
#include "paddle/common/ddim.h"
#include "paddle/fluid/pir/dialect/operator/ir/manual_op.h"
#include "paddle/fluid/pir/dialect/operator/ir/op_attribute.h"
#include "paddle/fluid/pir/dialect/operator/ir/op_type.h"
#include "paddle/fluid/pir/dialect/operator/ir/pd_op.h"
#include "paddle/fluid/pir/drr/api/match_context.h"
#include "paddle/pir/core/builtin_dialect.h"
#include "paddle/pir/dialect/shape/utils/dim_expr.h"
#include "paddle/pir/pass/pass.h"
#include "paddle/pir/pattern_rewrite/pattern_applicator.h"
#include "paddle/pir/pattern_rewrite/pattern_match.h"
#include "paddle/pir/pattern_rewrite/pattern_rewrite_driver.h"

namespace cinn {
namespace dialect {
namespace ir {

namespace {

// A symbolic dimension bound to one axis of an input tensor's shape.
struct TensorDimInShape {
  pir::Value value;
  int axis;
};

// A symbolic dimension bound to one element of an input tensor's data.
struct TensorDimInData {
  pir::Value value;
  int axis;
};

using TensorDim = std::variant<TensorDimInShape, TensorDimInData>;

using TensorDim4SymbolNameT =
    std::function<std::optional<TensorDim>(const std::string& symbol_name)>;

// Converts a symbol::DimExpr tree into pir::Values by recursively building
// the corresponding shape-manipulation ops.
struct CachedDimExprToValueConverter {
  CachedDimExprToValueConverter(
      const TensorDim4SymbolNameT& TensorDim4SymbolNameVal,
      pir::PatternRewriter* rewriter_val)
      : TensorDim4SymbolName(TensorDim4SymbolNameVal), rewriter(rewriter_val) {}

  TensorDim4SymbolNameT TensorDim4SymbolName;
  pir::PatternRewriter* rewriter;

  // TODO(): Refactor to a cached version once std::hash<symbol::DimExpr>() is
  // ready. std::unordered_map<symbol::DimExpr, pir::Value>
  // symbol_names2cached_value_;

  pir::Value ConvertToValue(const symbol::DimExpr& dim_expr) {
    // TODO(): Cache the returned value once std::hash<symbol::DimExpr>() is
    // ready.
    return std::visit(
        [&](const auto& impl) { return ConvertToValueImpl(impl); },
        dim_expr.variant());
  }

  pir::Value GetInputShapeByInputTensor(pir::Value input_tensor) {
    auto iter = tensor2shape_.find(input_tensor);
    if (iter == tensor2shape_.end()) {
      pir::Value input_shape =
          rewriter->Build<paddle::dialect::ShapeOp>(input_tensor).out();
      iter = tensor2shape_.emplace(input_tensor, input_shape).first;
    }
    return iter->second;
  }

 private:
  std::unordered_map<pir::Value /*input tensor*/,
                     pir::Value /*input shape tensor*/>
      tensor2shape_;

  pir::Value ConvertToValueImpl(int64_t dim_expr) {
    return rewriter
        ->Build<paddle::dialect::FullIntArrayOp>(std::vector{dim_expr},
                                                 phi::DataType::INT64)
        .out();
  }

  pir::Value ConvertToValueImpl(const std::string& symbol_name) {
    const auto& tensor_dim = TensorDim4SymbolName(symbol_name);
    PADDLE_ENFORCE(
        tensor_dim.has_value(),
        phi::errors::PreconditionNotMet(
            "symbol [%s] is not bound to any input of the generate_shape op",
            symbol_name));
    return std::visit(
        [&](const auto& impl) { return ConvertTensorDimToValue(impl); },
        tensor_dim.value());
  }

  pir::Value ConvertTensorDimToValue(const TensorDimInShape& tensor_dim) {
    pir::Value input_shape = GetInputShapeByInputTensor(tensor_dim.value);
    return ConvertTensorDimToValue(
        TensorDimInData{.value = input_shape, .axis = tensor_dim.axis});
  }

  pir::Value ConvertTensorDimToValue(const TensorDimInData& tensor_dim) {
    return rewriter
        ->Build<paddle::dialect::SliceOp>(
            tensor_dim.value,
            std::vector<int64_t>{0LL},
            std::vector<int64_t>{tensor_dim.axis},
            std::vector<int64_t>{tensor_dim.axis + 1},
            std::vector<int64_t>{},
            std::vector<int64_t>{})
        .out();
  }

  pir::Value ConvertToValueImpl(
      const symbol::Negative<symbol::DimExpr>& dim_expr) {
    LOG(FATAL) << "Dead code. This case should be handled by "
                  "ConvertToValueImpl(symbol::Add<symbol::DimExpr>).";
  }

  pir::Value ConvertToValueImpl(
      const symbol::Reciprocal<symbol::DimExpr>& dim_expr) {
    LOG(FATAL) << "Dead code. This case should be handled by "
                  "ConvertToValueImpl(symbol::Mul<symbol::DimExpr>).";
  }

  // Lowers a0 + a1 + ...; Negative operands become subtractions.
  pir::Value ConvertToValueImpl(const symbol::Add<symbol::DimExpr>& dim_expr) {
    const auto& [operands] = dim_expr;
    CHECK_GT(operands->size(), 0);
    pir::Value acc = ConvertToValue(operands->at(0));
    for (int i = 1; i < operands->size(); ++i) {
      if (operands->at(i).isa<symbol::Negative<symbol::DimExpr>>()) {
        const auto& [operand] =
            *operands->at(i).dyn_cast<symbol::Negative<symbol::DimExpr>>();
        pir::Value operand_value = ConvertToValue(operand);
        acc = rewriter->Build<paddle::dialect::SubtractOp>(acc, operand_value)
                  .out();
      } else {
        pir::Value operand_value = ConvertToValue(operands->at(i));
        acc = rewriter->Build<paddle::dialect::AddOp>(acc, operand_value).out();
      }
    }
    return acc;
  }

  // Lowers a0 * a1 * ...; Reciprocal operands become divisions.
  pir::Value ConvertToValueImpl(const symbol::Mul<symbol::DimExpr>& dim_expr) {
    const auto& [operands] = dim_expr;
    CHECK_GT(operands->size(), 0);
    pir::Value prod = ConvertToValue(operands->at(0));
    for (int i = 1; i < operands->size(); ++i) {
      if (operands->at(i).isa<symbol::Reciprocal<symbol::DimExpr>>()) {
        const auto& [operand] =
            *operands->at(i).dyn_cast<symbol::Reciprocal<symbol::DimExpr>>();
        pir::Value operand_value = ConvertToValue(operand);
        prod = rewriter->Build<paddle::dialect::DivideOp>(prod, operand_value)
                   .out();
      } else {
        pir::Value operand_value = ConvertToValue(operands->at(i));
        prod = rewriter->Build<paddle::dialect::MultiplyOp>(prod, operand_value)
                   .out();
      }
    }
    return prod;
  }

  pir::Value ConvertToValueImpl(const symbol::Max<symbol::DimExpr>& dim_expr) {
    const auto& [operands] = dim_expr;
    CHECK_GT(operands->size(), 0);
    pir::Value max = ConvertToValue(operands->at(0));
    for (int i = 1; i < operands->size(); ++i) {
      pir::Value operand_value = ConvertToValue(operands->at(i));
      max = rewriter->Build<paddle::dialect::MaxOp>(max, operand_value).out();
    }
    return max;
  }

  pir::Value ConvertToValueImpl(const symbol::Min<symbol::DimExpr>& dim_expr) {
    const auto& [operands] = dim_expr;
    CHECK_GT(operands->size(), 0);
    pir::Value min = ConvertToValue(operands->at(0));
    for (int i = 1; i < operands->size(); ++i) {
      pir::Value operand_value = ConvertToValue(operands->at(i));
      min = rewriter->Build<paddle::dialect::MinOp>(min, operand_value).out();
    }
    return min;
  }

  pir::Value ConvertToValueImpl(
      const symbol::Broadcast<symbol::DimExpr>& dim_expr) {
    const auto& [operands] = dim_expr;
    CHECK_GT(operands->size(), 0);
    pir::Value broadcasted = ConvertToValue(operands->at(0));
    for (int i = 1; i < operands->size(); ++i) {
      pir::Value operand_value = ConvertToValue(operands->at(i));
      broadcasted = rewriter
                        ->Build<paddle::dialect::ShapeBroadcastOp>(
                            broadcasted, operand_value)
                        .out();
    }
    return broadcasted;
  }
};

}  // namespace

// Rewrites each cinn_op.generate_shape op into elementary shape ops whose
// per-dimension results are gathered by a pir::CombineOp.
class SplitGenerateShapeIntoShapeOps
    : public pir::OpRewritePattern<cinn::dialect::GenerateShapeOp> {
 public:
  using pir::OpRewritePattern<cinn::dialect::GenerateShapeOp>::OpRewritePattern;

  bool MatchAndRewrite(cinn::dialect::GenerateShapeOp op,
                       pir::PatternRewriter& rewriter) const override {
    std::optional<pir::Value> out_replacement =
        GetOutReplacement(op, &rewriter);
    if (!out_replacement.has_value()) return false;
    rewriter.ReplaceAllUsesWith(op->result(0), out_replacement.value());
    return true;
  }

  std::optional<pir::Value> GetOutReplacement(
      cinn::dialect::GenerateShapeOp op, pir::PatternRewriter* rewriter) const {
    std::vector<symbol::DimExpr> dim_exprs = GetOutDimExprs(op);
    TensorDim4SymbolNameT TensorDim4SymbolName =
        MakeGetterTensorDim4SymbolName(op);
    if (!TensorDim4SymbolName) return std::nullopt;
    CachedDimExprToValueConverter converter{TensorDim4SymbolName, rewriter};
    return GetValueOfRewrittenOps(dim_exprs, &converter);
  }

  TensorDim4SymbolNameT MakeGetterTensorDim4SymbolName(
      cinn::dialect::GenerateShapeOp op) const {
    std::unordered_map<std::string, TensorDim> symbol_name2tensor_dim{};
    const auto& attr_map = op->attributes();
    const auto& iter = attr_map.find("symbol_bindings");
    PADDLE_ENFORCE((iter != attr_map.end()),
                   phi::errors::PreconditionNotMet(
                       "attr symbol_bindings MUST be in the attribute map of "
                       "op [%s]",
                       op->name()));
    pir::Attribute attr = iter->second;
    auto* Convert =
        &cinn::dialect::GenerateShapeOp::ConvertAttributeToSymbolBindings;
    const auto& symbol_bindings = Convert(attr);
    PADDLE_ENFORCE(
        symbol_bindings.has_value(),
        phi::errors::PreconditionNotMet("attr symbol_bindings of op [%s] "
                                        "cannot be converted to symbol "
                                        "bindings",
                                        op->name()));
    for (const auto& symbol_binding : symbol_bindings.value()) {
      InsertSymbolBinding(op, symbol_binding, &symbol_name2tensor_dim);
    }
    return [map = std::move(symbol_name2tensor_dim)](
               const std::string& symbol_name) -> std::optional<TensorDim> {
      auto iter = map.find(symbol_name);
      if (iter == map.end()) return std::nullopt;
      return iter->second;
    };
  }

  void InsertSymbolBinding(
      cinn::dialect::GenerateShapeOp op,
      const cinn::dialect::GenerateShapeOp::SymbolBinding& symbol_binding,
      std::unordered_map<std::string, TensorDim>* symbol_name2tensor_dim)
      const {
    return std::visit(
        [&](const auto& impl) {
          return InsertSymbolBindingImpl(op, impl, symbol_name2tensor_dim);
        },
        symbol_binding);
  }

  void InsertSymbolBindingImpl(
      cinn::dialect::GenerateShapeOp op,
      const cinn::dialect::GenerateShapeOp::DataSymbolBinding& symbol_binding,
      std::unordered_map<std::string, TensorDim>* symbol_name2tensor_dim)
      const {
    (*symbol_name2tensor_dim)[symbol_binding.symbol_name] = TensorDimInData{
        .value = op.operand_source(symbol_binding.input_tensor_idx),
        .axis = symbol_binding.input_tensor_dim_idx};
  }

  void InsertSymbolBindingImpl(
      cinn::dialect::GenerateShapeOp op,
      const cinn::dialect::GenerateShapeOp::ShapeSymbolBinding& symbol_binding,
      std::unordered_map<std::string, TensorDim>* symbol_name2tensor_dim)
      const {
    (*symbol_name2tensor_dim)[symbol_binding.symbol_name] = TensorDimInShape{
        .value = op.operand_source(symbol_binding.input_tensor_idx),
        .axis = symbol_binding.input_tensor_dim_idx};
  }

  std::vector<symbol::DimExpr> GetOutDimExprs(
      cinn::dialect::GenerateShapeOp op) const {
    const auto& attr_map = op->attributes();
    const auto& iter = attr_map.find("output_dim_exprs");
    PADDLE_ENFORCE(
        (iter != attr_map.end()),
        phi::errors::PreconditionNotMet(
            "attr output_dim_exprs MUST be in the attribute map of op [%s]",
            op->name()));
    pir::Attribute output_dim_exprs_attr = iter->second;
    PADDLE_ENFORCE(
        output_dim_exprs_attr.isa<pir::ArrayAttribute>(),
        phi::errors::PreconditionNotMet(
            "attr output_dim_exprs of op [%s] must be a pir::ArrayAttribute",
            op->name()));
    std::vector<symbol::DimExpr> ret{};
    const auto& output_dim_exprs =
        output_dim_exprs_attr.dyn_cast<pir::ArrayAttribute>();
    for (int i = 0; i < output_dim_exprs.size(); ++i) {
      const auto& attr = output_dim_exprs.at(i);
      const auto& opt_dim_expr = cinn::dialect::ConvertAttributeToDimExpr(attr);
      CHECK(opt_dim_expr.has_value());
      ret.emplace_back(opt_dim_expr.value());
    }
    return ret;
  }

  pir::Value GetValueOfRewrittenOps(
      const std::vector<symbol::DimExpr>& dim_exprs,
      CachedDimExprToValueConverter* converter) const {
    const std::vector<pir::Value>& values_from_dim_exprs =
        GetValuesOfRewrittenOps(dim_exprs, converter);
    return converter->rewriter->Build<pir::CombineOp>(values_from_dim_exprs)
        .out();
  }

  std::vector<pir::Value> GetValuesOfRewrittenOps(
      const std::vector<symbol::DimExpr>& dim_exprs,
      CachedDimExprToValueConverter* converter) const {
    std::vector<pir::Value> ret;
    for (const auto& dim_expr : dim_exprs) {
      const auto& simplified = cinn::common::SimplifyDimExpr(dim_expr);
      pir::Value value = converter->ConvertToValue(simplified);
      ret.push_back(value);
    }
    return ret;
  }
};

SplitGenerateShapeIntoShapeOpsPass::SplitGenerateShapeIntoShapeOpsPass()
    : pir::PatternRewritePass("split_generate_shape_into_shape_ops_pass", 1) {}

pir::RewritePatternSet SplitGenerateShapeIntoShapeOpsPass::InitializePatterns(
    pir::IrContext* context) {
  pir::RewritePatternSet ps(context);
  ps.Add<SplitGenerateShapeIntoShapeOps>(context);
  return ps;
}

bool SplitGenerateShapeIntoShapeOpsPass::CanApplyOn(pir::Operation* op) const {
  return op->isa<pir::ModuleOp>() && op->num_regions() > 0;
}

}  // namespace ir
}  // namespace dialect
}  // namespace cinn
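
For orientation, a minimal sketch of running the new pass over a program, assuming the usual pir::PassManager API and that the pass header declares the class with a publicly accessible constructor; the setup below is illustrative, not part of this commit:

    #include <memory>

    #include "paddle/cinn/hlir/dialect/operator/transforms/split_generate_shape_into_shape_ops_pass.h"
    #include "paddle/pir/core/ir_context.h"
    #include "paddle/pir/pass/pass_manager.h"

    // Hypothetical helper: apply the pass to a pir::Program in place.
    void SplitGenerateShapeOps(pir::Program* program) {
      pir::IrContext* ctx = pir::IrContext::Instance();
      pir::PassManager pass_manager(ctx);
      pass_manager.AddPass(
          std::make_unique<
              cinn::dialect::ir::SplitGenerateShapeIntoShapeOpsPass>());
      // CanApplyOn() restricts the pass to module ops with at least one region.
      pass_manager.Run(program);
    }

After the pass runs, downstream consumers see only plain shape ops (ShapeOp, SliceOp, elementwise arithmetic, CombineOp) instead of the opaque cinn_op.generate_shape op.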
