@@ -48,6 +48,7 @@

 PD_DECLARE_bool(cinn_use_cuda_vectorize);
 PD_DECLARE_bool(cinn_check_tensor_buffer_map);
+PD_DECLARE_bool(cinn_longlong2int);
 const int default_priority = 100;

 namespace cinn {
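For orientation (this gloss is not part of the diff): the new `cinn_longlong2int` flag gates a pass that, judging by the `LongLong2Int` call further down, narrows 64-bit index arithmetic to 32-bit where that is provably safe. A minimal standalone sketch of the safety precondition, with illustrative names that are not CINN API:

#include <cstdint>
#include <limits>
#include <vector>

// Hypothetical helper: narrowing int64 indices to int32 is only safe when
// every loop extent (and every input's element count) fits in int32_t.
bool CanNarrowToInt32(const std::vector<int64_t>& extents) {
  for (int64_t e : extents) {
    if (e < 0 || e > std::numeric_limits<int32_t>::max()) return false;
  }
  return true;
}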
@@ -195,49 +196,48 @@ BucketLoweredFuncsWrapper OpLowererImpl::BucketLower(
   // including preparing function args and temporary variables,
   // applying low-level optimization passes, etc.
   std::vector<ir::Expr> scheduled_func_bodies;
+  std::vector<ir::SymbolicPredicate> predicates;
   for (std::pair<ir::SymbolicPredicate, ir::Expr>& cond2body :
        cond2func_bodies) {
+    predicates.push_back(cond2body.first);
     scheduled_func_bodies.push_back(cond2body.second);
   }
   std::vector<ir::Tensor> group_func_arg_tensors_copy = group_func_arg_tensors;
   std::vector<ir::Argument> group_func_args;
   std::vector<ir::Tensor> infer_shape_tensor_args;
-  std::vector<ir::LoweredFunc> funcs = PostProcess(group,
-                                                   tensor_map,
-                                                   {scheduled_func_bodies},
-                                                   &group_func_arg_tensors_copy,
-                                                   &group_func_args,
-                                                   &infer_shape_tensor_args);
+
+  std::vector<CondFuncPriorWrapper> warps_processed =
+      PostProcess(group,
+                  tensor_map,
+                  fusion_group_info,
+                  {scheduled_func_bodies},
+                  {predicates},
+                  {priorities},
+                  &group_func_arg_tensors_copy,
+                  &group_func_args,
+                  &infer_shape_tensor_args);
   if (FLAGS_cinn_check_tensor_buffer_map) {
-    for (ir::LoweredFunc& func : funcs) {
-      optim::CheckTensorBufferMap(func->body, "BucketLower PostProcess");
+    for (auto& warp : warps_processed) {
+      optim::CheckTensorBufferMap(std::get<1>(warp)->body,
+                                  "BucketLower PostProcess");
     }
     VLOG(3) << "PostProcess tensor-buffer map check succeed";
   }
-  PADDLE_ENFORCE_EQ(funcs.size(),
-                    cond2func_bodies.size(),
-                    ::common::errors::InvalidArgument(
-                        "The size of funcs and cond2func_bodies should be "
-                        "the same."));
-  PADDLE_ENFORCE_EQ(funcs.size(),
-                    priorities.size() + 1,
-                    ::common::errors::InvalidArgument(
-                        "The size of funcs should equals to the "
-                        "size of priorities plus one."));
+
   BucketLoweredFuncsWrapper funcs_wrapper;
-  for (int i = 0; i < funcs.size() - 1; ++i) {
-    funcs_wrapper.predicate2funcs.emplace_back(
-        std::make_tuple(cond2func_bodies[i].first, funcs[i], priorities[i]));
+  for (int i = 0; i < warps_processed.size() - 1; ++i) {
+    funcs_wrapper.predicate2funcs.emplace_back(warps_processed[i]);
   }
+
   // The last func is x86 kernel.
-  for (size_t i = funcs.size() - 1; i < funcs.size(); ++i) {
-    if (funcs[i]->body == ir::Expr(-1)) {
-      continue;
-    }
-    funcs[i]->name = funcs[i]->name + "_CX86";
-    funcs_wrapper.predicate2funcsCX86.emplace_back(cond2func_bodies[i].first,
-                                                   funcs[i]);
+  auto [predicate_postprocessed, func_postprocessed, _] =
+      warps_processed[warps_processed.size() - 1];
+  if (func_postprocessed->body != ir::Expr(-1)) {
+    func_postprocessed->name = func_postprocessed->name + "_CX86";
+    funcs_wrapper.predicate2funcsCX86.emplace_back(predicate_postprocessed,
+                                                   func_postprocessed);
   }
+
   funcs_wrapper.infer_shape_func =
       GenerateInferShapeFunc(group, infer_shape_tensor_args, group_func_args);
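A note for readers (an assumption, not something the diff itself shows): given the `std::get<1>(warp)` access and the `[predicate_postprocessed, func_postprocessed, _]` structured binding above, `CondFuncPriorWrapper` is presumably a tuple alias along these lines:

#include <tuple>

// Assumed definition, for orientation only; the real alias is declared
// elsewhere in the codebase: (predicate, lowered func, priority).
using CondFuncPriorWrapper =
    std::tuple<ir::SymbolicPredicate, ir::LoweredFunc, int>;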
@@ -258,13 +258,18 @@ std::unordered_set<std::string> CollectStoreBufferNames(
   return buffer_names;
 }

-std::vector<ir::LoweredFunc> OpLowererImpl::PostProcess(
+std::vector<CondFuncPriorWrapper> OpLowererImpl::PostProcess(
     const OpLoweringGroupPtr& group,
     const std::unordered_map<::pir::Value, ir::Tensor>& tensor_map,
+    const std::shared_ptr<FusionGroupInfo>& fusion_group_info,
     std::vector<ir::Expr> func_bodies,
+    std::vector<ir::SymbolicPredicate> predicates,
+    std::vector<int> priorities,
     std::vector<ir::Tensor>* group_func_arg_tensors,
     std::vector<ir::Argument>* group_func_args,
     std::vector<ir::Tensor>* infer_shape_arg_tensor) {
+  std::vector<ir::Expr> inputs_element_size;
+
   // 1.Prepare function args
   group->mut_input_names().clear();
   std::unordered_set<std::string> store_buffer_names =
@@ -280,6 +285,12 @@ std::vector<ir::LoweredFunc> OpLowererImpl::PostProcess(
                            ? ir::Argument::IO::kOutput
                            : ir::Argument::IO::kInput;
     (*group_func_args).emplace_back(arg_tensor->buffer, io_type);
+    // Collect each input's element size for the longlong2int pass.
+    if (FLAGS_cinn_longlong2int) {
+      inputs_element_size.push_back(common::FoldExpr(
+          [](const Expr& a, const Expr& b) { return ir::Mul::Make(a, b); },
+          arg_tensor->shape));
+    }
     arg_name_set.insert(arg_tensor->buffer->name);
   }
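The `FoldExpr` call above multiplies all shape extents of an input tensor into a single symbolic element-count expression. A standalone analogue with concrete dimensions, for illustration only (plain C++, not the CINN `ir::Expr` API):

#include <cstdint>
#include <functional>
#include <numeric>
#include <vector>

// Same fold over concrete extents: element count = product of the dims.
int64_t ElementCount(const std::vector<int64_t>& shape) {
  return std::accumulate(
      shape.begin(), shape.end(), int64_t{1}, std::multiplies<int64_t>());
}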
@@ -330,6 +341,7 @@ std::vector<ir::LoweredFunc> OpLowererImpl::PostProcess(
   std::map<int, CINNKernelInfo::SymbolArgBindInfo> mps;
   // update args for dynamic dim
   int non_tensor_arg_idx = group_func_args->size();
+
   std::unordered_set<std::string> symbol_args_set;
   for (int tensor_arg_idx = 0; tensor_arg_idx < input_tensor_size;
        tensor_arg_idx++) {
@@ -381,7 +393,11 @@ std::vector<ir::LoweredFunc> OpLowererImpl::PostProcess(
     AddDimSymbolArgs();
     AddValueSymbolArgs();
   }
-  std::vector<ir::LoweredFunc> lowered_funcs;
+
+  std::vector<ir::LoweredFunc> ret_lowered_funcs;
+  std::vector<ir::SymbolicPredicate> ret_predicates;
+  std::vector<int> ret_priorities;
+
   for (int i = 0; i < func_bodies.size(); ++i) {
     ir::Expr func_body = func_bodies[i];
     optim::EliminateDeadScheduleBlock(&(func_body), group->output_names());
@@ -416,14 +432,53 @@ std::vector<ir::LoweredFunc> OpLowererImpl::PostProcess(
       func = optim::Optimize(func, common::DefaultHostTarget(), false);
     }
     func->num_output_tensors = infer_shape_arg_tensor->size();
-    lowered_funcs.push_back(std::move(func));
-  }

-  // 5. Unify temp_space args and set temp_space sizes
-  UnifyTempSpaceArgs(&lowered_funcs);
-  group->mut_temp_space_sizes() = CollectTempSpaceSizes(lowered_funcs);
+    // 5. Apply longlong2int pass
+    if (i != func_bodies.size() - 1) {
+      LongLong2Int(symbol_args_set,
+                   fusion_group_info->loop_ranges_expr,
+                   inputs_element_size,
+                   priorities[i],
+                   &predicates[i],
+                   &func,
+                   &ret_predicates,
+                   &ret_lowered_funcs,
+                   &ret_priorities);
+    }
+    ret_predicates.push_back(std::move(predicates[i]));
+    ret_lowered_funcs.push_back(std::move(func));
+    // The host func has no priority; since the tuples require alignment,
+    // set -1 here.
+    if (i != func_bodies.size() - 1) {
+      ret_priorities.push_back(std::move(priorities[i]));
+    } else {
+      ret_priorities.push_back(-1);
+    }
+  }

-  return lowered_funcs;
+  // 6. Unify temp_space args and set temp_space sizes
+  UnifyTempSpaceArgs(&ret_lowered_funcs);
+  group->mut_temp_space_sizes() = CollectTempSpaceSizes(ret_lowered_funcs);
+
+  PADDLE_ENFORCE_EQ(
+      ret_lowered_funcs.size(),
+      ret_predicates.size(),
+      ::common::errors::InvalidArgument(
+          "The size of ret_lowered_funcs and ret_predicates should be "
+          "the same."));
+  PADDLE_ENFORCE_EQ(
+      ret_lowered_funcs.size(),
+      ret_priorities.size(),
+      ::common::errors::InvalidArgument(
+          "The size of ret_lowered_funcs and ret_priorities should be "
+          "the same."));
+
+  std::vector<CondFuncPriorWrapper> ret;
+  for (size_t i = 0; i < ret_lowered_funcs.size(); ++i) {
+    ret.emplace_back(std::move(ret_predicates[i]),
+                     std::move(ret_lowered_funcs[i]),
+                     std::move(ret_priorities[i]));
+  }
+  return ret;
 }

 std::vector<ir::stmt::BlockRef> OpLowererImpl::LowerOps(
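Taken together, `PostProcess` now returns one `CondFuncPriorWrapper` per device bucket plus a trailing host entry whose priority is the -1 sentinel. A hedged consumption sketch (the variable names reuse the diff's; the dispatch comments are illustrative, not the actual caller):

// Illustrative only: iterate the returned wrappers; priority == -1 marks
// the trailing x86 host kernel, everything else is a device bucket.
for (auto& [predicate, func, priority] : warps_processed) {
  if (priority == -1) {
    // host (CX86) kernel: emitted without a scheduling priority
  } else {
    // device bucket selected at runtime when `predicate` holds
  }
}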