Skip to content

Commit 286c529

Browse files
cast support -1 dim (#72283)
1 parent 48b55cd commit 286c529

File tree

3 files changed

+12
-0
lines changed

3 files changed

+12
-0
lines changed

paddle/phi/kernels/cpu/cast_kernel.cc

+4
Original file line numberDiff line numberDiff line change
@@ -26,6 +26,10 @@ void CastKernel(const Context& dev_ctx,
2626
DataType out_dtype,
2727
DenseTensor* out) {
2828
if (x.dtype() == out_dtype) {
29+
if (x.dims() == phi::make_ddim({-1})) {
30+
*out = x;
31+
return;
32+
}
2933
if (!out->IsSharedWith(x)) {
3034
phi::Copy(dev_ctx, x, dev_ctx.GetPlace(), false, out);
3135
}

paddle/phi/kernels/gpu/cast_kernel.cu

+4
Original file line numberDiff line numberDiff line change
@@ -25,6 +25,10 @@ void CastKernel(const Context& dev_ctx,
2525
DataType out_dtype,
2626
DenseTensor* out) {
2727
if (x.dtype() == out_dtype) {
28+
if (x.dims() == phi::make_ddim({-1})) {
29+
*out = x;
30+
return;
31+
}
2832
if (!out->IsSharedWith(x)) {
2933
phi::Copy(dev_ctx, x, dev_ctx.GetPlace(), false, out);
3034
}

paddle/phi/kernels/xpu/cast_kernel.cc

+4
Original file line numberDiff line numberDiff line change
@@ -77,6 +77,10 @@ void CastKernel(const Context& dev_ctx,
7777
DataType out_dtype,
7878
DenseTensor* out) {
7979
if (x.dtype() == out_dtype) {
80+
if (x.dims() == phi::make_ddim({-1})) {
81+
*out = x;
82+
return;
83+
}
8084
if (!out->IsSharedWith(x)) {
8185
phi::Copy(dev_ctx, x, dev_ctx.GetPlace(), false, out);
8286
}

0 commit comments

Comments
 (0)