Skip to content

Commit 5bcef40

Browse files
【dygraph】Set value fp8 (#73052)
* add float8 support for concat * add concat grad support fp8 * set_value support fp8 * modify python api * modify ci test * Apply suggestions from code review
1 parent 18fdfc1 commit 5bcef40

File tree

3 files changed

+71
-4
lines changed

3 files changed

+71
-4
lines changed

paddle/fluid/pybind/eager_method.cc

Lines changed: 60 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -29,6 +29,7 @@ typedef SSIZE_T ssize_t;
2929
#include "paddle/fluid/eager/hooks.h"
3030
#include "paddle/fluid/eager/utils.h"
3131
#include "paddle/fluid/framework/convert_utils.h"
32+
#include "paddle/fluid/framework/tensor_util.h"
3233
#include "paddle/fluid/platform/enforce.h"
3334
#include "paddle/fluid/pybind/eager.h"
3435
#include "paddle/fluid/pybind/eager_utils.h"
@@ -1398,6 +1399,61 @@ static PyObject* tensor_method_get_underline_tensor(TensorObject* self,
13981399
EAGER_CATCH_AND_THROW_RETURN_NULL
13991400
}
14001401

1402+
static PyObject* tensor_method_set_underline_tensor(TensorObject* self,
1403+
PyObject* args,
1404+
PyObject* kwargs) {
1405+
EAGER_TRY
1406+
auto& value = GetTensorFromArgs("set_tensor", "value", args, 0, false);
1407+
if (!value.defined()) {
1408+
PADDLE_THROW(
1409+
common::errors::Unavailable("The `set_tensor()` method of (Dist)Tensor "
1410+
"get a non initialized src value"));
1411+
} else if (value.is_dense_tensor()) {
1412+
auto* src_tensor = static_cast<phi::DenseTensor*>(value.impl().get());
1413+
if (self->tensor.is_dense_tensor()) {
1414+
auto* dst_tensor =
1415+
static_cast<phi::DenseTensor*>(self->tensor.impl().get());
1416+
framework::TensorCopy(*src_tensor, dst_tensor->place(), dst_tensor);
1417+
} else {
1418+
PADDLE_THROW(common::errors::Unavailable(
1419+
"The `set_tensor()` method of non DenseTensor get a DenseTensor src "
1420+
"value"));
1421+
}
1422+
1423+
} else if (value.is_dist_tensor()) {
1424+
#ifdef PADDLE_WITH_DISTRIBUTE
1425+
auto* src_tensor =
1426+
static_cast<phi::distributed::DistTensor*>(value.impl().get());
1427+
if (self->tensor.is_dist_tensor()) {
1428+
auto* dst_tensor =
1429+
static_cast<phi::distributed::DistTensor*>(self->tensor.impl().get());
1430+
framework::TensorCopy(*(src_tensor->unsafe_mutable_value()),
1431+
dst_tensor->place(),
1432+
dst_tensor->unsafe_mutable_value());
1433+
1434+
// TensorCopyFrom(dst_tensor->unsafe_mutable_value(),
1435+
// *(src_tensor->unsafe_mutable_value()), dst_tensor->place(), -1);
1436+
} else {
1437+
PADDLE_THROW(
1438+
common::errors::Unavailable("The `set_tensor()` method of non "
1439+
"DistTensor get a DistTensor src value"));
1440+
}
1441+
#else
1442+
PADDLE_THROW(common::errors::Unavailable(
1443+
"The `set_tensor()` method of (Dist)Tensor is not supported in the "
1444+
"current PaddlePaddle, please recompile and installPaddlePaddle "
1445+
"with the option of `WITH_DISTRIBUTE=ON`."));
1446+
#endif
1447+
1448+
} else {
1449+
PADDLE_THROW(common::errors::Unavailable(
1450+
"The `set_tensor()` method of (Dist)Tensor get a non "
1451+
"DenseTensor/DistTensor src value"));
1452+
}
1453+
RETURN_PY_NONE
1454+
EAGER_CATCH_AND_THROW_RETURN_NULL
1455+
}
1456+
14011457
static PyObject* tensor_method_get_underline_selected_rows(TensorObject* self,
14021458
PyObject* args,
14031459
PyObject* kwargs) {
@@ -3643,6 +3699,10 @@ PyMethodDef variable_methods[] = { // NOLINT
36433699
(PyCFunction)(void (*)())tensor_method__get_tensor_from_selected_rows,
36443700
METH_VARARGS | METH_KEYWORDS,
36453701
nullptr},
3702+
{"set_tensor",
3703+
(PyCFunction)(void (*)())tensor_method_set_underline_tensor,
3704+
METH_VARARGS | METH_KEYWORDS,
3705+
nullptr},
36463706
{"_getitem_dygraph",
36473707
(PyCFunction)(void (*)())tensor__getitem_dygraph,
36483708
METH_VARARGS | METH_KEYWORDS,

python/paddle/base/dygraph/tensor_patch_methods.py

Lines changed: 10 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -258,11 +258,17 @@ def set_value(
258258
self.value().process_mesh,
259259
self.value().placements,
260260
)
261-
self.value().get_tensor().set(value.get_tensor())
261+
if isinstance(value, paddle.Tensor):
262+
self.value().set_tensor(value)
263+
else:
264+
self.value().get_tensor().set(value.get_tensor())
262265
return
263-
self.value().get_tensor().set(
264-
value, framework._current_expected_place()
265-
)
266+
if isinstance(value, paddle.Tensor):
267+
self.value().set_tensor(value)
268+
else:
269+
self.value().get_tensor().set(
270+
value, framework._current_expected_place()
271+
)
266272

267273
@framework.dygraph_only
268274
def backward(

test/dygraph_to_static/test_tensor_attr_consistency.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -70,6 +70,7 @@
7070
'rows',
7171
'set_string_list',
7272
'set_value',
73+
'set_tensor',
7374
'set_vocab',
7475
'strides',
7576
'to_sparse_coo',

0 commit comments

Comments
 (0)