@@ -29,6 +29,7 @@ typedef SSIZE_T ssize_t;
29
29
#include " paddle/fluid/eager/hooks.h"
30
30
#include " paddle/fluid/eager/utils.h"
31
31
#include " paddle/fluid/framework/convert_utils.h"
32
+ #include " paddle/fluid/framework/tensor_util.h"
32
33
#include " paddle/fluid/platform/enforce.h"
33
34
#include " paddle/fluid/pybind/eager.h"
34
35
#include " paddle/fluid/pybind/eager_utils.h"
@@ -1398,6 +1399,61 @@ static PyObject* tensor_method_get_underline_tensor(TensorObject* self,
1398
1399
EAGER_CATCH_AND_THROW_RETURN_NULL
1399
1400
}
1400
1401
1402
+ static PyObject* tensor_method_set_underline_tensor (TensorObject* self,
1403
+ PyObject* args,
1404
+ PyObject* kwargs) {
1405
+ EAGER_TRY
1406
+ auto & value = GetTensorFromArgs (" set_tensor" , " value" , args, 0 , false );
1407
+ if (!value.defined ()) {
1408
+ PADDLE_THROW (
1409
+ common::errors::Unavailable (" The `set_tensor()` method of (Dist)Tensor "
1410
+ " get a non initialized src value" ));
1411
+ } else if (value.is_dense_tensor ()) {
1412
+ auto * src_tensor = static_cast <phi::DenseTensor*>(value.impl ().get ());
1413
+ if (self->tensor .is_dense_tensor ()) {
1414
+ auto * dst_tensor =
1415
+ static_cast <phi::DenseTensor*>(self->tensor .impl ().get ());
1416
+ framework::TensorCopy (*src_tensor, dst_tensor->place (), dst_tensor);
1417
+ } else {
1418
+ PADDLE_THROW (common::errors::Unavailable (
1419
+ " The `set_tensor()` method of non DenseTensor get a DenseTensor src "
1420
+ " value" ));
1421
+ }
1422
+
1423
+ } else if (value.is_dist_tensor ()) {
1424
+ #ifdef PADDLE_WITH_DISTRIBUTE
1425
+ auto * src_tensor =
1426
+ static_cast <phi::distributed::DistTensor*>(value.impl ().get ());
1427
+ if (self->tensor .is_dist_tensor ()) {
1428
+ auto * dst_tensor =
1429
+ static_cast <phi::distributed::DistTensor*>(self->tensor .impl ().get ());
1430
+ framework::TensorCopy (*(src_tensor->unsafe_mutable_value ()),
1431
+ dst_tensor->place (),
1432
+ dst_tensor->unsafe_mutable_value ());
1433
+
1434
+ // TensorCopyFrom(dst_tensor->unsafe_mutable_value(),
1435
+ // *(src_tensor->unsafe_mutable_value()), dst_tensor->place(), -1);
1436
+ } else {
1437
+ PADDLE_THROW (
1438
+ common::errors::Unavailable (" The `set_tensor()` method of non "
1439
+ " DistTensor get a DistTensor src value" ));
1440
+ }
1441
+ #else
1442
+ PADDLE_THROW (common::errors::Unavailable (
1443
+ " The `set_tensor()` method of (Dist)Tensor is not supported in the "
1444
+ " current PaddlePaddle, please recompile and installPaddlePaddle "
1445
+ " with the option of `WITH_DISTRIBUTE=ON`." ));
1446
+ #endif
1447
+
1448
+ } else {
1449
+ PADDLE_THROW (common::errors::Unavailable (
1450
+ " The `set_tensor()` method of (Dist)Tensor get a non "
1451
+ " DenseTensor/DistTensor src value" ));
1452
+ }
1453
+ RETURN_PY_NONE
1454
+ EAGER_CATCH_AND_THROW_RETURN_NULL
1455
+ }
1456
+
1401
1457
static PyObject* tensor_method_get_underline_selected_rows (TensorObject* self,
1402
1458
PyObject* args,
1403
1459
PyObject* kwargs) {
@@ -3643,6 +3699,10 @@ PyMethodDef variable_methods[] = { // NOLINT
3643
3699
(PyCFunction)(void (*)())tensor_method__get_tensor_from_selected_rows,
3644
3700
METH_VARARGS | METH_KEYWORDS,
3645
3701
nullptr },
3702
+ {" set_tensor" ,
3703
+ (PyCFunction)(void (*)())tensor_method_set_underline_tensor,
3704
+ METH_VARARGS | METH_KEYWORDS,
3705
+ nullptr },
3646
3706
{" _getitem_dygraph" ,
3647
3707
(PyCFunction)(void (*)())tensor__getitem_dygraph,
3648
3708
METH_VARARGS | METH_KEYWORDS,
0 commit comments