@@ -77,43 +77,6 @@ static void MoveOrCopyVar(framework::Variable* dst,
  }
}

- #ifdef PADDLE_WITH_XPU
- template <typename T>
- void XPUTensorAddFunctor(const phi::Place& place,
-                          const phi::DenseTensor& src,
-                          phi::DenseTensor* dst) {
-   using XPUType = typename XPUTypeTrait<T>::Type;
-   phi::XPUContext* ctx = dynamic_cast<phi::XPUContext*>(
-       phi::DeviceContextPool::Instance().Get(place));
-   const XPUType* x = reinterpret_cast<const XPUType*>(src.data<T>());
-   XPUType* y = reinterpret_cast<XPUType*>(dst->mutable_data<T>(place));
-   int r = -1;
-   int numel = static_cast<int>(src.numel());
-   if (std::is_same<T, double>::value) {
-     xpu::ctx_guard RAII_GUARD(ctx->x_context());
-     float* x_cast_to_fp32 = RAII_GUARD.alloc<float>(numel);
-     PADDLE_ENFORCE_XDNN_NOT_NULL(x_cast_to_fp32);
-     float* y_cast_to_fp32 = RAII_GUARD.alloc<float>(numel);
-     PADDLE_ENFORCE_XDNN_NOT_NULL(y_cast_to_fp32);
-     r = xpu::cast<XPUType, float>(ctx->x_context(), x, x_cast_to_fp32, numel);
-     PADDLE_ENFORCE_XDNN_SUCCESS(r, "cast");
-     r = xpu::cast<XPUType, float>(ctx->x_context(), y, y_cast_to_fp32, numel);
-     PADDLE_ENFORCE_XDNN_SUCCESS(r, "cast");
-     r = xpu::add<float>(ctx->x_context(),
-                         x_cast_to_fp32,
-                         y_cast_to_fp32,
-                         y_cast_to_fp32,
-                         numel);
-     PADDLE_ENFORCE_XDNN_SUCCESS(r, "add");
-     r = xpu::cast<float, XPUType>(ctx->x_context(), y_cast_to_fp32, y, numel);
-     PADDLE_ENFORCE_XDNN_SUCCESS(r, "cast");
-   } else {
-     r = xpu::add<XPUType>(ctx->x_context(), x, y, y, numel);
-     PADDLE_ENFORCE_XDNN_SUCCESS(r, "add");
-   }
- }
- #endif
-
template <typename TType>
TType* GetInnerMutableTensor(framework::Variable* dst) {
  auto* dst_tensor = dst->GetMutable<TType>();
@@ -218,6 +181,15 @@ void TensorAdd(const VarType& src, VarType* dst) {
#endif
  }

+   if (phi::is_xpu_place(place)) {
+ #if defined(PADDLE_WITH_XPU)
+     PADDLE_TENSOR_ADD(float, phi::XPUContext);
+     PADDLE_TENSOR_ADD(double, phi::XPUContext);
+     PADDLE_TENSOR_ADD(phi::dtype::float16, phi::XPUContext);
+     PADDLE_TENSOR_ADD(phi::dtype::bfloat16, phi::XPUContext);
+ #endif
+   }
+
#define TENSOR_ADD_EIGEN(T)                           \
  auto cpu_ctx = static_cast<phi::CPUContext*>(       \
      phi::DeviceContextPool::Instance().Get(place)); \
@@ -264,29 +236,6 @@ void TensorAdd(const VarType& src, VarType* dst) {
#endif
  }

- #ifdef PADDLE_WITH_XPU
-   if (phi::is_xpu_place(place)) {
-     if (data_type == framework::DataTypeTrait<float>::DataType()) {
-       XPUTensorAddFunctor<float>(place, src_tensor, dst_tensor);
-     } else if (data_type ==
-                framework::DataTypeTrait<phi::dtype::float16>::DataType()) {
-       XPUTensorAddFunctor<phi::dtype::float16>(place, src_tensor, dst_tensor);
-     } else if (data_type == framework::DataTypeTrait<double>::DataType()) {
-       XPUTensorAddFunctor<double>(place, src_tensor, dst_tensor);
-     } else if (data_type ==
-                framework::DataTypeTrait<phi::dtype::bfloat16>::DataType()) {
-       XPUTensorAddFunctor<phi::dtype::bfloat16>(place, src_tensor, dst_tensor);
-     } else {
-       PADDLE_THROW(common::errors::Unimplemented(
-           "Gradient accumulation of data type (%s) on place (%s) is not "
-           "supported in imperative mode",
-           framework::DataTypeToString(data_type),
-           place));
-     }
-     return;
-   }
- #endif
-
  PADDLE_THROW(common::errors::Unimplemented(
      "Gradient accumulation of data type (%s) on place (%s) is not "
      "supported in imperative mode",