Commit dfc6d4c

[CINN] PostProcess supports x86 composite reduce (PaddlePaddle#72370)
1 parent c1fb672 commit dfc6d4c

3 files changed: +184 −18 lines


paddle/cinn/optim/optimize.cc

+6 −3
@@ -72,10 +72,10 @@ ir::LoweredFunc Optimize(ir::LoweredFunc fn,
 
   {
     FuncPassManager func_pass_manager;
-    func_pass_manager.AddPass(CreateRealizeCompositeReducePass());
+    func_pass_manager.AddPass(CreateRealizeCompositeReducePass(target));
     func_pass_manager.AddPass(CreateReindexTransposeBufferPass());
     func_pass_manager.Run(copied);
-    VLOG(4) << "After Optimize CustomizedReduce and ReindexTransposeBuffer: "
+    VLOG(4) << "After Optimize CompositeReducePass and ReindexTransposeBuffer: "
            << copied;
  }
 
@@ -187,7 +187,10 @@ ir::LoweredFunc Optimize(ir::LoweredFunc fn,
 
   LowerIntrin(&copied->body, target);
   VLOG(10) << "After LowerIntrin:" << copied;
-
+  // re-compute buffer cast exprs since
+  // x86 codegen needs correct buffer types to generate
+  // symbol table
+  copied->PrepareBufferCastExprs(false);
   return copied;
 }
 
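Note: the added PrepareBufferCastExprs(false) call re-derives the buffer cast expressions after LowerIntrin so that the x86 backend's symbol table sees the composite types (e.g. a Welford struct type) that the new pass assigned to reduce buffers. A toy, self-contained illustration of the idea follows; it is not CINN's codegen, and the buffer names and dtypes are made up:

    // Toy illustration: a codegen symbol table is driven by per-buffer dtypes.
    // If a buffer's recorded dtype were stale ("float" instead of
    // "welford_fp32"), the emitted declaration would not match how the buffer
    // is actually used.
    #include <iostream>
    #include <map>
    #include <string>

    int main() {
      // dtype per buffer, as the pass leaves it after retyping reduce buffers
      // (hypothetical names).
      std::map<std::string, std::string> buffer_dtype = {
          {"var_reduce_local", "welford_fp32"},
          {"var_out", "float"},
      };
      // Emit one declaration per buffer, the way a symbol table would drive it.
      for (const auto& [name, dtype] : buffer_dtype) {
        std::cout << dtype << "* " << name << ";\n";
      }
      return 0;
    }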
paddle/cinn/optim/realize_composite_reduce_pass.cc

+172 −13
@@ -95,11 +95,24 @@ CompositeTypes GetArgReduceUnderlyingType(const ir::Expr& expr) {
 
 void SetInitValue(Store store_stmt,
                   common::Type new_type,
-                  const CompositeTypes& comp_type) {
+                  const CompositeTypes& comp_type,
+                  std::string prefix = "") {
+  // prefix: if target is x86, we can not call constructor for POD struct
+  // the intrinsic function for creating struct is usually "create_" + typename
   ir::Expr init_value = store_stmt->value();
+  auto call_op = init_value.As<ir::Call>();
+  // if the type is already a call
+  if (call_op != nullptr) {
+    call_op->set_type(new_type);
+    if (call_op->name.find("argidx_") == 0 ||
+        call_op->name.find("welford_") == 0) {
+      call_op->name = prefix + call_op->name;
+    }
+    return;
+  }
   if (comp_type.type == ReduceType::kVariance) {
     store_stmt->set_value(ir::Call::Make(new_type,
-                                         new_type.customized_type(),
+                                         prefix + new_type.customized_type(),
                                          {init_value, init_value, init_value},
                                          {},
                                          ir::CallType::Intrinsic));
@@ -111,7 +124,7 @@ void SetInitValue(Store store_stmt,
       index_init->set_type(common::Int(64));
     }
     store_stmt->set_value(ir::Call::Make(new_type,
-                                         new_type.customized_type(),
+                                         prefix + new_type.customized_type(),
                                          {init_value, index_init},
                                          {},
                                          ir::CallType::Intrinsic));
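
Note: the new prefix parameter exists because, as the added comment says, x86 codegen cannot call a constructor on a POD composite value; initialization instead goes through an intrinsic named "create_" plus the composite type name. A rough standalone sketch of that convention; the struct layout and field names are assumptions, not the actual CINN runtime types:

    // Hypothetical sketch: on x86 a composite reduce value is a POD struct, so
    // a free "create_<type>" function plays the role of a constructor.
    #include <cstdio>

    struct welford_fp32 {
      float mean;
      float m2;
      float weight;
    };

    // The intrinsic the pass would name by prepending "create_" to the type.
    static welford_fp32 create_welford_fp32(float mean, float m2, float weight) {
      welford_fp32 w;
      w.mean = mean;
      w.m2 = m2;
      w.weight = weight;
      return w;
    }

    int main() {
      // Mirrors the rewritten reduce_init value for the variance case, where
      // the same init expression is passed for all three fields.
      welford_fp32 init = create_welford_fp32(0.0f, 0.0f, 0.0f);
      std::printf("%g %g %g\n", init.mean, init.m2, init.weight);
      return 0;
    }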
@@ -199,6 +212,60 @@ std::map<ir::Buffer, CompositeTypes> CollectTypedReduceBuffers(
   return typed_buffers;
 }
 
+void ReplaceOutputBufferX86(
+    const BlockRef& body,
+    const std::set<ir::Buffer>& out_buffer_map,
+    const std::map<ir::Buffer, CompositeTypes>& typed_buffers) {
+  // re-route the reduce_init buffer to the local staging buffer
+  // and set the type for the buffers correctly
+  struct BufferRelationRecorder {
+    Store reduce_init;
+    Store write_back;
+  };
+  std::map<ir::Buffer, BufferRelationRecorder> buffer_relations;
+  for (auto buffer : out_buffer_map) {
+    buffer_relations.emplace(buffer, BufferRelationRecorder());
+  }
+  const auto VisitFn = [&](const StmtRef& stmt) {
+    if (!stmt.isa<Store>()) return;
+    Store store_stmt = stmt.as<Store>();
+
+    auto* tensor = store_stmt->tensor().as_tensor();
+    auto& buffer = tensor->buffer;
+    auto buffer_it = buffer_relations.find(buffer);
+    // check whether the buffer is related to output args
+    if (buffer_it == buffer_relations.end()) return;
+    if (ir::IsReduceInitTensorName(tensor->name)) {
+      buffer_it->second.reduce_init = store_stmt;
+    } else {
+      buffer_it->second.write_back = store_stmt;
+    }
+  };
+
+  ir::stmt::Visit(body, VisitFn, [](auto) {});
+
+  for (auto& [_, buffer_rel] : buffer_relations) {
+    // both should be defined
+    if (!buffer_rel.reduce_init.defined() || !buffer_rel.write_back.defined())
+      continue;
+    auto wb_value = buffer_rel.write_back->value();
+    if (auto load_node = wb_value.As<ir::Load>()) {
+      auto wb_load_buffer = load_node->tensor.as_tensor()->buffer;
+      auto wb_load_it = typed_buffers.find(wb_load_buffer);
+      PADDLE_ENFORCE_NE(wb_load_it,
+                        typed_buffers.end(),
+                        ::common::errors::Fatal(
+                            "Buffer '%s' should be defined in typed_buffers.",
+                            wb_load_buffer->name));
+      // set the buffer of the reduce_init to write back buffer
+      ir::Expr new_tensor =
+          ir::ir_utils::IRCopy(buffer_rel.reduce_init->tensor());
+      new_tensor.as_tensor()->buffer = wb_load_buffer;
+      buffer_rel.reduce_init->set_tensor(new_tensor);
+    }
+  }
+}
+
 Store GetStoreOfSchedule(const Schedule& stmt) {
   Store store_stmt;
   bool found = false;
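
Note: ReplaceOutputBufferX86 pairs, for each output buffer, the reduce_init store with the store that writes the final value back, and then re-points the init store at the staging buffer the write-back loads from, since that buffer carries the composite type. A simplified standalone model of that pairing and re-routing, using toy structs rather than CINN IR:

    // Simplified model of the reduce_init / write-back pairing. ToyStore
    // stands in for a CINN Store statement; buffer names are illustrative.
    #include <iostream>
    #include <map>
    #include <string>
    #include <vector>

    struct ToyStore {
      std::string buffer;       // buffer the store writes into
      std::string load_buffer;  // buffer its value is loaded from, if any
      bool is_reduce_init = false;
    };

    int main() {
      // Output buffer "var_out": an init store plus a write-back that loads
      // the staging buffer "var_local" holding the composite value.
      std::vector<ToyStore> stores = {
          {"var_out", "", true},
          {"var_out", "var_local", false},
      };

      struct Relation { ToyStore* init = nullptr; ToyStore* write_back = nullptr; };
      std::map<std::string, Relation> relations;
      for (auto& s : stores) {
        Relation& rel = relations[s.buffer];
        (s.is_reduce_init ? rel.init : rel.write_back) = &s;
      }

      // Re-route: only when both stores were found, point the init store at
      // the staging buffer so it is typed and initialized as a composite value.
      for (auto& [buffer, rel] : relations) {
        if (rel.init && rel.write_back && !rel.write_back->load_buffer.empty()) {
          rel.init->buffer = rel.write_back->load_buffer;
        }
      }

      std::cout << stores[0].buffer << "\n";  // prints "var_local"
      return 0;
    }

The intent, as the block comment in the Run() hunk below describes, is that the reduce_init then picks up the composite buffer type instead of the raw output type.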
@@ -406,7 +473,8 @@ struct LoadTypeMutator : public ir::IRMutator<> {
 };
 
 void SetBufferType(ir::LoweredFunc func,
-                   const std::map<ir::Buffer, CompositeTypes>& typed_buffers) {
+                   const std::map<ir::Buffer, CompositeTypes>& typed_buffers,
+                   bool is_x86_arch) {
   // Make a map from the buffers to their element and composite reduce types,
   // otherwise it's hard to know a buffer's original type. The original type
   // must be known to perform casting (back) in LoadTypeMutator::Visit()
@@ -439,9 +507,10 @@ void SetBufferType(ir::LoweredFunc func,
       new_tensor.as_tensor()->set_type(new_type);
       new_tensor.as_tensor()->buffer->dtype = new_type;
       store_stmt->set_tensor(new_tensor);
-
-      if (ir::IsReduceInitTensorName(tensor->name)) {
-        SetInitValue(store_stmt, new_type, composite_type);
+      stmt->set_type(new_type);
+      if (ir::IsReduceInitTensorName(new_tensor.as_tensor()->name)) {
+        std::string call_prefix = is_x86_arch ? "create_" : "";
+        SetInitValue(store_stmt, new_type, composite_type, call_prefix);
       }
     }
 
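Note: for argmax/argmin the composite value pairs the running extreme with an int64 index (SetInitValue above casts the index init to Int(64)), and on x86 the combine is routed to a "max_"/"min_"-prefixed intrinsic, as the prefix table in the next hunk shows. A standalone sketch of such a pair type and a max-style combine; the type name, field names and tie-breaking rule are assumptions:

    // Standalone sketch of an argidx-style composite value: the running
    // maximum paired with its int64 index, plus a "max_"-style combine.
    #include <cstdint>
    #include <cstdio>

    struct argidx_fp32_i64 {
      float value;
      int64_t index;
    };

    static argidx_fp32_i64 create_argidx_fp32_i64(float value, int64_t index) {
      return argidx_fp32_i64{value, index};
    }

    static argidx_fp32_i64 max_argidx_fp32_i64(argidx_fp32_i64 a,
                                               argidx_fp32_i64 b) {
      // Keep the larger value; on ties keep the smaller index (assumed policy).
      if (b.value > a.value || (b.value == a.value && b.index < a.index)) {
        return b;
      }
      return a;
    }

    int main() {
      argidx_fp32_i64 acc = create_argidx_fp32_i64(-1e30f, 0);  // init state
      const float data[] = {0.5f, 2.0f, 1.5f};
      for (int64_t i = 0; i < 3; ++i) {
        acc = max_argidx_fp32_i64(acc, create_argidx_fp32_i64(data[i], i));
      }
      std::printf("argmax index=%lld value=%f\n",
                  static_cast<long long>(acc.index), acc.value);
      return 0;
    }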
@@ -490,12 +559,72 @@ struct ReduceExternCallMutator : public ir::IRMutator<> {
   }
 };
 
-void ReplaceReduceExternCall(const BlockRef& body) {
+struct ReduceExternCallMutatorX86 : public ir::IRMutator<> {
+  // unlike non x86 counterpart, we do not replace the call
+  // by a arithmetic IR node, but instead call x86-exclusive funcs
+  void operator()(ir::Expr* expr) { ir::IRMutator<>::Visit(expr, expr); }
+
+ private:
+  void Visit(const ir::Call* op, ir::Expr* expr) override {
+    ir::IRMutator<>::Visit(op, expr);
+    auto reduce_type_ = GetReduceType(*expr);
+    if (reduce_type_ == ReduceType::kNone) return;
+    ir::Expr lhs = op->read_args[0];
+    ir::Expr rhs = op->read_args[1];
+    std::string lhs_type = lhs.type().to_string();
+    if (lhs.type() != rhs.type()) {
+      if (auto call_op = rhs.As<ir::Call>()) {
+        // for argidx type, avoid redundant type casting, but this is ugly
+        if (call_op->name.find("argidx") == 0) {
+          call_op->name = "create_" + call_op->name;
+          rhs->set_type(lhs.type());
+        }
+      } else {
+        // welford pod type call create function on x86
+        ir::Expr m2_init(0.f), weight_init(1.f);
+        if (lhs_type == "welford_fp64") {
+          m2_init->set_type(common::F64());
+          weight_init->set_type(common::F64());
+        }
+        rhs = ir::Call::Make(lhs.type(),
+                             "create_" + lhs_type,
+                             {rhs, m2_init, weight_init},
+                             {},
+                             ir::CallType::Intrinsic);
+      }
+    }
+    std::string call_prefix = "";
+    switch (reduce_type_) {
+      case ReduceType::kVariance:
+        call_prefix = "sum_";
+        break;
+      case ReduceType::kArgmax:
+        call_prefix = "max_";
+        break;
+      case ReduceType::kArgmin:
+        call_prefix = "min_";
+        break;
+      default:
+        break;
+    }
+    *expr = ir::Call::Make(lhs.type(),
+                           call_prefix + lhs_type,
+                           {lhs, rhs},
+                           {},
+                           ir::CallType::Intrinsic);
+  }
+};
+
+void ReplaceReduceExternCall(const BlockRef& body, bool is_x86_arch = false) {
   const auto VisitFn = [&](StmtRef stmt) {
     if (!stmt.isa<Store>()) return;
     Store store_stmt = stmt.as<Store>();
     ir::Expr new_value = ir::ir_utils::IRCopy(store_stmt->value());
-    ReduceExternCallMutator()(&new_value);
+    if (is_x86_arch) {
+      ReduceExternCallMutatorX86()(&new_value);
+    } else {
+      ReduceExternCallMutator()(&new_value);
+    }
     store_stmt->set_value(new_value);
   };
 
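Note: unlike the GPU path, which rewrites these calls into arithmetic IR nodes, the x86 mutator above keeps them as intrinsic calls such as create_welford_fp64, or sum_ / max_ / min_ plus the composite type name. Conceptually, a sum_welford_* combine merges two Welford partial states (the parallel variance update). A standalone sketch follows; the field names and exact semantics are assumptions, not the actual CINN runtime implementation:

    // Conceptual sketch of what a "sum_welford_fp32"-style combine computes:
    // merging two Welford partial states (mean, m2, weight) for variance.
    #include <cstdio>

    struct welford_fp32 {
      float mean;
      float m2;      // sum of squared deviations from the mean
      float weight;  // count (or total weight) of merged elements
    };

    static welford_fp32 sum_welford_fp32(welford_fp32 a, welford_fp32 b) {
      float w = a.weight + b.weight;
      if (w == 0.0f) return a;
      float delta = b.mean - a.mean;
      welford_fp32 out;
      out.weight = w;
      out.mean = a.mean + delta * (b.weight / w);
      out.m2 = a.m2 + b.m2 + delta * delta * (a.weight * b.weight / w);
      return out;
    }

    int main() {
      // Two single-element partial states for the values 1.0 and 3.0.
      welford_fp32 a{1.0f, 0.0f, 1.0f}, b{3.0f, 0.0f, 1.0f};
      welford_fp32 s = sum_welford_fp32(a, b);
      // Population variance = m2 / weight = 2.0 / 2 = 1.0 for {1, 3}.
      std::printf("mean=%f var=%f\n", s.mean, s.m2 / s.weight);
      return 0;
    }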
@@ -528,18 +657,48 @@ LogicalResult RealizeCompositeReducePass::Run(ir::LoweredFunc func) {
   typed_buffers = ResolveUndefinedArgIdxType(std::move(typed_buffers),
                                              std::move(arg_stores));
 
+  bool is_x86_arch = false;
+  target_.arch.Match(
+      [&](std::variant<common::X86Arch>) {
+        /**
+         * trace the CPU buffer for reduce init. For x86, the schedule pass
+         * will not be applied; therefore, the reduce_init buffer will be the
+         * same as the output buffer, which leads to incorrect buffer type
+         * and op type for codegen.
+         *
+         * (1) we first extract the buffer for each output arg
+         * (2) find all stores to the corresponding output buffer; this op is
+         *     prior to the output type cast, for x86 IR reduce_init and the
+         *     writing back op use the same buffer (output tensor buffer)
+         * (3) create a mapping: if the buffer of a store (the value of the
+         *     store) is in typed_buffers, we try finding the reduce_init
+         *     related op and change the buffer and op type of the reduce_init
+         */
+        is_x86_arch = true;
+        std::set<ir::Buffer> output_buffers;
+        for (auto& arg : func->args) {
+          if (!arg.is_output()) continue;
+          output_buffers.emplace(arg.buffer_arg());
+        }
+        ReplaceOutputBufferX86(body, output_buffers, typed_buffers);
+      },
+      [&](std::variant<common::NVGPUArch,
+                       common::HygonDCUArchHIP,
+                       common::HygonDCUArchSYCL,
+                       common::ARMArch,
+                       common::UnknownArch>) {});
   // Step 3. Change the data type of buffers to the corresponding type.
-  SetBufferType(func, typed_buffers);
+  SetBufferType(func, typed_buffers, is_x86_arch);
 
   // Step 4. Replace the `cinn_reduce_variance` and `cinn_argmax` calls
   // in order to reuse the cross-thread/block reduction templates.
-  ReplaceReduceExternCall(body);
+  ReplaceReduceExternCall(body, is_x86_arch);
 
   return LogicalResult::success();
 }
 
-std::unique_ptr<FuncPass> CreateRealizeCompositeReducePass() {
-  return std::make_unique<RealizeCompositeReducePass>();
+std::unique_ptr<FuncPass> CreateRealizeCompositeReducePass(Target target) {
+  return std::make_unique<RealizeCompositeReducePass>(target);
 }
 
 }  // namespace optim
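
Note: target_.arch.Match selects the x86-only fix-up by matching on the architecture variant; every other architecture falls into the empty handler, so is_x86_arch stays false there and the GPU-style lowering is kept. A rough standalone analogue of that dispatch using std::variant and std::visit; the arch types below are stand-ins, not CINN's common::Arch:

    // Rough analogue of the arch dispatch deciding is_x86_arch.
    #include <iostream>
    #include <type_traits>
    #include <variant>

    struct X86Arch {};
    struct NVGPUArch {};
    struct ARMArch {};
    using Arch = std::variant<X86Arch, NVGPUArch, ARMArch>;

    int main() {
      Arch arch = X86Arch{};
      bool is_x86_arch = false;
      std::visit(
          [&](const auto& a) {
            if constexpr (std::is_same_v<std::decay_t<decltype(a)>, X86Arch>) {
              is_x86_arch = true;  // the x86-only buffer fix-up would run here
            }
          },
          arch);
      std::cout << std::boolalpha << is_x86_arch << "\n";  // prints "true"
      return 0;
    }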

paddle/cinn/optim/realize_composite_reduce_pass.h

+6 −2
@@ -76,11 +76,15 @@ namespace optim {
  */
 class RealizeCompositeReducePass : public FuncPass {
  public:
-  RealizeCompositeReducePass() : FuncPass("realize_composite_reduce") {}
+  explicit RealizeCompositeReducePass(Target target)
+      : FuncPass("realize_composite_reduce"), target_(target) {}
   LogicalResult Run(ir::LoweredFunc func) override;
+
+ private:
+  const Target target_;
 };
 
-std::unique_ptr<FuncPass> CreateRealizeCompositeReducePass();
+std::unique_ptr<FuncPass> CreateRealizeCompositeReducePass(Target target);
 
 }  // namespace optim
 }  // namespace cinn
