File tree 3 files changed +12
-5
lines changed
3 files changed +12
-5
lines changed Original file line number Diff line number Diff line change @@ -121,7 +121,7 @@ void BindGuard(pybind11::module *m) {
121
121
*m,
122
122
" NumPyArrayValueMatchGuard" ,
123
123
R"DOC( NumPyArrayValueMatchGuard Class.)DOC" )
124
- .def (py::init<const py::array &>(), py::arg (" array" ));
124
+ .def (py::init<const py::object &>(), py::arg (" array" ));
125
125
126
126
m->def (
127
127
" merge_guard" ,
Original file line number Diff line number Diff line change @@ -158,7 +158,10 @@ bool NumPyArrayValueMatchGuard::check(PyObject* value) {
158
158
}
159
159
160
160
py::object py_value = py::cast<py::object>(value);
161
- return expected_.attr (" __eq__" )(py_value).attr (" all" )().cast <bool >();
161
+ return py::cast<py::object>(expected_)
162
+ .attr (" __eq__" )(py_value)
163
+ .attr (" all" )()
164
+ .cast <bool >();
162
165
}
163
166
164
167
PyObject* ConstantExprNode::eval (FrameProxy* frame) { return value_ptr_; }
Original file line number Diff line number Diff line change @@ -223,13 +223,17 @@ class NumpyDtypeMatchGuard : public GuardBase {
223
223
224
224
class NumPyArrayValueMatchGuard : public GuardBase {
225
225
public:
226
- explicit NumPyArrayValueMatchGuard (const py::array& array)
227
- : expected_(array) {}
226
+ explicit NumPyArrayValueMatchGuard (const py::object& array)
227
+ : expected_(array.ptr()) {
228
+ Py_INCREF (expected_);
229
+ }
230
+
231
+ ~NumPyArrayValueMatchGuard () override { Py_DECREF (expected_); }
228
232
229
233
bool check (PyObject* value) override ;
230
234
231
235
private:
232
- py::array expected_;
236
+ PyObject* expected_;
233
237
};
234
238
235
239
class GuardTreeNode {};
You can’t perform that action at this time.
0 commit comments