Skip to content

Commit e5b4340

Browse files
authored
[SOT][Faster Guard] support NumpyDtypeMatchGuard (#71900)
* [SOT][Faster Guard] support `NumpyDtypeMatchGuard` * fix review * fix review and test build * test build * fix build error * use pyobj * add todo
1 parent f28e8c3 commit e5b4340

File tree

5 files changed

+66
-2
lines changed

5 files changed

+66
-2
lines changed

paddle/fluid/pybind/jit.cc

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -108,6 +108,11 @@ void BindGuard(pybind11::module *m) {
108108
std::shared_ptr<InstanceCheckGuard>>(
109109
*m, "InstanceCheckGuard", R"DOC(InstanceCheckGuard Class.)DOC")
110110
.def(py::init<const py::object &>(), py::arg("isinstance_obj"));
111+
py::class_<NumpyDtypeMatchGuard,
112+
GuardBase,
113+
std::shared_ptr<NumpyDtypeMatchGuard>>(
114+
*m, "NumpyDtypeMatchGuard", R"DOC(NumpyDtypeMatchGuard Class.)DOC")
115+
.def(py::init<const py::object &>(), py::arg("dtype"));
111116

112117
m->def(
113118
"merge_guard",

paddle/fluid/pybind/sot/guards.cc

Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,7 @@ limitations under the License. */
1919

2020
#include <Python.h>
2121
#include <frameobject.h>
22+
#include "pybind11/numpy.h"
2223

2324
#if !defined(PyObject_CallOneArg) && !PY_3_9_PLUS
2425
static inline PyObject* PyObject_CallOneArg(PyObject* func, PyObject* arg) {
@@ -133,4 +134,20 @@ bool InstanceCheckGuard::check(PyObject* value) {
133134
return PyObject_IsInstance(value, expected_);
134135
}
135136

137+
bool NumpyDtypeMatchGuard::check(PyObject* value) {
138+
if (value == nullptr) {
139+
return false;
140+
}
141+
142+
// TODO(dev): encountered a compilation error: "declared with greater
143+
// visibility than the type of its field", so had to put the conversion here
144+
py::dtype expected_dtype = py::cast<py::dtype>(expected_);
145+
146+
if (py::isinstance<py::array>(value)) {
147+
return py::cast<py::array>(value).dtype().is(expected_dtype);
148+
}
149+
150+
return expected_dtype.equal(py::handle(value).get_type());
151+
}
152+
136153
#endif

paddle/fluid/pybind/sot/guards.h

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -202,4 +202,19 @@ class InstanceCheckGuard : public GuardBase {
202202
PyObject* expected_;
203203
};
204204

205+
class NumpyDtypeMatchGuard : public GuardBase {
206+
public:
207+
explicit NumpyDtypeMatchGuard(const py::object& dtype)
208+
: expected_(dtype.ptr()) {
209+
Py_INCREF(expected_);
210+
}
211+
212+
~NumpyDtypeMatchGuard() override { Py_DECREF(expected_); }
213+
214+
bool check(PyObject* value) override;
215+
216+
private:
217+
PyObject* expected_;
218+
};
219+
205220
#endif

python/paddle/jit/sot/opcode_translator/executor/variables/basic.py

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1331,16 +1331,20 @@ def getattr(self, name: str, default=None):
13311331
def make_stringified_guard(self) -> list[StringifiedExpression]:
13321332
frame_value_tracer = self.tracker.trace_value_from_frame()
13331333

1334-
dtype_guard = StringifiedExpression(
1334+
dtype_guard = FasterStringifiedExpression(
13351335
f"{{}}.dtype == {NumpyVariable.format_dtype(self.get_py_value().dtype)}",
1336+
paddle.framework.core.NumpyDtypeMatchGuard(
1337+
self.get_py_value().dtype
1338+
),
13361339
[frame_value_tracer],
13371340
union_free_vars(frame_value_tracer.free_vars, {"np": np}),
13381341
)
13391342

13401343
return [
13411344
dtype_guard,
1342-
StringifiedExpression(
1345+
FasterStringifiedExpression(
13431346
f"{{}} == {NumpyVariable.format_number(self.get_py_value())}",
1347+
paddle.framework.core.ValueMatchGuard(self.get_py_value()),
13441348
[frame_value_tracer],
13451349
union_free_vars(frame_value_tracer.free_vars, {"np": np}),
13461350
),

test/sot/test_faster_guard.py

Lines changed: 23 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,8 @@
1717
import unittest
1818
from collections import OrderedDict
1919

20+
import numpy as np
21+
2022
import paddle
2123

2224

@@ -125,6 +127,27 @@ def test_id_match_guard(self):
125127
self.assertTrue(guard_id.check(layer))
126128
self.assertFalse(guard_id.check(paddle.nn.Linear(10, 10)))
127129

130+
def test_numpy_dtype_match_guard(self):
131+
np_array = np.array(1, dtype=np.int32)
132+
guard_numpy_dtype = paddle.framework.core.NumpyDtypeMatchGuard(
133+
np_array.dtype
134+
)
135+
self.assertTrue(guard_numpy_dtype.check(np_array))
136+
self.assertTrue(guard_numpy_dtype.check(np.array(1, dtype=np.int32)))
137+
self.assertTrue(guard_numpy_dtype.check(np.int32()))
138+
self.assertFalse(guard_numpy_dtype.check(np.array(1, dtype=np.int64)))
139+
self.assertFalse(guard_numpy_dtype.check(np.float32()))
140+
self.assertFalse(guard_numpy_dtype.check(np.bool_()))
141+
142+
np_bool = np.bool_(1)
143+
guard_numpy_bool_dtype = paddle.framework.core.NumpyDtypeMatchGuard(
144+
np_bool.dtype
145+
)
146+
self.assertTrue(guard_numpy_bool_dtype.check(np.bool_()))
147+
self.assertTrue(
148+
guard_numpy_bool_dtype.check(np.array(1, dtype=np.bool_))
149+
)
150+
128151

129152
class TestFasterGuardGroup(unittest.TestCase):
130153
def test_guard_group(self):

0 commit comments

Comments
 (0)