Skip to content

Commit ec59644

Browse files
authored
[SOT][Faster Guard] Implement more make_faster_guard (#72272)
1 parent e829589 commit ec59644

File tree

11 files changed

+218
-81
lines changed

11 files changed

+218
-81
lines changed

paddle/fluid/pybind/jit.cc

+3 -3
Original file line number | Diff line number | Diff line change
@@ -144,7 +144,7 @@ void BindGuardTree(pybind11::module *m) {
144144
*m, "GuardTree", R"DOC(GuardTree Class.)DOC")
145145
.def(py::init<
146146
const std::vector<std::vector<std::shared_ptr<GuardNode>>> &>(),
147-
py::arg("guard_nodes_list"))
147+
py::arg("guard_chain_list"))
148148
.def(
149149
"lookup",
150150
[](GuardTree &self, py::object frame) {
@@ -180,7 +180,7 @@ void BindGuardTree(pybind11::module *m) {
180180
return self.lookup(reinterpret_cast<FrameProxy *>(frame.ptr()));
181181
},
182182
py::arg("frame"))
183-
.def("stringify", &GuardNode::stringify);
183+
.def("stringify", &GuardNode::stringify, py::arg("indent") = 0);
184184

185185
py::class_<ExprNode, std::shared_ptr<ExprNode>>(
186186
*m, "ExprNode", R"DOC(ExprNode Class.)DOC")
@@ -190,7 +190,7 @@ void BindGuardTree(pybind11::module *m) {
190190
return self.eval(reinterpret_cast<FrameProxy *>(frame.ptr()));
191191
},
192192
py::arg("frame"))
193-
.def("stringify", &ExprNode::stringify);
193+
.def("stringify", &ExprNode::stringify, py::arg("indent") = 0);
194194

195195
py::class_<ConstantExprNode, ExprNode, std::shared_ptr<ConstantExprNode>>(
196196
*m, "ConstantExprNode", R"DOC(ConstantExprNode Class.)DOC")

paddle/fluid/pybind/sot/guards.cc

