Commit 28eae79

[SOT][PIR] Support optional tensor output (#72737)
1 parent 1a06c9e

10 files changed: +395 −215 lines

paddle/fluid/pybind/eager_utils.cc (+53 −50)
@@ -2005,48 +2005,52 @@ PyObject* GetEmptyTensorsWithVarDesc(PyObject* self, PyObject* args) {
 
 paddle::Tensor CreateTensorFromValue(const pir::Value& value) {
   auto tensor = paddle::Tensor();
-
-  auto dims = phi::vectorize(GetValueDims(value));
-  auto ddims = phi::make_ddim(dims);
-  if (auto name = pir::utils::name_analysis::TryGetValueFirstName(value)) {
-    tensor.set_name(name.value());
-  }
   auto autograd_meta = egr::EagerUtils::autograd_meta(&tensor);
   autograd_meta->SetPersistable(false);
   autograd_meta->SetStopGradient(GetValueBoolAttr(value, kAttrStopGradients));
 
-  if (value.type().isa<paddle::dialect::DenseTensorType>()) {
-    // TODO(jiabin): Maybe support LegacyLoD later
-    std::shared_ptr<phi::DenseTensor> dense_tensor = nullptr;
-    auto dtype = paddle::dialect::TransToPhiDataType(
-        value.type().dyn_cast<paddle::dialect::DenseTensorType>().dtype());
-
-    if (dims.size() == 1 && dims[0] == 0) {
-      std::shared_ptr<phi::Allocation> allocation_ptr = nullptr;
-      dense_tensor = std::make_shared<phi::DenseTensor>(
-          allocation_ptr, phi::DenseTensorMeta(dtype, ddims));
-    } else {
-      // TODO(dev): we need enhance check for ddims.
-      dense_tensor = std::make_shared<phi::DenseTensor>(
-          std::make_shared<phi::Allocation>(),
-          phi::DenseTensorMeta(dtype, ddims));
-    }
+  if (value.impl() == nullptr || !value.type()) {
+    // do-nothing, just skip the Value with nullptr
+  } else {
+    auto dims = phi::vectorize(GetValueDims(value));
+    auto ddims = phi::make_ddim(dims);
+    if (auto name = pir::utils::name_analysis::TryGetValueFirstName(value)) {
+      tensor.set_name(name.value());
+    }
+
+    if (value.type().isa<paddle::dialect::DenseTensorType>()) {
+      // TODO(jiabin): Maybe support LegacyLoD later
+      std::shared_ptr<phi::DenseTensor> dense_tensor = nullptr;
+      auto dtype = paddle::dialect::TransToPhiDataType(
+          value.type().dyn_cast<paddle::dialect::DenseTensorType>().dtype());
+
+      if (dims.size() == 1 && dims[0] == 0) {
+        std::shared_ptr<phi::Allocation> allocation_ptr = nullptr;
+        dense_tensor = std::make_shared<phi::DenseTensor>(
+            allocation_ptr, phi::DenseTensorMeta(dtype, ddims));
+      } else {
+        // TODO(dev): we need enhance check for ddims.
+        dense_tensor = std::make_shared<phi::DenseTensor>(
+            std::make_shared<phi::Allocation>(),
+            phi::DenseTensorMeta(dtype, ddims));
+      }
 
-    if (value.type().isa<paddle::dialect::DistDenseTensorType>()) {
-      paddle::dialect::DistDenseTensorType value_type =
-          value.type().dyn_cast<paddle::dialect::DistDenseTensorType>();
-      auto pir_attr = value_type.tensor_dist_attr();
-      auto mesh = pir_attr.process_mesh_attr().process_mesh();
-      auto placements = pir_attr.placements();
-      tensor.set_impl(std::make_shared<phi::distributed::DistTensor>(
-          dense_tensor, mesh, placements));
-    } else {
-      tensor.set_impl(dense_tensor);
+      if (value.type().isa<paddle::dialect::DistDenseTensorType>()) {
+        paddle::dialect::DistDenseTensorType value_type =
+            value.type().dyn_cast<paddle::dialect::DistDenseTensorType>();
+        auto pir_attr = value_type.tensor_dist_attr();
+        auto mesh = pir_attr.process_mesh_attr().process_mesh();
+        auto placements = pir_attr.placements();
+        tensor.set_impl(std::make_shared<phi::distributed::DistTensor>(
+            dense_tensor, mesh, placements));
+      } else {
+        tensor.set_impl(dense_tensor);
+      }
+    } else if (value.type().isa<paddle::dialect::SelectedRowsType>()) {
+      std::shared_ptr<phi::SelectedRows> selected_rows_tensor =
+          std::make_shared<phi::SelectedRows>();
+      tensor.set_impl(selected_rows_tensor);
     }
-  } else if (value.type().isa<paddle::dialect::SelectedRowsType>()) {
-    std::shared_ptr<phi::SelectedRows> selected_rows_tensor =
-        std::make_shared<phi::SelectedRows>();
-    tensor.set_impl(selected_rows_tensor);
   }
 
   if (!autograd_meta->GetMutableGradNode()) {
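In short, this hunk makes CreateTensorFromValue tolerate an absent pir::Value: when value.impl() is null or the value carries no type, it now returns a default-constructed paddle::Tensor (autograd meta set, no impl) instead of reading dims and dtype from an invalid value. That is the hook that lets an optional operator output come back to Python as a placeholder. A minimal standalone sketch of the guard pattern, with Handle and Tensor as hypothetical stand-ins for pir::Value and paddle::Tensor:

#include <memory>

struct Handle {               // stand-in for pir::Value
  std::shared_ptr<int> impl;  // null when the output is absent
};
struct Tensor {               // stand-in for paddle::Tensor
  bool defined = false;
};

// Mirrors the new guard: an absent handle yields an empty placeholder
// instead of dereferencing a null impl to read shape/dtype.
Tensor CreateFromHandle(const Handle& h) {
  Tensor t;
  if (h.impl == nullptr) {
    // do nothing: t stays an undefined placeholder
  } else {
    t.defined = true;  // the real code would read metadata and allocate here
  }
  return t;
}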
@@ -2063,29 +2067,28 @@ PyObject* GetEmptyTensorsWithValue(PyObject* self, PyObject* args) {
 
   auto value_list = PyTuple_GetItem(args, 0);
 
+  auto CreateTensorFromValueWithCache =
+      [&out_tensor_map](const pir::Value& value) {
+        if (out_tensor_map.find(value) == out_tensor_map.end()) {
+          paddle::Tensor tensor = CreateTensorFromValue(value);
+          out_tensor_map[value] = tensor;
+          return tensor;
+        } else {
+          return out_tensor_map[value];
+        }
+      };
+
   if (PyList_Check(value_list)) {
     Py_ssize_t len = PyList_Size(value_list);
     for (Py_ssize_t i = 0; i < len; i++) {
       auto value = PyObjectCast<pir::Value>(PyList_GetItem(value_list, i));
-      if (out_tensor_map.find(value) == out_tensor_map.end()) {
-        paddle::Tensor tensor = CreateTensorFromValue(value);
-        out_tensor_map[value] = tensor;
-        result.emplace_back(tensor);
-      } else {
-        result.emplace_back(out_tensor_map[value]);
-      }
+      result.emplace_back(CreateTensorFromValueWithCache(value));
     }
   } else if (PyTuple_Check(value_list)) {
     Py_ssize_t len = PyTuple_Size(value_list);
     for (Py_ssize_t i = 0; i < len; i++) {
       auto value = PyObjectCast<pir::Value>(PyTuple_GetItem(value_list, i));
-      if (out_tensor_map.find(value) == out_tensor_map.end()) {
-        paddle::Tensor tensor = CreateTensorFromValue(value);
-        out_tensor_map[value] = tensor;
-        result.emplace_back(tensor);
-      } else {
-        result.emplace_back(out_tensor_map[value]);
-      }
+      result.emplace_back(CreateTensorFromValueWithCache(value));
     }
   } else if (value_list != Py_None) {
     PADDLE_THROW(common::errors::InvalidArgument(
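This hunk is a behavior-preserving refactor: the identical look-up-or-create logic in the list and tuple branches is hoisted into one CreateTensorFromValueWithCache lambda that captures out_tensor_map by reference, so each distinct pir::Value is materialized as exactly one paddle::Tensor no matter how many times it appears. The same memoization shape in isolation (the names below are illustrative, not Paddle APIs):

#include <iostream>
#include <string>
#include <unordered_map>

int main() {
  std::unordered_map<int, std::string> cache;  // plays the role of out_tensor_map
  auto get_with_cache = [&cache](int key) {
    if (cache.find(key) == cache.end()) {
      std::string made = "tensor_" + std::to_string(key);  // "expensive" create
      cache[key] = made;
      return made;       // first request constructs and records the result
    }
    return cache[key];   // repeated requests reuse the cached object
  };
  std::cout << get_with_cache(7) << "\n";  // constructs
  std::cout << get_with_cache(7) << "\n";  // cache hit, same value
}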

paddle/fluid/pybind/pybind.cc (+1 −9)
@@ -1030,24 +1030,16 @@ void BindDecompRule(pybind11::module *m) {
          int start_index,
          int end_index) {
         VLOG(4) << "[Prim] Bind Decomp sinking_decomp begin.";
-        py::list res;
         auto original_insertion_point =
             paddle::dialect::ApiBuilder::Instance().GetCurrentInsertionPoint();
         DecompProgram decomp_object(
             program, src_vars, blacklist, whitelist, start_index, end_index);
         decomp_object.decomp_program();
         std::vector<pir::Value> tar_vars = decomp_object.get_dst_vars();
-        for (size_t i = 0; i < tar_vars.size(); ++i) {
-          if (!tar_vars[i]) {
-            res.append(nullptr);
-          } else {
-            res.append(tar_vars[i]);
-          }
-        }
         paddle::dialect::ApiBuilder::Instance().SetInsertionPoint(
             original_insertion_point);
         VLOG(4) << "[Prim] Bind Decomp sinking_decomp end.";
-        return res;
+        return tar_vars;
       });
 
   m->def("call_decomp_rule", [](pir::Operation &fwd_op) {
