Skip to content

Commit 013a19b

Browse files
committed
export some classes as PADDLE_API
1 parent 043128a commit 013a19b

File tree

1 file changed

+4
-13
lines changed

1 file changed

+4
-13
lines changed

paddle/phi/api/include/torch_compat_runtime.h

Lines changed: 4 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,7 @@
2121
#include <iostream>
2222
#include <optional>
2323
#include <variant>
24+
#include "paddle/common/macros.h"
2425
#include "paddle/phi/api/include/api.h"
2526
#include "paddle/phi/api/include/tensor.h"
2627
#include "paddle/phi/common/bfloat16.h"
@@ -91,7 +92,7 @@ class ArrayRef {
9192
};
9293
using IntArrayRef = ArrayRef<int64_t>;
9394

94-
enum class MemoryFormat : int8_t {
95+
enum class PADDLE_API MemoryFormat : int8_t {
9596
Contiguous,
9697
Preserve,
9798
ChannelsLast,
@@ -122,15 +123,15 @@ using BFloat16 = c10::BFloat16;
122123
_(uint16_t, UINT16, UInt16) \
123124
_(uint32_t, UINT32, UInt32)
124125

125-
enum class ScalarType : int8_t {
126+
enum class PADDLE_API ScalarType : int8_t {
126127
#define DEFINE_ST_ENUM_VAL_(_1, _2, n) n,
127128
FORALL_PADDLE_AND_TORCH_DTYPES(DEFINE_ST_ENUM_VAL_)
128129
#undef DEFINE_ENUM_ST_ENUM_VAL_
129130
Undefined,
130131
NumOptions
131132
};
132133

133-
struct TensorOptions {
134+
struct PADDLE_API TensorOptions {
134135
TensorOptions()
135136
: requires_grad_(false),
136137
pinned_memory_(false),
@@ -273,28 +274,18 @@ using Dtype = at::ScalarType;
273274

274275
void compiling_test() {
275276
// Example usage of the Tensor class
276-
std::cout << "111111";
277277
at::Tensor a = at::ones({2, 3}, at::TensorOptions());
278-
std::cout << "222222";
279278
at::Tensor b = at::full({2, 3}, 1, at::ScalarType::Float);
280-
std::cout << "333333";
281279
double c = 10;
282280
at::Tensor a_contig = a.contiguous();
283-
std::cout << "444444";
284281
at::Tensor b_contig = b.contiguous();
285-
std::cout << "555555";
286282
at::Tensor result = at::empty(a_contig.sizes(), a_contig.options());
287-
std::cout << "666666";
288283
const float* a_ptr = a_contig.data_ptr<float>();
289-
std::cout << "777777";
290284
const float* b_ptr = b_contig.data_ptr<float>();
291-
std::cout << "888888";
292285
float* result_ptr = result.data_ptr<float>();
293-
std::cout << "999999";
294286
for (int64_t i = 0; i < a_contig.numel(); i++) {
295287
result_ptr[i] = a_ptr[i] * b_ptr[i] + c;
296288
}
297-
std::cout << "000000";
298289
// Show result
299290
for (int64_t i = 0; i < a_contig.numel(); i++) {
300291
std::cout << "Result[" << i << "] = " << a_ptr[i] * b_ptr[i] + c

0 commit comments

Comments
 (0)