Skip to content

Commit 011d008

Browse files
committed
[Extension Operants] Extension supports tensor operants
1 parent 746b774 commit 011d008

File tree

6 files changed

+20
-9
lines changed

6 files changed

+20
-9
lines changed

paddle/fluid/pybind/eager_functions.cc

Lines changed: 8 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -499,15 +499,18 @@ static PyObject* eager_api_jit_function_call(PyObject* self,
499499
EAGER_CATCH_AND_THROW_RETURN_NULL
500500
}
501501

502-
static PyObject* eager_api_init_eager_and_static_tensor_operants(
503-
PyObject* self, PyObject* args, PyObject* kwargs) {
502+
static PyObject* eager_api_init_tensor_operants(PyObject* self,
503+
PyObject* args,
504+
PyObject* kwargs) {
504505
EAGER_TRY
505506

506507
paddle::OperantsManager::Instance().eager_operants.reset(
507508
new paddle::prim::EagerTensorOperants());
508509
paddle::OperantsManager::Instance().static_operants.reset(
509510
new paddle::prim::StaticTensorOperants());
510-
VLOG(4) << "Initialize eager and static tensor operants successfully";
511+
paddle::OperantsManager::Instance().phi_operants.reset(
512+
new paddle::operants::PhiTensorOperants());
513+
VLOG(4) << "Initialize tensor operants successfully";
511514

512515
RETURN_PY_NONE
513516
EAGER_CATCH_AND_THROW_RETURN_NULL
@@ -1123,9 +1126,8 @@ PyMethodDef variable_functions[] = {
11231126
(PyCFunction)(void (*)(void))eager_api_run_custom_op,
11241127
METH_VARARGS | METH_KEYWORDS,
11251128
NULL},
1126-
{"_init_eager_and_static_tensor_operants",
1127-
(PyCFunction)(void (*)(
1128-
void))eager_api_init_eager_and_static_tensor_operants,
1129+
{"_init_tensor_operants",
1130+
(PyCFunction)(void (*)(void))eager_api_init_tensor_operants,
11291131
METH_VARARGS | METH_KEYWORDS,
11301132
NULL},
11311133
{"tensor_copy",

paddle/utils/CMakeLists.txt

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -17,5 +17,5 @@ if(NOT ((NOT WITH_PYTHON) AND ON_INFER))
1717
cc_library(
1818
pybind_util
1919
SRCS pybind.cc
20-
DEPS phi_tensor_raw)
20+
DEPS phi_tensor_raw flags)
2121
endif()

paddle/utils/pybind.cc

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -13,8 +13,11 @@
1313
// limitations under the License.
1414

1515
#include "paddle/utils/pybind.h"
16+
17+
#include "gflags/gflags.h"
1618
#include "paddle/phi/core/enforce.h"
1719

20+
DECLARE_string(tensor_operants_mode);
1821
namespace paddle {
1922
namespace pybind {
2023

@@ -66,5 +69,7 @@ PyObject* ToPyObject(const paddle::experimental::Tensor& value,
6669
return obj;
6770
}
6871

72+
void SwitchTensorOperantsMode() { FLAGS_tensor_operants_mode = "phi"; }
73+
6974
} // namespace pybind
7075
} // namespace paddle

paddle/utils/pybind.h

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -46,6 +46,9 @@ paddle::experimental::Tensor CastPyArg2Tensor(PyObject* obj, ssize_t arg_pos);
4646
PyObject* ToPyObject(const paddle::experimental::Tensor& value,
4747
bool return_py_none_if_not_initialize = false);
4848

49+
// Internal use only, switch tensor_operants_mode to phi
50+
void SwitchTensorOperantsMode();
51+
4952
} // namespace pybind
5053
} // namespace paddle
5154

@@ -59,6 +62,7 @@ struct type_caster<paddle::experimental::Tensor> {
5962
_("paddle::experimental::Tensor"));
6063

6164
bool load(handle src, bool) {
65+
paddle::pybind::SwitchTensorOperantsMode();
6266
PyObject* obj = src.ptr();
6367
if (paddle::pybind::PyCheckTensor(obj)) {
6468
value = paddle::pybind::CastPyArg2Tensor(obj, 0);

python/paddle/fluid/__init__.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -231,7 +231,7 @@ def remove_flag_if_exists(name):
231231
core.init_glog(sys.argv[0])
232232
# don't init_p2p when in unittest to save time.
233233
core.init_devices()
234-
core.eager._init_eager_and_static_tensor_operants()
234+
core.eager._init_tensor_operants()
235235
core.init_default_kernel_signatures()
236236
core.init_memory_method()
237237

python/paddle/fluid/tests/cpp_extension/custom_extension.cc

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -21,7 +21,7 @@
2121
paddle::Tensor custom_sub(paddle::Tensor x, paddle::Tensor y);
2222

2323
paddle::Tensor custom_add(const paddle::Tensor& x, const paddle::Tensor& y) {
24-
return paddle::add(paddle::exp(x), paddle::exp(y));
24+
return x.exp() + y.exp();
2525
}
2626

2727
paddle::Tensor nullable_tensor(bool return_none = false) {

0 commit comments

Comments
 (0)