Skip to content

Commit d81a381

Browse files
authored
[SOT][Faster Guard][3.13] implement faster lookup (#71994)
1 parent 0a85d41 commit d81a381

File tree

15 files changed

+328
-71
lines changed

15 files changed

+328
-71
lines changed

paddle/fluid/pybind/jit.cc

+19-3
Original file line numberDiff line numberDiff line change
@@ -125,6 +125,9 @@ void BindGuard(pybind11::module *m) {
125125
py::class_<WeakRefMatchGuard, GuardBase, std::shared_ptr<WeakRefMatchGuard>>(
126126
*m, "WeakRefMatchGuard", R"DOC(WeakRefMatchGuard Class.)DOC")
127127
.def(py::init<const py::object &>(), py::arg("func"));
128+
py::class_<DummyGuard, GuardBase, std::shared_ptr<DummyGuard>>(
129+
*m, "DummyGuard", R"DOC(DummyGuard Class.)DOC")
130+
.def(py::init<>());
128131

129132
m->def(
130133
"merge_guard",
@@ -147,16 +150,23 @@ void BindGuardTree(pybind11::module *m) {
147150
[](GuardTree &self, py::object frame) {
148151
return self.lookup(reinterpret_cast<FrameProxy *>(frame.ptr()));
149152
},
150-
py::arg("frame"));
153+
py::arg("frame"))
154+
.def(
155+
"add_guard_chain",
156+
[](GuardTree &self,
157+
const std::vector<std::shared_ptr<GuardNode>> &guard_chain) {
158+
self.add_guard_chain(guard_chain);
159+
},
160+
py::arg("guard_chain"));
151161

152162
py::class_<GuardNode, std::shared_ptr<GuardNode>>(
153163
*m, "GuardNode", R"DOC(GuardNode Class.)DOC")
154164
.def(py::init<const std::shared_ptr<GuardBase> &,
155-
const std::shared_ptr<ExprNode> &,
165+
const std::vector<std::shared_ptr<ExprNode>> &,
156166
const std::vector<std::shared_ptr<GuardNode>> &,
157167
const std::optional<int> &>(),
158168
py::arg("guard"),
159-
py::arg("expr"),
169+
py::arg("exprs"),
160170
py::arg("next_guard_nodes") = py::list(),
161171
py::arg("return_cache_index") = py::none())
162172
.def_property(
@@ -183,6 +193,12 @@ void BindGuardTree(pybind11::module *m) {
183193
*m, "ConstantExprNode", R"DOC(ConstantExprNode Class.)DOC")
184194
.def(py::init<const py::object &>(), py::arg("value_ptr"));
185195

196+
py::class_<ExternVarExprNode, ExprNode, std::shared_ptr<ExternVarExprNode>>(
197+
*m, "ExternVarExprNode", R"DOC(ExternVarExprNode Class.)DOC")
198+
.def(py::init<const std::string &, const py::object &>(),
199+
py::arg("var_name"),
200+
py::arg("value_ptr"));
201+
186202
py::class_<LocalVarExprNode, ExprNode, std::shared_ptr<LocalVarExprNode>>(
187203
*m, "LocalVarExprNode", R"DOC(LocalVarExprNode Class.)DOC")
188204
.def(py::init<const std::string &>(), py::arg("var_name"));

paddle/fluid/pybind/sot/guards.cc

+4
Original file line numberDiff line numberDiff line change
@@ -195,6 +195,8 @@ bool WeakRefMatchGuard::check(PyObject* value) {
195195

196196
PyObject* ConstantExprNode::eval(FrameProxy* frame) { return value_ptr_; }
197197

198+
PyObject* ExternVarExprNode::eval(FrameProxy* frame) { return value_ptr_; }
199+
198200
PyObject* LocalVarExprNode::eval(FrameProxy* frame) {
199201
#if PY_3_13_PLUS
200202
return PyDict_GetItemString(frame->locals, var_name_.c_str());
@@ -222,6 +224,8 @@ PyObject* ItemExprNode::eval(FrameProxy* frame) {
222224
}
223225

224226
std::optional<int> GuardNode::lookup(FrameProxy* frame) {
227+
// TODO(zrr1999): support multiple exprs
228+
auto expr = exprs.back();
225229
auto value = expr->eval(frame);
226230
if (guard->check(value)) {
227231
if (return_cache_index.has_value()) {

paddle/fluid/pybind/sot/guards.h

+37-9
Original file line numberDiff line numberDiff line change
@@ -250,6 +250,11 @@ class WeakRefMatchGuard : public GuardBase {
250250
PyObject* expected_;
251251
};
252252

253+
class DummyGuard : public GuardBase {
254+
public:
255+
bool check(PyObject* value) override { return true; }
256+
};
257+
253258
class GuardTreeNode {};
254259

255260
class AttributeExprNode;
@@ -272,6 +277,21 @@ class ConstantExprNode : public ExprNode {
272277
private:
273278
PyObject* value_ptr_;
274279
};
280+
class ExternVarExprNode : public ExprNode {
281+
public:
282+
explicit ExternVarExprNode(const std::string& var_name,
283+
const py::object& value_obj)
284+
: value_ptr_(value_obj.ptr()), var_name_(var_name) {
285+
Py_INCREF(value_ptr_);
286+
}
287+
288+
~ExternVarExprNode() { Py_DECREF(value_ptr_); }
289+
PyObject* eval(FrameProxy* frame);
290+
291+
private:
292+
PyObject* value_ptr_;
293+
std::string var_name_;
294+
};
275295

276296
class LocalVarExprNode : public ExprNode {
277297
public:
@@ -321,16 +341,16 @@ class ItemExprNode : public ExprNode {
321341
class GuardNode : public GuardTreeNode {
322342
public:
323343
std::shared_ptr<GuardBase> guard;
324-
std::shared_ptr<ExprNode> expr;
344+
std::vector<std::shared_ptr<ExprNode>> exprs;
325345
std::vector<std::shared_ptr<GuardNode>> next_guard_nodes;
326346
// return_cache_index is used to record the index of the guard list
327347
std::optional<int> return_cache_index;
328348
GuardNode(std::shared_ptr<GuardBase> guard,
329-
std::shared_ptr<ExprNode> expr,
349+
std::vector<std::shared_ptr<ExprNode>> exprs,
330350
std::vector<std::shared_ptr<GuardNode>> next_guard_nodes,
331351
std::optional<int> return_cache_index)
332352
: guard(guard),
333-
expr(expr),
353+
exprs(exprs),
334354
next_guard_nodes(next_guard_nodes),
335355
return_cache_index(return_cache_index) {}
336356

@@ -342,13 +362,21 @@ class GuardTree {
342362
GuardTree(const std::vector<std::vector<std::shared_ptr<GuardNode>>>&
343363
guard_nodes_list) {
344364
for (size_t index = 0; index < guard_nodes_list.size(); ++index) {
345-
const auto& guard_nodes = guard_nodes_list[index];
346-
for (size_t i = 1; i < guard_nodes.size(); ++i) {
347-
guard_nodes[i - 1]->next_guard_nodes.push_back(guard_nodes[i]);
348-
}
349-
guard_nodes.back()->return_cache_index = index;
350-
guard_nodes_.push_back(guard_nodes.front());
365+
add_guard_chain(guard_nodes_list[index]);
366+
}
367+
}
368+
void add_guard_chain(
369+
const std::vector<std::shared_ptr<GuardNode>>& guard_chain) {
370+
if (guard_chain.empty()) {
371+
// TODO(zrr1999): empty guard nodes means that some
372+
// tracker.make_faster_guard is not implemented.
373+
return;
374+
}
375+
for (size_t i = 1; i < guard_chain.size(); ++i) {
376+
guard_chain[i - 1]->next_guard_nodes.push_back(guard_chain[i]);
351377
}
378+
guard_chain.back()->return_cache_index = guard_nodes_.size();
379+
guard_nodes_.push_back(guard_chain.front());
352380
}
353381

354382
std::optional<int> lookup(FrameProxy* frame);

python/paddle/jit/sot/opcode_translator/executor/executor_cache.py

+58-17
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,7 @@
1818
import traceback
1919
from typing import TYPE_CHECKING, List, Tuple
2020

21+
import paddle
2122
from paddle.base.dygraph.base import sot_simulation_mode_guard
2223

2324
from ...profiler import EventGuard, event_register
@@ -47,6 +48,8 @@
4748

4849
GuardedFunction = Tuple[CustomCode, Guard]
4950
GuardedFunctions = List[GuardedFunction]
51+
GuardChain = List[paddle.framework.core.GuardNode]
52+
GuardChainList = List[GuardChain]
5053

5154
dummy_guard: Guard = lambda frame: True
5255
dummy_guard.expr = "lambda frame: True"
@@ -66,7 +69,9 @@ class OpcodeExecutorCache(metaclass=Singleton):
6669
"""
6770

6871
MAX_CACHE_SIZE = 20
69-
cache: dict[types.CodeType, GuardedFunctions]
72+
cache: dict[
73+
types.CodeType, tuple[GuardedFunctions, paddle.framework.core.GuardTree]
74+
]
7075
translate_count: int
7176
code_symbolic_inputs: dict[types.CodeType, dict[str, None | dict[int, int]]]
7277

@@ -105,16 +110,25 @@ def __call__(self, frame: types.FrameType, **kwargs) -> CustomCode:
105110
code: types.CodeType = frame.f_code
106111
if code not in self.cache:
107112
log(2, f"[Cache]: Firstly call {code}\n")
108-
new_custom_code, guard_fn = self.translate(frame, **kwargs)
113+
new_custom_code, guard_fn, guard_chain = self.translate(
114+
frame, **kwargs
115+
)
109116
assert guard_fn is not None
110-
self.cache[code] = [(new_custom_code, guard_fn)]
117+
assert guard_chain is not None
118+
self.cache[code] = [
119+
(new_custom_code, guard_fn)
120+
], paddle.framework.core.GuardTree([guard_chain])
111121
return new_custom_code
112-
guarded_fns = self.cache[code]
113-
return self.lookup(frame, guarded_fns, **kwargs)
122+
guarded_fns, guard_tree = self.cache[code]
123+
return self.lookup(frame, guarded_fns, guard_tree, **kwargs)
114124

115125
@event_register("lookup")
116126
def lookup(
117-
self, frame: types.FrameType, guarded_fns: GuardedFunctions, **kwargs
127+
self,
128+
frame: types.FrameType,
129+
guarded_fns: GuardedFunctions,
130+
guard_tree: paddle.framework.core.GuardTree,
131+
**kwargs,
118132
) -> CustomCode:
119133
"""
120134
Looks up the cache for a matching code object and returns a custom code object if a matching guard function is found, otherwise None.
@@ -132,8 +146,17 @@ def lookup(
132146
return CustomCode(None, False)
133147

134148
enable_strict_guard = ENV_SOT_ENABLE_STRICT_GUARD_CHECK.get()
149+
enable_guard_tree = ENV_SOT_ENABLE_GUARD_TREE.get()
150+
151+
cache_index = None
152+
if enable_strict_guard or enable_guard_tree:
153+
cache_index = guard_tree.lookup(frame)
135154

136-
for custom_code, guard_fn in guarded_fns:
155+
if not enable_strict_guard and cache_index is not None:
156+
# TODO(zrr1999): add a mapping between custom_code and cache_index
157+
return guarded_fns[cache_index][0]
158+
159+
for index, (custom_code, guard_fn) in enumerate(guarded_fns):
137160
if enable_strict_guard:
138161
mirror_guard_error = None
139162
try:
@@ -157,9 +180,12 @@ def lookup(
157180
2,
158181
f"[Cache] Cache hit, Guard is \n{getattr(guard_fn, 'expr', 'None')}\n",
159182
)
183+
# TODO(zrr1999): cache_index should be equal to index when enable_strict_guard.
184+
# assert (
185+
# cache_index is None or index == cache_index
186+
# ), f"cache_index({cache_index}) is not equal to index({index})"
160187
return custom_code
161-
elif not ENV_SOT_ENABLE_GUARD_TREE.get():
162-
# TODO(zrr1999): remove condition after faster guard tree support error analysis
188+
else:
163189
log_do(
164190
4,
165191
self.analyse_guard_global_object(guard_fn),
@@ -192,9 +218,11 @@ def lookup(
192218
)
193219

194220
log(2, "[Cache]: all guards missed\n")
195-
new_custom_code, guard_fn = self.translate(frame, **kwargs)
221+
new_custom_code, guard_fn, guard_chain = self.translate(frame, **kwargs)
196222
if guard_fn is not None:
223+
assert guard_chain is not None
197224
guarded_fns.append((new_custom_code, guard_fn))
225+
guard_tree.add_guard_chain(guard_chain)
198226
return new_custom_code
199227

200228
def before_translate_hook(self, frame: types.FrameType):
@@ -203,7 +231,7 @@ def before_translate_hook(self, frame: types.FrameType):
203231

204232
def translate(
205233
self, frame: types.FrameType, **kwargs
206-
) -> tuple[CustomCode, Guard | None]:
234+
) -> tuple[CustomCode, Guard | None, GuardChain | None]:
207235
"""
208236
Translates the given frame's code object and returns the cache getter function and a guarded function for the translated code object.
209237
@@ -215,8 +243,10 @@ def translate(
215243
"""
216244
self.before_translate_hook(frame)
217245
self.translate_count += 1
218-
custom_new_code, guard_fn = start_translate(frame, **kwargs)
219-
return custom_new_code, guard_fn
246+
custom_new_code, guard_fn, guard_chain = start_translate(
247+
frame, **kwargs
248+
)
249+
return custom_new_code, guard_fn, guard_chain
220250

221251
def analyse_guard_global_object(self, guard_fn):
222252
def inner():
@@ -255,15 +285,15 @@ def inner():
255285
def start_translate(
256286
frame: types.FrameType,
257287
**kwargs,
258-
) -> tuple[CustomCode, Guard | None]:
288+
) -> tuple[CustomCode, Guard | None, GuardChain | None]:
259289
"""
260-
Starts the translation process for the given frame and returns the translated code object and its guard function, or None if translation fails.
290+
Starts the translation process for the given frame and returns the translated code object, its guard function and its guard tree node, or None if translation fails.
261291
262292
Args:
263293
frame: The frame to be translated.
264294
265295
Returns:
266-
tuple[CustomCode, Guard | None]: The translated code object and its guard function, or None if translation fails.
296+
tuple[CustomCode, Guard | None, GuardChain | None]: The translated code object, its guard function and its guard tree node, or None if translation fails.
267297
"""
268298
graph = FunctionGraph(frame.f_code, frame.f_globals, **kwargs)
269299
vframe = VirtualFrame.from_real_frame(frame, graph)
@@ -280,8 +310,10 @@ def start_translate(
280310
return (
281311
CustomCode(None, True),
282312
None,
313+
None,
283314
)
284-
return new_custom_code, guard_fn
315+
guard_chain = simulator.guard_chain
316+
return new_custom_code, guard_fn, guard_chain
285317
# TODO(0x45f): handle BreakGraphError to trigger fallback
286318
except BreakGraphError as e:
287319
raise RuntimeError(
@@ -299,9 +331,18 @@ def start_translate(
299331
f"Unsupported Frame is {frame.f_code}, error message is: \n"
300332
+ "".join(traceback.format_exception(type(e), e, e.__traceback__)),
301333
)
334+
335+
dummy_guard_chain = [
336+
# TODO(zrr1999): GuardNode should support zero-expr constructor
337+
paddle.framework.core.GuardNode(
338+
paddle.framework.core.DummyGuard(),
339+
[paddle.framework.core.ConstantExprNode(True)],
340+
)
341+
]
302342
return (
303343
CustomCode(None, e.disable_eval_frame),
304344
dummy_guard,
345+
dummy_guard_chain,
305346
)
306347
except Exception as e:
307348
raise InnerError(OpcodeExecutorBase.error_message_summary(e)) from e

python/paddle/jit/sot/opcode_translator/executor/function_graph.py

+22-8
Original file line numberDiff line numberDiff line change
@@ -43,6 +43,7 @@
4343
from ...utils import (
4444
ENV_SOT_ALLOW_DYNAMIC_SHAPE,
4545
ENV_SOT_ENABLE_GUARD_TREE,
46+
ENV_SOT_ENABLE_STRICT_GUARD_CHECK,
4647
NameGenerator,
4748
SotUndefinedVar,
4849
inner_error_default_handler,
@@ -60,7 +61,7 @@
6061
SotExtraInfo,
6162
)
6263
from ..instruction_utils import get_instructions
63-
from .guard import Guard, StringifiedExpression, make_faster_guard, make_guard
64+
from .guard import Guard, StringifiedExpression, make_guard
6465
from .mutable_data import MutationDel, MutationNew, MutationSet
6566
from .pycode_generator import PyCodeGen
6667
from .side_effects import (
@@ -316,17 +317,30 @@ def collect(inp):
316317
)
317318

318319
@property
319-
@event_register("guard_fn")
320-
def guard_fn(self) -> Guard:
321-
if ENV_SOT_ENABLE_GUARD_TREE.get():
322-
guard_nodes: list[paddle.framework.core.GuardNode] = []
323-
with EventGuard("guard_fn: find vars and make faster guard"):
320+
@event_register("guard_chain")
321+
def guard_chain(self) -> list[paddle.framework.core.GuardNode]:
322+
enable_strict_guard = ENV_SOT_ENABLE_STRICT_GUARD_CHECK.get()
323+
enable_guard_tree = ENV_SOT_ENABLE_GUARD_TREE.get()
324+
325+
if not enable_strict_guard and not enable_guard_tree:
326+
return []
327+
guard_chain: list[paddle.framework.core.GuardNode] = []
328+
329+
with EventGuard("guard_fn: find vars and make faster guard"):
330+
try:
324331
for variable in find_traceable_vars(
325332
self.input_variables + list(self._global_guarded_variables)
326333
):
327-
guard_nodes.extend(variable.make_faster_guard())
328-
return make_faster_guard(guard_nodes)
334+
guard_chain.extend(variable.make_faster_guard())
335+
except NotImplementedError as e:
336+
log(2, f"[Guard] make faster guard nodes error: {e}\n")
337+
# TODO(zrr1999): empty list means that some tracker.make_faster_guard is not implemented.
338+
return []
339+
return guard_chain
329340

341+
@property
342+
@event_register("guard_fn")
343+
def guard_fn(self) -> Guard:
330344
with switch_symbol_registry():
331345
guards: list[StringifiedExpression] = []
332346
with EventGuard("guard_fn: find vars and make stringified guard"):

0 commit comments

Comments
 (0)