@@ -2962,7 +2962,8 @@ static PyObject* tensor_method_strides(TensorObject* self,
                                        PyObject* kwargs) {
   EAGER_TRY
   std::vector<int64_t> value;
-  if (!self->tensor.defined() || !self->tensor.is_dense_tensor()) {
+  if (!self->tensor.defined() ||
+      (!self->tensor.is_dense_tensor() && !self->tensor.is_dist_tensor())) {
     return ToPyObject(value);
   }
   auto stride = self->tensor.strides();
@@ -3002,20 +3003,24 @@ static PyObject* tensor_contiguous(TensorObject* self,
                                    PyObject* args,
                                    PyObject* kwargs) {
   EAGER_TRY
-  if (self->tensor.is_dense_tensor()) {
-    auto dense_tensor =
-        std::dynamic_pointer_cast<phi::DenseTensor>(self->tensor.impl());
+  if (self->tensor.is_dense_tensor() || self->tensor.is_dist_tensor()) {
+    phi::DenseTensor* dense_tensor = nullptr;
+    if (self->tensor.is_dist_tensor()) {
+      dense_tensor =
+          static_cast<phi::distributed::DistTensor*>(self->tensor.impl().get())
+              ->unsafe_mutable_value();
+    } else {
+      dense_tensor = static_cast<phi::DenseTensor*>(self->tensor.impl().get());
+    }
     if (dense_tensor->meta().is_contiguous()) {
       Py_INCREF(self);
       return reinterpret_cast<PyObject*>(self);
     } else {
       eager_gil_scoped_release guard;
-      self->tensor.set_impl(std::make_shared<phi::DenseTensor>(std::move(
-          paddle::experimental::Trans2Contiguous(*(dense_tensor.get())))));
+      *dense_tensor = paddle::experimental::Trans2Contiguous(*dense_tensor);
       Py_INCREF(self);
       return reinterpret_cast<PyObject*>(self);
     }
-
   } else {
     Py_INCREF(self);
     return reinterpret_cast<PyObject*>(self);
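The unwrap-to-DenseTensor branch above reappears in tensor_is_contiguous and tensor_method__uva below. A minimal standalone sketch of that dispatch pattern, using mock stand-ins rather than the real phi::DenseTensor / phi::distributed::DistTensor types, to show why mutating through the returned pointer keeps the distributed wrapper intact:

```cpp
// Sketch only: mock types, not Paddle's actual classes or headers.
#include <cassert>
#include <memory>

struct DenseTensorMock {
  bool contiguous;
};

// A distributed tensor wraps a local dense shard plus placement metadata.
struct DistTensorMock {
  DenseTensorMock local_value{true};
  DenseTensorMock* unsafe_mutable_value() { return &local_value; }
};

struct TensorMock {
  std::shared_ptr<void> impl;
  bool is_dense;
  bool is_dist;

  // Mirrors the added branch: unwrap a dist tensor to its local dense value,
  // otherwise cast the impl to the dense type directly.
  DenseTensorMock* mutable_dense() {
    if (is_dist) {
      return static_cast<DistTensorMock*>(impl.get())->unsafe_mutable_value();
    }
    assert(is_dense);
    return static_cast<DenseTensorMock*>(impl.get());
  }
};

int main() {
  TensorMock dist{std::make_shared<DistTensorMock>(), false, true};

  // Writing through the returned pointer mutates the local shard in place,
  // so the outer distributed wrapper and its metadata are preserved -- the
  // reason tensor_contiguous now assigns into *dense_tensor instead of
  // replacing the whole impl via set_impl.
  dist.mutable_dense()->contiguous = false;
  assert(!static_cast<DistTensorMock*>(dist.impl.get())->local_value.contiguous);
  return 0;
}
```

Since the same three-way cast shows up in three methods, it could arguably be factored into a small shared helper, though that is beyond the scope of this diff.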
@@ -3050,6 +3055,11 @@ static PyObject* tensor_is_contiguous(TensorObject* self,
     auto dense_tensor =
         std::dynamic_pointer_cast<phi::DenseTensor>(self->tensor.impl());
     return ToPyObject(dense_tensor->meta().is_contiguous());
+  } else if (self->tensor.is_dist_tensor()) {
+    auto dense_tensor = std::dynamic_pointer_cast<phi::distributed::DistTensor>(
+                            self->tensor.impl())
+                            ->unsafe_mutable_value();
+    return ToPyObject(dense_tensor->meta().is_contiguous());
   } else {
     return ToPyObject(true);
   }
@@ -3074,19 +3084,27 @@ static PyObject* tensor_method__uva(TensorObject* self,
                                     PyObject* kwargs) {
   EAGER_TRY
   VLOG(4) << "Running in tensor_method__uva.";
-  PADDLE_ENFORCE_EQ(self->tensor.is_dense_tensor(),
-                    true,
-                    platform::errors::InvalidArgument(
-                        "Unified virtual addressing only support "
-                        "DenseTensor currently."));
+  PADDLE_ENFORCE_EQ(
+      self->tensor.is_dense_tensor() || self->tensor.is_dist_tensor(),
+      true,
+      platform::errors::InvalidArgument(
+          "Unified virtual addressing only support "
+          "DenseTensor and DistTensor currently."));
   PADDLE_ENFORCE_EQ(platform::is_cpu_place(self->tensor.place()),
                     true,
                     platform::errors::InvalidArgument(
                         "Unified virtual addressing only support "
                         "CPU Tensor currently."));
   int device_id = pybind::CastPyArg2AttrLong(PyTuple_GET_ITEM(args, 0), 0);
-  auto* self_tensor = static_cast<phi::DenseTensor*>(self->tensor.impl().get());
-  tensor_uva(self_tensor, device_id);
+  phi::DenseTensor* dense_tensor = nullptr;
+  if (self->tensor.is_dist_tensor()) {
+    dense_tensor =
+        static_cast<phi::distributed::DistTensor*>(self->tensor.impl().get())
+            ->unsafe_mutable_value();
+  } else {
+    dense_tensor = static_cast<phi::DenseTensor*>(self->tensor.impl().get());
+  }
+  tensor_uva(dense_tensor, device_id);

   RETURN_PY_NONE
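For context on what the _uva path does once the DenseTensor is extracted: tensor_uva(dense_tensor, device_id) presumably pins and registers the CPU allocation so the selected GPU can address it directly (hence the CPU-place enforce above). A generic, self-contained CUDA sketch of that mechanism — this is an assumption about the underlying technique, not Paddle's actual tensor_uva implementation; the buffer size and device id are arbitrary:

```cpp
// Host-memory registration for unified virtual addressing (UVA).
#include <cuda_runtime.h>
#include <cstdio>
#include <vector>

int main() {
  std::vector<float> host(1 << 20, 1.0f);  // existing CPU-resident buffer
  cudaSetDevice(0);                        // analogous to the device_id argument

  // Pin and register the allocation so the GPU can address the same memory.
  cudaError_t err = cudaHostRegister(host.data(),
                                     host.size() * sizeof(float),
                                     cudaHostRegisterMapped);
  if (err != cudaSuccess) {
    std::printf("register failed: %s\n", cudaGetErrorString(err));
    return 1;
  }

  // Fetch the device-side alias of the registered host memory.
  float* device_ptr = nullptr;
  cudaHostGetDevicePointer(reinterpret_cast<void**>(&device_ptr),
                           host.data(), 0);
  std::printf("host %p aliased on device as %p\n",
              static_cast<void*>(host.data()),
              static_cast<void*>(device_ptr));

  cudaHostUnregister(host.data());
  return 0;
}
```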