19
19
#include " paddle/phi/core/dense_tensor.h"
20
20
#include " paddle/phi/core/tensor_utils.h"
21
21
#include " paddle/phi/kernels/empty_kernel.h"
22
+ #include " paddle/phi/kernels/expand_kernel.h"
22
23
#include " paddle/phi/kernels/funcs/broadcast_function.h"
23
24
#include " paddle/phi/kernels/funcs/eigen/common.h"
24
25
#include " paddle/phi/kernels/funcs/eigen/eigen_function.h"
25
26
#include " paddle/phi/kernels/funcs/elementwise_functor.h"
26
27
#include " paddle/phi/kernels/funcs/slice_utils.h"
27
-
28
28
namespace phi {
29
29
30
30
// check whether the tensor with dimension of second can assign to the
@@ -89,7 +89,6 @@ void SetValueImpl(const Context& dev_ctx,
89
89
in_dims, axes, starts_local, ends_local, &steps_local);
90
90
auto decrease_slice_dims =
91
91
phi::funcs::GetDecreasedDims (slice_dims, decrease_axes);
92
-
93
92
auto slice_dims_for_assign = decrease_slice_dims;
94
93
if (!none_axes.empty ()) {
95
94
std::vector<int64_t > slice_dims_with_none;
@@ -115,33 +114,36 @@ void SetValueImpl(const Context& dev_ctx,
115
114
116
115
slice_dims_for_assign = common::make_ddim (slice_dims_with_none);
117
116
}
117
+ CheckIsDimsMatch (slice_dims_for_assign, value.dims ());
118
+
119
+ auto value_shape = phi::vectorize<int64_t >(value.dims ());
120
+
121
+ DenseTensor value_tensor = Empty<T>(dev_ctx, IntArray{value_shape});
122
+ value_tensor = value;
123
+ auto it = value_shape.begin ();
124
+ while (it != value_shape.end () && *it == 1 ) {
125
+ it = value_shape.erase (it);
126
+ }
127
+ if (value_shape.empty ()) value_shape.push_back (1 );
128
+ value_tensor.Resize (phi::make_ddim (value_shape));
129
+
130
+ auto expand_shape = phi::vectorize<int64_t >(slice_dims_for_assign);
131
+ for (size_t i = 0 ; i <= expand_shape.size (); i++) {
132
+ if (expand_shape[i] == 0 ) expand_shape[i] = 1 ;
133
+ }
134
+ if (expand_shape.empty ()) expand_shape.push_back (1 );
135
+ DenseTensor expand_tensor = Empty<T>(dev_ctx, IntArray{expand_shape});
118
136
119
137
auto place = dev_ctx.GetPlace ();
120
138
auto & eigen_place = *dev_ctx.eigen_device ();
121
139
122
- // Here copy data from input to avoid data loss at PE and Graph level.
123
- // TODO(liym27): Speed up in the future version.
124
- // - Q: Why don't call ShareDataWith to speed up?
125
- // - A: Because it's not supported to ShareDataWith on OP's input and output
126
- // https://github.com/PaddlePaddle/Paddle/wiki/ShareDataWith-and-ShareBufferWith-are-prohibited-in-OP
127
- // - Q: Why don't delete Input, after all, the input and output are the same
128
- // Tensor at program level?
129
- // - A: If deleting Input, the graph will be complex, such as there will
130
- // be two ops points to the output in graph: op1 -> output <- set_value.
131
- // In this case, we have to find a way to handle the running order of
132
- // set_value is what we want.
133
140
Copy (dev_ctx, in, place, false , out);
141
+ ExpandKernel<T, Context>(
142
+ dev_ctx, value_tensor, IntArray{expand_shape}, &expand_tensor);
143
+ expand_tensor.Resize (slice_dims);
134
144
135
- DenseTensor slice_tensor =
136
- Empty<T>(dev_ctx, IntArray{slice_dims.Get (), slice_dims.size ()});
137
- DenseTensor pad_tensor =
138
- Empty<T>(dev_ctx, IntArray{in_dims.Get (), in_dims.size ()});
139
- auto pad_e = EigenTensor<T, RANK>::From (pad_tensor, in_dims);
140
145
auto out_e = EigenTensor<T, RANK>::From (*out);
141
- auto slice_e = EigenTensor<T, RANK>::From (slice_tensor, slice_dims);
142
-
143
- // Step 1: Set the value of out at `_index` to zero
144
- slice_e.device (eigen_place) = slice_e.constant (T (0 ));
146
+ auto value_e = EigenTensor<T, RANK>::From (expand_tensor);
145
147
146
148
auto starts_indices = Eigen::DSizes<Eigen::DenseIndex, RANK>();
147
149
auto ends_indices = Eigen::DSizes<Eigen::DenseIndex, RANK>();
@@ -164,65 +166,7 @@ void SetValueImpl(const Context& dev_ctx,
164
166
}
165
167
166
168
out_e.stridedSlice (starts_indices, ends_indices, strides_indices)
167
- .device (eigen_place) = slice_e;
168
-
169
- // Step 2: Set a tensor with the same shape as out tensor. And its data at
170
- // '_index' is the same as value, and data out of '_index' to zero
171
-
172
- // - Step 2.1 Set slice tensor with value
173
-
174
- // NOTE(liym27): [ Why resize slice_tensor here? ]
175
- // A: When do broadcasting on slice_tensor and value, the shape of
176
- // slice_tensor should be decreased dims.
177
- // e.g.
178
- // x[:,0] = value
179
- // x's shape = [3, 4], value's shape = [3]
180
- // We get slice_dims = [3, 1], decrease_slice_dims = [3]
181
- // If do broadcasting on Tensor with shape [3, 1] and [3], the result's
182
- // shape is [3, 3], which cross the border;
183
- // If do broadcasting on Tensor with shape [3] and [3], the result's shape
184
- // is [3], which is right.
185
-
186
- slice_tensor.Resize (slice_dims_for_assign);
187
-
188
- CheckIsDimsMatch (slice_dims_for_assign, value.dims ());
189
-
190
- bool is_gpu_place = dev_ctx.GetPlace ().GetType () == phi::AllocationType::GPU;
191
- if (is_gpu_place || slice_tensor.dims ().size () >= value.dims ().size ()) {
192
- // [Why here we confirm running device]
193
- // ElementwiseComputeEx can do broadcasting in two cases:
194
- // 1. The place is GPU.
195
- // 2. The place is CPU, and the 'x' does not need broadcast.
196
- // Please see the note in
197
- // paddle/fluid/operators/elementwise/elementwise_op_function.h
198
- // So, here we choose different logic depending on the device to avoid
199
- // numerical problems, temporarily.
200
- //
201
- // TODO(zoooo0820): Reimplement logic of set_value to avoid using
202
- // elementwise-sub.
203
- funcs::ElementwiseCompute<funcs::SubtractFunctor<T>, T>(
204
- dev_ctx,
205
- slice_tensor,
206
- value,
207
- funcs::SubtractFunctor<T>(),
208
- &slice_tensor);
209
- } else {
210
- funcs::ElementwiseCompute<funcs::InverseSubtractFunctor<T>, T>(
211
- dev_ctx,
212
- slice_tensor,
213
- value,
214
- funcs::InverseSubtractFunctor<T>(),
215
- &slice_tensor);
216
- }
217
- slice_tensor.Resize (slice_dims);
218
-
219
- // - Step 2.2 Pad slice tensor with 0
220
- pad_e.device (eigen_place) = pad_e.constant (T (0 ));
221
- pad_e.stridedSlice (starts_indices, ends_indices, strides_indices)
222
- .device (eigen_place) = slice_e;
223
-
224
- // Step 3: Set out tensor with value
225
- out_e.device (eigen_place) = out_e - pad_e;
169
+ .device (eigen_place) = value_e;
226
170
}
227
171
228
172
template <typename T, typename Context>
0 commit comments