Skip to content

Commit 9713c09

Browse files
committed
split into separate files
1 parent 013a19b commit 9713c09

File tree

16 files changed

+620
-274
lines changed

16 files changed

+620
-274
lines changed

paddle/phi/api/include/torch_compat_runtime.h

Lines changed: 1 addition & 274 deletions
Original file line numberDiff line numberDiff line change
@@ -18,277 +18,4 @@
1818

1919
#pragma once
2020

21-
#include <iostream>
22-
#include <optional>
23-
#include <variant>
24-
#include "paddle/common/macros.h"
25-
#include "paddle/phi/api/include/api.h"
26-
#include "paddle/phi/api/include/tensor.h"
27-
#include "paddle/phi/common/bfloat16.h"
28-
#include "paddle/phi/common/complex.h"
29-
#include "paddle/phi/common/data_type.h"
30-
#include "paddle/phi/common/float16.h"
31-
#include "paddle/phi/common/float8_e4m3fn.h"
32-
#include "paddle/phi/common/float8_e5m2.h"
33-
#include "paddle/phi/common/place.h"
34-
#include "paddle/phi/core/ddim.h"
35-
36-
namespace logging {
37-
#define UNSUPPORTED_FEATURE_IN_PADDLE(feature) \
38-
std::cerr << "Unsupported feature in Paddle: " << feature << std::endl; \
39-
std::abort();
40-
} // namespace logging
41-
42-
namespace c10 {
43-
template <typename T>
44-
using complex = ::phi::dtype::complex<T>;
45-
using Half = ::phi::dtype::float16;
46-
using Float8_e5m2 = ::phi::dtype::float8_e5m2;
47-
using Float8_e4m3fn = ::phi::dtype::float8_e4m3fn;
48-
using BFloat16 = ::phi::dtype::bfloat16;
49-
// using
50-
} // namespace c10
51-
52-
namespace at {
53-
54-
using Scalar = paddle::experimental::Scalar;
55-
struct Device {};
56-
struct Layout {};
57-
58-
using PaddleTensor = paddle::Tensor;
59-
60-
// IntArrayRef
61-
template <typename T>
62-
class ArrayRef {
63-
private:
64-
/// The start of the array, in an external buffer.
65-
const T* Data;
66-
67-
/// The number of elements.
68-
size_t Length;
69-
70-
public:
71-
/// Construct an empty ArrayRef.
72-
/* implicit */ constexpr ArrayRef() : Data(nullptr), Length(0) {}
73-
74-
constexpr ArrayRef(const T& OneElt) : Data(&OneElt), Length(1) {} // NOLINT
75-
76-
/// Construct an ArrayRef from a pointer and length.
77-
constexpr ArrayRef(const T* data, size_t length)
78-
: Data(data), Length(length) {}
79-
80-
/// Construct an ArrayRef from a range.
81-
constexpr ArrayRef(const T* begin, const T* end)
82-
: Data(begin), Length(end - begin) {}
83-
84-
/* implicit */ constexpr ArrayRef(const std::initializer_list<T>& Vec)
85-
: Data(std::begin(Vec) == std::end(Vec) ? static_cast<T*>(nullptr)
86-
: std::begin(Vec)),
87-
Length(Vec.size()) {}
88-
89-
const paddle::IntArray _PD_ToPaddleIntArray() const {
90-
return paddle::IntArray(Data, Length);
91-
}
92-
};
93-
using IntArrayRef = ArrayRef<int64_t>;
94-
95-
enum class PADDLE_API MemoryFormat : int8_t {
96-
Contiguous,
97-
Preserve,
98-
ChannelsLast,
99-
ChannelsLast3d,
100-
NumOptions
101-
};
102-
103-
// Datatype
104-
using Half = c10::Half;
105-
using BFloat16 = c10::BFloat16;
106-
107-
// ScalarType
108-
#define FORALL_PADDLE_AND_TORCH_DTYPES(_) \
109-
_(uint8_t, UINT8, Byte) \
110-
_(int8_t, INT8, Char) \
111-
_(int16_t, INT16, Short) \
112-
_(int32_t, INT32, Int) \
113-
_(int64_t, INT64, Long) \
114-
_(at::Half, FLOAT16, Half) \
115-
_(float, FLOAT32, Float) \
116-
_(double, FLOAT64, Double) \
117-
_(c10::complex<float>, COMPLEX64, ComplexFloat) \
118-
_(c10::complex<double>, COMPLEX128, ComplexDouble) \
119-
_(bool, BOOL, Bool) \
120-
_(at::BFloat16, BFLOAT16, BFloat16) \
121-
_(c10::Float8_e5m2, FLOAT8_E5M2, Float8_e5m2) \
122-
_(c10::Float8_e4m3fn, FLOAT8_E4M3FN, Float8_e4m3fn) \
123-
_(uint16_t, UINT16, UInt16) \
124-
_(uint32_t, UINT32, UInt32)
125-
126-
enum class PADDLE_API ScalarType : int8_t {
127-
#define DEFINE_ST_ENUM_VAL_(_1, _2, n) n,
128-
FORALL_PADDLE_AND_TORCH_DTYPES(DEFINE_ST_ENUM_VAL_)
129-
#undef DEFINE_ENUM_ST_ENUM_VAL_
130-
Undefined,
131-
NumOptions
132-
};
133-
134-
struct PADDLE_API TensorOptions {
135-
TensorOptions()
136-
: requires_grad_(false),
137-
pinned_memory_(false),
138-
has_device_(false),
139-
has_dtype_(false),
140-
has_layout_(false),
141-
has_requires_grad_(false),
142-
has_pinned_memory_(false),
143-
has_memory_format_(false) {}
144-
145-
ScalarType _PD_GetScalarType() const { return dtype; }
146-
::phi::Place _PD_GetPlace() const { return place; }
147-
148-
private:
149-
ScalarType dtype = ScalarType::Float;
150-
::phi::Place place = ::phi::CPUPlace();
151-
bool requires_grad_ : 1;
152-
bool pinned_memory_ : 1;
153-
154-
bool has_device_ : 1;
155-
bool has_dtype_ : 1;
156-
bool has_layout_ : 1;
157-
bool has_requires_grad_ : 1;
158-
bool has_pinned_memory_ : 1;
159-
bool has_memory_format_ : 1;
160-
};
161-
162-
namespace conversion {
163-
inline IntArrayRef _PD_PhiDDimToIntArrayRef(const phi::DDim& ddim) {
164-
return IntArrayRef(ddim.Get(), ddim.size());
165-
}
166-
167-
inline phi::DataType _PD_AtenScalarTypeToPhiDataType(ScalarType dtype) {
168-
switch (dtype) {
169-
#define DEFINE_ST_TO_DT_CASE_(_1, _dt, _st) \
170-
case ScalarType::_st: \
171-
return phi::DataType::_dt;
172-
FORALL_PADDLE_AND_TORCH_DTYPES(DEFINE_ST_TO_DT_CASE_)
173-
#undef DEFINE_ST_TO_DT_CASE_
174-
default:
175-
UNSUPPORTED_FEATURE_IN_PADDLE("Unsupported ScalarType")
176-
}
177-
} // namespace conversion
178-
179-
} // namespace conversion
180-
class PADDLE_API Tensor {
181-
public:
182-
Tensor(const PaddleTensor& tensor) : tensor_(tensor){}; // NOLINT
183-
184-
void* data_ptr() { return tensor_.data(); }
185-
template <typename T>
186-
T* data_ptr() const {
187-
return const_cast<T*>(tensor_.data<T>());
188-
}
189-
int64_t stride(int64_t dim) const {
190-
return tensor_.strides()[static_cast<int>(dim)];
191-
}
192-
IntArrayRef strides() const {
193-
return conversion::_PD_PhiDDimToIntArrayRef(tensor_.strides());
194-
}
195-
196-
int64_t size(int64_t dim) const {
197-
return tensor_.dims()[static_cast<int>(dim)];
198-
}
199-
200-
IntArrayRef sizes() const {
201-
return conversion::_PD_PhiDDimToIntArrayRef(tensor_.dims());
202-
}
203-
204-
int64_t numel() const { return tensor_.numel(); }
205-
206-
at::Tensor contiguous(
207-
MemoryFormat memory_format = MemoryFormat::Contiguous) const {
208-
if (memory_format != MemoryFormat::Contiguous) {
209-
UNSUPPORTED_FEATURE_IN_PADDLE("`MemoryFormat` other than Contiguous")
210-
}
211-
return tensor_.contiguous();
212-
}
213-
214-
TensorOptions options() const {
215-
// TODO(SigureMo): Implement this
216-
return TensorOptions();
217-
}
218-
219-
PaddleTensor _PD_GetInner() const { return tensor_; }
220-
221-
private:
222-
PaddleTensor tensor_;
223-
};
224-
225-
at::Tensor empty(
226-
at::IntArrayRef size,
227-
at::TensorOptions options = {},
228-
::std::optional<at::MemoryFormat> memory_format = ::std::nullopt) {
229-
if (memory_format.has_value()) {
230-
UNSUPPORTED_FEATURE_IN_PADDLE("`MemoryFormat`")
231-
}
232-
return paddle::experimental::empty(
233-
size._PD_ToPaddleIntArray(),
234-
conversion::_PD_AtenScalarTypeToPhiDataType(options._PD_GetScalarType()),
235-
options._PD_GetPlace());
236-
}
237-
at::Tensor ones(at::IntArrayRef size, at::TensorOptions options = {}) {
238-
return paddle::experimental::ones(
239-
size._PD_ToPaddleIntArray(),
240-
conversion::_PD_AtenScalarTypeToPhiDataType(options._PD_GetScalarType()),
241-
options._PD_GetPlace());
242-
}
243-
at::Tensor zeros(at::IntArrayRef size, at::TensorOptions options = {}) {
244-
return paddle::experimental::zeros(
245-
size._PD_ToPaddleIntArray(),
246-
conversion::_PD_AtenScalarTypeToPhiDataType(options._PD_GetScalarType()),
247-
options._PD_GetPlace());
248-
}
249-
250-
at::Tensor full(at::IntArrayRef size,
251-
const at::Scalar& fill_value,
252-
::std::optional<at::ScalarType> dtype = {},
253-
::std::optional<at::Layout> layout = {},
254-
::std::optional<at::Device> device = {},
255-
::std::optional<bool> pin_memory = {}) {
256-
if (pin_memory.has_value()) {
257-
UNSUPPORTED_FEATURE_IN_PADDLE("`pin_memory` option in full")
258-
}
259-
return paddle::experimental::full(
260-
size._PD_ToPaddleIntArray(),
261-
fill_value,
262-
dtype.has_value() ? conversion::_PD_AtenScalarTypeToPhiDataType(*dtype)
263-
: phi::DataType::FLOAT32,
264-
phi::CPUPlace() // TODO(SigureMo): support other places
265-
);
266-
}
267-
268-
} // namespace at
269-
270-
namespace torch {
271-
using Tensor = at::Tensor;
272-
using Dtype = at::ScalarType;
273-
} // namespace torch
274-
275-
void compiling_test() {
276-
// Example usage of the Tensor class
277-
at::Tensor a = at::ones({2, 3}, at::TensorOptions());
278-
at::Tensor b = at::full({2, 3}, 1, at::ScalarType::Float);
279-
double c = 10;
280-
at::Tensor a_contig = a.contiguous();
281-
at::Tensor b_contig = b.contiguous();
282-
at::Tensor result = at::empty(a_contig.sizes(), a_contig.options());
283-
const float* a_ptr = a_contig.data_ptr<float>();
284-
const float* b_ptr = b_contig.data_ptr<float>();
285-
float* result_ptr = result.data_ptr<float>();
286-
for (int64_t i = 0; i < a_contig.numel(); i++) {
287-
result_ptr[i] = a_ptr[i] * b_ptr[i] + c;
288-
}
289-
// Show result
290-
for (int64_t i = 0; i < a_contig.numel(); i++) {
291-
std::cout << "Result[" << i << "] = " << a_ptr[i] * b_ptr[i] + c
292-
<< std::endl;
293-
}
294-
}
21+
#include "paddle/phi/api/include/torch_like_api/torch/api.h"
Lines changed: 52 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,52 @@
1+
// Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.
2+
//
3+
// Licensed under the Apache License, Version 2.0 (the "License");
4+
// you may not use this file except in compliance with the License.
5+
// You may obtain a copy of the License at
6+
//
7+
// http://www.apache.org/licenses/LICENSE-2.0
8+
//
9+
// Unless required by applicable law or agreed to in writing, software
10+
// distributed under the License is distributed on an "AS IS" BASIS,
11+
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
// See the License for the specific language governing permissions and
13+
// limitations under the License.
14+
15+
#pragma once
16+
17+
#include "paddle/phi/api/include/torch_like_api/c10/array_ref.h"
18+
#include "paddle/phi/api/include/torch_like_api/c10/data_type.h"
19+
#include "paddle/phi/api/include/torch_like_api/c10/exception.h"
20+
#include "paddle/phi/api/include/torch_like_api/c10/memory_format.h"
21+
#include "paddle/phi/api/include/torch_like_api/c10/scalar_type.h"
22+
#include "paddle/phi/api/include/torch_like_api/c10/tensor_options.h"
23+
#include "paddle/phi/common/scalar.h"
24+
25+
namespace at {
26+
27+
// TensorOptions
28+
using c10::TensorOptions;
29+
30+
// DataType
31+
using Half = c10::Half;
32+
using BFloat16 = c10::BFloat16;
33+
34+
// ScalarType
35+
using c10::ScalarType;
36+
37+
#define REDEFINE_CONSTANT_IN_AT(_1, _2, name) \
38+
constexpr ScalarType k##name = c10::k##name;
39+
FOREACH_PADDLE_AND_TORCH_DTYPES(REDEFINE_CONSTANT_IN_AT)
40+
#undef REDEFINE_CONSTANT_IN_AT
41+
42+
// IntArrayRef
43+
using c10::IntArrayRef;
44+
45+
// MemoryFormat
46+
using c10::MemoryFormat;
47+
48+
// Scalar
49+
using Scalar = paddle::experimental::Scalar;
50+
struct Device {};
51+
struct Layout {};
52+
} // namespace at
Lines changed: 66 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,66 @@
1+
// Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.
2+
//
3+
// Licensed under the Apache License, Version 2.0 (the "License");
4+
// you may not use this file except in compliance with the License.
5+
// You may obtain a copy of the License at
6+
//
7+
// http://www.apache.org/licenses/LICENSE-2.0
8+
//
9+
// Unless required by applicable law or agreed to in writing, software
10+
// distributed under the License is distributed on an "AS IS" BASIS,
11+
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
// See the License for the specific language governing permissions and
13+
// limitations under the License.
14+
15+
#pragma once
16+
17+
#include "paddle/phi/api/include/api.h"
18+
#include "paddle/phi/api/include/torch_like_api/aten/common.h"
19+
#include "paddle/phi/api/include/torch_like_api/aten/tensor.h"
20+
21+
namespace at {
22+
23+
at::Tensor empty(
24+
at::IntArrayRef size,
25+
at::TensorOptions options = {},
26+
::std::optional<at::MemoryFormat> memory_format = ::std::nullopt) {
27+
if (memory_format.has_value()) {
28+
UNSUPPORTED_FEATURE_IN_PADDLE("`MemoryFormat`")
29+
}
30+
return paddle::experimental::empty(
31+
size._PD_ToPaddleIntArray(),
32+
compat::_PD_AtenScalarTypeToPhiDataType(options._PD_GetScalarType()),
33+
options._PD_GetPlace());
34+
}
35+
at::Tensor ones(at::IntArrayRef size, at::TensorOptions options = {}) {
36+
return paddle::experimental::ones(
37+
size._PD_ToPaddleIntArray(),
38+
compat::_PD_AtenScalarTypeToPhiDataType(options._PD_GetScalarType()),
39+
options._PD_GetPlace());
40+
}
41+
at::Tensor zeros(at::IntArrayRef size, at::TensorOptions options = {}) {
42+
return paddle::experimental::zeros(
43+
size._PD_ToPaddleIntArray(),
44+
compat::_PD_AtenScalarTypeToPhiDataType(options._PD_GetScalarType()),
45+
options._PD_GetPlace());
46+
}
47+
48+
at::Tensor full(at::IntArrayRef size,
49+
const at::Scalar& fill_value,
50+
::std::optional<at::ScalarType> dtype = {},
51+
::std::optional<at::Layout> layout = {},
52+
::std::optional<at::Device> device = {},
53+
::std::optional<bool> pin_memory = {}) {
54+
if (pin_memory.has_value()) {
55+
UNSUPPORTED_FEATURE_IN_PADDLE("`pin_memory` option in full")
56+
}
57+
return paddle::experimental::full(
58+
size._PD_ToPaddleIntArray(),
59+
fill_value,
60+
dtype.has_value() ? compat::_PD_AtenScalarTypeToPhiDataType(*dtype)
61+
: phi::DataType::FLOAT32,
62+
phi::CPUPlace() // TODO(SigureMo): support other places
63+
);
64+
}
65+
66+
} // namespace at

0 commit comments

Comments
 (0)