Skip to content

Commit 5e8bfa1

Browse files
committed
[PIR+CINN]Refactor CINNGroupLoweringPass into ::pir::Pass
1 parent 137ead7 commit 5e8bfa1

20 files changed

+161
-179
lines changed

paddle/cinn/hlir/dialect/operator/transforms/CMakeLists.txt

+2-11
Original file line numberDiff line numberDiff line change
@@ -1,15 +1,6 @@
1+
add_subdirectory(group_merge)
2+
13
if(NOT CINN_ONLY)
2-
cinn_cc_library(
3-
op_with_group_merge_pass
4-
SRCS
5-
group_with_group_merge_pass.cc
6-
op_with_group_merge_pass.cc
7-
cinn_group_lowering_pass.cc
8-
tensor_node.cc
9-
DEPS
10-
op_dialect_vjp
11-
pir_compiler
12-
cinn_runtime_dialect)
134

145
cinn_cc_library(
156
pd_to_cinn_pass
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,13 @@
1+
if(NOT CINN_ONLY)
2+
cinn_cc_library(
3+
op_with_group_merge_pass
4+
SRCS
5+
group_with_group_merge_pass.cc
6+
op_with_group_merge_pass.cc
7+
cinn_group_lowering_pass.cc
8+
tensor_node.cc
9+
DEPS
10+
op_dialect_vjp
11+
pir_compiler
12+
cinn_runtime_dialect)
13+
endif()

paddle/cinn/hlir/dialect/operator/transforms/cinn_group_lowering_pass.cc renamed to paddle/cinn/hlir/dialect/operator/transforms/group_merge/cinn_group_lowering_pass.cc

+73-92
Original file line numberDiff line numberDiff line change
@@ -14,23 +14,24 @@
1414

1515
#pragma once
1616

17-
#include "paddle/cinn/hlir/dialect/operator/transforms/cinn_group_lowering_pass.h"
17+
#include "paddle/cinn/hlir/dialect/operator/transforms/group_merge/cinn_group_lowering_pass.h"
1818

1919
#include <unordered_map>
2020

2121
#include "paddle/cinn/hlir/dialect/operator/ir/manual_op.h"
2222
#include "paddle/cinn/hlir/dialect/operator/ir/op_attribute.h"
2323
#include "paddle/cinn/hlir/dialect/operator/ir/op_dialect.h"
24-
#include "paddle/cinn/hlir/dialect/operator/transforms/op_with_group_merge_pass.h"
24+
#include "paddle/cinn/hlir/dialect/operator/transforms/group_merge/op_with_group_merge_pass.h"
2525
#include "paddle/cinn/hlir/dialect/runtime/ir/jit_kernel_op.h"
2626
#include "paddle/cinn/hlir/dialect/runtime/ir/runtime_dialect.h"
2727
#include "paddle/cinn/hlir/framework/pir_compiler.h"
2828
#include "paddle/fluid/pir/dialect/kernel/ir/kernel_dialect.h"
29+
#include "paddle/pir/core/program.h"
2930
#include "paddle/pir/dialect/control_flow/ir/cf_op.h"
31+
#include "paddle/pir/pass/pass_registry.h"
32+
#include "paddle/pir/pattern_rewrite/frozen_rewrite_pattern_set.h"
3033

31-
namespace cinn {
32-
namespace dialect {
33-
namespace ir {
34+
namespace {
3435

3536
std::vector<pir::Value> GetBlockOutsideInput(
3637
const std::vector<pir::Operation*> op_list) {
@@ -123,113 +124,93 @@ std::vector<pir::Operation*> GetOutputOpList(
123124
return vec_res;
124125
}
125126

126-
std::unique_ptr<pir::Program> CINNGroupLoweringPass(::pir::Program* program) {
127-
::pir::IrContext* ctx = ::pir::IrContext::Instance();
127+
class GroupOpPattern : public pir::OpRewritePattern<cinn::dialect::GroupOp> {
128+
public:
129+
using pir::OpRewritePattern<cinn::dialect::GroupOp>::OpRewritePattern;
128130

129-
ctx->GetOrRegisterDialect<cinn::dialect::RuntimeDialect>();
130-
ctx->GetOrRegisterDialect<cinn::dialect::OperatorDialect>();
131-
ctx->GetOrRegisterDialect<paddle::dialect::KernelDialect>();
131+
bool MatchAndRewrite(cinn::dialect::GroupOp group_op,
132+
pir::PatternRewriter& rewriter) const override {
133+
::pir::IrContext* ctx = ::pir::IrContext::Instance();
134+
auto target = cinn::common::DefaultNVGPUTarget();
135+
auto* program = group_op->GetParentProgram();
136+
// TODO(Aurelius84): Remove scope after cleaning PirCompiler usless Build
137+
// Interface
138+
auto scope = std::make_shared<cinn::hlir::framework::Scope>();
132139

133-
std::string jit_op_name = cinn::dialect::JitKernelOp::name();
134-
::pir::OpInfo op_info = ctx->GetRegisteredOpInfo(jit_op_name);
140+
VLOG(4) << "start Lowering Group Op: " << group_op;
135141

136-
auto ir_program = std::make_unique<::pir::Program>(ctx);
137-
std::unordered_map<pir::Value, pir::Value> value_map;
142+
// op fusion
143+
auto op_fusion = cinn::dialect::ir::OpFusionPassInternal(
144+
GetOpListNotIncludeYield(group_op.ops()),
145+
GetOutputOpList(group_op.ops()));
138146

139-
auto target = cinn::common::DefaultNVGPUTarget();
140-
auto scope = cinn::hlir::framework::BuildScope(target, *program);
147+
// fusion merge
148+
auto group_list =
149+
cinn::dialect::ir::GeneralFusionMergePassInternal(op_fusion);
141150

142-
for (auto it = program->block()->begin(); it != program->block()->end();
143-
++it) {
144-
if (it->isa<cinn::dialect::GroupOp>()) {
145-
// GetOpList and Call cinn CodeGen
146-
auto group_op = it->dyn_cast<cinn::dialect::GroupOp>();
151+
for (auto group : group_list) {
152+
auto ir_compiler = std::make_shared<cinn::hlir::framework::PirCompiler>(
153+
*program, target, scope);
154+
cinn::hlir::framework::PirCompilerManager::Instance().insert(ir_compiler);
147155

148-
// op fusion
149-
auto op_fusion = cinn::dialect::ir::OpFusionPassInternal(
150-
GetOpListNotIncludeYield(group_op.ops()),
151-
GetOutputOpList(group_op.ops()));
156+
auto fn_ptr_res = ir_compiler->BuildCUDAJITInfo({group});
157+
std::unordered_map<std::string, ::pir::Attribute> op_attrs{
158+
{cinn::dialect::JitKernelOp::kAttrName,
159+
cinn::dialect::CUDAJITInfoAttribute::get(ctx, fn_ptr_res[0])},
160+
};
152161

153-
// fusion merge
154-
auto group_list =
155-
cinn::dialect::ir::GeneralFusionMergePassInternal(op_fusion);
162+
// Generate jit kernel op input and output
163+
auto vec_ins = GetBlockOutsideInput(group->ops);
156164

157-
// using yield op to sort
158-
std::unordered_map<::pir::Value, size_t> value2id;
159-
auto yeild_op = group_op.ops().back();
160-
for (size_t i = 0; i < yeild_op->num_operands(); ++i) {
161-
value2id[yeild_op->operand_source(i)] = i;
165+
std::vector<pir::Type> vec_types;
166+
for (size_t i = 0; i < group->output_values.size(); ++i) {
167+
vec_types.push_back(group->output_values[i].type());
162168
}
163169

164-
for (auto group : group_list) {
165-
auto ir_compiler = std::make_shared<cinn::hlir::framework::PirCompiler>(
166-
*program, target, scope);
167-
hlir::framework::PirCompilerManager::Instance().insert(ir_compiler);
168-
auto fn_ptr_res = ir_compiler->BuildCUDAJITInfo({group});
169-
std::unordered_map<std::string, ::pir::Attribute> op_attrs{
170-
{cinn::dialect::JitKernelOp::kAttrName,
171-
cinn::dialect::CUDAJITInfoAttribute::get(ctx, fn_ptr_res[0])},
172-
};
173-
174-
// Generate jit kernel op input and output
175-
auto vec_ins = GetBlockOutsideInput(group->ops);
176-
177-
std::vector<pir::Value> vec_new_ins;
178-
for (size_t i = 0; i < vec_ins.size(); ++i) {
179-
vec_new_ins.push_back(value_map.at(vec_ins[i]));
180-
}
181-
182-
std::unordered_map<size_t, size_t> codegen2orig;
183-
184-
std::vector<pir::Type> vec_types;
185-
for (size_t i = 0; i < group->output_values.size(); ++i) {
186-
vec_types.push_back(group->output_values[i].type());
187-
}
170+
auto jit_kernel_op = rewriter.Build<cinn::dialect::JitKernelOp>(
171+
vec_ins, op_attrs, vec_types);
172+
for (size_t i = 0; i < jit_kernel_op.num_results(); ++i) {
173+
rewriter.ReplaceAllUsesWith(group->output_values[i],
174+
jit_kernel_op.result(i));
175+
}
176+
}
177+
rewriter.EraseOp(group_op);
178+
return true;
179+
}
180+
};
188181

189-
::pir::Operation* cinn_op =
190-
::pir::Operation::Create(vec_new_ins, op_attrs, vec_types, op_info);
182+
class CinnGroupLoweringPass : public pir::PatternRewritePass {
183+
public:
184+
CinnGroupLoweringPass() : pir::PatternRewritePass("cinn_group_lowering", 1) {}
191185

192-
for (size_t i = 0; i < cinn_op->num_results(); ++i) {
193-
auto find_it = value2id.find(group->output_values[i]);
194-
if (find_it == value2id.end()) {
195-
value_map[group->output_values[i]] = cinn_op->result(i);
196-
} else {
197-
value_map[group_op.result(find_it->second)] = cinn_op->result(i);
198-
}
199-
}
186+
pir::RewritePatternSet InitializePatterns(pir::IrContext* context) override {
187+
context->GetOrRegisterDialect<cinn::dialect::RuntimeDialect>();
188+
context->GetOrRegisterDialect<cinn::dialect::OperatorDialect>();
189+
context->GetOrRegisterDialect<paddle::dialect::KernelDialect>();
200190

201-
ir_program->block()->push_back(cinn_op);
202-
}
191+
pir::RewritePatternSet ps(context);
192+
ps.Add<GroupOpPattern>(context);
203193

204-
} else {
205-
std::vector<pir::Value> vec_ins;
194+
return ps;
195+
}
206196

207-
for (size_t i = 0; i < it->num_operands(); ++i) {
208-
if (it->operand_source(i)) {
209-
vec_ins.push_back(value_map.at(it->operand_source(i)));
210-
} else {
211-
vec_ins.push_back(it->operand_source(i));
212-
}
213-
}
197+
bool CanApplyOn(pir::Operation* op) const override {
198+
return op->isa<pir::ModuleOp>() && op->num_regions() > 0;
199+
}
200+
};
214201

215-
std::vector<pir::Type> vec_types;
216-
for (size_t i = 0; i < it->num_results(); ++i) {
217-
vec_types.push_back(it->result(i).type());
218-
}
202+
} // namespace
219203

220-
::pir::OpInfo info1 = ctx->GetRegisteredOpInfo(it->name());
221-
::pir::Operation* op =
222-
::pir::Operation::Create(vec_ins, it->attributes(), vec_types, info1);
204+
namespace cinn {
205+
namespace dialect {
206+
namespace ir {
223207

224-
ir_program->block()->push_back(op);
225-
for (size_t i = 0; i < it->num_results(); ++i) {
226-
value_map[it->result(i)] = op->result(i);
227-
}
228-
}
229-
}
230-
return ir_program;
208+
std::unique_ptr<::pir::Pass> CreateCinnGroupLoweringPass() {
209+
return std::make_unique<CinnGroupLoweringPass>();
231210
}
232211

233212
} // namespace ir
234213
} // namespace dialect
235214
} // namespace cinn
215+
216+
REGISTER_IR_PASS(cinn_group_lowering, CinnGroupLoweringPass);

paddle/cinn/hlir/dialect/operator/transforms/cinn_group_lowering_pass.h renamed to paddle/cinn/hlir/dialect/operator/transforms/group_merge/cinn_group_lowering_pass.h

+3-4
Original file line numberDiff line numberDiff line change
@@ -14,14 +14,13 @@
1414

1515
#pragma once
1616

17-
#include "paddle/pir/core/program.h"
17+
#include <memory>
18+
#include "paddle/pir/pass/pass.h"
1819

1920
namespace cinn {
2021
namespace dialect {
2122
namespace ir {
22-
23-
std::unique_ptr<pir::Program> CINNGroupLoweringPass(::pir::Program* program);
24-
23+
std::unique_ptr<::pir::Pass> CreateCinnGroupLoweringPass();
2524
} // namespace ir
2625
} // namespace dialect
2726
} // namespace cinn

paddle/cinn/hlir/dialect/operator/transforms/group_with_group_merge_pass.cc renamed to paddle/cinn/hlir/dialect/operator/transforms/group_merge/group_with_group_merge_pass.cc

+5-5
Original file line numberDiff line numberDiff line change
@@ -16,15 +16,15 @@
1616
#include <set>
1717
#include <unordered_map>
1818

19-
#include "paddle/cinn/hlir/dialect/operator/transforms/op_group.h"
19+
#include "paddle/cinn/hlir/dialect/operator/transforms/group_merge/op_group.h"
2020
#include "paddle/pir/core/ir_printer.h"
2121
#include "paddle/pir/core/value.h"
2222

23-
#include "paddle/cinn/hlir/dialect/operator/transforms/group_with_group_merge_pass_utils.h"
24-
#include "paddle/cinn/hlir/dialect/operator/transforms/op_with_group_merge_pass.h"
23+
#include "paddle/cinn/hlir/dialect/operator/transforms/group_merge/group_with_group_merge_pass_utils.h"
24+
#include "paddle/cinn/hlir/dialect/operator/transforms/group_merge/op_with_group_merge_pass.h"
2525

26-
#include "paddle/cinn/hlir/dialect/operator/transforms/group_with_group_merge_util.h"
27-
#include "paddle/cinn/hlir/dialect/operator/transforms/op_with_group_merge_util.h"
26+
#include "paddle/cinn/hlir/dialect/operator/transforms/group_merge/group_with_group_merge_util.h"
27+
#include "paddle/cinn/hlir/dialect/operator/transforms/group_merge/op_with_group_merge_util.h"
2828
#include "paddle/phi/core/flags.h"
2929

3030
#include "paddle/cinn/common/is_reachable_predicator.h"

paddle/cinn/hlir/dialect/operator/transforms/group_with_group_merge_pass_utils.h renamed to paddle/cinn/hlir/dialect/operator/transforms/group_merge/group_with_group_merge_pass_utils.h

+2-2
Original file line numberDiff line numberDiff line change
@@ -14,8 +14,8 @@
1414

1515
#pragma once
1616

17-
#include "paddle/cinn/hlir/dialect/operator/transforms/op_group.h"
18-
#include "paddle/cinn/hlir/dialect/operator/transforms/op_with_group_merge_util.h"
17+
#include "paddle/cinn/hlir/dialect/operator/transforms/group_merge/op_group.h"
18+
#include "paddle/cinn/hlir/dialect/operator/transforms/group_merge/op_with_group_merge_util.h"
1919

2020
namespace cinn {
2121
namespace dialect {

paddle/cinn/hlir/dialect/operator/transforms/group_with_group_merge_util.h renamed to paddle/cinn/hlir/dialect/operator/transforms/group_merge/group_with_group_merge_util.h

+1-1
Original file line numberDiff line numberDiff line change
@@ -27,7 +27,7 @@
2727
#include "paddle/pir/core/value.h"
2828
#include "paddle/pir/dialect/control_flow/ir/cf_op.h"
2929

30-
#include "paddle/cinn/hlir/dialect/operator/transforms/op_with_group_merge_util.h"
30+
#include "paddle/cinn/hlir/dialect/operator/transforms/group_merge/op_with_group_merge_util.h"
3131

3232
namespace cinn {
3333
namespace dialect {

paddle/cinn/hlir/dialect/operator/transforms/op_group.h renamed to paddle/cinn/hlir/dialect/operator/transforms/group_merge/op_group.h

+2-2
Original file line numberDiff line numberDiff line change
@@ -16,8 +16,8 @@
1616

1717
#include <memory>
1818

19-
#include "paddle/cinn/hlir/dialect/operator/transforms/op_node.h"
20-
#include "paddle/cinn/hlir/dialect/operator/transforms/op_with_group_merge_util.h"
19+
#include "paddle/cinn/hlir/dialect/operator/transforms/group_merge/op_node.h"
20+
#include "paddle/cinn/hlir/dialect/operator/transforms/group_merge/op_with_group_merge_util.h"
2121

2222
namespace cinn {
2323
namespace dialect {

paddle/cinn/hlir/dialect/operator/transforms/op_node.h renamed to paddle/cinn/hlir/dialect/operator/transforms/group_merge/op_node.h

+2-2
Original file line numberDiff line numberDiff line change
@@ -15,8 +15,8 @@
1515
#pragma once
1616

1717
#include <memory>
18-
#include "paddle/cinn/hlir/dialect/operator/transforms/op_with_group_merge_util.h"
19-
#include "paddle/cinn/hlir/dialect/operator/transforms/tensor_node.h"
18+
#include "paddle/cinn/hlir/dialect/operator/transforms/group_merge/op_with_group_merge_util.h"
19+
#include "paddle/cinn/hlir/dialect/operator/transforms/group_merge/tensor_node.h"
2020
#include "paddle/fluid/pir/dialect/operator/utils/utils.h"
2121
#include "paddle/pir/core/operation.h"
2222

paddle/cinn/hlir/dialect/operator/transforms/op_with_group_merge_pass.cc renamed to paddle/cinn/hlir/dialect/operator/transforms/group_merge/op_with_group_merge_pass.cc

+1-1
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,7 @@
1212
// See the License for the specific language governing permissions and
1313
// limitations under the License.
1414

15-
#include "paddle/cinn/hlir/dialect/operator/transforms/op_with_group_merge_pass.h"
15+
#include "paddle/cinn/hlir/dialect/operator/transforms/group_merge/op_with_group_merge_pass.h"
1616

1717
#include <limits.h>
1818
#include <memory>

paddle/cinn/hlir/dialect/operator/transforms/op_with_group_merge_pass.h renamed to paddle/cinn/hlir/dialect/operator/transforms/group_merge/op_with_group_merge_pass.h

+1-1
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,7 @@
1414

1515
#pragma once
1616

17-
#include "paddle/cinn/hlir/dialect/operator/transforms/op_with_group_merge_util.h"
17+
#include "paddle/cinn/hlir/dialect/operator/transforms/group_merge/op_with_group_merge_util.h"
1818
#include "paddle/pir/core/program.h"
1919

2020
namespace cinn {

paddle/cinn/hlir/dialect/operator/transforms/tensor_node.cc renamed to paddle/cinn/hlir/dialect/operator/transforms/group_merge/tensor_node.cc

+2-2
Original file line numberDiff line numberDiff line change
@@ -14,9 +14,9 @@
1414

1515
#pragma once
1616

17-
#include "paddle/cinn/hlir/dialect/operator/transforms/tensor_node.h"
17+
#include "paddle/cinn/hlir/dialect/operator/transforms/group_merge/tensor_node.h"
1818

19-
#include "paddle/cinn/hlir/dialect/operator/transforms/op_node.h"
19+
#include "paddle/cinn/hlir/dialect/operator/transforms/group_merge/op_node.h"
2020

2121
namespace cinn {
2222
namespace dialect {

paddle/cinn/hlir/dialect/runtime/ir/jit_kernel_op.cc

+17
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,23 @@ namespace dialect {
2424

2525
const char* JitKernelOp::attributes_name[attributes_num] = {kAttrName};
2626

27+
void JitKernelOp::Build(::pir::Builder& builder,
28+
const pir::OperationArgument& argument,
29+
const std::vector<::pir::Value>& x,
30+
const ::pir::AttributeMap& attributes,
31+
const std::vector<::pir::Type>& out_types) {
32+
VLOG(4) << "Start build JitKernelOp";
33+
34+
VLOG(4) << "Builder construction inputs";
35+
argument.AddInputs(x);
36+
37+
VLOG(4) << "Builder construction attributes";
38+
argument.AddAttributes(attributes);
39+
40+
VLOG(4) << "Builder construction outputs";
41+
argument.AddOutputs(out_types.begin(), out_types.end());
42+
}
43+
2744
void JitKernelOp::VerifySig() {
2845
VLOG(4) << "Verifying inputs, outputs and attributes for: JitKernelOp.";
2946

0 commit comments

Comments
 (0)