@@ -54,6 +54,7 @@ typedef SSIZE_T ssize_t;
54
54
#pragma GCC diagnostic ignored "-Wmissing-field-initializers"
55
55
#include " paddle/common/ddim.h"
56
56
#include " paddle/fluid/eager/amp_utils.h"
57
+ #include " paddle/fluid/eager/api/generated/eager_generated/backwards/nodes.h"
57
58
#include " paddle/fluid/eager/api/generated/eager_generated/forwards/dygraph_functions.h"
58
59
#include " paddle/fluid/eager/eager_amp_auto_cast.h"
59
60
#include " paddle/fluid/framework/python_headers.h"
@@ -1359,6 +1360,7 @@ static PyObject* tensor__getitem_dygraph(TensorObject* self,
1359
1360
&use_strided_slice);
1360
1361
1361
1362
// step2: Dealing with basic indexing
1363
+ bool out_is_view = false ;
1362
1364
auto out = getTensorWithBasicIndexing (tensor,
1363
1365
&slice_axes,
1364
1366
&slice_starts,
@@ -1367,7 +1369,8 @@ static PyObject* tensor__getitem_dygraph(TensorObject* self,
1367
1369
&decrease_axis,
1368
1370
&none_axes,
1369
1371
&infer_flags,
1370
- &use_strided_slice);
1372
+ &use_strided_slice,
1373
+ &out_is_view);
1371
1374
1372
1375
if (!has_advanced_index) {
1373
1376
return ToPyObject (out);
@@ -1386,7 +1389,8 @@ static PyObject* tensor__getitem_dygraph(TensorObject* self,
1386
1389
&trans_back_dim,
1387
1390
&pos_of_new_dim,
1388
1391
&rank_of_new_dim,
1389
- &trans_dim);
1392
+ &trans_dim,
1393
+ &out_is_view);
1390
1394
1391
1395
if (transed_index.size () == 1 &&
1392
1396
transed_index[0 ].dtype () == phi::DataType::BOOL) {
@@ -1416,14 +1420,14 @@ static PyObject* tensor__getitem_dygraph(TensorObject* self,
1416
1420
1417
1421
if (pos_of_new_dim != 0 ) {
1418
1422
std::vector<int > perm (out.shape ().size (), 0 );
1419
- int tmp1 = pos_of_new_dim , tmp2 = 0 ,
1423
+ int tmp1 = rank_of_new_dim , tmp2 = 0 ,
1420
1424
tmp3 = pos_of_new_dim + rank_of_new_dim;
1421
1425
for (int i = 0 ; i < static_cast <int >(out.shape ().size ()); ++i) {
1422
- if (i < rank_of_new_dim ) {
1426
+ if (i < pos_of_new_dim ) {
1423
1427
perm[i] =
1424
- tmp1++; // range(pos_of_new_dim , pos_of_new_dim + rank_of_new_dim)
1425
- } else if (i >= rank_of_new_dim && i < pos_of_new_dim + rank_of_new_dim) {
1426
- perm[i] = tmp2++; // range(0, pos_of_new_dim )
1428
+ tmp1++; // range(rank_of_new_dim , pos_of_new_dim + rank_of_new_dim)
1429
+ } else if (i >= pos_of_new_dim && i < pos_of_new_dim + rank_of_new_dim) {
1430
+ perm[i] = tmp2++; // range(0, rank_of_new_dim )
1427
1431
} else {
1428
1432
perm[i] = tmp3++; // range(pos_of_new_dim + rank_of_new_dim, out.ndim)
1429
1433
}
@@ -1681,6 +1685,7 @@ static PyObject* tensor__setitem_dygraph(TensorObject* self,
1681
1685
// 3. assign values to the sliced result by index_put OP;
1682
1686
// 4. transpose back and assign the result to original tensor by set_value
1683
1687
// OP.
1688
+ bool out_is_view = false ;
1684
1689
paddle::Tensor sub_tensor = getTensorWithBasicIndexing (tensor,
1685
1690
&slice_axes,
1686
1691
&slice_starts,
@@ -1689,7 +1694,8 @@ static PyObject* tensor__setitem_dygraph(TensorObject* self,
1689
1694
&decrease_axis,
1690
1695
&none_axes,
1691
1696
&infer_flags,
1692
- &use_strided_slice);
1697
+ &use_strided_slice,
1698
+ &out_is_view);
1693
1699
1694
1700
std::vector<paddle::Tensor> transed_index;
1695
1701
std::vector<int > trans_back_dim, trans_dim;
@@ -1705,65 +1711,126 @@ static PyObject* tensor__setitem_dygraph(TensorObject* self,
1705
1711
&trans_back_dim,
1706
1712
&pos_of_new_dim,
1707
1713
&rank_of_new_dim,
1708
- &trans_dim);
1714
+ &trans_dim,
1715
+ &out_is_view);
1709
1716
1710
1717
// Release gil and do tracing
1711
1718
py::gil_scoped_release release;
1712
-
1713
- if (value_tensor.initialized () &&
1714
- (self->tensor .dtype () != value_tensor.dtype ())) {
1715
- if (egr::Controller::Instance ().GetAMPLevel () !=
1716
- paddle::imperative::AmpLevel::O0) {
1717
- paddle::small_vector<std::vector<paddle::Tensor>,
1718
- egr::kSlotSmallVectorSize >
1719
- tmps = {{self->tensor }, {value_tensor}};
1720
- auto amp_dtype = egr::GetAmpDestDtype (" index_put" , tmps);
1721
- self->tensor = egr::EagerAmpAutoCast (
1722
- self->tensor .name (), self->tensor , amp_dtype, " index_put" );
1723
- value_tensor = egr::EagerAmpAutoCast (
1724
- value_tensor.name (), value_tensor, amp_dtype, " index_put" );
1725
- }
1719
+ if (value_tensor.initialized ()) {
1726
1720
if (self->tensor .dtype () != value_tensor.dtype ()) {
1727
- value_tensor = cast_ad_func (value_tensor, self->tensor .dtype ());
1721
+ if (egr::Controller::Instance ().GetAMPLevel () !=
1722
+ paddle::imperative::AmpLevel::O0) {
1723
+ paddle::small_vector<std::vector<paddle::Tensor>,
1724
+ egr::kSlotSmallVectorSize >
1725
+ tmps = {{self->tensor }, {value_tensor}};
1726
+ auto amp_dtype = egr::GetAmpDestDtype (" index_put" , tmps);
1727
+ self->tensor = egr::EagerAmpAutoCast (
1728
+ self->tensor .name (), self->tensor , amp_dtype, " index_put" );
1729
+ value_tensor = egr::EagerAmpAutoCast (
1730
+ value_tensor.name (), value_tensor, amp_dtype, " index_put" );
1731
+ }
1732
+ if (self->tensor .dtype () != value_tensor.dtype ()) {
1733
+ value_tensor = cast_ad_func (value_tensor, self->tensor .dtype ());
1734
+ }
1728
1735
}
1729
- }
1730
1736
1731
- if (value_tensor.dims ().size () > 1 && pos_of_new_dim != 0 ) {
1732
- value_tensor = transpose_ad_func (value_tensor, trans_dim);
1733
- }
1737
+ if (value_tensor.dims ().size () > 1 && pos_of_new_dim != 0 ) {
1738
+ value_tensor = transpose_ad_func (value_tensor, trans_dim);
1739
+ }
1734
1740
1735
- // TODO(zoooo0820) 1.Using inplace version index_put
1736
- // 2.Remove following code after backward bug fixed.
1737
- transed_sub_tensor = assign_ad_func (transed_sub_tensor);
1741
+ const phi::distributed::ProcessMesh* mesh = nullptr ;
1742
+ if (InputsContainDistTensor (
1743
+ &mesh, self->tensor , transed_sub_tensor, value_tensor)) {
1744
+ ConvertAllInputsToDistTensor (
1745
+ mesh, self->tensor , transed_sub_tensor, value_tensor);
1746
+ }
1738
1747
1739
- const phi::distributed::ProcessMesh* mesh = nullptr ;
1740
- if (InputsContainDistTensor (
1741
- &mesh, self->tensor , transed_sub_tensor, value_tensor)) {
1742
- ConvertAllInputsToDistTensor (
1743
- mesh, self->tensor , transed_sub_tensor, value_tensor);
1744
- }
1748
+ if (transed_index.size () == 1 &&
1749
+ transed_index[0 ].dtype () == phi::DataType::BOOL &&
1750
+ transed_index[0 ].shape ().size () == self->tensor .shape ().size ()) {
1751
+ if (value_tensor.shape () != self->tensor .shape ()) {
1752
+ value_tensor = expand_ad_func (value_tensor, self->tensor .shape ());
1753
+ }
1754
+ transed_sub_tensor =
1755
+ where__ad_func (logical_not_ad_func (transed_index[0 ]),
1756
+ transed_sub_tensor,
1757
+ value_tensor);
1758
+ } else {
1759
+ transed_sub_tensor =
1760
+ index_put__ad_func (transed_sub_tensor, transed_index, value_tensor);
1761
+ }
1745
1762
1746
- transed_sub_tensor =
1747
- index_put_ad_func (transed_sub_tensor, transed_index, value_tensor);
1748
-
1749
- paddle::Tensor transback_sub_tensor =
1750
- transpose_ad_func (transed_sub_tensor, trans_back_dim);
1751
-
1752
- self->tensor = set_value_with_tensor__ad_func (self->tensor ,
1753
- transback_sub_tensor,
1754
- slice_starts,
1755
- slice_ends,
1756
- slice_strides,
1757
- slice_axes,
1758
- decrease_axis,
1759
- none_axes);
1760
- if (PyCheckTensor (value_obj)) {
1761
- // pass the stop_gradient from value to tensor.
1762
- // pass stop gradient should be done after CheckInplace in
1763
- // set_value__dygraph_function.
1764
- if (!egr::EagerUtils::autograd_meta (&value_tensor)->StopGradient () &&
1765
- egr::EagerUtils::autograd_meta (&self->tensor )->StopGradient ()) {
1766
- egr::EagerUtils::autograd_meta (&self->tensor )->SetStopGradient (false );
1763
+ if (out_is_view) {
1764
+ // NOTE(zoooo0820): if out_is_view is true, it is a case of
1765
+ // combined-indexing setitem, i.e. firstly we get a view of
1766
+ // self->tensor, then modified it with inplace api index_put_ For now,
1767
+ // in design of Paddle, the forward result is right. But the backward
1768
+ // edge can not be established because the Base Tensor cannot sense
1769
+ // whether it has been modified by other operations. Following codes are
1770
+ // to add a new node (set_value_with_tensor_grad) to record the backward
1771
+ // edge, with out ad_function which needs to do the forward calculation.
1772
+
1773
+ egr::AutogradMeta* x_autograd_meta =
1774
+ egr::EagerUtils::nullable_autograd_meta (self->tensor );
1775
+ egr::AutogradMeta* values_autograd_meta =
1776
+ egr::EagerUtils::nullable_autograd_meta (transed_sub_tensor);
1777
+ bool trace_backward = egr::Controller::Instance ().HasGrad ();
1778
+ bool require_any_grad = egr::EagerUtils::ComputeRequireGrad (
1779
+ trace_backward, x_autograd_meta, values_autograd_meta);
1780
+ // Node Declaration
1781
+ std::shared_ptr<SetValueWithTensorGradNode> grad_node;
1782
+ // Set grad_node before API Call
1783
+ if (require_any_grad) {
1784
+ paddle::Tensor transback_sub_tensor =
1785
+ transpose_ad_func (transed_sub_tensor, trans_back_dim);
1786
+ const auto & values_tmp =
1787
+ (require_any_grad && transback_sub_tensor.is_dense_tensor () &&
1788
+ !std::dynamic_pointer_cast<phi::DenseTensor>(
1789
+ transback_sub_tensor.impl ())
1790
+ ->meta ()
1791
+ .is_contiguous ())
1792
+ ? paddle::Tensor (
1793
+ std::make_shared<phi::DenseTensor>(
1794
+ std::move (paddle::experimental::Trans2Contiguous (
1795
+ *(std::dynamic_pointer_cast<phi::DenseTensor>(
1796
+ transback_sub_tensor.impl ()))))),
1797
+ transback_sub_tensor.mutable_autograd_meta ())
1798
+ : transback_sub_tensor;
1799
+
1800
+ grad_node = std::shared_ptr<SetValueWithTensorGradNode>(
1801
+ new SetValueWithTensorGradNode (1 , 2 )); // NOLINT
1802
+ grad_node->SetAttributestarts (slice_starts);
1803
+ grad_node->SetAttributeends (slice_ends);
1804
+ grad_node->SetAttributesteps (slice_strides);
1805
+ grad_node->SetAttributeaxes (slice_axes);
1806
+ grad_node->SetAttributedecrease_axes (decrease_axis);
1807
+ grad_node->SetAttributenone_axes (none_axes);
1808
+ grad_node->SetTensorWrappervalues (values_tmp);
1809
+
1810
+ paddle::memory::LogDeviceMemoryStats (
1811
+ egr::Controller::Instance ().GetExpectedPlace (),
1812
+ " set_value_with_tensor" );
1813
+ egr::EagerUtils::CheckInplace (
1814
+ self->tensor , x_autograd_meta, require_any_grad);
1815
+ egr::EagerUtils::PassStopGradient (false , x_autograd_meta);
1816
+ // SetGradOutMeta & SetEdges
1817
+ grad_node->SetGradOutMeta (self->tensor , 0 );
1818
+ grad_node->SetGradOutMeta (transback_sub_tensor, 1 );
1819
+ if (x_autograd_meta) {
1820
+ egr::EagerUtils::SetOutRankWithSlot (x_autograd_meta, 0 );
1821
+ egr::EagerUtils::SetHistory (x_autograd_meta, grad_node);
1822
+ }
1823
+ grad_node->SetGradInMeta (self->tensor , 0 );
1824
+ }
1825
+ }
1826
+ if (PyCheckTensor (value_obj)) {
1827
+ // pass the stop_gradient from value to tensor.
1828
+ // pass stop gradient should be done after CheckInplace in
1829
+ // set_value__dygraph_function.
1830
+ if (!egr::EagerUtils::autograd_meta (&value_tensor)->StopGradient () &&
1831
+ egr::EagerUtils::autograd_meta (&self->tensor )->StopGradient ()) {
1832
+ egr::EagerUtils::autograd_meta (&self->tensor )->SetStopGradient (false );
1833
+ }
1767
1834
}
1768
1835
}
1769
1836
}
0 commit comments