@@ -2005,48 +2005,52 @@ PyObject* GetEmptyTensorsWithVarDesc(PyObject* self, PyObject* args) {
2005
2005
2006
2006
paddle::Tensor CreateTensorFromValue (const pir::Value& value) {
2007
2007
auto tensor = paddle::Tensor ();
2008
-
2009
- auto dims = phi::vectorize (GetValueDims (value));
2010
- auto ddims = phi::make_ddim (dims);
2011
- if (auto name = pir::utils::name_analysis::TryGetValueFirstName (value)) {
2012
- tensor.set_name (name.value ());
2013
- }
2014
2008
auto autograd_meta = egr::EagerUtils::autograd_meta (&tensor);
2015
2009
autograd_meta->SetPersistable (false );
2016
2010
autograd_meta->SetStopGradient (GetValueBoolAttr (value, kAttrStopGradients ));
2017
2011
2018
- if (value.type ().isa <paddle::dialect::DenseTensorType>()) {
2019
- // TODO(jiabin): Maybe support LegacyLoD later
2020
- std::shared_ptr<phi::DenseTensor> dense_tensor = nullptr ;
2021
- auto dtype = paddle::dialect::TransToPhiDataType (
2022
- value.type ().dyn_cast <paddle::dialect::DenseTensorType>().dtype ());
2023
-
2024
- if (dims.size () == 1 && dims[0 ] == 0 ) {
2025
- std::shared_ptr<phi::Allocation> allocation_ptr = nullptr ;
2026
- dense_tensor = std::make_shared<phi::DenseTensor>(
2027
- allocation_ptr, phi::DenseTensorMeta (dtype, ddims));
2028
- } else {
2029
- // TODO(dev): we need enhance check for ddims.
2030
- dense_tensor = std::make_shared<phi::DenseTensor>(
2031
- std::make_shared<phi::Allocation>(),
2032
- phi::DenseTensorMeta (dtype, ddims));
2033
- }
2012
+ if (value.impl () == nullptr || !value.type ()) {
2013
+ // do-nothing, just skip the Value with nullptr
2014
+ } else {
2015
+ auto dims = phi::vectorize (GetValueDims (value));
2016
+ auto ddims = phi::make_ddim (dims);
2017
+ if (auto name = pir::utils::name_analysis::TryGetValueFirstName (value)) {
2018
+ tensor.set_name (name.value ());
2019
+ }
2020
+
2021
+ if (value.type ().isa <paddle::dialect::DenseTensorType>()) {
2022
+ // TODO(jiabin): Maybe support LegacyLoD later
2023
+ std::shared_ptr<phi::DenseTensor> dense_tensor = nullptr ;
2024
+ auto dtype = paddle::dialect::TransToPhiDataType (
2025
+ value.type ().dyn_cast <paddle::dialect::DenseTensorType>().dtype ());
2026
+
2027
+ if (dims.size () == 1 && dims[0 ] == 0 ) {
2028
+ std::shared_ptr<phi::Allocation> allocation_ptr = nullptr ;
2029
+ dense_tensor = std::make_shared<phi::DenseTensor>(
2030
+ allocation_ptr, phi::DenseTensorMeta (dtype, ddims));
2031
+ } else {
2032
+ // TODO(dev): we need enhance check for ddims.
2033
+ dense_tensor = std::make_shared<phi::DenseTensor>(
2034
+ std::make_shared<phi::Allocation>(),
2035
+ phi::DenseTensorMeta (dtype, ddims));
2036
+ }
2034
2037
2035
- if (value.type ().isa <paddle::dialect::DistDenseTensorType>()) {
2036
- paddle::dialect::DistDenseTensorType value_type =
2037
- value.type ().dyn_cast <paddle::dialect::DistDenseTensorType>();
2038
- auto pir_attr = value_type.tensor_dist_attr ();
2039
- auto mesh = pir_attr.process_mesh_attr ().process_mesh ();
2040
- auto placements = pir_attr.placements ();
2041
- tensor.set_impl (std::make_shared<phi::distributed::DistTensor>(
2042
- dense_tensor, mesh, placements));
2043
- } else {
2044
- tensor.set_impl (dense_tensor);
2038
+ if (value.type ().isa <paddle::dialect::DistDenseTensorType>()) {
2039
+ paddle::dialect::DistDenseTensorType value_type =
2040
+ value.type ().dyn_cast <paddle::dialect::DistDenseTensorType>();
2041
+ auto pir_attr = value_type.tensor_dist_attr ();
2042
+ auto mesh = pir_attr.process_mesh_attr ().process_mesh ();
2043
+ auto placements = pir_attr.placements ();
2044
+ tensor.set_impl (std::make_shared<phi::distributed::DistTensor>(
2045
+ dense_tensor, mesh, placements));
2046
+ } else {
2047
+ tensor.set_impl (dense_tensor);
2048
+ }
2049
+ } else if (value.type ().isa <paddle::dialect::SelectedRowsType>()) {
2050
+ std::shared_ptr<phi::SelectedRows> selected_rows_tensor =
2051
+ std::make_shared<phi::SelectedRows>();
2052
+ tensor.set_impl (selected_rows_tensor);
2045
2053
}
2046
- } else if (value.type ().isa <paddle::dialect::SelectedRowsType>()) {
2047
- std::shared_ptr<phi::SelectedRows> selected_rows_tensor =
2048
- std::make_shared<phi::SelectedRows>();
2049
- tensor.set_impl (selected_rows_tensor);
2050
2054
}
2051
2055
2052
2056
if (!autograd_meta->GetMutableGradNode ()) {
@@ -2063,29 +2067,28 @@ PyObject* GetEmptyTensorsWithValue(PyObject* self, PyObject* args) {
2063
2067
2064
2068
auto value_list = PyTuple_GetItem (args, 0 );
2065
2069
2070
+ auto CreateTensorFromValueWithCache =
2071
+ [&out_tensor_map](const pir::Value& value) {
2072
+ if (out_tensor_map.find (value) == out_tensor_map.end ()) {
2073
+ paddle::Tensor tensor = CreateTensorFromValue (value);
2074
+ out_tensor_map[value] = tensor;
2075
+ return tensor;
2076
+ } else {
2077
+ return out_tensor_map[value];
2078
+ }
2079
+ };
2080
+
2066
2081
if (PyList_Check (value_list)) {
2067
2082
Py_ssize_t len = PyList_Size (value_list);
2068
2083
for (Py_ssize_t i = 0 ; i < len; i++) {
2069
2084
auto value = PyObjectCast<pir::Value>(PyList_GetItem (value_list, i));
2070
- if (out_tensor_map.find (value) == out_tensor_map.end ()) {
2071
- paddle::Tensor tensor = CreateTensorFromValue (value);
2072
- out_tensor_map[value] = tensor;
2073
- result.emplace_back (tensor);
2074
- } else {
2075
- result.emplace_back (out_tensor_map[value]);
2076
- }
2085
+ result.emplace_back (CreateTensorFromValueWithCache (value));
2077
2086
}
2078
2087
} else if (PyTuple_Check (value_list)) {
2079
2088
Py_ssize_t len = PyTuple_Size (value_list);
2080
2089
for (Py_ssize_t i = 0 ; i < len; i++) {
2081
2090
auto value = PyObjectCast<pir::Value>(PyTuple_GetItem (value_list, i));
2082
- if (out_tensor_map.find (value) == out_tensor_map.end ()) {
2083
- paddle::Tensor tensor = CreateTensorFromValue (value);
2084
- out_tensor_map[value] = tensor;
2085
- result.emplace_back (tensor);
2086
- } else {
2087
- result.emplace_back (out_tensor_map[value]);
2088
- }
2091
+ result.emplace_back (CreateTensorFromValueWithCache (value));
2089
2092
}
2090
2093
} else if (value_list != Py_None) {
2091
2094
PADDLE_THROW (common::errors::InvalidArgument (
0 commit comments