Commit 9931609

[XPU] fix xpu grad merge bug when using amp master_grad (cherry-pick from 068d95f) (#71694)
1 parent 8513fe6

File tree

1 file changed (+9 −60 lines)

paddle/fluid/imperative/gradient_accumulator.cc

@@ -77,43 +77,6 @@ static void MoveOrCopyVar(framework::Variable* dst,
   }
 }
 
-#ifdef PADDLE_WITH_XPU
-template <typename T>
-void XPUTensorAddFunctor(const phi::Place& place,
-                         const phi::DenseTensor& src,
-                         phi::DenseTensor* dst) {
-  using XPUType = typename XPUTypeTrait<T>::Type;
-  phi::XPUContext* ctx = dynamic_cast<phi::XPUContext*>(
-      phi::DeviceContextPool::Instance().Get(place));
-  const XPUType* x = reinterpret_cast<const XPUType*>(src.data<T>());
-  XPUType* y = reinterpret_cast<XPUType*>(dst->mutable_data<T>(place));
-  int r = -1;
-  int numel = static_cast<int>(src.numel());
-  if (std::is_same<T, double>::value) {
-    xpu::ctx_guard RAII_GUARD(ctx->x_context());
-    float* x_cast_to_fp32 = RAII_GUARD.alloc<float>(numel);
-    PADDLE_ENFORCE_XDNN_NOT_NULL(x_cast_to_fp32);
-    float* y_cast_to_fp32 = RAII_GUARD.alloc<float>(numel);
-    PADDLE_ENFORCE_XDNN_NOT_NULL(y_cast_to_fp32);
-    r = xpu::cast<XPUType, float>(ctx->x_context(), x, x_cast_to_fp32, numel);
-    PADDLE_ENFORCE_XDNN_SUCCESS(r, "cast");
-    r = xpu::cast<XPUType, float>(ctx->x_context(), y, y_cast_to_fp32, numel);
-    PADDLE_ENFORCE_XDNN_SUCCESS(r, "cast");
-    r = xpu::add<float>(ctx->x_context(),
-                        x_cast_to_fp32,
-                        y_cast_to_fp32,
-                        y_cast_to_fp32,
-                        numel);
-    PADDLE_ENFORCE_XDNN_SUCCESS(r, "add");
-    r = xpu::cast<float, XPUType>(ctx->x_context(), y_cast_to_fp32, y, numel);
-    PADDLE_ENFORCE_XDNN_SUCCESS(r, "cast");
-  } else {
-    r = xpu::add<XPUType>(ctx->x_context(), x, y, y, numel);
-    PADDLE_ENFORCE_XDNN_SUCCESS(r, "add");
-  }
-}
-#endif
-
 template <typename TType>
 TType* GetInnerMutableTensor(framework::Variable* dst) {
   auto* dst_tensor = dst->GetMutable<TType>();
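
Note on the removal (a reading of the diff, not stated in the commit message): XPUTensorAddFunctor addresses src and dst through the same template parameter T, so it only works when both tensors already share a dtype. The two lines quoted below (comments added) carry that assumption; under amp master_grad the accumulated master gradient is kept in fp32 while incoming grads can be fp16/bf16, which is presumably the grad-merge bug the title refers to.

// Quoted from the removed XPUTensorAddFunctor, comments added:
const XPUType* x = reinterpret_cast<const XPUType*>(src.data<T>());    // src read as T
XPUType* y = reinterpret_cast<XPUType*>(dst->mutable_data<T>(place));  // dst forced to T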
@@ -218,6 +181,15 @@ void TensorAdd(const VarType& src, VarType* dst) {
 #endif
   }
 
+  if (phi::is_xpu_place(place)) {
+#if defined(PADDLE_WITH_XPU)
+    PADDLE_TENSOR_ADD(float, phi::XPUContext);
+    PADDLE_TENSOR_ADD(double, phi::XPUContext);
+    PADDLE_TENSOR_ADD(phi::dtype::float16, phi::XPUContext);
+    PADDLE_TENSOR_ADD(phi::dtype::bfloat16, phi::XPUContext);
+#endif
+  }
+
 #define TENSOR_ADD_EIGEN(T) \
   auto cpu_ctx = static_cast<phi::CPUContext*>( \
       phi::DeviceContextPool::Instance().Get(place)); \
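
PADDLE_TENSOR_ADD is the same dispatch macro the CPU and CUDA branches of TensorAdd use; its definition sits earlier in gradient_accumulator.cc and is not part of this diff. A minimal sketch of its likely shape, assuming it follows that shared pattern of matching on data_type and calling phi::AddKernel (verify against the actual file before relying on it):

// Assumed sketch, not from this diff: dispatch on data_type, then run the
// device's phi::AddKernel in place on dst_tensor.
#define PADDLE_TENSOR_ADD(T, CONTEXT)                                          \
  if (data_type == framework::DataTypeTrait<T>::DataType()) {                  \
    auto* dev_ctx = static_cast<CONTEXT*>(                                     \
        phi::DeviceContextPool::Instance().Get(place));                        \
    phi::AddKernel<T, CONTEXT>(*dev_ctx, src_tensor, *dst_tensor, dst_tensor); \
    return;                                                                    \
  }

Routing XPU through this macro means fp16/bf16 accumulation reuses phi's add kernel instead of the hand-rolled xpu::cast/xpu::add sequence, keeping dtype handling consistent with the CPU and CUDA paths.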
@@ -264,29 +236,6 @@ void TensorAdd(const VarType& src, VarType* dst) {
 #endif
   }
 
-#ifdef PADDLE_WITH_XPU
-  if (phi::is_xpu_place(place)) {
-    if (data_type == framework::DataTypeTrait<float>::DataType()) {
-      XPUTensorAddFunctor<float>(place, src_tensor, dst_tensor);
-    } else if (data_type ==
-               framework::DataTypeTrait<phi::dtype::float16>::DataType()) {
-      XPUTensorAddFunctor<phi::dtype::float16>(place, src_tensor, dst_tensor);
-    } else if (data_type == framework::DataTypeTrait<double>::DataType()) {
-      XPUTensorAddFunctor<double>(place, src_tensor, dst_tensor);
-    } else if (data_type ==
-               framework::DataTypeTrait<phi::dtype::bfloat16>::DataType()) {
-      XPUTensorAddFunctor<phi::dtype::bfloat16>(place, src_tensor, dst_tensor);
-    } else {
-      PADDLE_THROW(common::errors::Unimplemented(
-          "Gradient accumulation of data type (%s) on place (%s) is not "
-          "supported in imperative mode",
-          framework::DataTypeToString(data_type),
-          place));
-    }
-    return;
-  }
-#endif
-
   PADDLE_THROW(common::errors::Unimplemented(
       "Gradient accumulation of data type (%s) on place (%s) is not "
       "supported in imperative mode",
