Skip to content

Commit 0e96df8

Browse files
authored
optimize set_value (#59425)
* optimize set_value
* fix none shape
1 parent 7dd9a56 commit 0e96df8

File tree

1 file changed

+25
-81
lines changed

1 file changed

+25
-81
lines changed

paddle/phi/kernels/impl/set_value_kernel_impl.h

Lines changed: 25 additions & 81 deletions
Original file line numberDiff line numberDiff line change
@@ -19,12 +19,12 @@
1919
#include "paddle/phi/core/dense_tensor.h"
2020
#include "paddle/phi/core/tensor_utils.h"
2121
#include "paddle/phi/kernels/empty_kernel.h"
22+
#include "paddle/phi/kernels/expand_kernel.h"
2223
#include "paddle/phi/kernels/funcs/broadcast_function.h"
2324
#include "paddle/phi/kernels/funcs/eigen/common.h"
2425
#include "paddle/phi/kernels/funcs/eigen/eigen_function.h"
2526
#include "paddle/phi/kernels/funcs/elementwise_functor.h"
2627
#include "paddle/phi/kernels/funcs/slice_utils.h"
27-
2828
namespace phi {
2929

3030
// check whether the tensor with dimension of second can assign to the
@@ -89,7 +89,6 @@ void SetValueImpl(const Context& dev_ctx,
8989
in_dims, axes, starts_local, ends_local, &steps_local);
9090
auto decrease_slice_dims =
9191
phi::funcs::GetDecreasedDims(slice_dims, decrease_axes);
92-
9392
auto slice_dims_for_assign = decrease_slice_dims;
9493
if (!none_axes.empty()) {
9594
std::vector<int64_t> slice_dims_with_none;
@@ -115,33 +114,36 @@ void SetValueImpl(const Context& dev_ctx,
115114

116115
slice_dims_for_assign = common::make_ddim(slice_dims_with_none);
117116
}
117+
CheckIsDimsMatch(slice_dims_for_assign, value.dims());
118+
119+
auto value_shape = phi::vectorize<int64_t>(value.dims());
120+
121+
DenseTensor value_tensor = Empty<T>(dev_ctx, IntArray{value_shape});
122+
value_tensor = value;
123+
auto it = value_shape.begin();
124+
while (it != value_shape.end() && *it == 1) {
125+
it = value_shape.erase(it);
126+
}
127+
if (value_shape.empty()) value_shape.push_back(1);
128+
value_tensor.Resize(phi::make_ddim(value_shape));
129+
130+
auto expand_shape = phi::vectorize<int64_t>(slice_dims_for_assign);
131+
for (size_t i = 0; i < expand_shape.size(); i++) {
132+
if (expand_shape[i] == 0) expand_shape[i] = 1;
133+
}
134+
if (expand_shape.empty()) expand_shape.push_back(1);
135+
DenseTensor expand_tensor = Empty<T>(dev_ctx, IntArray{expand_shape});
118136

119137
auto place = dev_ctx.GetPlace();
120138
auto& eigen_place = *dev_ctx.eigen_device();
121139

122-
// Here copy data from input to avoid data loss at PE and Graph level.
123-
// TODO(liym27): Speed up in the future version.
124-
// - Q: Why don't call ShareDataWith to speed up?
125-
// - A: Because it's not supported to ShareDataWith on OP's input and output
126-
// https://github.com/PaddlePaddle/Paddle/wiki/ShareDataWith-and-ShareBufferWith-are-prohibited-in-OP
127-
// - Q: Why don't delete Input, after all, the input and output are the same
128-
// Tensor at program level?
129-
// - A: If deleting Input, the graph will be complex, such as there will
130-
// be two ops points to the output in graph: op1 -> output <- set_value.
131-
// In this case, we have to find a way to handle the running order of
132-
// set_value is what we want.
133140
Copy(dev_ctx, in, place, false, out);
141+
ExpandKernel<T, Context>(
142+
dev_ctx, value_tensor, IntArray{expand_shape}, &expand_tensor);
143+
expand_tensor.Resize(slice_dims);
134144

135-
DenseTensor slice_tensor =
136-
Empty<T>(dev_ctx, IntArray{slice_dims.Get(), slice_dims.size()});
137-
DenseTensor pad_tensor =
138-
Empty<T>(dev_ctx, IntArray{in_dims.Get(), in_dims.size()});
139-
auto pad_e = EigenTensor<T, RANK>::From(pad_tensor, in_dims);
140145
auto out_e = EigenTensor<T, RANK>::From(*out);
141-
auto slice_e = EigenTensor<T, RANK>::From(slice_tensor, slice_dims);
142-
143-
// Step 1: Set the value of out at `_index` to zero
144-
slice_e.device(eigen_place) = slice_e.constant(T(0));
146+
auto value_e = EigenTensor<T, RANK>::From(expand_tensor);
145147

146148
auto starts_indices = Eigen::DSizes<Eigen::DenseIndex, RANK>();
147149
auto ends_indices = Eigen::DSizes<Eigen::DenseIndex, RANK>();
@@ -164,65 +166,7 @@ void SetValueImpl(const Context& dev_ctx,
164166
}
165167

166168
out_e.stridedSlice(starts_indices, ends_indices, strides_indices)
167-
.device(eigen_place) = slice_e;
168-
169-
// Step 2: Set a tensor with the same shape as out tensor. And its data at
170-
// '_index' is the same as value, and data out of '_index' to zero
171-
172-
// - Step 2.1 Set slice tensor with value
173-
174-
// NOTE(liym27): [ Why resize slice_tensor here? ]
175-
// A: When do broadcasting on slice_tensor and value, the shape of
176-
// slice_tensor should be decreased dims.
177-
// e.g.
178-
// x[:,0] = value
179-
// x's shape = [3, 4], value's shape = [3]
180-
// We get slice_dims = [3, 1], decrease_slice_dims = [3]
181-
// If do broadcasting on Tensor with shape [3, 1] and [3], the result's
182-
// shape is [3, 3], which cross the border;
183-
// If do broadcasting on Tensor with shape [3] and [3], the result's shape
184-
// is [3], which is right.
185-
186-
slice_tensor.Resize(slice_dims_for_assign);
187-
188-
CheckIsDimsMatch(slice_dims_for_assign, value.dims());
189-
190-
bool is_gpu_place = dev_ctx.GetPlace().GetType() == phi::AllocationType::GPU;
191-
if (is_gpu_place || slice_tensor.dims().size() >= value.dims().size()) {
192-
// [Why here we confirm running device]
193-
// ElementwiseComputeEx can do broadcasting in two cases:
194-
// 1. The place is GPU.
195-
// 2. The place is CPU, and the 'x' does not need broadcast.
196-
// Please see the note in
197-
// paddle/fluid/operators/elementwise/elementwise_op_function.h
198-
// So, here we choose different logic depending on the device to avoid
199-
// numerical problems, temporarily.
200-
//
201-
// TODO(zoooo0820): Reimplement logic of set_value to avoid using
202-
// elementwise-sub.
203-
funcs::ElementwiseCompute<funcs::SubtractFunctor<T>, T>(
204-
dev_ctx,
205-
slice_tensor,
206-
value,
207-
funcs::SubtractFunctor<T>(),
208-
&slice_tensor);
209-
} else {
210-
funcs::ElementwiseCompute<funcs::InverseSubtractFunctor<T>, T>(
211-
dev_ctx,
212-
slice_tensor,
213-
value,
214-
funcs::InverseSubtractFunctor<T>(),
215-
&slice_tensor);
216-
}
217-
slice_tensor.Resize(slice_dims);
218-
219-
// - Step 2.2 Pad slice tensor with 0
220-
pad_e.device(eigen_place) = pad_e.constant(T(0));
221-
pad_e.stridedSlice(starts_indices, ends_indices, strides_indices)
222-
.device(eigen_place) = slice_e;
223-
224-
// Step 3: Set out tensor with value
225-
out_e.device(eigen_place) = out_e - pad_e;
169+
.device(eigen_place) = value_e;
226170
}
227171

228172
template <typename T, typename Context>

0 commit comments

Comments
 (0)