Commit 609f55e

[Cherry-pick] Fix indexing shape bug and Optimize (#62117)
* tensor_array slice in PIR (#60503)
* use slice_array; currently hits an error of a destroyed OpResult still being in use
* disable the PIR test until the bug is fixed
* Optimize advanced setting by removing the last set_value (#60771)
* pure-advanced setitem no longer calls set_value back
* fix multiple outputs in tensor_array_pir
* only in dynamic mode
* add an advanced-setting-only case to fix coverage
* fast pass for bool setitem (#61021)
* fast pass for bool setitem
* fix the 0-size value case
* remove the final set_value OP in combined-indexing setting (#60983)
* remove set_value in combined-indexing setting
* use a combined-setting case to test
* Optimize index_put preprocessing (#61060)
* reduce vector operations when there is no bool index
* reduce vector operations in index_put
* reduce vector operations
* no change for value
* fix shape error in combined getitem (#61922)
1 parent f4d9adf commit 609f55e
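
At the Python level, the indexing paths touched by these cherry-picked changes look roughly like the following. This is a minimal, illustrative sketch, not code from the commit; the tensor shapes and values are made up.

import paddle

x = paddle.zeros([4, 5])

# fast pass for bool setitem (#61021): a bool mask with the same rank as x
# can be handled with a where() instead of the generic index_put path
mask = x > 0
x[mask] = 1.0

# combined basic + advanced indexing setitem (#60771, #60983): the basic
# slice yields a view that is then written in place, so the trailing
# set_value OP can be dropped
x[1:3, [0, 2]] = 2.0

# combined getitem where the advanced index is not on the leading axis;
# the output-shape fix in this commit (#61922) targets cases like this
y = x[1:3, [0, 2]]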

7 files changed: 335 additions, 183 deletions

paddle/fluid/pybind/eager_method.cc

Lines changed: 125 additions & 58 deletions
@@ -54,6 +54,7 @@ typedef SSIZE_T ssize_t;
 #pragma GCC diagnostic ignored "-Wmissing-field-initializers"
 #include "paddle/common/ddim.h"
 #include "paddle/fluid/eager/amp_utils.h"
+#include "paddle/fluid/eager/api/generated/eager_generated/backwards/nodes.h"
 #include "paddle/fluid/eager/api/generated/eager_generated/forwards/dygraph_functions.h"
 #include "paddle/fluid/eager/eager_amp_auto_cast.h"
 #include "paddle/fluid/framework/python_headers.h"
@@ -1359,6 +1360,7 @@ static PyObject* tensor__getitem_dygraph(TensorObject* self,
                               &use_strided_slice);

   // step2: Dealing with basic indexing
+  bool out_is_view = false;
   auto out = getTensorWithBasicIndexing(tensor,
                                         &slice_axes,
                                         &slice_starts,
@@ -1367,7 +1369,8 @@ static PyObject* tensor__getitem_dygraph(TensorObject* self,
                                         &decrease_axis,
                                         &none_axes,
                                         &infer_flags,
-                                        &use_strided_slice);
+                                        &use_strided_slice,
+                                        &out_is_view);

   if (!has_advanced_index) {
     return ToPyObject(out);
@@ -1386,7 +1389,8 @@ static PyObject* tensor__getitem_dygraph(TensorObject* self,
                              &trans_back_dim,
                              &pos_of_new_dim,
                              &rank_of_new_dim,
-                             &trans_dim);
+                             &trans_dim,
+                             &out_is_view);

   if (transed_index.size() == 1 &&
       transed_index[0].dtype() == phi::DataType::BOOL) {
@@ -1416,14 +1420,14 @@ static PyObject* tensor__getitem_dygraph(TensorObject* self,

   if (pos_of_new_dim != 0) {
     std::vector<int> perm(out.shape().size(), 0);
-    int tmp1 = pos_of_new_dim, tmp2 = 0,
+    int tmp1 = rank_of_new_dim, tmp2 = 0,
         tmp3 = pos_of_new_dim + rank_of_new_dim;
     for (int i = 0; i < static_cast<int>(out.shape().size()); ++i) {
-      if (i < rank_of_new_dim) {
+      if (i < pos_of_new_dim) {
         perm[i] =
-            tmp1++;  // range(pos_of_new_dim, pos_of_new_dim + rank_of_new_dim)
-      } else if (i >= rank_of_new_dim && i < pos_of_new_dim + rank_of_new_dim) {
-        perm[i] = tmp2++;  // range(0, pos_of_new_dim)
+            tmp1++;  // range(rank_of_new_dim, pos_of_new_dim + rank_of_new_dim)
+      } else if (i >= pos_of_new_dim && i < pos_of_new_dim + rank_of_new_dim) {
+        perm[i] = tmp2++;  // range(0, rank_of_new_dim)
       } else {
         perm[i] = tmp3++;  // range(pos_of_new_dim + rank_of_new_dim, out.ndim)
       }
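
The last hunk above swaps pos_of_new_dim and rank_of_new_dim when building the permutation that moves the dims produced by the advanced index (which the gather places at the front of out) back to their target position. Below is a small Python sketch of the corrected construction with the same variable names; it is illustrative only, not Paddle source.

def build_perm(ndim, pos_of_new_dim, rank_of_new_dim):
    # dims 0..rank_of_new_dim-1 of `out` are the advanced-index dims;
    # the permutation moves them so they start at pos_of_new_dim
    perm = [0] * ndim
    tmp1, tmp2, tmp3 = rank_of_new_dim, 0, pos_of_new_dim + rank_of_new_dim
    for i in range(ndim):
        if i < pos_of_new_dim:
            perm[i] = tmp1  # range(rank_of_new_dim, pos_of_new_dim + rank_of_new_dim)
            tmp1 += 1
        elif i < pos_of_new_dim + rank_of_new_dim:
            perm[i] = tmp2  # range(0, rank_of_new_dim)
            tmp2 += 1
        else:
            perm[i] = tmp3  # range(pos_of_new_dim + rank_of_new_dim, ndim)
            tmp3 += 1
    return perm

# e.g. x[:, :, [0, 2]] on a rank-3 tensor: the gathered dim sits at axis 0 of
# the intermediate result and must move to axis 2, so the permutation is [1, 2, 0]
assert build_perm(3, pos_of_new_dim=2, rank_of_new_dim=1) == [1, 2, 0]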
@@ -1681,6 +1685,7 @@ static PyObject* tensor__setitem_dygraph(TensorObject* self,
   // 3. assign values to the sliced result by index_put OP;
   // 4. transpose back and assign the result to original tensor by set_value
   // OP.
+  bool out_is_view = false;
   paddle::Tensor sub_tensor = getTensorWithBasicIndexing(tensor,
                                                          &slice_axes,
                                                          &slice_starts,
@@ -1689,7 +1694,8 @@ static PyObject* tensor__setitem_dygraph(TensorObject* self,
                                                          &decrease_axis,
                                                          &none_axes,
                                                          &infer_flags,
-                                                         &use_strided_slice);
+                                                         &use_strided_slice,
+                                                         &out_is_view);

   std::vector<paddle::Tensor> transed_index;
   std::vector<int> trans_back_dim, trans_dim;
@@ -1705,65 +1711,126 @@ static PyObject* tensor__setitem_dygraph(TensorObject* self,
                            &trans_back_dim,
                            &pos_of_new_dim,
                            &rank_of_new_dim,
-                           &trans_dim);
+                           &trans_dim,
+                           &out_is_view);

     // Release gil and do tracing
     py::gil_scoped_release release;
-
-    if (value_tensor.initialized() &&
-        (self->tensor.dtype() != value_tensor.dtype())) {
-      if (egr::Controller::Instance().GetAMPLevel() !=
-          paddle::imperative::AmpLevel::O0) {
-        paddle::small_vector<std::vector<paddle::Tensor>,
-                             egr::kSlotSmallVectorSize>
-            tmps = {{self->tensor}, {value_tensor}};
-        auto amp_dtype = egr::GetAmpDestDtype("index_put", tmps);
-        self->tensor = egr::EagerAmpAutoCast(
-            self->tensor.name(), self->tensor, amp_dtype, "index_put");
-        value_tensor = egr::EagerAmpAutoCast(
-            value_tensor.name(), value_tensor, amp_dtype, "index_put");
-      }
+    if (value_tensor.initialized()) {
       if (self->tensor.dtype() != value_tensor.dtype()) {
-        value_tensor = cast_ad_func(value_tensor, self->tensor.dtype());
+        if (egr::Controller::Instance().GetAMPLevel() !=
+            paddle::imperative::AmpLevel::O0) {
+          paddle::small_vector<std::vector<paddle::Tensor>,
+                               egr::kSlotSmallVectorSize>
+              tmps = {{self->tensor}, {value_tensor}};
+          auto amp_dtype = egr::GetAmpDestDtype("index_put", tmps);
+          self->tensor = egr::EagerAmpAutoCast(
+              self->tensor.name(), self->tensor, amp_dtype, "index_put");
+          value_tensor = egr::EagerAmpAutoCast(
+              value_tensor.name(), value_tensor, amp_dtype, "index_put");
+        }
+        if (self->tensor.dtype() != value_tensor.dtype()) {
+          value_tensor = cast_ad_func(value_tensor, self->tensor.dtype());
+        }
       }
-    }

-    if (value_tensor.dims().size() > 1 && pos_of_new_dim != 0) {
-      value_tensor = transpose_ad_func(value_tensor, trans_dim);
-    }
+      if (value_tensor.dims().size() > 1 && pos_of_new_dim != 0) {
+        value_tensor = transpose_ad_func(value_tensor, trans_dim);
+      }

-    // TODO(zoooo0820) 1.Using inplace version index_put
-    // 2.Remove following code after backward bug fixed.
-    transed_sub_tensor = assign_ad_func(transed_sub_tensor);
+      const phi::distributed::ProcessMesh* mesh = nullptr;
+      if (InputsContainDistTensor(
+              &mesh, self->tensor, transed_sub_tensor, value_tensor)) {
+        ConvertAllInputsToDistTensor(
+            mesh, self->tensor, transed_sub_tensor, value_tensor);
+      }

-    const phi::distributed::ProcessMesh* mesh = nullptr;
-    if (InputsContainDistTensor(
-            &mesh, self->tensor, transed_sub_tensor, value_tensor)) {
-      ConvertAllInputsToDistTensor(
-          mesh, self->tensor, transed_sub_tensor, value_tensor);
-    }
+      if (transed_index.size() == 1 &&
+          transed_index[0].dtype() == phi::DataType::BOOL &&
+          transed_index[0].shape().size() == self->tensor.shape().size()) {
+        if (value_tensor.shape() != self->tensor.shape()) {
+          value_tensor = expand_ad_func(value_tensor, self->tensor.shape());
+        }
+        transed_sub_tensor =
+            where__ad_func(logical_not_ad_func(transed_index[0]),
+                           transed_sub_tensor,
+                           value_tensor);
+      } else {
+        transed_sub_tensor =
+            index_put__ad_func(transed_sub_tensor, transed_index, value_tensor);
+      }

-    transed_sub_tensor =
-        index_put_ad_func(transed_sub_tensor, transed_index, value_tensor);
-
-    paddle::Tensor transback_sub_tensor =
-        transpose_ad_func(transed_sub_tensor, trans_back_dim);
-
-    self->tensor = set_value_with_tensor__ad_func(self->tensor,
-                                                  transback_sub_tensor,
-                                                  slice_starts,
-                                                  slice_ends,
-                                                  slice_strides,
-                                                  slice_axes,
-                                                  decrease_axis,
-                                                  none_axes);
-    if (PyCheckTensor(value_obj)) {
-      // pass the stop_gradient from value to tensor.
-      // pass stop gradient should be done after CheckInplace in
-      // set_value__dygraph_function.
-      if (!egr::EagerUtils::autograd_meta(&value_tensor)->StopGradient() &&
-          egr::EagerUtils::autograd_meta(&self->tensor)->StopGradient()) {
-        egr::EagerUtils::autograd_meta(&self->tensor)->SetStopGradient(false);
+      if (out_is_view) {
+        // NOTE(zoooo0820): if out_is_view is true, it is a case of
+        // combined-indexing setitem, i.e. firstly we get a view of
+        // self->tensor, then modified it with inplace api index_put_ For now,
+        // in design of Paddle, the forward result is right. But the backward
+        // edge can not be established because the Base Tensor cannot sense
+        // whether it has been modified by other operations. Following codes are
+        // to add a new node (set_value_with_tensor_grad) to record the backward
+        // edge, with out ad_function which needs to do the forward calculation.
+
+        egr::AutogradMeta* x_autograd_meta =
+            egr::EagerUtils::nullable_autograd_meta(self->tensor);
+        egr::AutogradMeta* values_autograd_meta =
+            egr::EagerUtils::nullable_autograd_meta(transed_sub_tensor);
+        bool trace_backward = egr::Controller::Instance().HasGrad();
+        bool require_any_grad = egr::EagerUtils::ComputeRequireGrad(
+            trace_backward, x_autograd_meta, values_autograd_meta);
+        // Node Declaration
+        std::shared_ptr<SetValueWithTensorGradNode> grad_node;
+        // Set grad_node before API Call
+        if (require_any_grad) {
+          paddle::Tensor transback_sub_tensor =
+              transpose_ad_func(transed_sub_tensor, trans_back_dim);
+          const auto& values_tmp =
+              (require_any_grad && transback_sub_tensor.is_dense_tensor() &&
+               !std::dynamic_pointer_cast<phi::DenseTensor>(
+                    transback_sub_tensor.impl())
+                    ->meta()
+                    .is_contiguous())
+                  ? paddle::Tensor(
+                        std::make_shared<phi::DenseTensor>(
+                            std::move(paddle::experimental::Trans2Contiguous(
+                                *(std::dynamic_pointer_cast<phi::DenseTensor>(
+                                    transback_sub_tensor.impl()))))),
+                        transback_sub_tensor.mutable_autograd_meta())
+                  : transback_sub_tensor;
+
+          grad_node = std::shared_ptr<SetValueWithTensorGradNode>(
+              new SetValueWithTensorGradNode(1, 2));  // NOLINT
+          grad_node->SetAttributestarts(slice_starts);
+          grad_node->SetAttributeends(slice_ends);
+          grad_node->SetAttributesteps(slice_strides);
+          grad_node->SetAttributeaxes(slice_axes);
+          grad_node->SetAttributedecrease_axes(decrease_axis);
+          grad_node->SetAttributenone_axes(none_axes);
+          grad_node->SetTensorWrappervalues(values_tmp);
+
+          paddle::memory::LogDeviceMemoryStats(
+              egr::Controller::Instance().GetExpectedPlace(),
+              "set_value_with_tensor");
+          egr::EagerUtils::CheckInplace(
+              self->tensor, x_autograd_meta, require_any_grad);
+          egr::EagerUtils::PassStopGradient(false, x_autograd_meta);
+          // SetGradOutMeta & SetEdges
+          grad_node->SetGradOutMeta(self->tensor, 0);
+          grad_node->SetGradOutMeta(transback_sub_tensor, 1);
+          if (x_autograd_meta) {
+            egr::EagerUtils::SetOutRankWithSlot(x_autograd_meta, 0);
+            egr::EagerUtils::SetHistory(x_autograd_meta, grad_node);
+          }
+          grad_node->SetGradInMeta(self->tensor, 0);
+        }
+      }
+      if (PyCheckTensor(value_obj)) {
+        // pass the stop_gradient from value to tensor.
+        // pass stop gradient should be done after CheckInplace in
+        // set_value__dygraph_function.
+        if (!egr::EagerUtils::autograd_meta(&value_tensor)->StopGradient() &&
+            egr::EagerUtils::autograd_meta(&self->tensor)->StopGradient()) {
+          egr::EagerUtils::autograd_meta(&self->tensor)->SetStopGradient(false);
+        }
       }
     }
   }
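
The new branch above handles a single full-rank bool index without index_put_: the value is expanded to the target shape and selected with the inverted mask. A rough Paddle-level equivalent of that fast pass is sketched below; it is illustrative only (shapes and values are made up, and the real code additionally handles AMP/dtype casting and dist tensors).

import paddle

x = paddle.arange(12, dtype="float32").reshape([3, 4])
mask = x > 5.0
value = paddle.to_tensor([0.0])

# generic path: bool-mask setitem goes through the index_put route
expected = x.clone()
expected[mask] = 0.0

# fast pass: expand the value to x's shape and keep x where the mask is False
fast = paddle.where(paddle.logical_not(mask),
                    x,
                    paddle.expand(value, x.shape))

assert bool(paddle.all(expected == fast))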

paddle/fluid/pybind/slice_utils.h

Lines changed: 7 additions & 2 deletions
@@ -347,11 +347,13 @@ static paddle::Tensor getTensorWithBasicIndexing(
     std::vector<int64_t>* decrease_axis,
     std::vector<int64_t>* none_axes,
     std::vector<int64_t>* infer_flags,
-    bool* use_strided_slice) {
+    bool* use_strided_slice,
+    bool* out_is_view) {
   paddle::Tensor out;
   if (slice_axes->empty()) {
     out = tensor;
   } else {
+    *out_is_view = true;
     if (!(*use_strided_slice)) {
       eager_gil_scoped_release guard;
       out = slice_ad_func(tensor,
@@ -372,6 +374,7 @@ static paddle::Tensor getTensorWithBasicIndexing(
     }
   }
   if (!none_axes->empty()) {
+    *out_is_view = true;
     eager_gil_scoped_release guard;
     // Deal with cases that decrease_axes is not empty
     // For example:
@@ -400,7 +403,8 @@ static paddle::Tensor dealWithAdvancedIndex(
     std::vector<int>* trans_back_dim,
     int* pos_of_new_dim,
     int* rank_of_new_dim,
-    std::vector<int>* trans_dim) {
+    std::vector<int>* trans_dim,
+    bool* out_is_view) {
   int p = 0;
   for (size_t i = 0; i < advanced_index_dim->size(); ++i) {
     auto index_dim = (*advanced_index_dim)[i];
@@ -443,6 +447,7 @@ static paddle::Tensor dealWithAdvancedIndex(
   if (original_dim_order == *trans_dim) {
     transed_tensor = tensor;
   } else {
+    *out_is_view = true;
     transed_tensor = transpose_ad_func(tensor, *trans_dim);
   }

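Taken together, the slice_utils.h changes thread a single out_is_view flag through the basic- and advanced-indexing helpers. A short pseudocode summary of where the flag is raised (illustrative, not Paddle source):

def marks_view(slice_axes, none_axes, needs_transpose):
    # out_is_view stays False only when basic indexing returns the original
    # tensor untouched; a slice, a None-axis unsqueeze, or a transpose for
    # advanced indexing means the later in-place write hits a derived tensor
    out_is_view = False
    if slice_axes:          # getTensorWithBasicIndexing: slice / strided_slice taken
        out_is_view = True
    if none_axes:           # getTensorWithBasicIndexing: None axes inserted
        out_is_view = True
    if needs_transpose:     # dealWithAdvancedIndex: tensor was transposed
        out_is_view = True
    return out_is_view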