Skip to content

Commit c023d88

Browse files
authored
[CINN] Update cinn jit data (#72510)
* [CINN] Update cinn jit data * refine
1 parent da056ba commit c023d88

File tree

2 files changed

+13
-4
lines changed

2 files changed

+13
-4
lines changed

paddle/fluid/framework/new_executor/instruction/cinn_jit_instruction.cc

+12-4
Original file line numberDiff line numberDiff line change
@@ -103,8 +103,14 @@ class CinnJitInstruction::FnPtrImpl {
103103

104104
// Pass real tensor data to cinn_buffer_t func args placeholder
105105
for (size_t i = 0; i < kernel_tensor_args.size(); ++i) {
106-
cinn_pod_value_to_buffer_p(&(func_args_[i]))->memory =
107-
reinterpret_cast<uint8_t*>(kernel_tensor_args[i]->data());
106+
if (!kernel_tensor_args[i]->has_allocation()) {
107+
VLOG(2) << "WARNING! Access DenseTensor::data() without allocation, "
108+
"return nullptr!";
109+
cinn_pod_value_to_buffer_p(&(func_args_[i]))->memory = nullptr;
110+
} else {
111+
cinn_pod_value_to_buffer_p(&(func_args_[i]))->memory =
112+
reinterpret_cast<uint8_t*>(kernel_tensor_args[i]->data());
113+
}
108114
}
109115

110116
// Launch host kernel
@@ -297,6 +303,7 @@ CinnJitInstruction::CinnJitInstruction(
297303
ir_dims_.push_back(
298304
result.type().dyn_cast<paddle::dialect::DenseTensorType>().dims());
299305
tensor_args_.push_back(tensor);
306+
alloc_tensors_.push_back(tensor);
300307
auto alloc_tensor_type =
301308
result.type().dyn_cast<paddle::dialect::AllocatedDenseTensorType>();
302309
tensor->set_type(
@@ -321,6 +328,7 @@ CinnJitInstruction::CinnJitInstruction(
321328
}
322329
for (auto& tensor : temp_space_tensors_) {
323330
tensor_args_.push_back(&tensor);
331+
alloc_tensors_.push_back(&tensor);
324332
}
325333
output_tensor_size += temp_space_tensors_.size();
326334
}
@@ -343,8 +351,8 @@ void CinnJitInstruction::Run() {
343351
fn_ptr_impl_->InferShape(
344352
tensor_args_, ir_dims_, input_tensor_size, output_tensor_size);
345353
}
346-
for (size_t i = 0; i < tensor_args_.size(); ++i) {
347-
dev_ctx_->Alloc(tensor_args_[i], tensor_args_[i]->dtype());
354+
for (size_t i = 0; i < alloc_tensors_.size(); ++i) {
355+
dev_ctx_->Alloc(alloc_tensors_[i], alloc_tensors_[i]->dtype());
348356
}
349357

350358
// 2. execute kernel

paddle/fluid/framework/new_executor/instruction/cinn_jit_instruction.h

+1
Original file line numberDiff line numberDiff line change
@@ -55,6 +55,7 @@ class CinnJitInstruction : public InstructionBase {
5555

5656
bool need_update_shape{false};
5757
std::vector<phi::DenseTensor*> tensor_args_;
58+
std::vector<phi::DenseTensor*> alloc_tensors_;
5859
std::vector<phi::DDim> ir_dims_;
5960

6061
// Tensors that hold the temporary spaces used by the kernel. These tensors

0 commit comments

Comments
 (0)