Skip to content

[SOT][Faster Guard] support TensorVariable Dist check #72327

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 23 commits into from
Apr 26, 2025
Merged
Show file tree
Hide file tree
Changes from 6 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 5 additions & 0 deletions paddle/fluid/pybind/jit.cc
Original file line number Diff line number Diff line change
Expand Up @@ -128,6 +128,11 @@ void BindGuard(pybind11::module *m) {
py::class_<DummyGuard, GuardBase, std::shared_ptr<DummyGuard>>(
*m, "DummyGuard", R"DOC(DummyGuard Class.)DOC")
.def(py::init<>());
py::class_<TensorDistMatchGuard,
GuardBase,
std::shared_ptr<TensorDistMatchGuard>>(
*m, "TensorDistMatchGuard", R"DOC(TensorDistMatchGuard Class.)DOC")
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

这里应该叫 TensorDistMetaMatchGuard

.def(py::init<const py::object &>(), py::arg("tensor"));
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

init 传的不是 dist info 吗?为什么叫 tensor?


m->def(
"merge_guard",
Expand Down
84 changes: 76 additions & 8 deletions paddle/fluid/pybind/sot/guards.cc
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@ limitations under the License. */

#include "paddle/fluid/pybind/sot/guards.h"
#include <optional>
#include "paddle/fluid/eager/utils.h"
#include "paddle/phi/api/include/tensor.h"

#if SOT_IS_SUPPORTED
Expand All @@ -33,6 +34,12 @@ static inline PyObject* PyObject_CallOneArg(PyObject* func, PyObject* arg) {
#define Py_IsNone(x) ((x) == Py_None)
#endif

#define CheckTensorFromPyObject(value) \
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

宏的话用大写,不然这样看不出来是宏,仅仅是函数调用看不出来这里会产生控制流

auto tensor = GetTensorFromPyObject(value); \
if (!tensor) { \
return false; \
}

#define HANDLE_NULL_VALUE(value) \
if ((value) == NULL) { \
PyErr_Clear(); \
Expand Down Expand Up @@ -105,20 +112,14 @@ bool LengthMatchGuard::check(PyObject* value) {
}

bool DtypeMatchGuard::check(PyObject* value) {
auto tensor = GetTensorFromPyObject(value);
if (!tensor) {
return false;
}
CheckTensorFromPyObject(value);
auto dtype = tensor->type();
return phi::TransToProtoVarType(dtype) == expected_;
}

bool ShapeMatchGuard::check(PyObject* value) {
HANDLE_NULL_VALUE(value);
auto tensor = GetTensorFromPyObject(value);
if (!tensor) {
return false;
}
CheckTensorFromPyObject(value);
auto shape = tensor->shape();
if (shape.size() != expected_.size()) {
return false;
Expand Down Expand Up @@ -201,6 +202,73 @@ bool WeakRefMatchGuard::check(PyObject* value) {
#endif
}

phi::distributed::DistTensor* get_dist_tensor_from_py_object(PyObject* obj) {
if (paddle::pybind::PyCheckTensor(obj)) {
auto tensor = reinterpret_cast<paddle::pybind::TensorObject*>(obj)->tensor;
if (tensor.is_dist_tensor()) {
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

这部分逻辑为啥不复用 get_dist_tensor_from_tensor

return static_cast<phi::distributed::DistTensor*>(tensor.impl().get());
}
}
return nullptr;
}

phi::distributed::DistTensor* get_dist_tensor_from_tensor(
const paddle::Tensor& tensor) {
if (tensor.is_dist_tensor()) {
return static_cast<phi::distributed::DistTensor*>(tensor.impl().get());
}
return nullptr;
}

bool TensorDistMatchGuard::check(PyObject* value) {
if (value == NULL && expected_ == NULL) {
return true;
}
CheckTensorFromPyObject(value);
if (tensor->is_dist_tensor() == false) {
return false;
}

// check expected_
auto expected_dist_tensor = get_dist_tensor_from_py_object(expected_);
if (expected_dist_tensor == nullptr) {
return false;
}

auto dist_tensor = get_dist_tensor_from_tensor(*tensor);
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

所有 expected 应该在构建阶段准备好,而不是在 check 时候再准备

if (dist_tensor == nullptr) {
return false;
}

auto expected_dist_mesh = expected_dist_tensor->process_mesh();
auto dist_mesh = dist_tensor->process_mesh();

// mesh.shape
if (expected_dist_mesh.shape() != dist_mesh.shape()) {
return false;
}

// mesh.process_ids
if (expected_dist_mesh.process_ids() != dist_mesh.process_ids()) {
return false;
}

// dims_mapping

// local_shape
auto local_shape = dist_tensor->value();
auto expected_local_shape = expected_dist_tensor->value();
if (local_shape.dims() != expected_local_shape.dims() ||
local_shape.numel() != expected_local_shape.numel() ||
local_shape.layout() != expected_local_shape.layout() ||
local_shape.dtype() != expected_local_shape.dtype() ||
local_shape.offset() != expected_local_shape.offset()) {
return false;
}

return true;
}

PyObject* ConstantExprNode::eval(FrameProxy* frame) { return value_ptr_; }
std::string ConstantExprNode::stringify(int indent) {
return py::str(value_ptr_);
Expand Down
15 changes: 15 additions & 0 deletions paddle/fluid/pybind/sot/guards.h
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@ limitations under the License. */
#include "paddle/utils/pybind.h"
#include "pybind11/numpy.h"
#include "pybind11/pybind11.h"
#include "pybind11/pytypes.h"

namespace py = pybind11;
#define PYBIND11_DETAILED_ERROR_MESSAGES
Expand Down Expand Up @@ -266,6 +267,20 @@ class WeakRefMatchGuard : public GuardBase {
PyObject* expected_;
};

class TensorDistMatchGuard : public GuardBase {
public:
explicit TensorDistMatchGuard(const py::object& obj) : expected_(obj.ptr()) {
Py_INCREF(expected_);
}

~TensorDistMatchGuard() override { Py_DECREF(expected_); }
bool check(PyObject* value) override;
std::string get_guard_name() const override { return "TensorDistMatchGuard"; }

private:
PyObject* expected_;
};

class DummyGuard : public GuardBase {
public:
bool check(PyObject* value) override { return true; }
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -527,6 +527,10 @@ def make_faster_guard(self) -> list[paddle.framework.core.GuardNode]:
],
),
# TODO(zrr1999): use TensorMetaMatchGuard to support dist_info check
paddle.framework.core.GuardNode(
paddle.framework.core.TensorDistMatchGuard(self.meta.dist_info),
[expr_node],
),
]

@check_guard
Expand Down
7 changes: 7 additions & 0 deletions test/sot/test_faster_guard.py
Original file line number Diff line number Diff line change
Expand Up @@ -202,6 +202,13 @@ def test_func():
self.assertFalse(guard_object.check(1))
self.assertFalse(guard_object.check("1"))

# def test_tensor_is_dist_guard(self):
# tensor = paddle.randn([2, 3])
# guard_tensor_is_dist = paddle.framework.core.TensorDistMatchGuard(
# tensor
# )
# self.assertTrue(guard_tensor_is_dist.check(tensor))
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

这块怎么不解开?参考 test/sot/test_sot_distribution.py 的测试条件



class TestFasterGuardGroup(unittest.TestCase):
def test_guard_group(self):
Expand Down