Skip to content

Commit ff4ec23

Browse files
authored
[Cpp Extension] Support optional types (#50764)
* [Cpp Extension] Support optional type * fix custom_extension.cc
1 parent dca3a09 commit ff4ec23

File tree

4 files changed

+60
-0
lines changed

4 files changed

+60
-0
lines changed

paddle/utils/pybind.h

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,7 @@
1515
#pragma once
1616

1717
#include "paddle/phi/api/include/tensor.h"
18+
#include "paddle/utils/optional.h"
1819
#include "pybind11/pybind11.h"
1920
#include "pybind11/stl.h"
2021

@@ -73,5 +74,12 @@ struct type_caster<paddle::experimental::Tensor> {
7374
src, true /* return_py_none_if_not_initialize */));
7475
}
7576
};
77+
78+
// Pybind11 bindings for optional types.
79+
// http://pybind11.readthedocs.io/en/stable/advanced/cast/stl.html#c-17-library-containers
80+
template <typename T>
81+
struct type_caster<paddle::optional<T>> : optional_caster<paddle::optional<T>> {
82+
};
83+
7684
} // namespace detail
7785
} // namespace pybind11

python/paddle/fluid/tests/cpp_extension/custom_extension.cc

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -32,10 +32,20 @@ paddle::Tensor nullable_tensor(bool return_none = false) {
3232
return t;
3333
}
3434

35+
paddle::optional<paddle::Tensor> optional_tensor(bool return_option = false) {
36+
paddle::optional<paddle::Tensor> t;
37+
if (!return_option) {
38+
t = paddle::ones({2, 2});
39+
}
40+
return t;
41+
}
42+
3543
PYBIND11_MODULE(custom_cpp_extension, m) {
3644
m.def("custom_add", &custom_add, "exp(x) + exp(y)");
3745
m.def("custom_sub", &custom_sub, "exp(x) - exp(y)");
3846
m.def("nullable_tensor", &nullable_tensor, "returned Tensor might be None");
47+
m.def(
48+
"optional_tensor", &optional_tensor, "returned Tensor might be optional");
3949

4050
py::class_<Power>(m, "Power")
4151
.def(py::init<int, int>())

python/paddle/fluid/tests/cpp_extension/test_cpp_extension_jit.py

Lines changed: 26 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -67,6 +67,8 @@ def tearDown(self):
6767
def test_cpp_extension(self):
6868
self._test_extension_function()
6969
self._test_extension_class()
70+
self._test_nullable_tensor()
71+
self._test_optional_tensor()
7072

7173
def _test_extension_function(self):
7274
for dtype in self.dtypes:
@@ -104,6 +106,30 @@ def _test_extension_class(self):
104106
atol=1e-5,
105107
)
106108

109+
def _test_nullable_tensor(self):
110+
x = custom_cpp_extension.nullable_tensor(True)
111+
assert x is None, "Return None when input parameter return_none = True"
112+
x = custom_cpp_extension.nullable_tensor(False).numpy()
113+
x_np = np.ones(shape=[2, 2])
114+
np.testing.assert_array_equal(
115+
x,
116+
x_np,
117+
err_msg='extension out: {},\n numpy out: {}'.format(x, x_np),
118+
)
119+
120+
def _test_optional_tensor(self):
121+
x = custom_cpp_extension.optional_tensor(True)
122+
assert (
123+
x is None
124+
), "Return None when input parameter return_option = True"
125+
x = custom_cpp_extension.optional_tensor(False).numpy()
126+
x_np = np.ones(shape=[2, 2])
127+
np.testing.assert_array_equal(
128+
x,
129+
x_np,
130+
err_msg='extension out: {},\n numpy out: {}'.format(x, x_np),
131+
)
132+
107133

108134
if __name__ == '__main__':
109135
unittest.main()

python/paddle/fluid/tests/cpp_extension/test_cpp_extension_setup.py

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -149,6 +149,7 @@ def test_cpp_extension(self):
149149
self._test_extension_function_mixed()
150150
self._test_extension_class()
151151
self._test_nullable_tensor()
152+
self._test_optional_tensor()
152153
# Custom op
153154
self._test_static()
154155
self._test_dynamic()
@@ -227,6 +228,21 @@ def _test_nullable_tensor(self):
227228
err_msg='extension out: {},\n numpy out: {}'.format(x, x_np),
228229
)
229230

231+
def _test_optional_tensor(self):
232+
import custom_cpp_extension
233+
234+
x = custom_cpp_extension.optional_tensor(True)
235+
assert (
236+
x is None
237+
), "Return None when input parameter return_option = True"
238+
x = custom_cpp_extension.optional_tensor(False).numpy()
239+
x_np = np.ones(shape=[2, 2])
240+
np.testing.assert_array_equal(
241+
x,
242+
x_np,
243+
err_msg='extension out: {},\n numpy out: {}'.format(x, x_np),
244+
)
245+
230246
def _test_static(self):
231247
import mix_relu_extension
232248

0 commit comments

Comments
 (0)