Skip to content

[SOT][Faster Guard] Implement more make_faster_guard #72272

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 11 commits into from
Apr 17, 2025
6 changes: 3 additions & 3 deletions paddle/fluid/pybind/jit.cc
Original file line number Diff line number Diff line change
Expand Up @@ -144,7 +144,7 @@ void BindGuardTree(pybind11::module *m) {
*m, "GuardTree", R"DOC(GuardTree Class.)DOC")
.def(py::init<
const std::vector<std::vector<std::shared_ptr<GuardNode>>> &>(),
py::arg("guard_nodes_list"))
py::arg("guard_chain_list"))
.def(
"lookup",
[](GuardTree &self, py::object frame) {
Expand Down Expand Up @@ -180,7 +180,7 @@ void BindGuardTree(pybind11::module *m) {
return self.lookup(reinterpret_cast<FrameProxy *>(frame.ptr()));
},
py::arg("frame"))
.def("stringify", &GuardNode::stringify);
.def("stringify", &GuardNode::stringify, py::arg("indent") = 0);

py::class_<ExprNode, std::shared_ptr<ExprNode>>(
*m, "ExprNode", R"DOC(ExprNode Class.)DOC")
Expand All @@ -190,7 +190,7 @@ void BindGuardTree(pybind11::module *m) {
return self.eval(reinterpret_cast<FrameProxy *>(frame.ptr()));
},
py::arg("frame"))
.def("stringify", &ExprNode::stringify);
.def("stringify", &ExprNode::stringify, py::arg("indent") = 0);

