Skip to content

Commit c25dd08

Browse files
committed
[CINN] prefetch load scalar tensor for vectorize situation
1 parent a8980c1 commit c25dd08

File tree

4 files changed

+227
-21
lines changed

4 files changed

+227
-21
lines changed

paddle/cinn/hlir/framework/pir/op_lowering_impl.cc

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -433,6 +433,11 @@ std::vector<ir::LoweredFunc> OpLowererImpl::PostProcess(
433433
} else {
434434
func = optim::Optimize(func, common::DefaultHostTarget(), false);
435435
}
436+
auto pre_load_temp_buffers =
437+
lang::GetPreLoadTempBufferAfterVectorize(func->body);
438+
func->temp_bufs.insert(func->temp_bufs.end(),
439+
pre_load_temp_buffers.begin(),
440+
pre_load_temp_buffers.end());
436441
func->num_output_tensors = infer_shape_arg_tensor->size();
437442
lowered_funcs.push_back(std::move(func));
438443
}

paddle/cinn/lang/lower.cc

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -228,6 +228,22 @@ std::vector<ir::Buffer> GetTempBuffers(const std::vector<ir::Argument>& args,
228228
return temp_buffers;
229229
}
230230

231+
std::vector<ir::Buffer> GetPreLoadTempBufferAfterVectorize(Expr body) {
232+
std::unordered_set<std::string> buffer_names;
233+
std::vector<ir::Buffer> temp_buffers;
234+
ir::ir_utils::CollectIRNodesWithoutTensor(body, [&](const Expr* x) {
235+
if (x->as_tensor() && x->as_tensor()->buffer.defined() &&
236+
!buffer_names.count(x->as_tensor()->buffer->name) &&
237+
utils::StartsWith(x->as_tensor()->buffer->name, "pre_load")) {
238+
buffer_names.insert(x->as_tensor()->buffer->name);
239+
temp_buffers.push_back(x->as_tensor()->buffer);
240+
return true;
241+
}
242+
return false;
243+
});
244+
return std::move(temp_buffers);
245+
}
246+
231247
std::set<ir::Tensor> CollectTempTensorsFromCtrlDepends(
232248
ast_gen_ius::TensorGroup* tensor_group,
233249
const std::vector<Tensor>& tensor_args) {

paddle/cinn/lang/lower.h

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -59,5 +59,7 @@ std::vector<ir::Buffer> GetTempBuffers(const std::vector<ir::Argument> &args,
5959
std::vector<ir::Buffer> GetTempBuffers(
6060
const std::vector<cinn::ir::Tensor> &tensor_args, Expr body);
6161

62+
std::vector<ir::Buffer> GetPreLoadTempBufferAfterVectorize(Expr body);
63+
6264
} // namespace lang
6365
} // namespace cinn

0 commit comments

Comments
 (0)