-
Notifications
You must be signed in to change notification settings - Fork 5.7k
[SOT][Faster Guard] support TensorVariable
Dist check
#72327
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from 6 commits
cfc6506
3e65c6c
303702f
fd714ac
5df2a20
d7ba019
7ded081
ac72e1e
29e4f91
c075166
5540a86
453ec87
134da0a
9e32db3
d09beb9
d27736c
0486513
f4ca2bc
1c88504
811cec3
0859dec
1af2954
3219385
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -128,6 +128,11 @@ void BindGuard(pybind11::module *m) { | |
py::class_<DummyGuard, GuardBase, std::shared_ptr<DummyGuard>>( | ||
*m, "DummyGuard", R"DOC(DummyGuard Class.)DOC") | ||
.def(py::init<>()); | ||
py::class_<TensorDistMatchGuard, | ||
GuardBase, | ||
std::shared_ptr<TensorDistMatchGuard>>( | ||
*m, "TensorDistMatchGuard", R"DOC(TensorDistMatchGuard Class.)DOC") | ||
.def(py::init<const py::object &>(), py::arg("tensor")); | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. init 传的不是 dist info 吗?为什么叫 tensor? |
||
|
||
m->def( | ||
"merge_guard", | ||
|
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -14,6 +14,7 @@ limitations under the License. */ | |
|
||
#include "paddle/fluid/pybind/sot/guards.h" | ||
#include <optional> | ||
#include "paddle/fluid/eager/utils.h" | ||
#include "paddle/phi/api/include/tensor.h" | ||
|
||
#if SOT_IS_SUPPORTED | ||
|
@@ -33,6 +34,12 @@ static inline PyObject* PyObject_CallOneArg(PyObject* func, PyObject* arg) { | |
#define Py_IsNone(x) ((x) == Py_None) | ||
#endif | ||
|
||
#define CheckTensorFromPyObject(value) \ | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. 宏的话用大写,不然这样看不出来是宏,仅仅是函数调用看不出来这里会产生控制流 |
||
auto tensor = GetTensorFromPyObject(value); \ | ||
if (!tensor) { \ | ||
return false; \ | ||
} | ||
|
||
#define HANDLE_NULL_VALUE(value) \ | ||
if ((value) == NULL) { \ | ||
PyErr_Clear(); \ | ||
|
@@ -105,20 +112,14 @@ bool LengthMatchGuard::check(PyObject* value) { | |
} | ||
|
||
bool DtypeMatchGuard::check(PyObject* value) { | ||
auto tensor = GetTensorFromPyObject(value); | ||
if (!tensor) { | ||
return false; | ||
} | ||
CheckTensorFromPyObject(value); | ||
auto dtype = tensor->type(); | ||
return phi::TransToProtoVarType(dtype) == expected_; | ||
} | ||
|
||
bool ShapeMatchGuard::check(PyObject* value) { | ||
HANDLE_NULL_VALUE(value); | ||
auto tensor = GetTensorFromPyObject(value); | ||
if (!tensor) { | ||
return false; | ||
} | ||
CheckTensorFromPyObject(value); | ||
auto shape = tensor->shape(); | ||
if (shape.size() != expected_.size()) { | ||
return false; | ||
|
@@ -201,6 +202,73 @@ bool WeakRefMatchGuard::check(PyObject* value) { | |
#endif | ||
} | ||
|
||
phi::distributed::DistTensor* get_dist_tensor_from_py_object(PyObject* obj) { | ||
if (paddle::pybind::PyCheckTensor(obj)) { | ||
auto tensor = reinterpret_cast<paddle::pybind::TensorObject*>(obj)->tensor; | ||
if (tensor.is_dist_tensor()) { | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. 这部分逻辑为啥不复用 |
||
return static_cast<phi::distributed::DistTensor*>(tensor.impl().get()); | ||
} | ||
} | ||
return nullptr; | ||
} | ||
|
||
phi::distributed::DistTensor* get_dist_tensor_from_tensor( | ||
const paddle::Tensor& tensor) { | ||
if (tensor.is_dist_tensor()) { | ||
return static_cast<phi::distributed::DistTensor*>(tensor.impl().get()); | ||
} | ||
return nullptr; | ||
} | ||
|
||
bool TensorDistMatchGuard::check(PyObject* value) { | ||
if (value == NULL && expected_ == NULL) { | ||
return true; | ||
} | ||
CheckTensorFromPyObject(value); | ||
if (tensor->is_dist_tensor() == false) { | ||
return false; | ||
} | ||
|
||
// check expected_ | ||
auto expected_dist_tensor = get_dist_tensor_from_py_object(expected_); | ||
if (expected_dist_tensor == nullptr) { | ||
return false; | ||
} | ||
|
||
auto dist_tensor = get_dist_tensor_from_tensor(*tensor); | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. 所有 expected 应该在构建阶段准备好,而不是在 check 时候再准备 |
||
if (dist_tensor == nullptr) { | ||
return false; | ||
} | ||
|
||
auto expected_dist_mesh = expected_dist_tensor->process_mesh(); | ||
auto dist_mesh = dist_tensor->process_mesh(); | ||
|
||
// mesh.shape | ||
if (expected_dist_mesh.shape() != dist_mesh.shape()) { | ||
return false; | ||
} | ||
|
||
// mesh.process_ids | ||
if (expected_dist_mesh.process_ids() != dist_mesh.process_ids()) { | ||
return false; | ||
} | ||
|
||
// dims_mapping | ||
|
||
// local_shape | ||
auto local_shape = dist_tensor->value(); | ||
auto expected_local_shape = expected_dist_tensor->value(); | ||
if (local_shape.dims() != expected_local_shape.dims() || | ||
local_shape.numel() != expected_local_shape.numel() || | ||
local_shape.layout() != expected_local_shape.layout() || | ||
local_shape.dtype() != expected_local_shape.dtype() || | ||
local_shape.offset() != expected_local_shape.offset()) { | ||
return false; | ||
} | ||
|
||
return true; | ||
} | ||
|
||
PyObject* ConstantExprNode::eval(FrameProxy* frame) { return value_ptr_; } | ||
std::string ConstantExprNode::stringify(int indent) { | ||
return py::str(value_ptr_); | ||
|
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -202,6 +202,13 @@ def test_func(): | |
self.assertFalse(guard_object.check(1)) | ||
self.assertFalse(guard_object.check("1")) | ||
|
||
# def test_tensor_is_dist_guard(self): | ||
# tensor = paddle.randn([2, 3]) | ||
# guard_tensor_is_dist = paddle.framework.core.TensorDistMatchGuard( | ||
# tensor | ||
# ) | ||
# self.assertTrue(guard_tensor_is_dist.check(tensor)) | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. 这块怎么不解开?参考 |
||
|
||
|
||
class TestFasterGuardGroup(unittest.TestCase): | ||
def test_guard_group(self): | ||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
这里应该叫
TensorDistMetaMatchGuard