
Commit 25cd44a

[AutoParallel] Eager method support autoparallel3 (PaddlePaddle#58476)
* PHI copy support auto parallel
1 parent 7ad9a5f commit 25cd44a

File tree

4 files changed: +67 -15 lines changed

paddle/fluid/pybind/eager_method.cc

Lines changed: 32 additions & 14 deletions
@@ -2962,7 +2962,8 @@ static PyObject* tensor_method_strides(TensorObject* self,
                                        PyObject* kwargs) {
   EAGER_TRY
   std::vector<int64_t> value;
-  if (!self->tensor.defined() || !self->tensor.is_dense_tensor()) {
+  if (!self->tensor.defined() ||
+      (!self->tensor.is_dense_tensor() && !self->tensor.is_dist_tensor())) {
     return ToPyObject(value);
   }
   auto stride = self->tensor.strides();
@@ -3002,20 +3003,24 @@ static PyObject* tensor_contiguous(TensorObject* self,
                                    PyObject* args,
                                    PyObject* kwargs) {
   EAGER_TRY
-  if (self->tensor.is_dense_tensor()) {
-    auto dense_tensor =
-        std::dynamic_pointer_cast<phi::DenseTensor>(self->tensor.impl());
+  if (self->tensor.is_dense_tensor() || self->tensor.is_dist_tensor()) {
+    phi::DenseTensor* dense_tensor = nullptr;
+    if (self->tensor.is_dist_tensor()) {
+      dense_tensor =
+          static_cast<phi::distributed::DistTensor*>(self->tensor.impl().get())
+              ->unsafe_mutable_value();
+    } else {
+      dense_tensor = static_cast<phi::DenseTensor*>(self->tensor.impl().get());
+    }
     if (dense_tensor->meta().is_contiguous()) {
       Py_INCREF(self);
       return reinterpret_cast<PyObject*>(self);
     } else {
       eager_gil_scoped_release guard;
-      self->tensor.set_impl(std::make_shared<phi::DenseTensor>(std::move(
-          paddle::experimental::Trans2Contiguous(*(dense_tensor.get())))));
+      *dense_tensor = paddle::experimental::Trans2Contiguous(*dense_tensor);
       Py_INCREF(self);
       return reinterpret_cast<PyObject*>(self);
     }
-
   } else {
     Py_INCREF(self);
     return reinterpret_cast<PyObject*>(self);
@@ -3050,6 +3055,11 @@ static PyObject* tensor_is_contiguous(TensorObject* self,
     auto dense_tensor =
         std::dynamic_pointer_cast<phi::DenseTensor>(self->tensor.impl());
     return ToPyObject(dense_tensor->meta().is_contiguous());
+  } else if (self->tensor.is_dist_tensor()) {
+    auto dense_tensor = std::dynamic_pointer_cast<phi::distributed::DistTensor>(
+                            self->tensor.impl())
+                            ->unsafe_mutable_value();
+    return ToPyObject(dense_tensor->meta().is_contiguous());
   } else {
     return ToPyObject(true);
   }
@@ -3074,19 +3084,27 @@ static PyObject* tensor_method__uva(TensorObject* self,
                                     PyObject* kwargs) {
   EAGER_TRY
   VLOG(4) << "Running in tensor_method__uva.";
-  PADDLE_ENFORCE_EQ(self->tensor.is_dense_tensor(),
-                    true,
-                    platform::errors::InvalidArgument(
-                        "Unified virtual addressing only support "
-                        "DenseTensor currently."));
+  PADDLE_ENFORCE_EQ(
+      self->tensor.is_dense_tensor() || self->tensor.is_dist_tensor(),
+      true,
+      platform::errors::InvalidArgument(
+          "Unified virtual addressing only support "
+          "DenseTensor and DistTensor currently."));
   PADDLE_ENFORCE_EQ(platform::is_cpu_place(self->tensor.place()),
                     true,
                     platform::errors::InvalidArgument(
                         "Unified virtual addressing only support "
                         "CPU Tensor currently."));
   int device_id = pybind::CastPyArg2AttrLong(PyTuple_GET_ITEM(args, 0), 0);
-  auto* self_tensor = static_cast<phi::DenseTensor*>(self->tensor.impl().get());
-  tensor_uva(self_tensor, device_id);
+  phi::DenseTensor* dense_tensor = nullptr;
+  if (self->tensor.is_dist_tensor()) {
+    dense_tensor =
+        static_cast<phi::distributed::DistTensor*>(self->tensor.impl().get())
+            ->unsafe_mutable_value();
+  } else {
+    dense_tensor = static_cast<phi::DenseTensor*>(self->tensor.impl().get());
+  }
+  tensor_uva(dense_tensor, device_id);
 
   RETURN_PY_NONE
 
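The Python-side behaviour these binding changes enable, as a hedged sketch (single-card run, mirroring the test added at the bottom of this commit and using only the methods it exercises: get_strides, is_contiguous, contiguous):

# Sketch: stride/contiguity queries now also work on a DistTensor.
import paddle
import paddle.distributed as dist

mesh = dist.ProcessMesh([0, 1], dim_names=["x"])
x = paddle.randn([10, 20]).reshape([20, 10])
d = dist.shard_tensor(
    x, dist_attr=dist.DistAttr(mesh=mesh, sharding_specs=[None, None])
)
print(d.get_strides())    # reads the strides of the local DenseTensor value
print(d.is_contiguous())  # checks the local value's meta instead of returning True
d = d.contiguous()        # rewrites the local value in place if non-contiguous
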
paddle/phi/api/lib/tensor.cc

Lines changed: 4 additions & 0 deletions
@@ -113,6 +113,10 @@ std::vector<int64_t> Tensor::shape() const {
 const phi::DDim &Tensor::strides() const {
   if (is_dense_tensor()) {
     return static_cast<phi::DenseTensor *>(impl_.get())->strides();
+  } else if (is_dist_tensor()) {
+    return static_cast<phi::distributed::DistTensor *>(impl_.get())
+        ->unsafe_mutable_value()
+        ->strides();
   } else {
     PADDLE_THROW(phi::errors::Unimplemented(
         "Only support strides operation on DenseTensor now."));

python/paddle/distributed/auto_parallel/api.py

Lines changed: 5 additions & 1 deletion
@@ -136,6 +136,10 @@ def shard_tensor(
             >>> print(d_tensor)
 
     """
+    if place is None:
+        place = paddle.framework._current_expected_place()
+    place = paddle.framework._get_paddle_place(place)
+
     # 1. create dense tensor
     # `paddle.to_tensor` supports both dynamic and static mode
     tensor = paddle.to_tensor(
@@ -154,7 +158,7 @@ def shard_tensor(
                 tensor, dist_attr=dist_attr, **tensor.__dict__
             )
         else:
-            return paddle.Tensor(tensor, dist_attr=dist_attr)
+            return paddle.Tensor(tensor, dist_attr=dist_attr, place=place)
     else:
         # TODO(zhiqiu): we need to refine the static shard_tensor
         return shard_tensor_static(
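A hedged usage sketch of the place handling added above (a CPU place is chosen so that _uva() can be called afterwards, as in the new test below):

# Sketch: shard_tensor resolves `place` and forwards it to paddle.Tensor.
import numpy as np
import paddle
import paddle.distributed as dist

mesh = dist.ProcessMesh([0, 1], dim_names=["x"])
place = paddle.CPUPlace()
x = paddle.to_tensor(np.random.random([10, 30]).astype('float32'), place=place)
d = dist.shard_tensor(
    x,
    place=place,  # when omitted, falls back to _current_expected_place()
    dist_attr=dist.DistAttr(mesh=mesh, sharding_specs=[None, None]),
)
d._uva()  # UVA is now accepted for a DistTensor holding a CPU DenseTensor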

test/auto_parallel/test_semi_auto_parallel_functional_in_single_card.py

Lines changed: 26 additions & 0 deletions
@@ -14,6 +14,8 @@
 
 import unittest
 
+import numpy as np
+
 import paddle
 import paddle.distributed as dist
 
@@ -84,6 +86,30 @@ def test_tensor__is_shared_buffer_with(self):
         dist_tensor._share_buffer_to(to)
         self.assertTrue(dist_tensor._is_shared_buffer_with(to))
 
+    def test_tensor_strides(self):
+        mesh = dist.ProcessMesh([0, 1], dim_names=["x"])
+        dense_tensor = paddle.randn([10, 20])
+        dense_tensor = dense_tensor.reshape([20, 10])
+        dist_tensor = dist.shard_tensor(
+            dense_tensor,
+            dist_attr=dist.DistAttr(mesh=mesh, sharding_specs=[None, None]),
+        )
+        strides = dist_tensor.get_strides()
+        is_contiguous = dist_tensor.is_contiguous()
+        dist_tensor = dist_tensor.contiguous()
+
+    def test_tensor_uva(self):
+        mesh = dist.ProcessMesh([0, 1], dim_names=["x"])
+        place = paddle.CPUPlace()
+        np_value = np.random.random(size=[10, 30]).astype('float32')
+        dense_tensor = paddle.to_tensor(np_value, place=place)
+        dist_tensor = dist.shard_tensor(
+            dense_tensor,
+            place=place,
+            dist_attr=dist.DistAttr(mesh=mesh, sharding_specs=[None, None]),
+        )
+        dist_tensor._uva()
+
 
 if __name__ == "__main__":
     unittest.main()
