Skip to content

Commit b4e16a3

Browse files
authored
[CINN] Add type casting for iter_vals (PaddlePaddle#72204)
1 parent 4a6fd1e commit b4e16a3

File tree

2 files changed

+1
-1
lines changed

2 files changed

+1
-1
lines changed

paddle/cinn/ast_gen_ius/ast_gen.cc

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -310,6 +310,7 @@ StmtRef AstGen::Build(const ir::Tensor& tensor, TensorGroup* tensor_group) {
310310
block_vars.push_back(Var(
311311
Expr(0), shape[i], cinn::UniqName("i" + std::to_string(i)), false));
312312
optim::ReplaceVarWithExpr(body, axis[i], block_vars[i]);
313+
if (shape[i].type() == Int(64)) axis_vars[i]->set_type(Int(64));
313314
axis_vars[i]->is_reduce_axis = false;
314315
iter_values.push_back(axis_vars[i]);
315316
}

paddle/cinn/hlir/pe/elementwise.cc

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -338,7 +338,6 @@ ir::Tensor Tril(const ir::Tensor& A,
338338
"The Tril op input tensor must have a rank "
339339
"greater than or equal to 2."));
340340
std::vector<Expr> new_indice(indice.end() - 2, indice.end());
341-
Expr col_indice = indice.back();
342341
return ir::Select::Make(new_indice[0] >= new_indice[1] - diagonal,
343342
A(indice),
344343
ir::Zero(A->type()));

0 commit comments

Comments
 (0)