Skip to content

Commit 683a141

Browse files
authored
add more capi to support stride (#62716)
1 parent fde63d1 commit 683a141

File tree

5 files changed

+213
-0
lines changed

5 files changed

+213
-0
lines changed

paddle/phi/capi/include/c_meta_tensor.h

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -39,13 +39,25 @@ int64_t PD_MetaTensorGetDim(const PD_MetaTensor *tensor,
3939
size_t index,
4040
PD_Status *status);
4141

42+
int64_t PD_MetaTensorGetNumStrides(const PD_MetaTensor *tensor,
43+
PD_Status *status);
44+
45+
int64_t PD_MetaTensorGetStride(const PD_MetaTensor *tensor,
46+
size_t index,
47+
PD_Status *status);
48+
4249
bool PD_MetaTensorIsValid(const PD_MetaTensor *tensor, PD_Status *status);
4350

4451
void PD_MetaTensorSetDims(PD_MetaTensor *tensor,
4552
int64_t ndims,
4653
const int64_t *dims,
4754
PD_Status *status);
4855

56+
void PD_MetaTensorSetStrides(PD_MetaTensor *tensor,
57+
int64_t nstrides,
58+
const int64_t *strides,
59+
PD_Status *status);
60+
4961
void PD_MetaTensorSetDataType(PD_MetaTensor *tensor,
5062
PD_DataType dtype,
5163
PD_Status *status);

paddle/phi/capi/include/c_tensor.h

Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -41,6 +41,12 @@ int64_t PD_TensorGetDim(const PD_Tensor *tensor,
4141
size_t index,
4242
PD_Status *status);
4343

44+
int64_t PD_TensorGetNumStrides(const PD_Tensor *tensor, PD_Status *status);
45+
46+
int64_t PD_TensorGetStride(const PD_Tensor *tensor,
47+
size_t index,
48+
PD_Status *status);
49+
4450
void PD_TensorGetLoD(const PD_Tensor *tensor,
4551
PD_List *data,
4652
PD_List *offset,
@@ -52,11 +58,22 @@ bool PD_TensorIsValid(const PD_Tensor *tensor, PD_Status *status);
5258

5359
void *PD_TensorGetHolder(const PD_Tensor *tensor, PD_Status *status);
5460

61+
size_t PD_TensorGetOffset(const PD_Tensor *tensor, PD_Status *status);
62+
5563
void PD_TensorSetDims(PD_Tensor *tensor,
5664
int64_t ndims,
5765
const int64_t *dims,
5866
PD_Status *status);
5967

68+
void PD_TensorSetOffset(PD_Tensor *tensor,
69+
const int64_t offset,
70+
PD_Status *status);
71+
72+
void PD_TensorSetStrides(PD_Tensor *tensor,
73+
int64_t nstrides,
74+
const int64_t *strides,
75+
PD_Status *status);
76+
6077
void PD_TensorSetDataType(PD_Tensor *tensor,
6178
PD_DataType dtype,
6279
PD_Status *status);

paddle/phi/capi/include/wrapper_base.h

Lines changed: 66 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -72,6 +72,19 @@ inline std::vector<int64_t> PD_TensorGetDims(PD_Tensor* tensor,
7272
return std::vector<int64_t>();
7373
}
7474

75+
inline std::vector<int64_t> PD_TensorGetStrides(PD_Tensor* tensor,
76+
PD_Status* status) {
77+
int64_t nstrides = PD_TensorGetNumStrides(tensor, status);
78+
if (nstrides > 0) {
79+
std::vector<int64_t> shape(nstrides);
80+
for (int64_t i = 0; i < nstrides; ++i) {
81+
shape[i] = PD_TensorGetStride(tensor, i, status);
82+
}
83+
return shape;
84+
}
85+
return std::vector<int64_t>();
86+
}
87+
7588
inline std::vector<int64_t> PD_MetaTensorGetDims(PD_MetaTensor* tensor,
7689
PD_Status* status) {
7790
int64_t ndims = PD_MetaTensorGetNumDims(tensor, status);
@@ -85,6 +98,19 @@ inline std::vector<int64_t> PD_MetaTensorGetDims(PD_MetaTensor* tensor,
8598
return std::vector<int64_t>();
8699
}
87100

101+
inline std::vector<int64_t> PD_MetaTensorGetStrides(PD_MetaTensor* tensor,
102+
PD_Status* status) {
103+
int64_t nstrides = PD_MetaTensorGetNumStrides(tensor, status);
104+
if (nstrides > 0) {
105+
std::vector<int64_t> shape(nstrides);
106+
for (int64_t i = 0; i < nstrides; ++i) {
107+
shape[i] = PD_MetaTensorGetStride(tensor, i, status);
108+
}
109+
return shape;
110+
}
111+
return std::vector<int64_t>();
112+
}
113+
88114
template <typename T>
89115
class WrapperBase {
90116
public:
@@ -134,13 +160,27 @@ class DenseTensor : public WrapperBase<PD_Tensor> {
134160
return holder;
135161
}
136162

163+
size_t offset() const {
164+
C_Status status;
165+
auto offset = PD_TensorGetOffset(raw_data(), &status);
166+
PD_CHECK_STATUS(status);
167+
return offset;
168+
}
169+
137170
std::vector<int64_t> dims() const {
138171
C_Status status;
139172
auto dimension = PD_TensorGetDims(raw_data(), &status);
140173
PD_CHECK_STATUS(status);
141174
return dimension;
142175
}
143176

177+
std::vector<int64_t> strides() const {
178+
C_Status status;
179+
auto strides = PD_TensorGetStrides(raw_data(), &status);
180+
PD_CHECK_STATUS(status);
181+
return strides;
182+
}
183+
144184
PD_DataType dtype() const {
145185
C_Status status;
146186
auto data_type = PD_TensorGetPDDataType(raw_data(), &status);
@@ -207,6 +247,18 @@ class DenseTensor : public WrapperBase<PD_Tensor> {
207247
PD_CHECK_STATUS(status);
208248
}
209249

250+
void set_offset(const int64_t& offset) {
251+
C_Status status;
252+
PD_TensorSetOffset(raw_data(), offset, &status);
253+
PD_CHECK_STATUS(status);
254+
}
255+
256+
void set_strides(const std::vector<int64_t>& strides) {
257+
C_Status status;
258+
PD_TensorSetStrides(raw_data(), strides.size(), strides.data(), &status);
259+
PD_CHECK_STATUS(status);
260+
}
261+
210262
void set_dtype(PD_DataType data_type) {
211263
C_Status status;
212264
PD_TensorSetDataType(raw_data(), data_type, &status);
@@ -513,6 +565,13 @@ class MetaTensor : WrapperBase<PD_MetaTensor> {
513565
return dimension;
514566
}
515567

568+
std::vector<int64_t> strides() const {
569+
C_Status status;
570+
auto strides = PD_MetaTensorGetStrides(raw_data(), &status);
571+
PD_CHECK_STATUS(status);
572+
return strides;
573+
}
574+
516575
PD_DataType dtype() const {
517576
C_Status status;
518577
auto data_type = PD_MetaTensorGetPDDataType(raw_data(), &status);
@@ -540,6 +599,13 @@ class MetaTensor : WrapperBase<PD_MetaTensor> {
540599
PD_CHECK_STATUS(status);
541600
}
542601

602+
void set_strides(const std::vector<int64_t>& strides) {
603+
C_Status status;
604+
PD_MetaTensorSetStrides(
605+
raw_data(), strides.size(), strides.data(), &status);
606+
PD_CHECK_STATUS(status);
607+
}
608+
543609
void set_dtype(PD_DataType data_type) {
544610
C_Status status;
545611
PD_MetaTensorSetDataType(raw_data(), data_type, &status);

paddle/phi/capi/lib/c_meta_tensor.cc

Lines changed: 46 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -88,6 +88,36 @@ int64_t PD_MetaTensorGetDim(const PD_MetaTensor *tensor,
8888
return cc_tensor->dims()[index];
8989
}
9090

91+
int64_t PD_MetaTensorGetNumStrides(const PD_MetaTensor *tensor,
92+
PD_Status *status) {
93+
if (status) {
94+
if (!tensor) {
95+
*status = C_FAILED;
96+
return 0;
97+
}
98+
*status = C_SUCCESS;
99+
}
100+
101+
auto cc_tensor = reinterpret_cast<const phi::MetaTensor *>(tensor);
102+
return cc_tensor->strides().size();
103+
}
104+
105+
int64_t PD_MetaTensorGetStride(const PD_MetaTensor *tensor,
106+
size_t index,
107+
PD_Status *status) {
108+
auto cc_tensor = reinterpret_cast<const phi::MetaTensor *>(tensor);
109+
110+
if (status) {
111+
if (!tensor || index >= static_cast<size_t>(cc_tensor->strides().size())) {
112+
*status = C_FAILED;
113+
return 0;
114+
}
115+
*status = C_SUCCESS;
116+
}
117+
118+
return cc_tensor->strides()[index];
119+
}
120+
91121
bool PD_MetaTensorIsValid(const PD_MetaTensor *tensor, PD_Status *status) {
92122
if (status) {
93123
if (!tensor) {
@@ -117,6 +147,22 @@ void PD_MetaTensorSetDims(PD_MetaTensor *tensor,
117147
cc_tensor->set_dims(common::make_ddim(shape));
118148
}
119149

150+
void PD_MetaTensorSetStrides(PD_MetaTensor *tensor,
151+
int64_t nstrides,
152+
const int64_t *strides,
153+
PD_Status *status) {
154+
if (status) {
155+
if (!tensor) {
156+
*status = C_FAILED;
157+
return;
158+
}
159+
*status = C_SUCCESS;
160+
}
161+
auto cc_tensor = reinterpret_cast<phi::MetaTensor *>(tensor);
162+
std::vector<int> shape(strides, strides + nstrides);
163+
cc_tensor->set_strides(common::make_ddim(shape));
164+
}
165+
120166
void PD_MetaTensorSetDataType(PD_MetaTensor *tensor,
121167
PD_DataType dtype,
122168
PD_Status *status) {

paddle/phi/capi/lib/c_tensor.cc

Lines changed: 72 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -111,6 +111,35 @@ int64_t PD_TensorGetDim(const PD_Tensor* tensor,
111111
return cc_tensor->dims()[index];
112112
}
113113

114+
int64_t PD_TensorGetNumStrides(const PD_Tensor* tensor, PD_Status* status) {
115+
if (status) {
116+
if (!tensor) {
117+
*status = C_FAILED;
118+
return 0;
119+
}
120+
*status = C_SUCCESS;
121+
}
122+
123+
auto cc_tensor = reinterpret_cast<const phi::DenseTensor*>(tensor);
124+
return cc_tensor->strides().size();
125+
}
126+
127+
int64_t PD_TensorGetStride(const PD_Tensor* tensor,
128+
size_t index,
129+
PD_Status* status) {
130+
auto cc_tensor = reinterpret_cast<const phi::DenseTensor*>(tensor);
131+
132+
if (status) {
133+
if (!tensor || index >= static_cast<size_t>(cc_tensor->strides().size())) {
134+
*status = C_FAILED;
135+
return 0;
136+
}
137+
*status = C_SUCCESS;
138+
}
139+
140+
return cc_tensor->strides()[index];
141+
}
142+
114143
void PD_TensorGetLoD(const PD_Tensor* tensor,
115144
PD_List* data,
116145
PD_List* offset,
@@ -185,6 +214,19 @@ void* PD_TensorGetHolder(const PD_Tensor* tensor, PD_Status* status) {
185214
return cc_tensor->Holder().get();
186215
}
187216

217+
size_t PD_TensorGetOffset(const PD_Tensor* tensor, PD_Status* status) {
218+
if (status) {
219+
if (!tensor) {
220+
*status = C_FAILED;
221+
return 0;
222+
}
223+
*status = C_SUCCESS;
224+
}
225+
226+
auto cc_tensor = reinterpret_cast<const phi::DenseTensor*>(tensor);
227+
return cc_tensor->offset();
228+
}
229+
188230
void PD_TensorSetDims(PD_Tensor* tensor,
189231
int64_t ndims,
190232
const int64_t* dims,
@@ -201,6 +243,36 @@ void PD_TensorSetDims(PD_Tensor* tensor,
201243
cc_tensor->Resize(common::make_ddim(shape));
202244
}
203245

246+
void PD_TensorSetOffset(PD_Tensor* tensor,
247+
const int64_t offset,
248+
PD_Status* status) {
249+
if (status) {
250+
if (!tensor) {
251+
*status = C_FAILED;
252+
return;
253+
}
254+
*status = C_SUCCESS;
255+
}
256+
auto cc_tensor = reinterpret_cast<phi::DenseTensor*>(tensor);
257+
cc_tensor->set_offset(offset);
258+
}
259+
260+
void PD_TensorSetStrides(PD_Tensor* tensor,
261+
int64_t nstrides,
262+
const int64_t* strides,
263+
PD_Status* status) {
264+
if (status) {
265+
if (!tensor) {
266+
*status = C_FAILED;
267+
return;
268+
}
269+
*status = C_SUCCESS;
270+
}
271+
auto cc_tensor = reinterpret_cast<phi::DenseTensor*>(tensor);
272+
std::vector<int> shape(strides, strides + nstrides);
273+
cc_tensor->set_strides(common::make_ddim(shape));
274+
}
275+
204276
void PD_TensorSetDataType(PD_Tensor* tensor,
205277
PD_DataType dtype,
206278
PD_Status* status) {

0 commit comments

Comments
 (0)