Skip to content

Commit b772a05

Browse files
committed
py::array -> py::object
1 parent 69e943f commit b772a05

File tree

3 files changed

+12
-5
lines changed

3 files changed

+12
-5
lines changed

paddle/fluid/pybind/jit.cc

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -121,7 +121,7 @@ void BindGuard(pybind11::module *m) {
121121
*m,
122122
"NumPyArrayValueMatchGuard",
123123
R"DOC(NumPyArrayValueMatchGuard Class.)DOC")
124-
.def(py::init<const py::array &>(), py::arg("array"));
124+
.def(py::init<const py::object &>(), py::arg("array"));
125125

126126
m->def(
127127
"merge_guard",

paddle/fluid/pybind/sot/guards.cc

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -158,7 +158,10 @@ bool NumPyArrayValueMatchGuard::check(PyObject* value) {
158158
}
159159

160160
py::object py_value = py::cast<py::object>(value);
161-
return expected_.attr("__eq__")(py_value).attr("all")().cast<bool>();
161+
return py::cast<py::object>(expected_)
162+
.attr("__eq__")(py_value)
163+
.attr("all")()
164+
.cast<bool>();
162165
}
163166

164167
PyObject* ConstantExprNode::eval(FrameProxy* frame) { return value_ptr_; }

paddle/fluid/pybind/sot/guards.h

Lines changed: 7 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -223,13 +223,17 @@ class NumpyDtypeMatchGuard : public GuardBase {
223223

224224
class NumPyArrayValueMatchGuard : public GuardBase {
225225
public:
226-
explicit NumPyArrayValueMatchGuard(const py::array& array)
227-
: expected_(array) {}
226+
explicit NumPyArrayValueMatchGuard(const py::object& array)
227+
: expected_(array.ptr()) {
228+
Py_INCREF(expected_);
229+
}
230+
231+
~NumPyArrayValueMatchGuard() override { Py_DECREF(expected_); }
228232

229233
bool check(PyObject* value) override;
230234

231235
private:
232-
py::array expected_;
236+
PyObject* expected_;
233237
};
234238

235239
class GuardTreeNode {};

0 commit comments

Comments
 (0)