#pragma once

-#include <iostream>
-#include <optional>
-#include <variant>
-#include "paddle/common/macros.h"
-#include "paddle/phi/api/include/api.h"
-#include "paddle/phi/api/include/tensor.h"
-#include "paddle/phi/common/bfloat16.h"
-#include "paddle/phi/common/complex.h"
-#include "paddle/phi/common/data_type.h"
-#include "paddle/phi/common/float16.h"
-#include "paddle/phi/common/float8_e4m3fn.h"
-#include "paddle/phi/common/float8_e5m2.h"
-#include "paddle/phi/common/place.h"
-#include "paddle/phi/core/ddim.h"
-
-namespace logging {
-#define UNSUPPORTED_FEATURE_IN_PADDLE(feature)                             \
-  std::cerr << "Unsupported feature in Paddle: " << feature << std::endl;  \
-  std::abort();
-}  // namespace logging
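
Note that the macro expands to two bare statements rather than a single `do { ... } while (0)` block, so call sites must supply their own braces, as every caller in this header does; for example, the check inside `Tensor::contiguous()` further down:

    if (memory_format != MemoryFormat::Contiguous) {  // braces keep both statements inside the if
      UNSUPPORTED_FEATURE_IN_PADDLE("`MemoryFormat` other than Contiguous")
    }
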
-
-namespace c10 {
-template <typename T>
-using complex = ::phi::dtype::complex<T>;
-using Half = ::phi::dtype::float16;
-using Float8_e5m2 = ::phi::dtype::float8_e5m2;
-using Float8_e4m3fn = ::phi::dtype::float8_e4m3fn;
-using BFloat16 = ::phi::dtype::bfloat16;
-}  // namespace c10
-
-namespace at {
-
-using Scalar = paddle::experimental::Scalar;
-struct Device {};
-struct Layout {};
-
-using PaddleTensor = paddle::Tensor;
-
-// IntArrayRef
-template <typename T>
-class ArrayRef {
- private:
-  /// The start of the array, in an external buffer.
-  const T* Data;
-
-  /// The number of elements.
-  size_t Length;
-
- public:
-  /// Construct an empty ArrayRef.
-  /* implicit */ constexpr ArrayRef() : Data(nullptr), Length(0) {}
-
-  /// Construct an ArrayRef from a single element.
-  constexpr ArrayRef(const T& OneElt) : Data(&OneElt), Length(1) {}  // NOLINT
-
-  /// Construct an ArrayRef from a pointer and length.
-  constexpr ArrayRef(const T* data, size_t length)
-      : Data(data), Length(length) {}
-
-  /// Construct an ArrayRef from a range.
-  constexpr ArrayRef(const T* begin, const T* end)
-      : Data(begin), Length(end - begin) {}
-
-  /// Construct an ArrayRef from an initializer list.
-  /* implicit */ constexpr ArrayRef(const std::initializer_list<T>& Vec)
-      : Data(std::begin(Vec) == std::end(Vec) ? static_cast<T*>(nullptr)
-                                              : std::begin(Vec)),
-        Length(Vec.size()) {}
-
-  /// Copy the viewed elements into a paddle::IntArray.
-  const paddle::IntArray _PD_ToPaddleIntArray() const {
-    return paddle::IntArray(Data, Length);
-  }
-};
-using IntArrayRef = ArrayRef<int64_t>;
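
Like the LLVM `ArrayRef` it mirrors, this is a non-owning view: the underlying buffer must outlive it. A short sketch of the common construction paths (the local names are illustrative):

    std::vector<int64_t> shape = {2, 3, 4};
    at::IntArrayRef view(shape.data(), shape.size());   // non-owning pointer + length
    paddle::IntArray ia = view._PD_ToPaddleIntArray();  // copies the values out
    // Braced lists such as at::ones({2, 3}) bind through the initializer_list
    // constructor and the view is only valid for that full expression.
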
-
-enum class PADDLE_API MemoryFormat : int8_t {
-  Contiguous,
-  Preserve,
-  ChannelsLast,
-  ChannelsLast3d,
-  NumOptions
-};
-
-// Datatype
-using Half = c10::Half;
-using BFloat16 = c10::BFloat16;
-
-// ScalarType
-#define FORALL_PADDLE_AND_TORCH_DTYPES(_)               \
-  _(uint8_t, UINT8, Byte)                               \
-  _(int8_t, INT8, Char)                                 \
-  _(int16_t, INT16, Short)                              \
-  _(int32_t, INT32, Int)                                \
-  _(int64_t, INT64, Long)                               \
-  _(at::Half, FLOAT16, Half)                            \
-  _(float, FLOAT32, Float)                              \
-  _(double, FLOAT64, Double)                            \
-  _(c10::complex<float>, COMPLEX64, ComplexFloat)       \
-  _(c10::complex<double>, COMPLEX128, ComplexDouble)    \
-  _(bool, BOOL, Bool)                                   \
-  _(at::BFloat16, BFLOAT16, BFloat16)                   \
-  _(c10::Float8_e5m2, FLOAT8_E5M2, Float8_e5m2)         \
-  _(c10::Float8_e4m3fn, FLOAT8_E4M3FN, Float8_e4m3fn)   \
-  _(uint16_t, UINT16, UInt16)                           \
-  _(uint32_t, UINT32, UInt32)
-
-enum class PADDLE_API ScalarType : int8_t {
-#define DEFINE_ST_ENUM_VAL_(_1, _2, n) n,
-  FORALL_PADDLE_AND_TORCH_DTYPES(DEFINE_ST_ENUM_VAL_)
-#undef DEFINE_ST_ENUM_VAL_
-  Undefined,
-  NumOptions
-};
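
For reference, each `_(ctype, PHI_ENUM, TorchName)` row of the table expands once per use site; under `DEFINE_ST_ENUM_VAL_` the `float` row, for example, contributes a plain enumerator, so the enum ends up with one enumerator per row followed by `Undefined` and `NumOptions`:

    // Expansion of the float row inside the enum body:
    //   _(float, FLOAT32, Float)  ->  Float,
    static_assert(static_cast<int8_t>(at::ScalarType::Byte) == 0,
                  "first table row becomes the first enumerator");
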
-
-struct PADDLE_API TensorOptions {
-  TensorOptions()
-      : requires_grad_(false),
-        pinned_memory_(false),
-        has_device_(false),
-        has_dtype_(false),
-        has_layout_(false),
-        has_requires_grad_(false),
-        has_pinned_memory_(false),
-        has_memory_format_(false) {}
-
-  ScalarType _PD_GetScalarType() const { return dtype; }
-  ::phi::Place _PD_GetPlace() const { return place; }
-
- private:
-  ScalarType dtype = ScalarType::Float;
-  ::phi::Place place = ::phi::CPUPlace();
-  bool requires_grad_ : 1;
-  bool pinned_memory_ : 1;
-
-  bool has_device_ : 1;
-  bool has_dtype_ : 1;
-  bool has_layout_ : 1;
-  bool has_requires_grad_ : 1;
-  bool has_pinned_memory_ : 1;
-  bool has_memory_format_ : 1;
-};
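
This version stores defaults and the `has_*` flags but exposes no setters, so a default-constructed `TensorOptions` always reports float32 on CPU:

    at::TensorOptions opts;                        // no setters yet; defaults only
    at::ScalarType st = opts._PD_GetScalarType();  // ScalarType::Float
    ::phi::Place place = opts._PD_GetPlace();      // ::phi::CPUPlace()
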
-
-namespace conversion {
-inline IntArrayRef _PD_PhiDDimToIntArrayRef(const phi::DDim& ddim) {
-  return IntArrayRef(ddim.Get(), ddim.size());
-}
-
-inline phi::DataType _PD_AtenScalarTypeToPhiDataType(ScalarType dtype) {
-  switch (dtype) {
-#define DEFINE_ST_TO_DT_CASE_(_1, _dt, _st) \
-  case ScalarType::_st:                     \
-    return phi::DataType::_dt;
-    FORALL_PADDLE_AND_TORCH_DTYPES(DEFINE_ST_TO_DT_CASE_)
-#undef DEFINE_ST_TO_DT_CASE_
-    default:
-      UNSUPPORTED_FEATURE_IN_PADDLE("Unsupported ScalarType")
-  }
-}
-
-}  // namespace conversion
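
Both helpers are pure value conversions; the switch gains one `case` per table row via `DEFINE_ST_TO_DT_CASE_`, and anything outside the table reaches the aborting default:

    phi::DataType dt =
        at::conversion::_PD_AtenScalarTypeToPhiDataType(at::ScalarType::Float);
    // dt == phi::DataType::FLOAT32; the expanded case reads:
    //   case ScalarType::Float: return phi::DataType::FLOAT32;
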
-class PADDLE_API Tensor {
- public:
-  Tensor(const PaddleTensor& tensor) : tensor_(tensor) {}  // NOLINT
-
-  void* data_ptr() { return tensor_.data(); }
-  template <typename T>
-  T* data_ptr() const {
-    return const_cast<T*>(tensor_.data<T>());
-  }
-  int64_t stride(int64_t dim) const {
-    return tensor_.strides()[static_cast<int>(dim)];
-  }
-  IntArrayRef strides() const {
-    return conversion::_PD_PhiDDimToIntArrayRef(tensor_.strides());
-  }
-
-  int64_t size(int64_t dim) const {
-    return tensor_.dims()[static_cast<int>(dim)];
-  }
-
-  IntArrayRef sizes() const {
-    return conversion::_PD_PhiDDimToIntArrayRef(tensor_.dims());
-  }
-
-  int64_t numel() const { return tensor_.numel(); }
-
-  at::Tensor contiguous(
-      MemoryFormat memory_format = MemoryFormat::Contiguous) const {
-    if (memory_format != MemoryFormat::Contiguous) {
-      UNSUPPORTED_FEATURE_IN_PADDLE("`MemoryFormat` other than Contiguous")
-    }
-    return tensor_.contiguous();
-  }
-
-  TensorOptions options() const {
-    // TODO(SigureMo): Implement this
-    return TensorOptions();
-  }
-
-  PaddleTensor _PD_GetInner() const { return tensor_; }
-
- private:
-  PaddleTensor tensor_;
-};
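
`at::Tensor` is a thin value wrapper over `paddle::Tensor`, so wrapping and unwrapping copy the tensor handle rather than the payload. A short sketch, assuming a `paddle::Tensor` named `pd` already exists:

    at::Tensor t(pd);                        // implicit wrap (also works in function arguments)
    int64_t n = t.numel();                   // forwards to the wrapped tensor
    at::IntArrayRef dims = t.sizes();        // view over the tensor's dims
    paddle::Tensor back = t._PD_GetInner();  // recover the underlying tensor by value
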
-
-at::Tensor empty(
-    at::IntArrayRef size,
-    at::TensorOptions options = {},
-    ::std::optional<at::MemoryFormat> memory_format = ::std::nullopt) {
-  if (memory_format.has_value()) {
-    UNSUPPORTED_FEATURE_IN_PADDLE("`MemoryFormat`")
-  }
-  return paddle::experimental::empty(
-      size._PD_ToPaddleIntArray(),
-      conversion::_PD_AtenScalarTypeToPhiDataType(options._PD_GetScalarType()),
-      options._PD_GetPlace());
-}
-at::Tensor ones(at::IntArrayRef size, at::TensorOptions options = {}) {
-  return paddle::experimental::ones(
-      size._PD_ToPaddleIntArray(),
-      conversion::_PD_AtenScalarTypeToPhiDataType(options._PD_GetScalarType()),
-      options._PD_GetPlace());
-}
-at::Tensor zeros(at::IntArrayRef size, at::TensorOptions options = {}) {
-  return paddle::experimental::zeros(
-      size._PD_ToPaddleIntArray(),
-      conversion::_PD_AtenScalarTypeToPhiDataType(options._PD_GetScalarType()),
-      options._PD_GetPlace());
-}
-
-at::Tensor full(at::IntArrayRef size,
-                const at::Scalar& fill_value,
-                ::std::optional<at::ScalarType> dtype = {},
-                ::std::optional<at::Layout> layout = {},
-                ::std::optional<at::Device> device = {},
-                ::std::optional<bool> pin_memory = {}) {
-  if (pin_memory.has_value()) {
-    UNSUPPORTED_FEATURE_IN_PADDLE("`pin_memory` option in full")
-  }
-  return paddle::experimental::full(
-      size._PD_ToPaddleIntArray(),
-      fill_value,
-      dtype.has_value() ? conversion::_PD_AtenScalarTypeToPhiDataType(*dtype)
-                        : phi::DataType::FLOAT32,
-      phi::CPUPlace()  // TODO(SigureMo): support other places
-  );
-}
-
-}  // namespace at
-
-namespace torch {
-using Tensor = at::Tensor;
-using Dtype = at::ScalarType;
-}  // namespace torch
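
The `torch` namespace only re-exports the `at` types, so the two spellings are interchangeable:

    static_assert(std::is_same_v<torch::Tensor, at::Tensor>, "alias, not a new type");
    torch::Tensor t = at::ones({2, 2}, at::TensorOptions());
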
-
-void compiling_test() {
-  // Example usage of the Tensor class
-  at::Tensor a = at::ones({2, 3}, at::TensorOptions());
-  at::Tensor b = at::full({2, 3}, 1, at::ScalarType::Float);
-  double c = 10;
-  at::Tensor a_contig = a.contiguous();
-  at::Tensor b_contig = b.contiguous();
-  at::Tensor result = at::empty(a_contig.sizes(), a_contig.options());
-  const float* a_ptr = a_contig.data_ptr<float>();
-  const float* b_ptr = b_contig.data_ptr<float>();
-  float* result_ptr = result.data_ptr<float>();
-  for (int64_t i = 0; i < a_contig.numel(); i++) {
-    result_ptr[i] = a_ptr[i] * b_ptr[i] + c;
-  }
-  // Show result
-  for (int64_t i = 0; i < a_contig.numel(); i++) {
-    std::cout << "Result[" << i << "] = " << result_ptr[i] << std::endl;
-  }
-}
+#include "paddle/phi/api/include/torch_like_api/torch/api.h"
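
With this change the umbrella header reduces to a single forwarding include; the definitions above presumably move under `torch_like_api/torch/` (not shown in this diff), so existing users of the original header path should compile unchanged:

    // Hypothetical downstream usage; the include path below is the pre-existing umbrella header.
    #include "paddle/phi/api/include/torch_like_api.h"
    at::Tensor t = at::zeros({3, 3});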