py::class_<ConstantExprNode, ExprNode, std::shared_ptr<ConstantExprNode>>(
*m, "ConstantExprNode", R"DOC(ConstantExprNode Class.)DOC")
Expand Down
37 changes: 28 additions & 9 deletions paddle/fluid/pybind/sot/guards.cc
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,12 @@ static inline PyObject* PyObject_CallOneArg(PyObject* func, PyObject* arg) {
#define Py_IsNone(x) ((x) == Py_None)
#endif

// Early-exit helper for guard `check()` implementations.
//
// If `value` is a null pointer, clears any pending Python error via
// PyErr_Clear() and makes the *enclosing* function `return false`.
// Otherwise it does nothing.
//
// The body is wrapped in `do { ... } while (0)` so that the macro expands to
// exactly one statement and must be terminated with a semicolon by the
// caller. This keeps it safe inside an unbraced `if (...) ...; else ...;`.
#define HANDLE_NULL_VALUE(value) \
  do {                           \
    if ((value) == nullptr) {    \
      PyErr_Clear();             \
      return false;              \
    }                            \
  } while (0)

static inline bool PyObject_Equal(PyObject* a, PyObject* b) {
if (a == b) {
return true;
Expand Down Expand Up @@ -84,6 +90,7 @@ bool TypeMatchGuard::check(PyObject* value) {
// Identity guard: succeeds only when `value` compares equal (by pointer) to
// the stored `expected_`.
bool IdMatchGuard::check(PyObject* value) {
  const bool is_same_object = (value == expected_);
  return is_same_object;
}

// Value guard: compares `value` against the stored expected value with
// PyObject_Equal.
bool ValueMatchGuard::check(PyObject* value) {
  // A null `value` fails the guard: the macro clears any pending Python
  // error and returns false from this function.
  HANDLE_NULL_VALUE(value);
  const bool values_match = PyObject_Equal(value, expected_value_);
  return values_match;
}

Expand All @@ -107,6 +114,7 @@ bool DtypeMatchGuard::check(PyObject* value) {
}

bool ShapeMatchGuard::check(PyObject* value) {
HANDLE_NULL_VALUE(value);
auto tensor = GetTensorFromPyObject(value);
if (!tensor) {
return false;
Expand Down Expand Up @@ -194,10 +202,12 @@ bool WeakRefMatchGuard::check(PyObject* value) {
}

// Evaluates to the stored constant pointer; `frame` is not consulted and the
// pointer is returned as-is.
PyObject* ConstantExprNode::eval(FrameProxy* frame) {
  PyObject* const constant = value_ptr_;
  return constant;
}
std::string ConstantExprNode::stringify() { return py::str(value_ptr_); }
// Renders the constant by wrapping it in `py::str` and converting the result
// to `std::string`. `indent` is accepted for interface parity and is not used.
std::string ConstantExprNode::stringify(int indent) {
  std::string rendered = py::str(value_ptr_);
  return rendered;
}

// Evaluates to the stored external value pointer; `frame` is not consulted
// and the pointer is returned as-is.
PyObject* ExternVarExprNode::eval(FrameProxy* frame) {
  PyObject* const extern_value = value_ptr_;
  return extern_value;
}
std::string ExternVarExprNode::stringify() { return var_name_; }
// The textual form of an external variable is just its name. `indent` is
// accepted for interface parity and is not used.
std::string ExternVarExprNode::stringify(int indent) {
  std::string name = var_name_;
  return name;
}

PyObject* LocalVarExprNode::eval(FrameProxy* frame) {
#if PY_3_13_PLUS
Expand All @@ -208,7 +218,7 @@ PyObject* LocalVarExprNode::eval(FrameProxy* frame) {
return PyDict_GetItemString(frame->f_locals, var_name_.c_str());
#endif
}
std::string LocalVarExprNode::stringify() {
// Renders the node as `locals[<var_name>]`. `indent` is accepted for
// interface parity and is not used.
std::string LocalVarExprNode::stringify(int indent) {
  std::string rendered = "locals[";
  rendered += var_name_;
  rendered += "]";
  return rendered;
}

Expand All @@ -219,15 +229,15 @@ PyObject* GlobalVarExprNode::eval(FrameProxy* frame) {
return PyDict_GetItemString(frame->f_globals, var_name_.c_str());
#endif
}
std::string GlobalVarExprNode::stringify() {
// Renders the node as `globals[<var_name>]`. `indent` is accepted for
// interface parity and is not used.
std::string GlobalVarExprNode::stringify(int indent) {
  std::string rendered("globals[");
  rendered.append(var_name_);
  rendered.append("]");
  return rendered;
}

// Evaluates the base expression against `frame` and looks up `attr_name_` on
// the result with PyObject_GetAttrString.
//
// Returns whatever PyObject_GetAttrString returns, or nullptr when the base
// expression itself evaluated to nullptr.
PyObject* AttributeExprNode::eval(FrameProxy* frame) {
  PyObject* var = var_expr_->eval(frame);
  // The base expression can legitimately yield nullptr (e.g. a variable
  // lookup via PyDict_GetItemString on a missing key). Passing nullptr to
  // PyObject_GetAttrString is invalid, so propagate the "no value" result
  // and let the guard decide how to treat it.
  if (var == nullptr) {
    return nullptr;
  }
  return PyObject_GetAttrString(var, attr_name_.c_str());
}
std::string AttributeExprNode::stringify() {
std::string AttributeExprNode::stringify(int indent) {
std::stringstream ss;
ss << var_expr_->stringify() << "." << attr_name_;
return ss.str();
Expand All @@ -238,7 +248,7 @@ PyObject* ItemExprNode::eval(FrameProxy* frame) {
PyObject* key = key_expr_->eval(frame);
return PyObject_GetItem(var, key);
}
std::string ItemExprNode::stringify() {
std::string ItemExprNode::stringify(int indent) {
std::stringstream ss;
ss << var_expr_->stringify() << "[" << key_expr_->stringify() << "]";
return ss.str();
Expand All @@ -261,10 +271,19 @@ std::optional<int> GuardNode::lookup(FrameProxy* frame) {
}
return std::nullopt;
}
std::string GuardNode::stringify() {
std::string GuardNode::stringify(int indent) {
std::stringstream ss;
ss << guard->get_guard_name();
// TODO(zrr1999): support multiple exprs
auto expr = exprs.back();
ss << std::string(indent, ' ') << guard->get_guard_name();
ss << "(" << exprs.back()->stringify() << ")";
if (!next_guard_nodes.empty()) {
ss << " |" << std::endl;
for (auto& next_guard_node : next_guard_nodes) {
ss << std::string(indent + 2, ' ');
ss << next_guard_node->stringify(indent + 2) << std::endl;
}
}
return ss.str();
}

Expand Down Expand Up @@ -295,7 +314,7 @@ std::string GuardTree::stringify() {
std::stringstream ss;
for (size_t i = 0; i < guard_nodes_.size(); ++i) {
if (i > 0) {
ss << " and ";
ss << std::endl << "and" << std::endl;
}
ss << guard_nodes_[i]->stringify();
}
Expand Down
22 changes: 11 additions & 11 deletions paddle/fluid/pybind/sot/guards.h
Original file line number Diff line number Diff line change
Expand Up @@ -275,7 +275,7 @@ class DummyGuard : public GuardBase {
class GuardTreeNode {
public:
virtual ~GuardTreeNode() = default;
virtual std::string stringify() = 0;
virtual std::string stringify(int indent = 0) = 0;
};

class AttributeExprNode;
Expand All @@ -295,7 +295,7 @@ class ConstantExprNode : public ExprNode {
}
~ConstantExprNode() { Py_DECREF(value_ptr_); }
PyObject* eval(FrameProxy* frame) override;
std::string stringify() override;
std::string stringify(int indent = 0) override;

private:
PyObject* value_ptr_;
Expand All @@ -310,7 +310,7 @@ class ExternVarExprNode : public ExprNode {

~ExternVarExprNode() { Py_DECREF(value_ptr_); }
PyObject* eval(FrameProxy* frame) override;
std::string stringify() override;
std::string stringify(int indent = 0) override;

private:
PyObject* value_ptr_;
Expand All @@ -323,7 +323,7 @@ class LocalVarExprNode : public ExprNode {
: var_name_(var_name) {}

PyObject* eval(FrameProxy* frame) override;
std::string stringify() override;
std::string stringify(int indent = 0) override;

private:
std::string var_name_;
Expand All @@ -334,7 +334,7 @@ class GlobalVarExprNode : public ExprNode {
: var_name_(var_name) {}

PyObject* eval(FrameProxy* frame) override;
std::string stringify() override;
std::string stringify(int indent = 0) override;

private:
std::string var_name_;
Expand All @@ -346,7 +346,7 @@ class AttributeExprNode : public ExprNode {
: var_expr_(var_expr), attr_name_(attr_name) {}

PyObject* eval(FrameProxy* frame) override;
std::string stringify() override;
std::string stringify(int indent = 0) override;

private:
std::shared_ptr<ExprNode> var_expr_;
Expand All @@ -359,7 +359,7 @@ class ItemExprNode : public ExprNode {
: var_expr_(var_expr), key_expr_(key_expr) {}

PyObject* eval(FrameProxy* frame) override;
std::string stringify() override;
std::string stringify(int indent = 0) override;

private:
std::shared_ptr<ExprNode> var_expr_;
Expand All @@ -381,17 +381,17 @@ class GuardNode : public GuardTreeNode {
exprs(exprs),
next_guard_nodes(next_guard_nodes),
return_cache_index(return_cache_index) {}
std::string stringify() override;
virtual ~GuardNode() = default;
std::string stringify(int indent = 0) override;
std::optional<int> lookup(FrameProxy* frame);
};

class GuardTree {
public:
GuardTree(const std::vector<std::vector<std::shared_ptr<GuardNode>>>&
guard_nodes_list) {
for (size_t index = 0; index < guard_nodes_list.size(); ++index) {
add_guard_chain(guard_nodes_list[index]);
guard_chain_list) {
for (size_t index = 0; index < guard_chain_list.size(); ++index) {
add_guard_chain(guard_chain_list[index]);
}
}
void add_guard_chain(
Expand Down
19 changes: 15 additions & 4 deletions python/paddle/jit/sot/opcode_translator/executor/executor_cache.py
Original file line number Diff line number Diff line change
Expand Up @@ -150,12 +150,23 @@ def lookup(

cache_index = None
if enable_strict_guard or enable_guard_tree:
log(4, f"[Cache] Guard tree is `{guard_tree.stringify()}`")
log(4, f"[Cache] Guard tree: \n{guard_tree.stringify()}")
cache_index = guard_tree.lookup(frame)

if not enable_strict_guard and cache_index is not None:
# TODO(zrr1999): add a mapping between custom_code and cache_index
return guarded_fns[cache_index][0]
if not enable_strict_guard and enable_guard_tree:
if cache_index is not None:
# TODO(zrr1999): add a mapping between custom_code and cache_index
return guarded_fns[cache_index][0]
else:
log(2, "[Cache]: all guards missed (guard tree mode)\n")
new_custom_code, guard_fn, guard_chain = self.translate(
frame, **kwargs
)
if guard_fn is not None:
assert guard_chain is not None
guarded_fns.append((new_custom_code, guard_fn))
guard_tree.add_guard_chain(guard_chain)
return new_custom_code

for index, (custom_code, guard_fn) in enumerate(guarded_fns):
if enable_strict_guard:
Expand Down
21 changes: 21 additions & 0 deletions python/paddle/jit/sot/opcode_translator/executor/guard.py
Original file line number Diff line number Diff line change
Expand Up @@ -294,6 +294,27 @@ def object_equal_stringified_guard(self) -> list[StringifiedExpression]:
]


@check_faster_guard
def object_equal_faster_guard(self) -> list[paddle.framework.core.GuardNode]:
    """Build a single-node guard chain that checks this variable's value.

    The current Python value is fetched once. If ``support_weak_ref``
    accepts it, the node uses ``WeakRefMatchGuard``; otherwise it uses
    ``ValueMatchGuard``. In both cases the guard is constructed from the
    plain Python value and paired with the tracker's guard-tree expression
    node.

    Returns:
        A one-element list containing the constructed ``GuardNode``.
    """
    expr_node = self.tracker.guard_tree_expr_node()
    py_value = self.get_py_value()

    # NOTE(review): WeakRefMatchGuard is given the object itself, not a
    # ``weakref.ref`` to it — presumably the guard creates its own weak
    # reference internally; confirm against the C++ constructor.
    if support_weak_ref(py_value):
        guard = paddle.framework.core.WeakRefMatchGuard(py_value)
    else:
        guard = paddle.framework.core.ValueMatchGuard(py_value)

    return [paddle.framework.core.GuardNode(guard, [expr_node])]


def stringify_pyobject(obj: object) -> tuple[str, dict[str, Any]]:
if isinstance(obj, paddle.core.VarDesc.VarType):
return f"paddle.core.VarDesc.VarType({obj.value})", {"paddle": paddle}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -347,11 +347,11 @@ def __hash__(self):

@check_faster_guard
def make_faster_guard(self) -> list[paddle.framework.core.GuardNode]:
expr = self.tracker.guard_tree_expr_node()
expr_node = self.tracker.guard_tree_expr_node()
return [
paddle.framework.core.GuardNode(
paddle.framework.core.ValueMatchGuard(self.get_py_value()),
[expr],
[expr_node],
)
]

Expand Down
Loading
Loading