Skip to content

Commit e7eb0e2

Browse files
authored
fix paddle-ort python bug (#42464)
* fix paddle-ort python bug * fix paddle-ort python bug
1 parent be77aee commit e7eb0e2

File tree

2 files changed

+34
-2
lines changed

2 files changed

+34
-2
lines changed

paddle/fluid/inference/api/details/zero_copy_tensor.cc

+33-2
Original file line numberDiff line numberDiff line change
@@ -674,8 +674,39 @@ void Tensor::ORTCopyFromCpu(const T *data) {
674674
OrtMemTypeDefault);
675675
size_t size = std::accumulate(begin(shape_), end(shape_), 1UL,
676676
std::multiplies<size_t>());
677-
auto ort_value = GetOrtVaule(memory_info, const_cast<T *>(data), size,
678-
shape_.data(), shape_.size());
677+
size_t buffer_size = size * sizeof(T);
678+
if (buffer_size > buffer_.size()) {
679+
buffer_.resize(buffer_size);
680+
}
681+
std::memcpy(static_cast<void *>(buffer_.data()), data, buffer_size);
682+
683+
auto onnx_dtype = ONNX_TENSOR_ELEMENT_DATA_TYPE_UNDEFINED;
684+
if (std::is_same<T, float>::value) {
685+
onnx_dtype = ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT;
686+
} else if (std::is_same<T, double>::value) {
687+
onnx_dtype = ONNX_TENSOR_ELEMENT_DATA_TYPE_DOUBLE;
688+
} else if (std::is_same<T, int64_t>::value) {
689+
onnx_dtype = ONNX_TENSOR_ELEMENT_DATA_TYPE_INT64;
690+
} else if (std::is_same<T, int32_t>::value) {
691+
onnx_dtype = ONNX_TENSOR_ELEMENT_DATA_TYPE_INT32;
692+
} else if (std::is_same<T, uint8_t>::value) {
693+
onnx_dtype = ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT8;
694+
} else if (std::is_same<T, int8_t>::value) {
695+
onnx_dtype = ONNX_TENSOR_ELEMENT_DATA_TYPE_INT8;
696+
} else if (std::is_same<T, float16>::value) {
697+
onnx_dtype = ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT16;
698+
}
699+
700+
if (onnx_dtype == ONNX_TENSOR_ELEMENT_DATA_TYPE_UNDEFINED) {
701+
PADDLE_THROW(paddle::platform::errors::InvalidArgument(
702+
"Found undefined data type for onnxruntime, only supports "
703+
"float16/float32/float64/int8/uint8/int32/int64."));
704+
}
705+
706+
auto ort_value =
707+
Ort::Value::CreateTensor(memory_info, buffer_.data(), buffer_size,
708+
shape_.data(), shape_.size(), onnx_dtype);
709+
679710
binding->BindInput(name_.c_str(), ort_value);
680711
}
681712

paddle/fluid/inference/api/paddle_tensor.h

+1
Original file line numberDiff line numberDiff line change
@@ -187,6 +187,7 @@ class PD_INFER_DECL Tensor {
187187
#ifdef PADDLE_WITH_ONNXRUNTIME
188188
bool is_ort_tensor_{false};
189189
std::vector<int64_t> shape_;
190+
std::vector<int8_t> buffer_;
190191
std::weak_ptr<Ort::IoBinding> binding_;
191192
int idx_{-1};
192193

0 commit comments

Comments
 (0)