Commit 64d227c

【CINN】longlong2int for dynamic shape (#71072)
* longlong2int for dynamic shape
* change cuda func args type
* add args for grid reduce
* fix bug for ci
* remove ! in func name
* refine code
* ir copy on host module args
* update dynamic cast
* fix comment
* polish code
* refine code
1 parent 6317290 commit 64d227c

12 files changed: +353, -75 lines

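As background for the diff below, "longlong2int" refers to narrowing 64-bit index arithmetic to 32-bit when a kernel's iteration domain is known to fit into int32, which is cheaper on GPU. The following is a minimal, hypothetical C++ sketch of that safety check only; the names are invented for illustration and this is not CINN's actual pass.

#include <cstdint>
#include <iostream>
#include <limits>
#include <vector>

// Returns true when every extent and the total element count fit into int32,
// i.e. when it is safe to do the index arithmetic in 32-bit integers.
bool CanNarrowToInt32(const std::vector<int64_t>& extents) {
  int64_t total = 1;
  for (int64_t e : extents) {
    if (e < 0 || e > std::numeric_limits<int32_t>::max()) return false;
    total *= e;  // cannot overflow in int64: both factors are <= INT32_MAX here
    if (total > std::numeric_limits<int32_t>::max()) return false;
  }
  return true;
}

int main() {
  std::vector<int64_t> small = {128, 256, 32};      // 1,048,576 elements -> fits
  std::vector<int64_t> large = {1 << 20, 1 << 15};  // 2^35 elements -> too big
  std::cout << CanNarrowToInt32(small) << " " << CanNarrowToInt32(large) << "\n";
  return 0;
}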
paddle/cinn/backends/codegen_device_util.cc

Lines changed: 34 additions & 12 deletions
@@ -170,16 +170,24 @@ detail::CollectBucketStrategyHostFunctionVisitor::GenDeviceKernelName(
   std::string cond_str = Predicate2String(predicate);
   // replace '-' with 'NEG'
   size_t pos = cond_str.find("-", 0);
-  const std::string replacement = "NEG";
+  const std::string replacement_neg = "NEG";
   while (pos != std::string::npos) {
-    cond_str.replace(pos, 1, replacement);
-    pos = cond_str.find("-", pos + replacement.length());
+    cond_str.replace(pos, 1, replacement_neg);
+    pos = cond_str.find("-", pos + replacement_neg.length());
+  }
+
+  // replace '!' with 'NOT'
+  pos = cond_str.find("!", 0);
+  const std::string replacement_not = "NOT";
+  while (pos != std::string::npos) {
+    cond_str.replace(pos, 1, replacement_not);
+    pos = cond_str.find("!", pos + replacement_not.length());
   }
   VLOG(3) << "predicate string: " << cond_str;
   // NOTE(chenxi67): The kernel name is too long to be supported in cuda12.3 so
   // we need to curtail it.
   const std::string new_fn_name = CurTailFnName(fn_name);
-  return new_fn_name + "__COND_" + cond_str + "__kernel";
+  return new_fn_name + "_COND_" + cond_str + "__kernel";
 }
 
 void detail::CollectBucketStrategyHostFunctionVisitor::ProcessLoweredFunc(
@@ -245,19 +253,33 @@ void detail::CollectBucketStrategyHostFunctionVisitor::ProcessLoweredFunc(
     call_kernel = runtime::intrinsic::call_sycl_kernel;
   });
   // TODO(Dmovic): use new ir when backend update done.
+  // Author(liujinnan): Copy args instead of use func args directly in host
+  // func. because after longlong2int pass, some type of loweredfunc args may be
+  // changed to int32, it cause compile error when lower to LLVM IR.
+  std::vector<ir::Expr> kernel_args_int64 = {
+      ir::ir_utils::IRCopy(func_node->cuda_axis_info.grid_dim(0)),
+      ir::ir_utils::IRCopy(func_node->cuda_axis_info.grid_dim(1)),
+      ir::ir_utils::IRCopy(func_node->cuda_axis_info.grid_dim(2)),
+      ir::ir_utils::IRCopy(func_node->cuda_axis_info.block_dim(0)),
+      ir::ir_utils::IRCopy(func_node->cuda_axis_info.block_dim(1)),
+      ir::ir_utils::IRCopy(func_node->cuda_axis_info.block_dim(2)),
+      ir::ir_utils::IRCopy(shared_mem_bytes.value()),
+      cinn::common::make_const(Int(64), 0) /* enable TryElevateInt32ToInt64 */};
+  ir::TryElevateInt32ToInt64(kernel_args_int64);
+
   ir::Expr call_extern_api =
       ir::Call::Make(Void(),
                      call_kernel.value(),
                      {kernel_ptr,
                       kernel_args_,
                       kernel_args_num_,
-                      func_node->cuda_axis_info.grid_dim(0),   // grid_x
-                      func_node->cuda_axis_info.grid_dim(1),   // grid_y
-                      func_node->cuda_axis_info.grid_dim(2),   // grid_z
-                      func_node->cuda_axis_info.block_dim(0),  // block_x
-                      func_node->cuda_axis_info.block_dim(1),  // block_y
-                      func_node->cuda_axis_info.block_dim(2),  // block_z
-                      shared_mem_bytes.value(),                // shared_mem
+                      kernel_args_int64.at(0),  // grid_x
+                      kernel_args_int64.at(1),  // grid_y
+                      kernel_args_int64.at(2),  // grid_z
+                      kernel_args_int64.at(3),  // block_x
+                      kernel_args_int64.at(4),  // block_y
+                      kernel_args_int64.at(5),  // block_z
+                      kernel_args_int64.at(6),  // shared_mem
                       kernel_stream_},
                      {},
                      ir::CallType::Extern,
@@ -335,7 +357,7 @@ void detail::CollectBucketStrategyHostFunctionVisitor::ProcessArgs(
                          ir::CallType::Extern,
                          ir::FunctionRef(),
                          0);
-      ir::Expr let_symbol = ir::Expr(args[i].var_arg());
+      ir::Expr let_symbol = ir::ir_utils::IRCopy(args[i].var_arg());
       let_symbol->set_type(type_of<int64_t>());
       ir::stmt::StmtRef stmt =
           ir::stmt::Let(let_symbol, call_get_value_in_kernel_args);

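Read on its own, the predicate sanitization added above amounts to a small string rewrite; the sketch below mirrors the hunk as a standalone helper (the function name, the main() driver, and the example predicate are invented for illustration):

#include <iostream>
#include <string>

// Replace '-' with "NEG" and '!' with "NOT" so the predicate string can be
// embedded into a legal kernel name, as in the hunk above.
std::string SanitizePredicate(std::string cond_str) {
  size_t pos = cond_str.find("-", 0);
  const std::string replacement_neg = "NEG";
  while (pos != std::string::npos) {
    cond_str.replace(pos, 1, replacement_neg);
    pos = cond_str.find("-", pos + replacement_neg.length());
  }
  pos = cond_str.find("!", 0);
  const std::string replacement_not = "NOT";
  while (pos != std::string::npos) {
    cond_str.replace(pos, 1, replacement_not);
    pos = cond_str.find("!", pos + replacement_not.length());
  }
  return cond_str;
}

int main() {
  // "S0>-1&&!(S1==0)" becomes "S0>NEG1&&NOT(S1==0)"
  std::cout << SanitizePredicate("S0>-1&&!(S1==0)") << "\n";
  return 0;
}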
paddle/cinn/hlir/framework/pir/op_lowering_impl.cc

Lines changed: 91 additions & 36 deletions
@@ -48,6 +48,7 @@
 
 PD_DECLARE_bool(cinn_use_cuda_vectorize);
 PD_DECLARE_bool(cinn_check_tensor_buffer_map);
+PD_DECLARE_bool(cinn_longlong2int);
 const int default_priority = 100;
 
 namespace cinn {
@@ -195,49 +196,48 @@ BucketLoweredFuncsWrapper OpLowererImpl::BucketLower(
   // including preparing function args and temporary variables,
   // applying low-level optimization passes, etc.
   std::vector<ir::Expr> scheduled_func_bodies;
+  std::vector<ir::SymbolicPredicate> predicates;
   for (std::pair<ir::SymbolicPredicate, ir::Expr>& cond2body :
        cond2func_bodies) {
+    predicates.push_back(cond2body.first);
     scheduled_func_bodies.push_back(cond2body.second);
   }
   std::vector<ir::Tensor> group_func_arg_tensors_copy = group_func_arg_tensors;
   std::vector<ir::Argument> group_func_args;
   std::vector<ir::Tensor> infer_shape_tensor_args;
-  std::vector<ir::LoweredFunc> funcs = PostProcess(group,
-                                                   tensor_map,
-                                                   {scheduled_func_bodies},
-                                                   &group_func_arg_tensors_copy,
-                                                   &group_func_args,
-                                                   &infer_shape_tensor_args);
+
+  std::vector<CondFuncPriorWrapper> warps_processed =
+      PostProcess(group,
+                  tensor_map,
+                  fusion_group_info,
+                  {scheduled_func_bodies},
+                  {predicates},
+                  {priorities},
+                  &group_func_arg_tensors_copy,
+                  &group_func_args,
+                  &infer_shape_tensor_args);
   if (FLAGS_cinn_check_tensor_buffer_map) {
-    for (ir::LoweredFunc& func : funcs) {
-      optim::CheckTensorBufferMap(func->body, "BucketLower PostProcess");
+    for (auto& warp : warps_processed) {
+      optim::CheckTensorBufferMap(std::get<1>(warp)->body,
+                                  "BucketLower PostProcess");
     }
     VLOG(3) << "PostProcess tensor-buffer map check succeed";
   }
-  PADDLE_ENFORCE_EQ(funcs.size(),
-                    cond2func_bodies.size(),
-                    ::common::errors::InvalidArgument(
-                        "The size of funcs and cond2func_bodies should be "
-                        "the same."));
-  PADDLE_ENFORCE_EQ(funcs.size(),
-                    priorities.size() + 1,
-                    ::common::errors::InvalidArgument(
-                        "The size of funcs should equals to the "
-                        "size of priorities plus one."));
+
   BucketLoweredFuncsWrapper funcs_wrapper;
-  for (int i = 0; i < funcs.size() - 1; ++i) {
-    funcs_wrapper.predicate2funcs.emplace_back(
-        std::make_tuple(cond2func_bodies[i].first, funcs[i], priorities[i]));
+  for (int i = 0; i < warps_processed.size() - 1; ++i) {
+    funcs_wrapper.predicate2funcs.emplace_back(warps_processed[i]);
   }
+
   // The last func is x86 kernel.
-  for (size_t i = funcs.size() - 1; i < funcs.size(); ++i) {
-    if (funcs[i]->body == ir::Expr(-1)) {
-      continue;
-    }
-    funcs[i]->name = funcs[i]->name + "_CX86";
-    funcs_wrapper.predicate2funcsCX86.emplace_back(cond2func_bodies[i].first,
-                                                   funcs[i]);
+  auto [predicate_postprocessed, func_postprocessed, _] =
+      warps_processed[warps_processed.size() - 1];
+  if (func_postprocessed->body != ir::Expr(-1)) {
+    func_postprocessed->name = func_postprocessed->name + "_CX86";
+    funcs_wrapper.predicate2funcsCX86.emplace_back(predicate_postprocessed,
+                                                   func_postprocessed);
   }
+
   funcs_wrapper.infer_shape_func =
       GenerateInferShapeFunc(group, infer_shape_tensor_args, group_func_args);
 
@@ -258,13 +258,18 @@ std::unordered_set<std::string> CollectStoreBufferNames(
   return buffer_names;
 }
 
-std::vector<ir::LoweredFunc> OpLowererImpl::PostProcess(
+std::vector<CondFuncPriorWrapper> OpLowererImpl::PostProcess(
     const OpLoweringGroupPtr& group,
    const std::unordered_map<::pir::Value, ir::Tensor>& tensor_map,
+    const std::shared_ptr<FusionGroupInfo>& fusion_group_info,
    std::vector<ir::Expr> func_bodies,
+    std::vector<ir::SymbolicPredicate> predicates,
+    std::vector<int> priorities,
    std::vector<ir::Tensor>* group_func_arg_tensors,
    std::vector<ir::Argument>* group_func_args,
    std::vector<ir::Tensor>* infer_shape_arg_tensor) {
+  std::vector<ir::Expr> inputs_element_size;
+
   // 1.Prepare function args
   group->mut_input_names().clear();
   std::unordered_set<std::string> store_buffer_names =
@@ -280,6 +285,12 @@ std::vector<ir::LoweredFunc> OpLowererImpl::PostProcess(
             ? ir::Argument::IO::kOutput
             : ir::Argument::IO::kInput;
     (*group_func_args).emplace_back(arg_tensor->buffer, io_type);
+    // collect element size for longlong2int pass.
+    if (FLAGS_cinn_longlong2int) {
+      inputs_element_size.push_back(common::FoldExpr(
+          [](const Expr& a, const Expr& b) { return ir::Mul::Make(a, b); },
+          arg_tensor->shape));
+    }
     arg_name_set.insert(arg_tensor->buffer->name);
   }
 
@@ -330,6 +341,7 @@ std::vector<ir::LoweredFunc> OpLowererImpl::PostProcess(
   std::map<int, CINNKernelInfo::SymbolArgBindInfo> mps;
   // update args for dynamic dim
   int non_tensor_arg_idx = group_func_args->size();
+
   std::unordered_set<std::string> symbol_args_set;
   for (int tensor_arg_idx = 0; tensor_arg_idx < input_tensor_size;
        tensor_arg_idx++) {
@@ -381,7 +393,11 @@ std::vector<ir::LoweredFunc> OpLowererImpl::PostProcess(
     AddDimSymbolArgs();
     AddValueSymbolArgs();
   }
-  std::vector<ir::LoweredFunc> lowered_funcs;
+
+  std::vector<ir::LoweredFunc> ret_lowered_funcs;
+  std::vector<ir::SymbolicPredicate> ret_predicates;
+  std::vector<int> ret_priorities;
+
   for (int i = 0; i < func_bodies.size(); ++i) {
     ir::Expr func_body = func_bodies[i];
     optim::EliminateDeadScheduleBlock(&(func_body), group->output_names());
@@ -416,14 +432,53 @@ std::vector<ir::LoweredFunc> OpLowererImpl::PostProcess(
      func = optim::Optimize(func, common::DefaultHostTarget(), false);
    }
    func->num_output_tensors = infer_shape_arg_tensor->size();
-    lowered_funcs.push_back(std::move(func));
-  }
 
-  // 5. Unify temp_space args and set temp_space sizes
-  UnifyTempSpaceArgs(&lowered_funcs);
-  group->mut_temp_space_sizes() = CollectTempSpaceSizes(lowered_funcs);
+    // 5. Apply longlong2int pass
+    if (i != func_bodies.size() - 1) {
+      LongLong2Int(symbol_args_set,
+                   fusion_group_info->loop_ranges_expr,
+                   inputs_element_size,
+                   priorities[i],
+                   &predicates[i],
+                   &func,
+                   &ret_predicates,
+                   &ret_lowered_funcs,
+                   &ret_priorities);
+    }
+    ret_predicates.push_back(std::move(predicates[i]));
+    ret_lowered_funcs.push_back(std::move(func));
+    // host func has no priority, since tuples require alignment, set -1 here.
+    if (i != func_bodies.size() - 1) {
+      ret_priorities.push_back(std::move(priorities[i]));
+    } else {
+      ret_priorities.push_back(-1);
+    }
+  }
 
-  return lowered_funcs;
+  // 6. Unify temp_space args and set temp_space sizes
+  UnifyTempSpaceArgs(&ret_lowered_funcs);
+  group->mut_temp_space_sizes() = CollectTempSpaceSizes(ret_lowered_funcs);
+
+  PADDLE_ENFORCE_EQ(
+      ret_lowered_funcs.size(),
+      ret_predicates.size(),
+      ::common::errors::InvalidArgument(
+          "The size of ret_lowered_funcs and ret_predicates should be "
+          "the same."));
+  PADDLE_ENFORCE_EQ(
+      ret_lowered_funcs.size(),
+      ret_priorities.size(),
+      ::common::errors::InvalidArgument(
+          "The size of ret_lowered_funcs and ret_priorities should be "
+          "the same."));
+
+  std::vector<CondFuncPriorWrapper> ret;
+  for (size_t i = 0; i < ret_lowered_funcs.size(); ++i) {
+    ret.emplace_back(std::move(ret_predicates[i]),
+                     std::move(ret_lowered_funcs[i]),
+                     std::move(ret_priorities[i]));
+  }
+  return ret;
 }
 
 std::vector<ir::stmt::BlockRef> OpLowererImpl::LowerOps(

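In the PostProcess hunks above, each input tensor's shape is folded into a single product expression (via common::FoldExpr with ir::Mul::Make) and collected into inputs_element_size so the later LongLong2Int call can judge whether 32-bit indexing is safe. A self-contained sketch of just that folding step, with plain int64_t standing in for ir::Expr (the names here are illustrative, not CINN APIs):

#include <cstdint>
#include <functional>
#include <iostream>
#include <numeric>
#include <vector>

// Fold a shape into its element count, analogous to folding the shape
// expressions with a multiplication operator.
int64_t FoldShape(const std::vector<int64_t>& shape) {
  return std::accumulate(shape.begin(), shape.end(), int64_t{1},
                         std::multiplies<int64_t>());
}

int main() {
  std::vector<std::vector<int64_t>> input_shapes = {{32, 128, 128}, {32, 128, 1}};
  std::vector<int64_t> inputs_element_size;
  for (const auto& shape : input_shapes) {
    inputs_element_size.push_back(FoldShape(shape));
  }
  for (int64_t n : inputs_element_size) {
    std::cout << n << "\n";  // prints 524288 and 4096
  }
  return 0;
}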
paddle/cinn/hlir/framework/pir/op_lowering_impl.h

Lines changed: 9 additions & 1 deletion
@@ -41,6 +41,8 @@ namespace pir {
 
 class PrettyNamer;
 using OpLoweringGroupPtr = std::shared_ptr<OpLoweringGroup>;
+using CondFuncPriorWrapper =
+    std::tuple<ir::SymbolicPredicate, ir::LoweredFunc, int>;
 
 using cinn::common::Target;
 class OpLowererImpl;
@@ -66,15 +68,21 @@ class OpLowererImpl : public OpLowererImplBase<OpLoweringGroupPtr> {
   * variables, applying low-level optimization passes, etc.
   * @param group The group to be lowered.
   * @param tensor_map All tensors used for calculating the group.
+   * @param fusion_group_info The info of the fusion group.
   * @param func_bodies The scheduled func bodies of group.
+   * @param predicates The symbolic predicate of each func.
+   * @param priorities The priority of each func.
   * @param group_func_arg_tensors Tensors used as the group function arguments.
   * @param group_func_args Arguments used as the group function arguments.
   * @return The lowered funcs after the post processing.
   */
-  std::vector<ir::LoweredFunc> PostProcess(
+  std::vector<CondFuncPriorWrapper> PostProcess(
      const OpLoweringGroupPtr& group,
      const std::unordered_map<::pir::Value, ir::Tensor>& tensor_map,
+      const std::shared_ptr<FusionGroupInfo>& fusion_group_info,
      std::vector<ir::Expr> func_bodies,
+      std::vector<ir::SymbolicPredicate> predicates,
+      std::vector<int> priorities,
      std::vector<ir::Tensor>* group_func_arg_tensors,
      std::vector<ir::Argument>* group_func_args,
      std::vector<ir::Tensor>* infer_shape_arg_tensor);

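The header change above defines CondFuncPriorWrapper as a std::tuple of (symbolic predicate, lowered function, priority). A toy sketch of how such tuples can be grouped and consumed, echoing the BucketLower handling earlier in the diff; plain strings stand in for ir::SymbolicPredicate and ir::LoweredFunc, and all values below are made up:

#include <iostream>
#include <string>
#include <tuple>
#include <vector>

using CondFuncPrior =
    std::tuple<std::string /*predicate*/, std::string /*func*/, int /*priority*/>;

int main() {
  // PostProcess-style output: device variants first, the host (x86) function
  // last with the placeholder priority -1, as in the patch.
  std::vector<CondFuncPrior> warps;
  warps.emplace_back("S0 <= 1024", "fn_bucket_small", 100);
  warps.emplace_back("S0 > 1024", "fn_bucket_large", 100);
  warps.emplace_back("true", "fn_host", -1);

  // Every entry except the last becomes a device kernel.
  for (size_t i = 0; i + 1 < warps.size(); ++i) {
    const auto& [pred, fn, prior] = warps[i];
    std::cout << "device: " << fn << " if " << pred
              << " (priority " << prior << ")\n";
  }

  // The last entry is the x86 fallback; the patch appends "_CX86" to its name.
  const auto& [pred, fn, prior] = warps.back();
  std::cout << "host: " << fn + "_CX86" << " if " << pred
            << " (priority " << prior << ")\n";
  return 0;
}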