
Commit f08e785

fix set_constant error (#59905) (#60104)
1 parent 8189a99 commit f08e785

7 files changed (+105 -94 lines)


paddle/fluid/framework/hogwild_worker.cc (+1 -1)

@@ -1131,7 +1131,7 @@ template <typename T>
 void HogwildWorker::SetZero(phi::DenseTensor *tensor,
                             const phi::DenseTensor &root_tensor) {
   tensor->mutable_data<T>(root_tensor.dims(), place_);
-  phi::funcs::set_constant(*dev_ctx_, tensor, static_cast<T>(0.0));
+  phi::funcs::set_constant(*dev_ctx_, tensor, 0.0);
 }
 
 void HogwildWorker::BindingDataFeedMemory() {
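
Note: the same caller-side pattern appears at every call site this commit touches. A before/after sketch of that pattern (names taken from the hunk above; tensor and dev_ctx_ are the surrounding members):

    // Before: the fill value had to be pre-cast to the element type T and was
    // then forwarded internally through a const void* pointer.
    phi::funcs::set_constant(*dev_ctx_, tensor, static_cast<T>(0.0));

    // After: a plain floating-point literal is passed; the functor selected for
    // the tensor's runtime dtype performs the cast.
    phi::funcs::set_constant(*dev_ctx_, tensor, 0.0);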

paddle/phi/kernels/funcs/math_function.cc (+13 -64)

@@ -143,71 +143,26 @@ DEFINE_CPU_TRANS_NORMAL(phi::dtype::complex<float>);
 DEFINE_CPU_TRANS_NORMAL(phi::dtype::complex<double>);
 
 struct TensorSetConstantCPU {
-  TensorSetConstantCPU(phi::DenseTensor* tensor, const void* value)
+  TensorSetConstantCPU(phi::DenseTensor* tensor, float value)
       : tensor_(tensor), value_(value) {}
   template <typename T>
   void apply() const {
     auto cpu = phi::CPUPlace();
     auto* begin = tensor_->mutable_data<T>(cpu);
-    const T* num_ptr = reinterpret_cast<const T*>(value_);
-    T num = *num_ptr;
-    std::fill(begin, begin + tensor_->numel(), num);
+    std::fill(begin, begin + tensor_->numel(), static_cast<T>(value_));
   }
   phi::DenseTensor* tensor_;
-  const void* value_;
+  float value_;
 };
 
-#ifdef PADDLE_WITH_XPU
-struct TensorSetConstantXPU {
-  TensorSetConstantXPU(const phi::DeviceContext& context,
-                       phi::DenseTensor* tensor,
-                       const void* value,
-                       phi::Place place)
-      : context_(context), tensor_(tensor), value_(value), place_(place) {}
-  template <typename T>
-  void apply() const {
-    auto* ctx = phi::DeviceContextPool::Instance().Get(place_);
-    auto data = ctx->Alloc<T>(tensor_);
-    const T* num = reinterpret_cast<const T*>(value_);
-    T num_value = static_cast<T>(*num);
-    int numel = tensor_->numel();
-    if (((std::is_same<T, float>::value) ||
-         (std::is_same<T, phi::dtype::float16>::value)) &&
-        (place_ == phi::XPUPlace())) {
-      using XPUType = typename XPUTypeTrait<T>::Type;
-      auto* dev_ctx = static_cast<phi::XPUContext*>(ctx);
-      int r = xpu::constant(dev_ctx->x_context(),
-                            reinterpret_cast<XPUType*>(data),
-                            numel,
-                            static_cast<XPUType>(num_value));
-      PADDLE_ENFORCE_XDNN_SUCCESS(r, "constant");
-      dev_ctx->Wait();
-    } else {
-      std::unique_ptr<T[]> data_cpu(new T[numel]);
-      std::fill(
-          data_cpu.get(), data_cpu.get() + numel, static_cast<T>(num_value));
-      memory_utils::Copy(place_,
-                         data,
-                         phi::CPUPlace(),
-                         static_cast<void*>(data_cpu.get()),
-                         numel * sizeof(T));
-    }
-  }
-  const phi::DeviceContext& context_;
-  phi::DenseTensor* tensor_;
-  const void* value_;
-  phi::Place place_;
-};
-#endif
-
 template <>
 void set_constant_with_place<phi::XPUPlace>(const phi::DeviceContext& context,
                                             phi::DenseTensor* tensor,
-                                            const void* value) {
+                                            float value) {
 #ifdef PADDLE_WITH_XPU
   phi::VisitDataType(
       tensor->dtype(),
-      TensorSetConstantXPU(context, tensor, value, tensor->place()));
+      TensorSetConstantXPU<float>(tensor, value, tensor->place()));
 #else
   PADDLE_THROW(phi::errors::PreconditionNotMet("Not compiled with XPU!"));
 #endif
@@ -216,15 +171,13 @@ void set_constant_with_place<phi::XPUPlace>(const phi::DeviceContext& context,
 template <>
 void set_constant_with_place<phi::IPUPlace>(const phi::DeviceContext& context,
                                             phi::DenseTensor* tensor,
-                                            const void* value) {
+                                            float value) {
   PADDLE_THROW(phi::errors::Unimplemented("IPUPlace is not supported"));
 }
 
 template <>
 void set_constant_with_place<phi::CustomPlace>(
-    const phi::DeviceContext& context,
-    phi::DenseTensor* tensor,
-    const void* value) {
+    const phi::DeviceContext& context, phi::DenseTensor* tensor, float value) {
 #ifdef PADDLE_WITH_CUSTOM_DEVICE
   auto kernel_result = phi::KernelFactory::Instance().SelectKernelOrThrowError(
       "full",
@@ -237,12 +190,10 @@ void set_constant_with_place<phi::CustomPlace>(
       const phi::Scalar&,
       DataType,
       phi::DenseTensor*);
-  const float* num_ptr = reinterpret_cast<const float*>(value);
-  float num = *num_ptr;
   auto* kernel_fn = kernel.GetVariadicKernelFn<kernel_signature>();
   (*kernel_fn)(context,
                phi::IntArray(common::vectorize(tensor->dims())),
-               phi::Scalar(num),
+               phi::Scalar(value),
                tensor->dtype(),
                tensor);
 #else
@@ -253,15 +204,13 @@ void set_constant_with_place<phi::XPUPlace>(const phi::DeviceContext& context,
 template <>
 void set_constant_with_place<phi::CPUPlace>(const phi::DeviceContext& context,
                                             phi::DenseTensor* tensor,
-                                            const void* value) {
+                                            float value) {
   phi::VisitDataType(tensor->dtype(), TensorSetConstantCPU(tensor, value));
 }
 
 template <>
 void set_constant_with_place<phi::GPUPinnedPlace>(
-    const phi::DeviceContext& context,
-    phi::DenseTensor* tensor,
-    const void* value) {
+    const phi::DeviceContext& context, phi::DenseTensor* tensor, float value) {
   phi::VisitDataType(tensor->dtype(), TensorSetConstantCPU(tensor, value));
 }
 
@@ -270,7 +219,7 @@ struct TensorSetConstantWithPlace {
   using result_type = void;
   TensorSetConstantWithPlace(const phi::DeviceContext& context,
                              phi::DenseTensor* tensor,
-                             const void* value)
+                             float value)
       : context_(context), tensor_(tensor), value_(value) {}
 
   template <typename Place>
@@ -280,12 +229,12 @@ struct TensorSetConstantWithPlace {
 
   const phi::DeviceContext& context_;
   phi::DenseTensor* tensor_;
-  const void* value_;
+  float value_;
 };
 
 void set_constant(const phi::DeviceContext& context,
                   phi::DenseTensor* tensor,
-                  const void* value) {
+                  float value) {
   TensorSetConstantWithPlace func(context, tensor, value);
 #ifdef PADDLE_WITH_CUSTOM_DEVICE
   if (context.GetPlace().GetType() == phi::AllocationType::CUSTOM) {
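
Every set_constant_with_place specialization above follows the same shape: phi::VisitDataType reads the tensor's runtime dtype and instantiates the functor's templated apply<T>() for that type, so the single float value is cast once per element type. A simplified, self-contained sketch of that dispatch idea (illustrative only, not the actual phi::VisitDataType implementation; DType and ToyTensor are made-up stand-ins):

    #include <algorithm>
    #include <cstdint>
    #include <vector>

    // Made-up stand-in for a dtype-tagged buffer; phi::DenseTensor carries the
    // same information (a runtime dtype plus typed storage).
    enum class DType { kFloat32, kFloat64, kInt32 };
    struct ToyTensor {
      DType dtype;
      int64_t numel;
      std::vector<unsigned char> storage;  // numel * sizeof(element) raw bytes
    };

    // Mirrors TensorSetConstantCPU after this patch: hold a float, cast per T.
    struct ToySetConstant {
      ToyTensor* tensor_;
      float value_;
      template <typename T>
      void apply() const {
        T* begin = reinterpret_cast<T*>(tensor_->storage.data());
        std::fill(begin, begin + tensor_->numel, static_cast<T>(value_));
      }
    };

    // Simplified dispatch in the spirit of phi::VisitDataType: pick T at runtime.
    template <typename Visitor>
    void VisitDType(DType dtype, const Visitor& visitor) {
      switch (dtype) {
        case DType::kFloat32: visitor.template apply<float>(); break;
        case DType::kFloat64: visitor.template apply<double>(); break;
        case DType::kInt32:   visitor.template apply<int32_t>(); break;
      }
    }

For example, VisitDType(t.dtype, ToySetConstant{&t, 0.0f}) fills t with zeros of whatever element type t was created with, which is how the CPU and GPUPinned specializations use TensorSetConstantCPU.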

paddle/phi/kernels/funcs/math_function.cu (+7 -15)

@@ -336,34 +336,26 @@ DEFINE_GPU_TRANS_NORMAL(phi::dtype::complex<double>);
 struct TensorSetConstantGPU {
   TensorSetConstantGPU(const phi::DeviceContext& context,
                        phi::DenseTensor* tensor,
-                       const void* value)
+                       float value)
       : context_(context), tensor_(tensor), value_(value) {}
 
   template <typename T>
   void apply() const {
-    // SetConstant<phi::GPUContext, T> functor;
-    // functor(reinterpret_cast<const phi::GPUContext&>(context_),
-    //         tensor_,
-    //         static_cast<T>(value_));
-    int N = static_cast<int>(tensor_->numel());
-    if (N <= 0) {
-      return;
-    }
-    auto& ctx = reinterpret_cast<const phi::GPUContext&>(context_);
-    const T* num = reinterpret_cast<const T*>(value_);
-    FillConstantKernel<T><<<(N + 512 - 1) / 512, 512, 0, ctx.stream()>>>(
-        N, tensor_->mutable_data<T>(ctx.GetPlace()), static_cast<T>(*num));
+    SetConstant<phi::GPUContext, T> functor;
+    functor(reinterpret_cast<const phi::GPUContext&>(context_),
+            tensor_,
+            static_cast<T>(value_));
   }
 
   const phi::DeviceContext& context_;
   phi::DenseTensor* tensor_;
-  const void* value_;
+  float value_;
 };
 
 template <>
 void set_constant_with_place<phi::GPUPlace>(const phi::DeviceContext& context,
                                             phi::DenseTensor* tensor,
-                                            const void* value) {
+                                            float value) {
   phi::VisitDataType(tensor->dtype(),
                      TensorSetConstantGPU(context, tensor, value));
 }

paddle/phi/kernels/funcs/math_function.h (+70 -9)

@@ -64,21 +64,23 @@ struct SetConstant {
                   T num);
 };
 
+#ifdef PADDLE_WITH_XPU
+template <typename T>
+struct SetConstant<phi::XPUContext, T> {
+  void operator()(const phi::XPUContext& context,
+                  phi::DenseTensor* tensor,
+                  T num);
+};
+#endif
+
 template <typename Place>
 void set_constant_with_place(const phi::DeviceContext& context,
                              phi::DenseTensor* tensor,
-                             const void* value);
-
-void set_constant(const phi::DeviceContext& context,
-                  phi::DenseTensor* tensor,
-                  const void* value);
+                             float value);
 
-template <typename T>
 void set_constant(const phi::DeviceContext& context,
                   phi::DenseTensor* tensor,
-                  const T value) {
-  set_constant(context, tensor, reinterpret_cast<const void*>(&value));
-}
+                  float value);
 
 template <typename DeviceContext, typename T>
 struct RowwiseAdd {
@@ -109,6 +111,65 @@ struct RowwiseMean {
                   phi::DenseTensor* vec);
 };
 
+#ifdef PADDLE_WITH_XPU
+template <typename U>
+struct TensorSetConstantXPU {
+  TensorSetConstantXPU(phi::DenseTensor* tensor, U value, phi::Place place)
+      : tensor_(tensor), value_(value), place_(place) {}
+  template <typename T>
+  void apply() const {
+    auto* ctx = phi::DeviceContextPool::Instance().Get(place_);
+    auto begin = ctx->Alloc<T>(tensor_);
+    int numel = tensor_->numel();
+    std::unique_ptr<T[]> data_cpu(new T[numel]);
+    std::fill(data_cpu.get(), data_cpu.get() + numel, static_cast<T>(value_));
+    memory_utils::Copy(place_,
+                       begin,
+                       phi::CPUPlace(),
+                       static_cast<void*>(data_cpu.get()),
+                       numel * sizeof(T));
+  }
+  phi::DenseTensor* tensor_;
+  U value_;
+  phi::Place place_;
+};
+
+template <>
+struct TensorSetConstantXPU<float> {
+  TensorSetConstantXPU(phi::DenseTensor* tensor, float value, phi::Place place)
+      : tensor_(tensor), value_(value), place_(place) {}
+  template <typename T>
+  void apply() const {
+    auto* ctx = phi::DeviceContextPool::Instance().Get(place_);
+    auto begin = ctx->Alloc<T>(tensor_);
+    int numel = tensor_->numel();
+    if (((std::is_same<T, float>::value) ||
+         (std::is_same<T, phi::dtype::float16>::value)) &&
+        (place_ == phi::XPUPlace())) {
+      using XPUType = typename XPUTypeTrait<T>::Type;
+      auto* dev_ctx = static_cast<phi::XPUContext*>(ctx);
+      int r = xpu::constant<XPUType>(dev_ctx->x_context(),
+                                     reinterpret_cast<XPUType*>(begin),
+                                     numel,
+                                     static_cast<XPUType>(value_));
+      PADDLE_ENFORCE_XDNN_SUCCESS(r, "constant");
+      dev_ctx->Wait();
+    } else {
+      std::unique_ptr<T[]> data_cpu(new T[numel]);
+      std::fill(data_cpu.get(), data_cpu.get() + numel, static_cast<T>(value_));
+      memory_utils::Copy(place_,
+                         begin,
+                         phi::CPUPlace(),
+                         static_cast<void*>(data_cpu.get()),
+                         numel * sizeof(T));
+    }
+  }
+  phi::DenseTensor* tensor_;
+  float value_;
+  phi::Place place_;
+};
+#endif
+
 template <typename Context, typename T>
 inline void TransCompute(const int dim,
                          const Context& dev_ctx,
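
After this change the header exposes two ways to fill a tensor with a constant: the non-template set_constant, which takes the value as a float and casts it to the tensor's runtime dtype, and the SetConstant<DeviceContext, T> functor, which keeps the value in the caller's element type T. A hedged usage sketch (dev_ctx, cpu_ctx, dense_tensor, and double_tensor are hypothetical, already-initialized objects; both entry points are assumed to live in phi::funcs, as this header suggests):

    // Untyped entry point: the float 1.0f is cast to dense_tensor's dtype.
    phi::funcs::set_constant(dev_ctx, &dense_tensor, 1.0f);

    // Typed functor: the value travels as double end to end on a CPU context.
    phi::funcs::SetConstant<phi::CPUContext, double> fill_zero;
    fill_zero(cpu_ctx, &double_tensor, 0.0);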

paddle/phi/kernels/funcs/math_function_impl.h (+12 -3)

@@ -29,11 +29,20 @@ template <typename DeviceContext, typename T>
 void SetConstant<DeviceContext, T>::operator()(const DeviceContext& context,
                                                phi::DenseTensor* tensor,
                                                T num) {
-  // auto t = phi::EigenVector<T>::Flatten(*tensor);
-  // t.device(*context.eigen_device()) = t.constant(static_cast<T>(num));
-  set_constant(context, tensor, reinterpret_cast<const void*>(&num));
+  auto t = phi::EigenVector<T>::Flatten(*tensor);
+  t.device(*context.eigen_device()) = t.constant(static_cast<T>(num));
 }
 
+#ifdef PADDLE_WITH_XPU
+template <typename T>
+void SetConstant<phi::XPUContext, T>::operator()(const phi::XPUContext& context,
+                                                 phi::DenseTensor* tensor,
+                                                 T num) {
+  phi::VisitDataType(tensor->dtype(),
+                     TensorSetConstantXPU<T>(tensor, num, context.GetPlace()));
+}
+#endif
+
 template <typename DeviceContext, typename T, int Rank>
 void Transpose<DeviceContext, T, Rank>::operator()(
     const DeviceContext& context,
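
Under PADDLE_WITH_XPU, the SetConstant<phi::XPUContext, T> specialization declared in math_function.h skips Eigen entirely: it re-dispatches on the tensor's runtime dtype and fills through TensorSetConstantXPU<T> (host-side std::fill plus memory_utils::Copy, with the float specialization taking the xpu::constant fast path for float/float16 tensors). A hedged sketch of the call it reduces to (hypothetical tensor and xpu_ctx names, XPU build assumed):

    // Roughly what SetConstant<phi::XPUContext, float>()(xpu_ctx, &tensor, 0.0f)
    // ends up doing after this patch:
    phi::VisitDataType(tensor.dtype(),
                       phi::funcs::TensorSetConstantXPU<float>(
                           &tensor, /*value=*/0.0f, xpu_ctx.GetPlace()));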

paddle/phi/kernels/legacy/cpu/one_hot_kernel.cc (+1 -1)

@@ -38,7 +38,7 @@ struct OneHotV2OpFunctor {
     auto* p_in_data = in_->data<InT>();
     auto numel = in_->numel();
     auto* p_out_data = ctx_.template Alloc<OutT>(out_);
-    funcs::set_constant(ctx_, out_, static_cast<OutT>(0.0));
+    funcs::set_constant(ctx_, out_, 0.0);
 
     for (int i = 0; i < numel; ++i) {
       PADDLE_ENFORCE_GE(

paddle/phi/kernels/legacy/gpu/one_hot_kernel.cu (+1 -1)

@@ -59,7 +59,7 @@ struct OneHotV2OpCUDAFunctor {
     auto numel = in_->numel();
     auto* p_out_data = ctx_.template Alloc<OutT>(out_);
     auto stream = ctx_.stream();
-    funcs::set_constant(ctx_, out_, static_cast<OutT>(0.0));
+    funcs::set_constant(ctx_, out_, 0.0);
 
     auto config = phi::backends::gpu::GetGpuLaunchConfig1D(ctx_, numel);