|
14 | 14 |
|
15 | 15 | #pragma once
|
16 | 16 |
|
17 |
| -#include "paddle/cinn/hlir/dialect/operator/transforms/cinn_group_lowering_pass.h" |
| 17 | +#include "paddle/cinn/hlir/dialect/operator/transforms/group_merge/cinn_group_lowering_pass.h" |
18 | 18 |
|
19 | 19 | #include <unordered_map>
|
20 | 20 |
|
21 | 21 | #include "paddle/cinn/hlir/dialect/operator/ir/manual_op.h"
|
22 | 22 | #include "paddle/cinn/hlir/dialect/operator/ir/op_attribute.h"
|
23 | 23 | #include "paddle/cinn/hlir/dialect/operator/ir/op_dialect.h"
|
24 |
| -#include "paddle/cinn/hlir/dialect/operator/transforms/op_with_group_merge_pass.h" |
| 24 | +#include "paddle/cinn/hlir/dialect/operator/transforms/group_merge/op_with_group_merge_pass.h" |
25 | 25 | #include "paddle/cinn/hlir/dialect/runtime/ir/jit_kernel_op.h"
|
26 | 26 | #include "paddle/cinn/hlir/dialect/runtime/ir/runtime_dialect.h"
|
27 | 27 | #include "paddle/cinn/hlir/framework/pir_compiler.h"
|
28 | 28 | #include "paddle/fluid/pir/dialect/kernel/ir/kernel_dialect.h"
|
| 29 | +#include "paddle/pir/core/program.h" |
29 | 30 | #include "paddle/pir/dialect/control_flow/ir/cf_op.h"
|
| 31 | +#include "paddle/pir/pass/pass_registry.h" |
| 32 | +#include "paddle/pir/pattern_rewrite/frozen_rewrite_pattern_set.h" |
30 | 33 |
|
31 |
| -namespace cinn { |
32 |
| -namespace dialect { |
33 |
| -namespace ir { |
| 34 | +namespace { |
34 | 35 |
|
35 | 36 | std::vector<pir::Value> GetBlockOutsideInput(
|
36 | 37 | const std::vector<pir::Operation*> op_list) {
|
@@ -123,113 +124,93 @@ std::vector<pir::Operation*> GetOutputOpList(
|
123 | 124 | return vec_res;
|
124 | 125 | }
|
125 | 126 |
|
126 |
| -std::unique_ptr<pir::Program> CINNGroupLoweringPass(::pir::Program* program) { |
127 |
| - ::pir::IrContext* ctx = ::pir::IrContext::Instance(); |
| 127 | +class GroupOpPattern : public pir::OpRewritePattern<cinn::dialect::GroupOp> { |
| 128 | + public: |
| 129 | + using pir::OpRewritePattern<cinn::dialect::GroupOp>::OpRewritePattern; |
128 | 130 |
|
129 |
| - ctx->GetOrRegisterDialect<cinn::dialect::RuntimeDialect>(); |
130 |
| - ctx->GetOrRegisterDialect<cinn::dialect::OperatorDialect>(); |
131 |
| - ctx->GetOrRegisterDialect<paddle::dialect::KernelDialect>(); |
| 131 | + bool MatchAndRewrite(cinn::dialect::GroupOp group_op, |
| 132 | + pir::PatternRewriter& rewriter) const override { |
| 133 | + ::pir::IrContext* ctx = ::pir::IrContext::Instance(); |
| 134 | + auto target = cinn::common::DefaultNVGPUTarget(); |
| 135 | + auto* program = group_op->GetParentProgram(); |
| 136 | + // TODO(Aurelius84): Remove scope after cleaning PirCompiler usless Build |
| 137 | + // Interface |
| 138 | + auto scope = std::make_shared<cinn::hlir::framework::Scope>(); |
132 | 139 |
|
133 |
| - std::string jit_op_name = cinn::dialect::JitKernelOp::name(); |
134 |
| - ::pir::OpInfo op_info = ctx->GetRegisteredOpInfo(jit_op_name); |
| 140 | + VLOG(4) << "start Lowering Group Op: " << group_op; |
135 | 141 |
|
136 |
| - auto ir_program = std::make_unique<::pir::Program>(ctx); |
137 |
| - std::unordered_map<pir::Value, pir::Value> value_map; |
| 142 | + // op fusion |
| 143 | + auto op_fusion = cinn::dialect::ir::OpFusionPassInternal( |
| 144 | + GetOpListNotIncludeYield(group_op.ops()), |
| 145 | + GetOutputOpList(group_op.ops())); |
138 | 146 |
|
139 |
| - auto target = cinn::common::DefaultNVGPUTarget(); |
140 |
| - auto scope = cinn::hlir::framework::BuildScope(target, *program); |
| 147 | + // fusion merge |
| 148 | + auto group_list = |
| 149 | + cinn::dialect::ir::GeneralFusionMergePassInternal(op_fusion); |
141 | 150 |
|
142 |
| - for (auto it = program->block()->begin(); it != program->block()->end(); |
143 |
| - ++it) { |
144 |
| - if (it->isa<cinn::dialect::GroupOp>()) { |
145 |
| - // GetOpList and Call cinn CodeGen |
146 |
| - auto group_op = it->dyn_cast<cinn::dialect::GroupOp>(); |
| 151 | + for (auto group : group_list) { |
| 152 | + auto ir_compiler = std::make_shared<cinn::hlir::framework::PirCompiler>( |
| 153 | + *program, target, scope); |
| 154 | + cinn::hlir::framework::PirCompilerManager::Instance().insert(ir_compiler); |
147 | 155 |
|
148 |
| - // op fusion |
149 |
| - auto op_fusion = cinn::dialect::ir::OpFusionPassInternal( |
150 |
| - GetOpListNotIncludeYield(group_op.ops()), |
151 |
| - GetOutputOpList(group_op.ops())); |
| 156 | + auto fn_ptr_res = ir_compiler->BuildCUDAJITInfo({group}); |
| 157 | + std::unordered_map<std::string, ::pir::Attribute> op_attrs{ |
| 158 | + {cinn::dialect::JitKernelOp::kAttrName, |
| 159 | + cinn::dialect::CUDAJITInfoAttribute::get(ctx, fn_ptr_res[0])}, |
| 160 | + }; |
152 | 161 |
|
153 |
| - // fusion merge |
154 |
| - auto group_list = |
155 |
| - cinn::dialect::ir::GeneralFusionMergePassInternal(op_fusion); |
| 162 | + // Generate jit kernel op input and output |
| 163 | + auto vec_ins = GetBlockOutsideInput(group->ops); |
156 | 164 |
|
157 |
| - // using yield op to sort |
158 |
| - std::unordered_map<::pir::Value, size_t> value2id; |
159 |
| - auto yeild_op = group_op.ops().back(); |
160 |
| - for (size_t i = 0; i < yeild_op->num_operands(); ++i) { |
161 |
| - value2id[yeild_op->operand_source(i)] = i; |
| 165 | + std::vector<pir::Type> vec_types; |
| 166 | + for (size_t i = 0; i < group->output_values.size(); ++i) { |
| 167 | + vec_types.push_back(group->output_values[i].type()); |
162 | 168 | }
|
163 | 169 |
|
164 |
| - for (auto group : group_list) { |
165 |
| - auto ir_compiler = std::make_shared<cinn::hlir::framework::PirCompiler>( |
166 |
| - *program, target, scope); |
167 |
| - hlir::framework::PirCompilerManager::Instance().insert(ir_compiler); |
168 |
| - auto fn_ptr_res = ir_compiler->BuildCUDAJITInfo({group}); |
169 |
| - std::unordered_map<std::string, ::pir::Attribute> op_attrs{ |
170 |
| - {cinn::dialect::JitKernelOp::kAttrName, |
171 |
| - cinn::dialect::CUDAJITInfoAttribute::get(ctx, fn_ptr_res[0])}, |
172 |
| - }; |
173 |
| - |
174 |
| - // Generate jit kernel op input and output |
175 |
| - auto vec_ins = GetBlockOutsideInput(group->ops); |
176 |
| - |
177 |
| - std::vector<pir::Value> vec_new_ins; |
178 |
| - for (size_t i = 0; i < vec_ins.size(); ++i) { |
179 |
| - vec_new_ins.push_back(value_map.at(vec_ins[i])); |
180 |
| - } |
181 |
| - |
182 |
| - std::unordered_map<size_t, size_t> codegen2orig; |
183 |
| - |
184 |
| - std::vector<pir::Type> vec_types; |
185 |
| - for (size_t i = 0; i < group->output_values.size(); ++i) { |
186 |
| - vec_types.push_back(group->output_values[i].type()); |
187 |
| - } |
| 170 | + auto jit_kernel_op = rewriter.Build<cinn::dialect::JitKernelOp>( |
| 171 | + vec_ins, op_attrs, vec_types); |
| 172 | + for (size_t i = 0; i < jit_kernel_op.num_results(); ++i) { |
| 173 | + rewriter.ReplaceAllUsesWith(group->output_values[i], |
| 174 | + jit_kernel_op.result(i)); |
| 175 | + } |
| 176 | + } |
| 177 | + rewriter.EraseOp(group_op); |
| 178 | + return true; |
| 179 | + } |
| 180 | +}; |
188 | 181 |
|
189 |
| - ::pir::Operation* cinn_op = |
190 |
| - ::pir::Operation::Create(vec_new_ins, op_attrs, vec_types, op_info); |
| 182 | +class CinnGroupLoweringPass : public pir::PatternRewritePass { |
| 183 | + public: |
| 184 | + CinnGroupLoweringPass() : pir::PatternRewritePass("cinn_group_lowering", 1) {} |
191 | 185 |
|
192 |
| - for (size_t i = 0; i < cinn_op->num_results(); ++i) { |
193 |
| - auto find_it = value2id.find(group->output_values[i]); |
194 |
| - if (find_it == value2id.end()) { |
195 |
| - value_map[group->output_values[i]] = cinn_op->result(i); |
196 |
| - } else { |
197 |
| - value_map[group_op.result(find_it->second)] = cinn_op->result(i); |
198 |
| - } |
199 |
| - } |
| 186 | + pir::RewritePatternSet InitializePatterns(pir::IrContext* context) override { |
| 187 | + context->GetOrRegisterDialect<cinn::dialect::RuntimeDialect>(); |
| 188 | + context->GetOrRegisterDialect<cinn::dialect::OperatorDialect>(); |
| 189 | + context->GetOrRegisterDialect<paddle::dialect::KernelDialect>(); |
200 | 190 |
|
201 |
| - ir_program->block()->push_back(cinn_op); |
202 |
| - } |
| 191 | + pir::RewritePatternSet ps(context); |
| 192 | + ps.Add<GroupOpPattern>(context); |
203 | 193 |
|
204 |
| - } else { |
205 |
| - std::vector<pir::Value> vec_ins; |
| 194 | + return ps; |
| 195 | + } |
206 | 196 |
|
207 |
| - for (size_t i = 0; i < it->num_operands(); ++i) { |
208 |
| - if (it->operand_source(i)) { |
209 |
| - vec_ins.push_back(value_map.at(it->operand_source(i))); |
210 |
| - } else { |
211 |
| - vec_ins.push_back(it->operand_source(i)); |
212 |
| - } |
213 |
| - } |
| 197 | + bool CanApplyOn(pir::Operation* op) const override { |
| 198 | + return op->isa<pir::ModuleOp>() && op->num_regions() > 0; |
| 199 | + } |
| 200 | +}; |
214 | 201 |
|
215 |
| - std::vector<pir::Type> vec_types; |
216 |
| - for (size_t i = 0; i < it->num_results(); ++i) { |
217 |
| - vec_types.push_back(it->result(i).type()); |
218 |
| - } |
| 202 | +} // namespace |
219 | 203 |
|
220 |
| - ::pir::OpInfo info1 = ctx->GetRegisteredOpInfo(it->name()); |
221 |
| - ::pir::Operation* op = |
222 |
| - ::pir::Operation::Create(vec_ins, it->attributes(), vec_types, info1); |
| 204 | +namespace cinn { |
| 205 | +namespace dialect { |
| 206 | +namespace ir { |
223 | 207 |
|
224 |
| - ir_program->block()->push_back(op); |
225 |
| - for (size_t i = 0; i < it->num_results(); ++i) { |
226 |
| - value_map[it->result(i)] = op->result(i); |
227 |
| - } |
228 |
| - } |
229 |
| - } |
230 |
| - return ir_program; |
| 208 | +std::unique_ptr<::pir::Pass> CreateCinnGroupLoweringPass() { |
| 209 | + return std::make_unique<CinnGroupLoweringPass>(); |
231 | 210 | }
|
232 | 211 |
|
233 | 212 | } // namespace ir
|
234 | 213 | } // namespace dialect
|
235 | 214 | } // namespace cinn
|
| 215 | + |
| 216 | +REGISTER_IR_PASS(cinn_group_lowering, CinnGroupLoweringPass); |
0 commit comments