@@ -143,71 +143,26 @@ DEFINE_CPU_TRANS_NORMAL(phi::dtype::complex<float>);
143
143
DEFINE_CPU_TRANS_NORMAL (phi::dtype::complex<double >);
144
144
145
145
struct TensorSetConstantCPU {
146
- TensorSetConstantCPU (phi::DenseTensor* tensor, const void * value)
146
+ TensorSetConstantCPU (phi::DenseTensor* tensor, float value)
147
147
: tensor_(tensor), value_(value) {}
148
148
template <typename T>
149
149
void apply () const {
150
150
auto cpu = phi::CPUPlace ();
151
151
auto * begin = tensor_->mutable_data <T>(cpu);
152
- const T* num_ptr = reinterpret_cast <const T*>(value_);
153
- T num = *num_ptr;
154
- std::fill (begin, begin + tensor_->numel (), num);
152
+ std::fill (begin, begin + tensor_->numel (), static_cast <T>(value_));
155
153
}
156
154
phi::DenseTensor* tensor_;
157
- const void * value_;
155
+ float value_;
158
156
};
159
157
160
- #ifdef PADDLE_WITH_XPU
161
- struct TensorSetConstantXPU {
162
- TensorSetConstantXPU (const phi::DeviceContext& context,
163
- phi::DenseTensor* tensor,
164
- const void * value,
165
- phi::Place place)
166
- : context_(context), tensor_(tensor), value_(value), place_(place) {}
167
- template <typename T>
168
- void apply () const {
169
- auto * ctx = phi::DeviceContextPool::Instance ().Get (place_);
170
- auto data = ctx->Alloc <T>(tensor_);
171
- const T* num = reinterpret_cast <const T*>(value_);
172
- T num_value = static_cast <T>(*num);
173
- int numel = tensor_->numel ();
174
- if (((std::is_same<T, float >::value) ||
175
- (std::is_same<T, phi::dtype::float16>::value)) &&
176
- (place_ == phi::XPUPlace ())) {
177
- using XPUType = typename XPUTypeTrait<T>::Type;
178
- auto * dev_ctx = static_cast <phi::XPUContext*>(ctx);
179
- int r = xpu::constant (dev_ctx->x_context (),
180
- reinterpret_cast <XPUType*>(data),
181
- numel,
182
- static_cast <XPUType>(num_value));
183
- PADDLE_ENFORCE_XDNN_SUCCESS (r, " constant" );
184
- dev_ctx->Wait ();
185
- } else {
186
- std::unique_ptr<T[]> data_cpu (new T[numel]);
187
- std::fill (
188
- data_cpu.get (), data_cpu.get () + numel, static_cast <T>(num_value));
189
- memory_utils::Copy (place_,
190
- data,
191
- phi::CPUPlace (),
192
- static_cast <void *>(data_cpu.get ()),
193
- numel * sizeof (T));
194
- }
195
- }
196
- const phi::DeviceContext& context_;
197
- phi::DenseTensor* tensor_;
198
- const void * value_;
199
- phi::Place place_;
200
- };
201
- #endif
202
-
203
158
template <>
204
159
void set_constant_with_place<phi::XPUPlace>(const phi::DeviceContext& context,
205
160
phi::DenseTensor* tensor,
206
- const void * value) {
161
+ float value) {
207
162
#ifdef PADDLE_WITH_XPU
208
163
phi::VisitDataType (
209
164
tensor->dtype (),
210
- TensorSetConstantXPU (context, tensor, value, tensor->place ()));
165
+ TensorSetConstantXPU< float >( tensor, value, tensor->place ()));
211
166
#else
212
167
PADDLE_THROW (phi::errors::PreconditionNotMet (" Not compiled with XPU!" ));
213
168
#endif
@@ -216,15 +171,13 @@ void set_constant_with_place<phi::XPUPlace>(const phi::DeviceContext& context,
216
171
template <>
217
172
void set_constant_with_place<phi::IPUPlace>(const phi::DeviceContext& context,
218
173
phi::DenseTensor* tensor,
219
- const void * value) {
174
+ float value) {
220
175
PADDLE_THROW (phi::errors::Unimplemented (" IPUPlace is not supported" ));
221
176
}
222
177
223
178
template <>
224
179
void set_constant_with_place<phi::CustomPlace>(
225
- const phi::DeviceContext& context,
226
- phi::DenseTensor* tensor,
227
- const void * value) {
180
+ const phi::DeviceContext& context, phi::DenseTensor* tensor, float value) {
228
181
#ifdef PADDLE_WITH_CUSTOM_DEVICE
229
182
auto kernel_result = phi::KernelFactory::Instance ().SelectKernelOrThrowError (
230
183
" full" ,
@@ -237,12 +190,10 @@ void set_constant_with_place<phi::CustomPlace>(
237
190
const phi::Scalar&,
238
191
DataType,
239
192
phi::DenseTensor*);
240
- const float * num_ptr = reinterpret_cast <const float *>(value);
241
- float num = *num_ptr;
242
193
auto * kernel_fn = kernel.GetVariadicKernelFn <kernel_signature>();
243
194
(*kernel_fn)(context,
244
195
phi::IntArray (common::vectorize (tensor->dims ())),
245
- phi::Scalar (num ),
196
+ phi::Scalar (value ),
246
197
tensor->dtype (),
247
198
tensor);
248
199
#else
@@ -253,15 +204,13 @@ void set_constant_with_place<phi::CustomPlace>(
253
204
template <>
254
205
void set_constant_with_place<phi::CPUPlace>(const phi::DeviceContext& context,
255
206
phi::DenseTensor* tensor,
256
- const void * value) {
207
+ float value) {
257
208
phi::VisitDataType (tensor->dtype (), TensorSetConstantCPU (tensor, value));
258
209
}
259
210
260
211
template <>
261
212
void set_constant_with_place<phi::GPUPinnedPlace>(
262
- const phi::DeviceContext& context,
263
- phi::DenseTensor* tensor,
264
- const void * value) {
213
+ const phi::DeviceContext& context, phi::DenseTensor* tensor, float value) {
265
214
phi::VisitDataType (tensor->dtype (), TensorSetConstantCPU (tensor, value));
266
215
}
267
216
@@ -270,7 +219,7 @@ struct TensorSetConstantWithPlace {
270
219
using result_type = void ;
271
220
TensorSetConstantWithPlace (const phi::DeviceContext& context,
272
221
phi::DenseTensor* tensor,
273
- const void * value)
222
+ float value)
274
223
: context_(context), tensor_(tensor), value_(value) {}
275
224
276
225
template <typename Place>
@@ -280,12 +229,12 @@ struct TensorSetConstantWithPlace {
280
229
281
230
const phi::DeviceContext& context_;
282
231
phi::DenseTensor* tensor_;
283
- const void * value_;
232
+ float value_;
284
233
};
285
234
286
235
void set_constant (const phi::DeviceContext& context,
287
236
phi::DenseTensor* tensor,
288
- const void * value) {
237
+ float value) {
289
238
TensorSetConstantWithPlace func (context, tensor, value);
290
239
#ifdef PADDLE_WITH_CUSTOM_DEVICE
291
240
if (context.GetPlace ().GetType () == phi::AllocationType::CUSTOM) {
0 commit comments