@@ -674,8 +674,39 @@ void Tensor::ORTCopyFromCpu(const T *data) {
674
674
OrtMemTypeDefault);
675
675
size_t size = std::accumulate (begin (shape_), end (shape_), 1UL ,
676
676
std::multiplies<size_t >());
677
- auto ort_value = GetOrtVaule (memory_info, const_cast <T *>(data), size,
678
- shape_.data (), shape_.size ());
677
+ size_t buffer_size = size * sizeof (T);
678
+ if (buffer_size > buffer_.size ()) {
679
+ buffer_.resize (buffer_size);
680
+ }
681
+ std::memcpy (static_cast <void *>(buffer_.data ()), data, buffer_size);
682
+
683
+ auto onnx_dtype = ONNX_TENSOR_ELEMENT_DATA_TYPE_UNDEFINED;
684
+ if (std::is_same<T, float >::value) {
685
+ onnx_dtype = ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT;
686
+ } else if (std::is_same<T, double >::value) {
687
+ onnx_dtype = ONNX_TENSOR_ELEMENT_DATA_TYPE_DOUBLE;
688
+ } else if (std::is_same<T, int64_t >::value) {
689
+ onnx_dtype = ONNX_TENSOR_ELEMENT_DATA_TYPE_INT64;
690
+ } else if (std::is_same<T, int32_t >::value) {
691
+ onnx_dtype = ONNX_TENSOR_ELEMENT_DATA_TYPE_INT32;
692
+ } else if (std::is_same<T, uint8_t >::value) {
693
+ onnx_dtype = ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT8;
694
+ } else if (std::is_same<T, int8_t >::value) {
695
+ onnx_dtype = ONNX_TENSOR_ELEMENT_DATA_TYPE_INT8;
696
+ } else if (std::is_same<T, float16>::value) {
697
+ onnx_dtype = ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT16;
698
+ }
699
+
700
+ if (onnx_dtype == ONNX_TENSOR_ELEMENT_DATA_TYPE_UNDEFINED) {
701
+ PADDLE_THROW (paddle::platform::errors::InvalidArgument (
702
+ " Found undefined data type for onnxruntime, only supports "
703
+ " float16/float32/float64/int8/uint8/int32/int64." ));
704
+ }
705
+
706
+ auto ort_value =
707
+ Ort::Value::CreateTensor (memory_info, buffer_.data (), buffer_size,
708
+ shape_.data (), shape_.size (), onnx_dtype);
709
+
679
710
binding->BindInput (name_.c_str (), ort_value);
680
711
}
681
712
0 commit comments