Skip to content

Commit 95de00e

Browse files
gouzilCopilot
andauthored
[SOT][Faster Guard] add InstanceCheckGuard (#69975)
--------- Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com>
1 parent 53f92d0 commit 95de00e

19 files changed

+122
-18
lines changed

paddle/fluid/pybind/jit.cc

+5
Original file line numberDiff line numberDiff line change
@@ -106,6 +106,11 @@ void BindGuard(pybind11::module *m) {
106106
py::class_<RangeMatchGuard, GuardBase, std::shared_ptr<RangeMatchGuard>>(
107107
*m, "RangeMatchGuard", R"DOC(RangeMatchGuard Class.)DOC")
108108
.def(py::init<const py::object &>(), py::arg("range_obj"));
109+
py::class_<InstanceCheckGuard,
110+
GuardBase,
111+
std::shared_ptr<InstanceCheckGuard>>(
112+
*m, "InstanceCheckGuard", R"DOC(InstanceCheckGuard Class.)DOC")
113+
.def(py::init<const py::object &>(), py::arg("isinstance_obj"));
109114

110115
m->def(
111116
"merge_guard",

paddle/fluid/pybind/sot/guards.cc

+4
Original file line numberDiff line numberDiff line change
@@ -123,4 +123,8 @@ bool LayerMatchGuard::check(PyObject* value) {
123123
return (training == Py_True) == training_;
124124
}
125125

126+
bool InstanceCheckGuard::check(PyObject* value) {
127+
return PyObject_IsInstance(value, expected_);
128+
}
129+
126130
#endif

paddle/fluid/pybind/sot/guards.h

+15
Original file line numberDiff line numberDiff line change
@@ -200,4 +200,19 @@ class RangeMatchGuard : public GuardGroup {
200200
}
201201
};
202202

203+
class InstanceCheckGuard : public GuardBase {
204+
public:
205+
explicit InstanceCheckGuard(const py::object& py_type)
206+
: expected_(py_type.ptr()) {
207+
Py_INCREF(expected_);
208+
}
209+
210+
~InstanceCheckGuard() override { Py_DECREF(expected_); }
211+
212+
bool check(PyObject* value) override;
213+
214+
private:
215+
PyObject* expected_;
216+
};
217+
203218
#endif

python/paddle/jit/sot/opcode_translator/executor/variables/basic.py

+3-2
Original file line numberDiff line numberDiff line change
@@ -1069,8 +1069,9 @@ def get_py_value(self, allow_tensor=False):
10691069
def make_stringified_guard(self) -> list[StringifiedExpression]:
10701070
frame_value_tracer = self.tracker.trace_value_from_frame()
10711071
result = [
1072-
StringifiedExpression(
1073-
"isinstance({}, slice)",
1072+
FasterStringifiedExpression(
1073+
"id(type({{}})) == id(slice)",
1074+
paddle.framework.core.TypeMatchGuard(slice),
10741075
[frame_value_tracer],
10751076
frame_value_tracer.free_vars,
10761077
),

python/paddle/jit/sot/opcode_translator/executor/variables/container.py

+23-7
Original file line numberDiff line numberDiff line change
@@ -19,10 +19,16 @@
1919
from functools import reduce
2020
from typing import TYPE_CHECKING, Any
2121

22+
import paddle
23+
2224
from ....utils import ConstTypes
2325
from ....utils.exceptions import FallbackError, InnerError
2426
from ..dispatcher import Dispatcher
25-
from ..guard import StringifiedExpression, check_guard
27+
from ..guard import (
28+
FasterStringifiedExpression,
29+
StringifiedExpression,
30+
check_guard,
31+
)
2632
from ..mutable_data import MutableDictLikeData, MutableListLikeData
2733
from ..tracker import (
2834
ConstTracker,
@@ -74,13 +80,23 @@ def bool(self):
7480
def make_stringified_guard(self) -> list[StringifiedExpression]:
7581
frame_value_tracer = self.tracker.trace_value_from_frame()
7682

77-
type_guard = StringifiedExpression(
78-
f"isinstance({{}}, {self.get_py_type().__name__})",
79-
[frame_value_tracer],
80-
frame_value_tracer.free_vars,
81-
)
82-
len_guard = StringifiedExpression(
83+
if self.get_py_type() is dict:
84+
type_guard = FasterStringifiedExpression(
85+
f"isinstance({{}}, {self.get_py_type().__name__})",
86+
paddle.framework.core.InstanceCheckGuard(self.get_py_type()),
87+
[frame_value_tracer],
88+
frame_value_tracer.free_vars,
89+
)
90+
else:
91+
type_guard = FasterStringifiedExpression(
92+
f"id(type({{}})) == {id(self.get_py_type())}",
93+
paddle.framework.core.TypeMatchGuard(self.get_py_type()),
94+
[frame_value_tracer],
95+
frame_value_tracer.free_vars,
96+
)
97+
len_guard = FasterStringifiedExpression(
8398
f"len({{}}) == {len(self.init_value)}",
99+
paddle.framework.core.LengthMatchGuard(len(self.init_value)),
84100
[frame_value_tracer],
85101
frame_value_tracer.free_vars,
86102
)

test/sot/test_04_list.py

+2-1
Original file line numberDiff line numberDiff line change
@@ -21,7 +21,7 @@
2121

2222
import unittest
2323

24-
from test_case_base import TestCaseBase
24+
from test_case_base import TestCaseBase, test_with_faster_guard
2525

2626
import paddle
2727
from paddle.jit.sot.psdb import check_no_breakgraph
@@ -368,6 +368,7 @@ def test_list_add(self):
368368
def test_list_inplace_add(self):
369369
self.assert_results(list_inplace_add)
370370

371+
@test_with_faster_guard
371372
def test_list_extend_range(self):
372373
self.assert_results(list_extend_range, paddle.to_tensor([1, 2]))
373374

test/sot/test_05_dict.py

+2-1
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,7 @@
1818

1919
import unittest
2020

21-
from test_case_base import TestCaseBase
21+
from test_case_base import TestCaseBase, test_with_faster_guard
2222

2323
import paddle
2424
from paddle.jit.sot.psdb import check_no_breakgraph
@@ -253,6 +253,7 @@ def test_construct(self):
253253
def test_dict_noargs(self):
254254
self.assert_results(dict_no_arguments)
255255

256+
@test_with_faster_guard
256257
def test_dict_fromkeys(self):
257258
self.assert_results(dict_test_fromkeys, (1, 2, 3, 4))
258259
self.assert_results(dict_test_fromkeys, [1, 2, 3, 4])

test/sot/test_07_unpack.py

+2-1
Original file line numberDiff line numberDiff line change
@@ -19,7 +19,7 @@
1919

2020
import unittest
2121

22-
from test_case_base import TestCaseBase
22+
from test_case_base import TestCaseBase, test_with_faster_guard
2323

2424
import paddle
2525

@@ -56,6 +56,7 @@ def test_unpack_tuple(self):
5656
def test_unpack_tensor(self):
5757
self.assert_results(unpack_tensor, paddle.to_tensor([2, 3]))
5858

59+
@test_with_faster_guard
5960
def test_unpack_ex_tuple(self):
6061
self.assert_results(unpack_ex_tuple, (1, 1, paddle.to_tensor(2)))
6162

test/sot/test_12_for_loop.py

+6
Original file line numberDiff line numberDiff line change
@@ -165,6 +165,7 @@ def test_list(self):
165165
a = paddle.to_tensor(1)
166166
self.assert_results(for_list_1, a)
167167

168+
@test_with_faster_guard
168169
def test_list_with_fallback(self):
169170
a = paddle.to_tensor(1)
170171
self.assert_results(for_list_2, a)
@@ -194,6 +195,7 @@ def test_for_break(self):
194195
paddle_output = for_break(a, gener())
195196
self.assert_nest_match(sym_output, paddle_output)
196197

198+
@test_with_faster_guard
197199
def test_for_continue(self):
198200
a = paddle.to_tensor(1)
199201
sym_output = symbolic_translate(for_continue)(a, gener())
@@ -205,6 +207,7 @@ def test_for_continue(self):
205207
# a = [1, 2, 3]
206208
# self.assert_results(for_enumerate_var_with_nested_range, a)
207209

210+
@test_with_faster_guard
208211
def test_create_var_in_loop(self):
209212
x = paddle.to_tensor(1, dtype="float32")
210213
a = [1, 2, 3]
@@ -217,6 +220,7 @@ def test_create_var_in_loop(self):
217220
def test_create_var_in_loop_with_same_name_as_global(self):
218221
self.assert_results(for_tmp_var_with_same_name_as_global_var)
219222

223+
@test_with_faster_guard
220224
def test_for_without_zero_iter(self):
221225
self_res_dict = {}
222226
output = paddle.to_tensor(2)
@@ -253,6 +257,7 @@ def for_enumerate_cache(func_list, x):
253257

254258

255259
class TestEnumerateCache(TestCaseBase):
260+
@test_with_faster_guard
256261
def test_run(self):
257262
func_list = [
258263
paddle.nn.Linear(10, 10),
@@ -294,6 +299,7 @@ class TestUndefinedVarInRiskyCodes(TestCaseBase):
294299
def test_undefined_var_case_0(self):
295300
self.assert_results(undefined_var_case_0)
296301

302+
@test_with_faster_guard
297303
def test_undefined_var_case_1(self):
298304
self.assert_results(undefined_var_case_1)
299305

test/sot/test_14_operators.py

+2-1
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,7 @@
1515
import operator
1616
import unittest
1717

18-
from test_case_base import TestCaseBase
18+
from test_case_base import TestCaseBase, test_with_faster_guard
1919

2020
import paddle
2121

@@ -328,6 +328,7 @@ def test_simple(self):
328328
self.assert_results(inplace_or, b, g)
329329
self.assert_results(inplace_xor, b, g)
330330

331+
@test_with_faster_guard
331332
def test_operator_simple(self):
332333
self.assert_results(operator_add, 1, paddle.to_tensor(2))
333334
self.assert_results(operator_mul, 1, paddle.to_tensor(2))

test/sot/test_15_slice.py

+4-1
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,7 @@
1818

1919
import unittest
2020

21-
from test_case_base import TestCaseBase
21+
from test_case_base import TestCaseBase, test_with_faster_guard
2222

2323
import paddle
2424
from paddle.jit.sot.psdb import check_no_breakgraph
@@ -55,6 +55,7 @@ def tensor_subscript_tensor(x: paddle.Tensor):
5555

5656

5757
class TestSlice(TestCaseBase):
58+
@test_with_faster_guard
5859
def test_simple(self):
5960
x = list(range(10))
6061
y = paddle.arange(10)
@@ -83,6 +84,7 @@ def layer_list_slice(layer, x):
8384

8485

8586
class TestLayerList(TestCaseBase):
87+
@test_with_faster_guard
8688
def test_layer_list_slice(self):
8789
layer = MyLayer()
8890
x = paddle.randn([5, 10])
@@ -127,6 +129,7 @@ def forward(self, x):
127129

128130

129131
class TestLayerListSlice(TestCaseBase):
132+
@test_with_faster_guard
130133
def test_layer_list_slice(self):
131134
x = paddle.randn([2, 5])
132135
net = LayerListNet()

test/sot/test_18_tensor_method.py

+2-1
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,7 @@
1414

1515
import unittest
1616

17-
from test_case_base import TestCaseBase
17+
from test_case_base import TestCaseBase, test_with_faster_guard
1818

1919
import paddle
2020

@@ -83,6 +83,7 @@ def test_tensor_method_passed_by_user(self):
8383
y = paddle.rand([42])
8484
self.assert_results(tensor_method_passed_by_user, x, y.add)
8585

86+
@test_with_faster_guard
8687
def test_tensor_method_property(self):
8788
x = paddle.rand([42, 24], dtype='float64')
8889
y = paddle.rand([42, 24], dtype='float32')

test/sot/test_builtin_bool.py

+2
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,7 @@
1717
from test_case_base import (
1818
TestCaseBase,
1919
test_instruction_translator_cache_context,
20+
test_with_faster_guard,
2021
)
2122

2223
import paddle
@@ -107,6 +108,7 @@ def test_object_disallow_breakgraph(self):
107108
call_bool_by_operator_truth_no_breakgraph, layer
108109
)
109110

111+
@test_with_faster_guard
110112
def test_object_allow_breakgraph(self):
111113
with test_instruction_translator_cache_context():
112114
obj = TestObjectWithLen([1, 2, 3])

test/sot/test_builtin_dispatch.py

+4
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,7 @@
2222
from test_case_base import (
2323
TestCaseBase,
2424
test_instruction_translator_cache_context,
25+
test_with_faster_guard,
2526
)
2627

2728
import paddle
@@ -135,9 +136,11 @@ def test_log(x: int):
135136

136137

137138
class TestBuiltinDispatch(TestCaseBase):
139+
@test_with_faster_guard
138140
def test_dispatch_len(self):
139141
self.assert_results(dispatch_len, paddle.to_tensor([1, 2, 3]))
140142

143+
@test_with_faster_guard
141144
def test_dispatch_bool(self):
142145
self.assert_results(dispatch_bool, paddle.to_tensor([1, 2, 3]))
143146

@@ -177,6 +180,7 @@ def test_not_dispatch_tensor_floor(self):
177180
def test_dispatch_float_floor(self):
178181
self.assert_results(dispatch_floor, 1.2)
179182

183+
@test_with_faster_guard
180184
def test_dispatch_sum(self):
181185
self.assert_results(test_sum_tuple, 1, 1)
182186
self.assert_results(test_sum_tuple, paddle.to_tensor(1), 1)

test/sot/test_builtin_map.py

+5-1
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,7 @@
1717
import unittest
1818
from typing import TYPE_CHECKING
1919

20-
from test_case_base import TestCaseBase
20+
from test_case_base import TestCaseBase, test_with_faster_guard
2121

2222
from paddle.jit import sot
2323
from paddle.jit.sot.psdb import check_no_breakgraph
@@ -99,12 +99,14 @@ def test_map_for_loop(x: list):
9999

100100

101101
class TestMap(TestCaseBase):
102+
@test_with_faster_guard
102103
def test_map(self):
103104
self.assert_results(test_map_list, [1, 2, 3, 4])
104105
self.assert_results(test_map_tuple, (1, 2, 3, 4))
105106
self.assert_results(test_map_range, range(5))
106107
self.assert_results(test_map_dict, {"a": 1, "b": 2, "c": 3})
107108

109+
@test_with_faster_guard
108110
def test_map_comprehension(self):
109111
self.assert_results(test_map_list_comprehension, [1, 2, 3, 4])
110112
self.assert_results(test_map_tuple_comprehension, (1, 2, 3, 4))
@@ -117,9 +119,11 @@ def test_map_with_breakgraph(self):
117119
with strict_mode_guard(False):
118120
self.assert_results(test_map_list_with_breakgraph, [1, 2, 3, 4])
119121

122+
@test_with_faster_guard
120123
def test_map_unpack(self):
121124
self.assert_results(test_map_unpack, [1, 2, 3, 4])
122125

126+
@test_with_faster_guard
123127
def test_map_for_loop(self):
124128
self.assert_results(test_map_for_loop, [7, 8, 9, 10])
125129

test/sot/test_builtin_zip.py

+4-1
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,7 @@
1414

1515
import unittest
1616

17-
from test_case_base import TestCaseBase
17+
from test_case_base import TestCaseBase, test_with_faster_guard
1818

1919
import paddle
2020
from paddle.jit.sot import psdb, symbolic_translate
@@ -77,6 +77,7 @@ def test_zip_8(iter_1, iter_2):
7777

7878

7979
class TestZip(TestCaseBase):
80+
@test_with_faster_guard
8081
def test_simple_cases(self):
8182
x = 8
8283
y = 5
@@ -92,10 +93,12 @@ def test_simple_cases(self):
9293
self.assert_results(test_zip_6, ty)
9394
self.assert_results(test_zip_7, layer_list, paddle.randn((10,)))
9495

96+
@test_with_faster_guard
9597
@min_graph_size_guard(0)
9698
def test_reconstruct(self):
9799
self.assert_results(test_zip_8, [1, 2, 3], [4, 5, 6])
98100

101+
@test_with_faster_guard
99102
@strict_mode_guard(False)
100103
@min_graph_size_guard(0)
101104
def test_zip_user_defined_iter(self):

test/sot/test_dtype.py

+1
Original file line numberDiff line numberDiff line change
@@ -74,6 +74,7 @@ def test_dtype_reconstruct(self):
7474

7575

7676
class TestDtypeGuard(TestCaseBase):
77+
@test_with_faster_guard
7778
def test_dtype_guard(self):
7879
dtype_map = {paddle.float32: paddle.float64}
7980
x = paddle.ones([2, 3], dtype="float32")

0 commit comments

Comments
 (0)