+28 -9
Original file line number | Diff line number | Diff line change
@@ -33,6 +33,12 @@ static inline PyObject* PyObject_CallOneArg(PyObject* func, PyObject* arg) {
3333
#define Py_IsNone(x) ((x) == Py_None)
3434
#endif
3535

36+
#define HANDLE_NULL_VALUE(value) \
37+
if ((value) == NULL) { \
38+
PyErr_Clear(); \
39+
return false; \
40+
}
41+
3642
static inline bool PyObject_Equal(PyObject* a, PyObject* b) {
3743
if (a == b) {
3844
return true;
@@ -84,6 +90,7 @@ bool TypeMatchGuard::check(PyObject* value) {
8490
bool IdMatchGuard::check(PyObject* value) { return value == expected_; }
8591

8692
bool ValueMatchGuard::check(PyObject* value) {
93+
HANDLE_NULL_VALUE(value);
8794
return PyObject_Equal(value, expected_value_);
8895
}
8996

@@ -107,6 +114,7 @@ bool DtypeMatchGuard::check(PyObject* value) {
107114
}
108115

109116
bool ShapeMatchGuard::check(PyObject* value) {
117+
HANDLE_NULL_VALUE(value);
110118
auto tensor = GetTensorFromPyObject(value);
111119
if (!tensor) {
112120
return false;
@@ -194,10 +202,12 @@ bool WeakRefMatchGuard::check(PyObject* value) {
194202
}
195203

196204
PyObject* ConstantExprNode::eval(FrameProxy* frame) { return value_ptr_; }
197-
std::string ConstantExprNode::stringify() { return py::str(value_ptr_); }
205+
std::string ConstantExprNode::stringify(int indent) {
206+
return py::str(value_ptr_);
207+
}
198208

199209
PyObject* ExternVarExprNode::eval(FrameProxy* frame) { return value_ptr_; }
200-
std::string ExternVarExprNode::stringify() { return var_name_; }
210+
std::string ExternVarExprNode::stringify(int indent) { return var_name_; }
201211

202212
PyObject* LocalVarExprNode::eval(FrameProxy* frame) {
203213
#if PY_3_13_PLUS
@@ -208,7 +218,7 @@ PyObject* LocalVarExprNode::eval(FrameProxy* frame) {
208218
return PyDict_GetItemString(frame->f_locals, var_name_.c_str());
209219
#endif
210220
}
211-
std::string LocalVarExprNode::stringify() {
221+
std::string LocalVarExprNode::stringify(int indent) {
212222
return "locals[" + var_name_ + "]";
213223
}
214224

@@ -219,15 +229,15 @@ PyObject* GlobalVarExprNode::eval(FrameProxy* frame) {
219229
return PyDict_GetItemString(frame->f_globals, var_name_.c_str());
220230
#endif
221231
}
222-
std::string GlobalVarExprNode::stringify() {
232+
std::string GlobalVarExprNode::stringify(int indent) {
223233
return "globals[" + var_name_ + "]";
224234
}
225235

226236
PyObject* AttributeExprNode::eval(FrameProxy* frame) {
227237
PyObject* var = var_expr_->eval(frame);
228238
return PyObject_GetAttrString(var, attr_name_.c_str());
229239
}
230-
std::string AttributeExprNode::stringify() {
240+
std::string AttributeExprNode::stringify(int indent) {
231241
std::stringstream ss;
232242
ss << var_expr_->stringify() << "." << attr_name_;
233243
return ss.str();
@@ -238,7 +248,7 @@ PyObject* ItemExprNode::eval(FrameProxy* frame) {
238248
PyObject* key = key_expr_->eval(frame);
239249
return PyObject_GetItem(var, key);
240250
}
241-
std::string ItemExprNode::stringify() {
251+
std::string ItemExprNode::stringify(int indent) {
242252
std::stringstream ss;
243253
ss << var_expr_->stringify() << "[" << key_expr_->stringify() << "]";
244254
return ss.str();
@@ -261,10 +271,19 @@ std::optional<int> GuardNode::lookup(FrameProxy* frame) {
261271
}
262272
return std::nullopt;
263273
}
264-
std::string GuardNode::stringify() {
274+
std::string GuardNode::stringify(int indent) {
265275
std::stringstream ss;
266-
ss << guard->get_guard_name();
276+
// TODO(zrr1999): support multiple exprs
277+
auto expr = exprs.back();
278+
ss << std::string(indent, ' ') << guard->get_guard_name();
267279
ss << "(" << exprs.back()->stringify() << ")";
280+
if (!next_guard_nodes.empty()) {
281+
ss << " |" << std::endl;
282+
for (auto& next_guard_node : next_guard_nodes) {
283+
ss << std::string(indent + 2, ' ');
284+
ss << next_guard_node->stringify(indent + 2) << std::endl;
285+
}
286+
}
268287
return ss.str();
269288
}
270289

@@ -295,7 +314,7 @@ std::string GuardTree::stringify() {
295314
std::stringstream ss;
296315
for (size_t i = 0; i < guard_nodes_.size(); ++i) {
297316
if (i > 0) {
298-
ss << " and ";
317+
ss << std::endl << "and" << std::endl;
299318
}
300319
ss << guard_nodes_[i]->stringify();
301320
}

paddle/fluid/pybind/sot/guards.h

+11 -11
Original file line number | Diff line number | Diff line change
@@ -275,7 +275,7 @@ class DummyGuard : public GuardBase {
275275
class GuardTreeNode {
276276
public:
277277
virtual ~GuardTreeNode() = default;
278-
virtual std::string stringify() = 0;
278+
virtual std::string stringify(int indent = 0) = 0;
279279
};
280280

281281
class AttributeExprNode;
@@ -295,7 +295,7 @@ class ConstantExprNode : public ExprNode {
295295
}
296296
~ConstantExprNode() { Py_DECREF(value_ptr_); }
297297
PyObject* eval(FrameProxy* frame) override;
298-
std::string stringify() override;
298+
std::string stringify(int indent = 0) override;
299299

300300
private:
301301
PyObject* value_ptr_;
@@ -310,7 +310,7 @@ class ExternVarExprNode : public ExprNode {
310310

311311
~ExternVarExprNode() { Py_DECREF(value_ptr_); }
312312
PyObject* eval(FrameProxy* frame) override;
313-
std::string stringify() override;
313+
std::string stringify(int indent = 0) override;
314314

315315
private:
316316
PyObject* value_ptr_;
@@ -323,7 +323,7 @@ class LocalVarExprNode : public ExprNode {
323323
: var_name_(var_name) {}
324324

325325
PyObject* eval(FrameProxy* frame) override;
326-
std::string stringify() override;
326+
std::string stringify(int indent = 0) override;
327327

328328
private:
329329
std::string var_name_;
@@ -334,7 +334,7 @@ class GlobalVarExprNode : public ExprNode {
334334
: var_name_(var_name) {}
335335

336336
PyObject* eval(FrameProxy* frame) override;
337-
std::string stringify() override;
337+
std::string stringify(int indent = 0) override;
338338

339339
private:
340340
std::string var_name_;
@@ -346,7 +346,7 @@ class AttributeExprNode : public ExprNode {
346346
: var_expr_(var_expr), attr_name_(attr_name) {}
347347

348348
PyObject* eval(FrameProxy* frame) override;
349-
std::string stringify() override;
349+
std::string stringify(int indent = 0) override;
350350

351351
private:
352352
std::shared_ptr<ExprNode> var_expr_;
@@ -359,7 +359,7 @@ class ItemExprNode : public ExprNode {
359359
: var_expr_(var_expr), key_expr_(key_expr) {}
360360

361361
PyObject* eval(FrameProxy* frame) override;
362-
std::string stringify() override;
362+
std::string stringify(int indent = 0) override;
363363

364364
private:
365365
std::shared_ptr<ExprNode> var_expr_;
@@ -381,17 +381,17 @@ class GuardNode : public GuardTreeNode {
381381
exprs(exprs),
382382
next_guard_nodes(next_guard_nodes),
383383
return_cache_index(return_cache_index) {}
384-
std::string stringify() override;
385384
virtual ~GuardNode() = default;
385+
std::string stringify(int indent = 0) override;
386386
std::optional<int> lookup(FrameProxy* frame);
387387
};
388388

389389
class GuardTree {
390390
public:
391391
GuardTree(const std::vector<std::vector<std::shared_ptr<GuardNode>>>&
392-
guard_nodes_list) {
393-
for (size_t index = 0; index < guard_nodes_list.size(); ++index) {
394-
add_guard_chain(guard_nodes_list[index]);
392+
guard_chain_list) {
393+
for (size_t index = 0; index < guard_chain_list.size(); ++index) {
394+
add_guard_chain(guard_chain_list[index]);
395395
}
396396
}
397397
void add_guard_chain(

python/paddle/jit/sot/opcode_translator/executor/executor_cache.py

+15 -4
Original file line number | Diff line number | Diff line change
@@ -150,12 +150,23 @@ def lookup(
150150

151151
cache_index = None
152152
if enable_strict_guard or enable_guard_tree:
153-
log(4, f"[Cache] Guard tree is `{guard_tree.stringify()}`")
153+
log(4, f"[Cache] Guard tree: \n{guard_tree.stringify()}")
154154
cache_index = guard_tree.lookup(frame)
155155

156-
if not enable_strict_guard and cache_index is not None:
157-
# TODO(zrr1999): add a mapping between custom_code and cache_index
158-
return guarded_fns[cache_index][0]
156+
if not enable_strict_guard and enable_guard_tree:
157+
if cache_index is not None:
158+
# TODO(zrr1999): add a mapping between custom_code and cache_index
159+
return guarded_fns[cache_index][0]
160+
else:
161+
log(2, "[Cache]: all guards missed (guard tree mode)\n")
162+
new_custom_code, guard_fn, guard_chain = self.translate(
163+
frame, **kwargs
164+
)
165+
if guard_fn is not None:
166+
assert guard_chain is not None
167+
guarded_fns.append((new_custom_code, guard_fn))
168+
guard_tree.add_guard_chain(guard_chain)
169+
return new_custom_code
159170

160171
for index, (custom_code, guard_fn) in enumerate(guarded_fns):
161172
if enable_strict_guard:

python/paddle/jit/sot/opcode_translator/executor/guard.py

+21
Original file line number | Diff line number | Diff line change
@@ -294,6 +294,27 @@ def object_equal_stringified_guard(self) -> list[StringifiedExpression]:
294294
]
295295

296296

297+
@check_faster_guard
298+
def object_equal_faster_guard(self) -> list[paddle.framework.core.GuardNode]:
299+
expr_node = self.tracker.guard_tree_expr_node()
300+
301+
weak_ref_obj = self.get_py_value()
302+
if support_weak_ref(weak_ref_obj):
303+
weak_ref_obj = weakref.ref(self.get_py_value())
304+
return [
305+
paddle.framework.core.GuardNode(
306+
paddle.framework.core.WeakRefMatchGuard(self.get_py_value()),
307+
[expr_node],
308+
)
309+
]
310+
return [
311+
paddle.framework.core.GuardNode(
312+
paddle.framework.core.ValueMatchGuard(weak_ref_obj),
313+
[expr_node],
314+
)
315+
]
316+
317+
297318
def stringify_pyobject(obj: object) -> tuple[str, dict[str, Any]]:
298319
if isinstance(obj, paddle.core.VarDesc.VarType):
299320
return f"paddle.core.VarDesc.VarType({obj.value})", {"paddle": paddle}

python/paddle/jit/sot/opcode_translator/executor/variables/base.py

+2 -2
Original file line number | Diff line number | Diff line change
@@ -347,11 +347,11 @@ def __hash__(self):
347347

348348
@check_faster_guard
349349
def make_faster_guard(self) -> list[paddle.framework.core.GuardNode]:
350-
expr = self.tracker.guard_tree_expr_node()
350+
expr_node = self.tracker.guard_tree_expr_node()
351351
return [
352352
paddle.framework.core.GuardNode(
353353
paddle.framework.core.ValueMatchGuard(self.get_py_value()),
354-
[expr],
354+
[expr_node],
355355
)
356356
]
357357

0 commit comments

Comments
 (0)