Skip to content

Commit 234ce2c

Browse files
test changes
1 parent 0e121f8 commit 234ce2c

File tree

6 files changed

+151
-22
lines changed

6 files changed

+151
-22
lines changed

src/tensor_array/_core/tensor_bind.cc

Lines changed: 50 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -186,9 +186,46 @@ PYBIND11_MODULE(tensor2, m)
186186
pybind11::arg("dtype") = S_INT_32
187187
);
188188

189+
m.def(
190+
"add",
191+
&tensor_array::value::add,
192+
pybind11::arg("value_1"),
193+
pybind11::arg("value_2")
194+
);
195+
196+
m.def(
197+
"multiply",
198+
&tensor_array::value::multiply,
199+
pybind11::arg("value_1"),
200+
pybind11::arg("value_2")
201+
);
202+
203+
m.def(
204+
"divide",
205+
&tensor_array::value::divide,
206+
pybind11::arg("value_1"),
207+
pybind11::arg("value_2")
208+
);
209+
210+
m.def(
211+
"matmul",
212+
&tensor_array::value::matmul,
213+
pybind11::arg("value_1"),
214+
pybind11::arg("value_2")
215+
);
216+
217+
m.def(
218+
"condition",
219+
&tensor_array::value::condition,
220+
pybind11::arg("condition_value"),
221+
pybind11::arg("value_if_true"),
222+
pybind11::arg("value_if_false")
223+
);
224+
189225
pybind11::class_<Tensor>(m, "Tensor")
190226
.def(pybind11::init())
191227
.def(pybind11::init(&tensor_copying))
228+
.def(pybind11::init(&convert_numpy_to_tensor_base<int>))
192229
.def(pybind11::init(&convert_numpy_to_tensor_base<float>))
193230
.def(pybind11::self + pybind11::self)
194231
.def(pybind11::self - pybind11::self)
@@ -207,33 +244,27 @@ PYBIND11_MODULE(tensor2, m)
207244
.def(+pybind11::self)
208245
.def(-pybind11::self)
209246
.def(hash(pybind11::self))
210-
.def("transpose", &Tensor::transpose)
211-
.def("calc_grad", &Tensor::calc_grad)
212-
.def("get_grad", &Tensor::get_grad)
213-
.def("sin", &Tensor::sin)
214-
.def("sin", &Tensor::sin)
215-
.def("cos", &Tensor::cos)
216-
.def("tan", &Tensor::tan)
217-
.def("sinh", &Tensor::sinh)
218-
.def("cosh", &Tensor::cosh)
219-
.def("tanh", &Tensor::tanh)
220-
.def("log", &Tensor::log)
221-
.def("clone", &Tensor::clone)
247+
.def("transpose", &tensor_array::value::Tensor::transpose)
248+
.def("calc_grad", &tensor_array::value::Tensor::calc_grad)
249+
.def("get_grad", &tensor_array::value::Tensor::get_grad)
250+
.def("sin", &tensor_array::value::Tensor::sin)
251+
.def("cos", &tensor_array::value::Tensor::cos)
252+
.def("tan", &tensor_array::value::Tensor::tan)
253+
.def("sinh", &tensor_array::value::Tensor::sinh)
254+
.def("cosh", &tensor_array::value::Tensor::cosh)
255+
.def("tanh", &tensor_array::value::Tensor::tanh)
256+
.def("log", &tensor_array::value::Tensor::log)
257+
.def("clone", &tensor_array::value::Tensor::clone)
222258
.def("cast", &tensor_cast_1)
223-
.def("add", &add)
224-
.def("multiply", &multiply)
225-
.def("divide", &divide)
226-
.def("matmul", &matmul)
227-
.def("condition", &condition)
228259
.def("numpy", &convert_tensor_to_numpy)
229260
.def("shape", &tensor_shape)
230261
.def("dtype", &tensor_type)
231262
.def("__getitem__", &python_index)
232263
.def("__getitem__", &python_slice)
233264
.def("__getitem__", &python_tuple_slice)
234265
.def("__len__", &python_len)
235-
.def("__matmul__", &matmul)
236-
.def("__rmatmul__", &matmul)
266+
.def("__matmul__", &tensor_array::value::matmul)
267+
.def("__rmatmul__", &tensor_array::value::matmul)
237268
.def("__repr__", &tensor_to_string)
238269
.def("__copy__", &tensor_copying);
239270
}

