|
17 | 17 |
|
18 | 18 | namespace custom_kernel {
|
19 | 19 |
|
// Forward declaration of the dtype-cast kernel defined elsewhere in this
// plugin. It converts `x` to `dtype` and writes the result into `out`
// (whose meta must already be set by the caller). Declared here so kernels
// in this translation unit can down-cast FLOAT64 tensors to FLOAT32 before
// invoking ACL ops.
template <typename T, typename Context>
void CastKernel(const Context& dev_ctx,
                const phi::DenseTensor& x,
                phi::DataType dtype,
                phi::DenseTensor* out);
20 | 26 | template <typename T, typename Context>
|
21 | 27 | void AclopSoftmaxKernel(const Context& dev_ctx,
|
22 | 28 | const phi::DenseTensor& x,
|
@@ -141,7 +147,32 @@ void SoftmaxGradKernel(const Context& dev_ctx,
// NOTE(review): this hunk shows only the tail of SoftmaxGradKernel; the
// signature and the earlier body (including the fallback call visible on
// the first context line below) are outside this view.
|
141 | 147 | dev_ctx, out, out_grad, axis, x_grad)));
|
142 | 148 | dev_ctx.template Alloc<T>(x_grad);
|
143 | 149 | int64_t dim = static_cast<int64_t>(axis);
|
144 |
| - EXEC_NPU_CMD(aclnnSoftmaxBackward, dev_ctx, out_grad, out, dim, *x_grad); |
| 150 | + |
// Down-cast the upstream gradient to FLOAT32 when it arrives as FLOAT64 —
// presumably because aclnnSoftmaxBackward does not accept double inputs;
// verify against the CANN op support list. Only the dtype field of the
// copied meta is changed, so shape/layout are preserved.
| 151 | + phi::DenseTensor cast_x; |
| 152 | + if (out_grad.dtype() == phi::DataType::FLOAT64) { |
| 153 | + phi::DenseTensorMeta meta(out_grad.meta()); |
| 154 | + meta.dtype = phi::DataType::FLOAT32; |
| 155 | + cast_x.set_meta(meta); |
| 156 | + |
| 157 | + custom_kernel::CastKernel<T, Context>( |
| 158 | + dev_ctx, out_grad, phi::DataType::FLOAT32, &cast_x); |
| 159 | + } else { |
// Non-FLOAT64 path: plain DenseTensor assignment — presumably a shallow
// copy sharing the original allocation, so no extra device memory is used.
| 160 | + cast_x = out_grad; |
| 161 | + } |
| 162 | + |
// Same FLOAT64 -> FLOAT32 treatment for the forward output `out`.
| 163 | + phi::DenseTensor cast_y; |
| 164 | + if (out.dtype() == phi::DataType::FLOAT64) { |
| 165 | + phi::DenseTensorMeta meta(out.meta()); |
| 166 | + meta.dtype = phi::DataType::FLOAT32; |
| 167 | + cast_y.set_meta(meta); |
| 168 | + |
| 169 | + custom_kernel::CastKernel<T, Context>( |
| 170 | + dev_ctx, out, phi::DataType::FLOAT32, &cast_y); |
| 171 | + } else { |
| 172 | + cast_y = out; |
| 173 | + } |
| 174 | + |
// NOTE(review): x_grad was allocated above with dtype T (line 148). When
// T == double, the op receives FLOAT32 inputs but a FLOAT64 output tensor,
// and no cast-back of the result is visible in this hunk — confirm the
// runtime performs that implicit output conversion, otherwise the result
// should be computed into a FLOAT32 temporary and cast back into x_grad.
| 175 | + EXEC_NPU_CMD(aclnnSoftmaxBackward, dev_ctx, cast_x, cast_y, dim, *x_grad); |
145 | 176 | }
|
146 | 177 |
|
147 | 178 | } // namespace custom_kernel
|
|
0 commit comments