@@ -95,11 +95,24 @@ CompositeTypes GetArgReduceUnderlyingType(const ir::Expr& expr) {
 
 void SetInitValue(Store store_stmt,
                   common::Type new_type,
-                  const CompositeTypes& comp_type) {
+                  const CompositeTypes& comp_type,
+                  std::string prefix = "") {
+  // prefix: when the target is x86, we cannot call a constructor for the POD
+  // struct; the intrinsic that creates the struct is usually "create_" + typename.
   ir::Expr init_value = store_stmt->value();
+  auto call_op = init_value.As<ir::Call>();
+  // If the init value is already a call, retype it in place.
+  if (call_op != nullptr) {
+    call_op->set_type(new_type);
+    if (call_op->name.find("argidx_") == 0 ||
+        call_op->name.find("welford_") == 0) {
+      call_op->name = prefix + call_op->name;
+    }
+    return;
+  }
   if (comp_type.type == ReduceType::kVariance) {
     store_stmt->set_value(ir::Call::Make(new_type,
-                                         new_type.customized_type(),
+                                         prefix + new_type.customized_type(),
                                          {init_value, init_value, init_value},
                                          {},
                                          ir::CallType::Intrinsic));
@@ -111,7 +124,7 @@ void SetInitValue(Store store_stmt,
       index_init->set_type(common::Int(64));
     }
     store_stmt->set_value(ir::Call::Make(new_type,
-                                         new_type.customized_type(),
+                                         prefix + new_type.customized_type(),
                                          {init_value, index_init},
                                          {},
                                          ir::CallType::Intrinsic));
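
Aside: the naming rule SetInitValue now applies is simply prefix + customized type name, where the prefix is "create_" on x86 (the POD structs have no constructor there) and empty elsewhere. A minimal standalone sketch of that rule, using an illustrative type name rather than CINN's API:

#include <iostream>
#include <string>

// Sketch of the init-call naming rule, not CINN code: on x86 the composite
// struct is built by an intrinsic named "create_" + <customized type name>.
std::string InitCallName(const std::string& customized_type, bool is_x86_arch) {
  const std::string prefix = is_x86_arch ? "create_" : "";
  return prefix + customized_type;
}

int main() {
  // "welford_fp32" is an illustrative customized type name.
  std::cout << InitCallName("welford_fp32", false) << "\n";  // welford_fp32
  std::cout << InitCallName("welford_fp32", true) << "\n";   // create_welford_fp32
  return 0;
}
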
@@ -199,6 +212,60 @@ std::map<ir::Buffer, CompositeTypes> CollectTypedReduceBuffers(
   return typed_buffers;
 }
 
+void ReplaceOutputBufferX86(
+    const BlockRef& body,
+    const std::set<ir::Buffer>& out_buffer_map,
+    const std::map<ir::Buffer, CompositeTypes>& typed_buffers) {
+  // Re-route the reduce_init buffer to the local staging buffer and set the
+  // buffer types correctly.
+  struct BufferRelationRecorder {
+    Store reduce_init;
+    Store write_back;
+  };
+  std::map<ir::Buffer, BufferRelationRecorder> buffer_relations;
+  for (auto buffer : out_buffer_map) {
+    buffer_relations.emplace(buffer, BufferRelationRecorder());
+  }
+  const auto VisitFn = [&](const StmtRef& stmt) {
+    if (!stmt.isa<Store>()) return;
+    Store store_stmt = stmt.as<Store>();
+
+    auto* tensor = store_stmt->tensor().as_tensor();
+    auto& buffer = tensor->buffer;
+    auto buffer_it = buffer_relations.find(buffer);
+    // Check whether the buffer is related to an output arg.
+    if (buffer_it == buffer_relations.end()) return;
+    if (ir::IsReduceInitTensorName(tensor->name)) {
+      buffer_it->second.reduce_init = store_stmt;
+    } else {
+      buffer_it->second.write_back = store_stmt;
+    }
+  };
+
+  ir::stmt::Visit(body, VisitFn, [](auto) {});
+
+  for (auto& [_, buffer_rel] : buffer_relations) {
+    // Both stores must have been found.
+    if (!buffer_rel.reduce_init.defined() || !buffer_rel.write_back.defined())
+      continue;
+    auto wb_value = buffer_rel.write_back->value();
+    if (auto load_node = wb_value.As<ir::Load>()) {
+      auto wb_load_buffer = load_node->tensor.as_tensor()->buffer;
+      auto wb_load_it = typed_buffers.find(wb_load_buffer);
+      PADDLE_ENFORCE_NE(wb_load_it,
+                        typed_buffers.end(),
+                        ::common::errors::Fatal(
+                            "Buffer '%s' should be defined in typed_buffers.",
+                            wb_load_buffer->name));
+      // Set the buffer of the reduce_init to the write-back buffer.
+      ir::Expr new_tensor =
+          ir::ir_utils::IRCopy(buffer_rel.reduce_init->tensor());
+      new_tensor.as_tensor()->buffer = wb_load_buffer;
+      buffer_rel.reduce_init->set_tensor(new_tensor);
+    }
+  }
+}
+
 Store GetStoreOfSchedule(const Schedule& stmt) {
   Store store_stmt;
   bool found = false;
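
Aside: ReplaceOutputBufferX86 pairs, per output buffer, the reduce_init store with the write-back store and then re-points the init store at the buffer the write-back reads from. A simplified, self-contained sketch of that pairing, with plain strings and structs standing in for CINN buffers and Store statements (all names illustrative):

#include <iostream>
#include <map>
#include <string>
#include <vector>

// Plain stand-in for a Store statement (illustrative, not CINN's Store).
struct FakeStore {
  bool is_reduce_init;       // true for the reduce_init store
  std::string buffer;        // buffer the store writes to
  std::string value_buffer;  // buffer loaded by the store's value (write-back only)
};

struct BufferRelation {
  const FakeStore* reduce_init = nullptr;
  const FakeStore* write_back = nullptr;
};

int main() {
  // Both stores initially target the same output buffer, as in unscheduled x86 IR.
  std::vector<FakeStore> stores = {
      {true, "out_buf", ""},                  // reduce_init store
      {false, "out_buf", "welford_staging"},  // write-back store reading a staging buffer
  };

  std::map<std::string, BufferRelation> relations;
  relations["out_buf"];  // only output-arg buffers are tracked

  for (const auto& s : stores) {
    auto it = relations.find(s.buffer);
    if (it == relations.end()) continue;
    if (s.is_reduce_init) {
      it->second.reduce_init = &s;
    } else {
      it->second.write_back = &s;
    }
  }

  for (const auto& [buf, rel] : relations) {
    if (rel.reduce_init == nullptr || rel.write_back == nullptr) continue;
    // Mirror of buffer_rel.reduce_init->set_tensor(new_tensor): the init store
    // now targets the buffer that the write-back store reads from.
    std::cout << "reduce_init buffer: " << buf << " -> "
              << rel.write_back->value_buffer << "\n";
  }
  return 0;
}
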
@@ -406,7 +473,8 @@ struct LoadTypeMutator : public ir::IRMutator<> {
 };
 
 void SetBufferType(ir::LoweredFunc func,
-                   const std::map<ir::Buffer, CompositeTypes>& typed_buffers) {
+                   const std::map<ir::Buffer, CompositeTypes>& typed_buffers,
+                   bool is_x86_arch) {
   // Make a map from the buffers to their element and composite reduce types,
   // otherwise it's hard to know a buffer's original type. The original type
   // must be known to perform casting (back) in LoadTypeMutator::Visit()
@@ -439,9 +507,10 @@ void SetBufferType(ir::LoweredFunc func,
       new_tensor.as_tensor()->set_type(new_type);
       new_tensor.as_tensor()->buffer->dtype = new_type;
       store_stmt->set_tensor(new_tensor);
-
-      if (ir::IsReduceInitTensorName(tensor->name)) {
-        SetInitValue(store_stmt, new_type, composite_type);
+      stmt->set_type(new_type);
+      if (ir::IsReduceInitTensorName(new_tensor.as_tensor()->name)) {
+        std::string call_prefix = is_x86_arch ? "create_" : "";
+        SetInitValue(store_stmt, new_type, composite_type, call_prefix);
       }
     }
 
@@ -490,12 +559,72 @@ struct ReduceExternCallMutator : public ir::IRMutator<> {
   }
 };
 
-void ReplaceReduceExternCall(const BlockRef& body) {
+struct ReduceExternCallMutatorX86 : public ir::IRMutator<> {
+  // Unlike its non-x86 counterpart, this mutator does not replace the call
+  // with an arithmetic IR node; it calls x86-exclusive functions instead.
+  void operator()(ir::Expr* expr) { ir::IRMutator<>::Visit(expr, expr); }
+
+ private:
+  void Visit(const ir::Call* op, ir::Expr* expr) override {
+    ir::IRMutator<>::Visit(op, expr);
+    auto reduce_type_ = GetReduceType(*expr);
+    if (reduce_type_ == ReduceType::kNone) return;
+    ir::Expr lhs = op->read_args[0];
+    ir::Expr rhs = op->read_args[1];
+    std::string lhs_type = lhs.type().to_string();
+    if (lhs.type() != rhs.type()) {
+      if (auto call_op = rhs.As<ir::Call>()) {
+        // For the argidx type, avoid a redundant type cast (admittedly ugly).
+        if (call_op->name.find("argidx") == 0) {
+          call_op->name = "create_" + call_op->name;
+          rhs->set_type(lhs.type());
+        }
+      } else {
+        // For the Welford POD type, call its create function on x86.
+        ir::Expr m2_init(0.f), weight_init(1.f);
+        if (lhs_type == "welford_fp64") {
+          m2_init->set_type(common::F64());
+          weight_init->set_type(common::F64());
+        }
+        rhs = ir::Call::Make(lhs.type(),
+                             "create_" + lhs_type,
+                             {rhs, m2_init, weight_init},
+                             {},
+                             ir::CallType::Intrinsic);
+      }
+    }
+    std::string call_prefix = "";
+    switch (reduce_type_) {
+      case ReduceType::kVariance:
+        call_prefix = "sum_";
+        break;
+      case ReduceType::kArgmax:
+        call_prefix = "max_";
+        break;
+      case ReduceType::kArgmin:
+        call_prefix = "min_";
+        break;
+      default:
+        break;
+    }
+    *expr = ir::Call::Make(lhs.type(),
+                           call_prefix + lhs_type,
+                           {lhs, rhs},
+                           {},
+                           ir::CallType::Intrinsic);
+  }
+};
+
+void ReplaceReduceExternCall(const BlockRef& body, bool is_x86_arch = false) {
   const auto VisitFn = [&](StmtRef stmt) {
     if (!stmt.isa<Store>()) return;
     Store store_stmt = stmt.as<Store>();
     ir::Expr new_value = ir::ir_utils::IRCopy(store_stmt->value());
-    ReduceExternCallMutator()(&new_value);
+    if (is_x86_arch) {
+      ReduceExternCallMutatorX86()(&new_value);
+    } else {
+      ReduceExternCallMutator()(&new_value);
+    }
     store_stmt->set_value(new_value);
   };
 
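
Aside: on x86 the rewritten reduce call is named <prefix> + <lhs type name>, with the prefix chosen from the reduce kind (sum_ for variance, max_/min_ for argmax/argmin), and a type-mismatched rhs is first wrapped in a create_<type> call. A standalone sketch of the name composition only, with an illustrative enum standing in for CINN's ReduceType:

#include <iostream>
#include <string>

// Illustrative stand-in for the reduce kinds handled above (not CINN's enum).
enum class ReduceKind { kVariance, kArgmax, kArgmin };

// Compose the x86 intrinsic name as "<prefix><lhs type name>".
std::string X86ReduceCallName(ReduceKind kind, const std::string& lhs_type) {
  std::string prefix;
  switch (kind) {
    case ReduceKind::kVariance:
      prefix = "sum_";
      break;
    case ReduceKind::kArgmax:
      prefix = "max_";
      break;
    case ReduceKind::kArgmin:
      prefix = "min_";
      break;
  }
  return prefix + lhs_type;
}

int main() {
  // "welford_fp64" appears in the patch; the argidx name below is only a guess.
  std::cout << X86ReduceCallName(ReduceKind::kVariance, "welford_fp64") << "\n";   // sum_welford_fp64
  std::cout << X86ReduceCallName(ReduceKind::kArgmax, "argidx_fp32_i64") << "\n";  // max_argidx_fp32_i64
  return 0;
}
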
@@ -528,18 +657,48 @@ LogicalResult RealizeCompositeReducePass::Run(ir::LoweredFunc func) {
   typed_buffers = ResolveUndefinedArgIdxType(std::move(typed_buffers),
                                              std::move(arg_stores));
 
+  bool is_x86_arch = false;
+  target_.arch.Match(
+      [&](std::variant<common::X86Arch>) {
+        /**
+         * Trace the CPU buffer used by reduce_init. For x86, the schedule
+         * pass is not applied, so the reduce_init buffer is the same as the
+         * output buffer, which leads to an incorrect buffer type and op type
+         * during codegen.
+         *
+         * (1) First extract the buffer of each output arg.
+         * (2) Find all stores to the corresponding output buffer. This runs
+         *     before the output type cast; in x86 IR, reduce_init and the
+         *     write-back op use the same buffer (the output tensor buffer).
+         * (3) Create a mapping. If the buffer loaded by a store's value is in
+         *     typed_buffers, find the related reduce_init op and change the
+         *     buffer and op type of that reduce_init.
+         */
+        is_x86_arch = true;
+        std::set<ir::Buffer> output_buffers;
+        for (auto& arg : func->args) {
+          if (!arg.is_output()) continue;
+          output_buffers.emplace(arg.buffer_arg());
+        }
+        ReplaceOutputBufferX86(body, output_buffers, typed_buffers);
+      },
+      [&](std::variant<common::NVGPUArch,
+                       common::HygonDCUArchHIP,
+                       common::HygonDCUArchSYCL,
+                       common::ARMArch,
+                       common::UnknownArch>) {});
   // Step 3. Change the data type of buffers to the corresponding type.
-  SetBufferType(func, typed_buffers);
+  SetBufferType(func, typed_buffers, is_x86_arch);
 
   // Step 4. Replace the `cinn_reduce_variance` and `cinn_argmax` calls
   // in order to reuse the cross-thread/block reduction templates.
-  ReplaceReduceExternCall(body);
+  ReplaceReduceExternCall(body, is_x86_arch);
 
   return LogicalResult::success();
 }
 
-std::unique_ptr<FuncPass> CreateRealizeCompositeReducePass() {
-  return std::make_unique<RealizeCompositeReducePass>();
+std::unique_ptr<FuncPass> CreateRealizeCompositeReducePass(Target target) {
+  return std::make_unique<RealizeCompositeReducePass>(target);
 }
 
 }  // namespace optim