src/tensor_array/core/__init__.py

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,4 @@
1-
from tensor_array.core.tensor2 import Tensor
2-
from tensor_array.core.tensor2 import zeros
3-
from tensor_array.core.tensor2 import DataType
1+
from .tensor import Tensor
2+
from .constants import *
3+
from .datatypes import DataTypes
4+
from .operator import *

src/tensor_array/core/constants.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,6 @@
1+
from tensor_array.core.tensor2 import zeros as zerosWrapper
2+
from .tensor import Tensor
3+
from .datatypes import DataTypes
4+
5+
def zeros(shape : Tensor, dtype : DataTypes = DataTypes.S_INT_32):
6+
return zerosWrapper(shape, dtype)

src/tensor_array/core/datatypes.py

Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,17 @@
1+
from tensor_array.core.tensor2 import DataType as DataTypeWrapper
2+
from enum import Enum
3+
4+
class DataTypes(Enum):
5+
BOOL = DataTypeWrapper.BOOL
6+
S_INT_8 = DataTypeWrapper.S_INT_8
7+
S_INT_16 = DataTypeWrapper.S_INT_16
8+
S_INT_32 = DataTypeWrapper.S_INT_32
9+
S_INT_64 = DataTypeWrapper.S_INT_64
10+
FLOAT = DataTypeWrapper.FLOAT
11+
DOUBLE = DataTypeWrapper.DOUBLE
12+
HALF = DataTypeWrapper.HALF
13+
BFLOAT16 = DataTypeWrapper.BFLOAT16
14+
U_INT_8 = DataTypeWrapper.U_INT_8
15+
U_INT_16 = DataTypeWrapper.U_INT_16
16+
U_INT_32 = DataTypeWrapper.U_INT_32
17+
U_INT_64 = DataTypeWrapper.U_INT_64

src/tensor_array/core/operator.py

Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,22 @@
1+
from .tensor import Tensor
2+
from tensor_array.core.tensor2 import add as addWrapper
3+
from tensor_array.core.tensor2 import multiply as multiplyWrapper
4+
from tensor_array.core.tensor2 import divide as divideWrapper
5+
from tensor_array.core.tensor2 import matmul as matmulWrapper
6+
from tensor_array.core.tensor2 import condition as conditionWrapper
7+
8+
def add(value_1 : Tensor, value_2 : Tensor):
9+
return addWrapper(value_1, value_2)
10+
11+
def divide(value_1 : Tensor, value_2 : Tensor):
12+
return multiplyWrapper(value_1, value_2)
13+
14+
def multiply(value_1 : Tensor, value_2 : Tensor):
15+
return divideWrapper(value_1, value_2)
16+
17+
def matmul(value_1 : Tensor, value_2 : Tensor):
18+
return matmulWrapper(value_1, value_2)
19+
20+
def condition(condition_value : Tensor, value_if_true : Tensor, value_if_false : Tensor):
21+
return conditionWrapper(condition_value, value_if_true, value_if_false)
22+

src/tensor_array/core/tensor.py

Lines changed: 52 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,52 @@
1+
from tensor_array.core.tensor2 import Tensor as TensorWrapper
2+
from .datatypes import DataTypes
3+
4+
class Tensor(TensorWrapper):
5+
def __init__(self, *args, **kwargs):
6+
super().__init__(*args, **kwargs)
7+
8+
def transpose(self, dim0: int, dim1: int, isDevive: bool):
9+
return super().transpose(dim0, dim1, isDevive)
10+
11+
def calc_grad(self):
12+
super().calc_grad()
13+
14+
def get_grad(self):
15+
return super().get_grad()
16+
17+
def sin(self):
18+
return super().sin()
19+
20+
def cos(self):
21+
return super().cos()
22+
23+
def tan(self):
24+
return super().tan()
25+
26+
def sinh(self):
27+
return super().sinh()
28+
29+
def cosh(self):
30+
return super().cosh()
31+
32+
def tanh(self):
33+
return super().tanh()
34+
35+
def log(self):
36+
return super().log()
37+
38+
def clone(self):
39+
return super().clone()
40+
41+
def cast(self, dtype: DataTypes):
42+
return super().cast(dtype)
43+
44+
def numpy(self):
45+
return super().numpy()
46+
47+
def shape(self):
48+
return super().shape()
49+
50+
def dtype(self):
51+
return super().dtype()
52+

0 commit comments

Comments
 (0)