Skip to content

Commit 464727d

Browse files
authored
[Kernel]Adapt squeeze/unsqeeze Kernel to remove xshape (PaddlePaddle#1373)
1 parent b5a4e77 commit 464727d

File tree

8 files changed

+129
-139
lines changed

8 files changed

+129
-139
lines changed

backends/gcu/kernels/squeeze_kernel.cc

Lines changed: 16 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -18,10 +18,10 @@
1818
namespace custom_kernel {
1919

2020
template <typename T, typename Context>
21-
void SqueezeInferKernel(const Context& dev_ctx,
22-
const phi::DenseTensor& x,
23-
const phi::IntArray& axes_int_array,
24-
phi::DenseTensor* out) {
21+
void SqueezeKernel(const Context& dev_ctx,
22+
const phi::DenseTensor& x,
23+
const phi::IntArray& axes_int_array,
24+
phi::DenseTensor* out) {
2525
PADDLE_GCU_KERNEL_TRACE("squeeze_infer");
2626
VLOG(6) << "[HOST_KERNEL] Impl on host for squeeze_infer";
2727
auto out_dims = out->dims();
@@ -33,28 +33,25 @@ void SqueezeInferKernel(const Context& dev_ctx,
3333
}
3434

3535
template <typename T, typename Context>
36-
void SqueezeKernel(const Context& dev_ctx,
37-
const phi::DenseTensor& x,
38-
const phi::IntArray& axes_int_array,
39-
phi::DenseTensor* out,
40-
phi::DenseTensor* xshape) {
36+
void SqueezeWithXShapeKernel(const Context& dev_ctx,
37+
const phi::DenseTensor& x,
38+
const phi::IntArray& axes_int_array,
39+
phi::DenseTensor* out,
40+
phi::DenseTensor* xshape) {
4141
PADDLE_GCU_KERNEL_TRACE("squeeze");
4242
VLOG(6) << "[HOST_KERNEL] Impl on host for squeeze";
43-
custom_kernel::SqueezeInferKernel<T, Context>(
44-
dev_ctx, x, axes_int_array, out);
43+
custom_kernel::SqueezeKernel<T, Context>(dev_ctx, x, axes_int_array, out);
4544
}
4645

4746
template <typename T, typename Context>
4847
void SqueezeGradKernel(const Context& dev_ctx,
49-
const phi::DenseTensor& xshape,
48+
const phi::DenseTensor& x,
5049
const phi::DenseTensor& dout,
5150
const phi::IntArray& axes_int_array,
5251
phi::DenseTensor* dx) {
5352
PADDLE_GCU_KERNEL_TRACE("squeeze_grad");
5453
VLOG(6) << "[HOST_KERNEL] Impl on host for squeeze_grad";
55-
auto xshape_dims = xshape.dims();
56-
auto x_dims = phi::slice_ddim(xshape_dims, 1, xshape_dims.size());
57-
dx->Resize(x_dims);
54+
auto x_dims = dx->dims();
5855
dev_ctx.template Alloc<T>(dx);
5956

6057
TensorCopy(dev_ctx, dout, false, dx);
@@ -63,10 +60,10 @@ void SqueezeGradKernel(const Context& dev_ctx,
6360

6461
} // namespace custom_kernel
6562

66-
PD_REGISTER_PLUGIN_KERNEL(squeeze_infer,
63+
PD_REGISTER_PLUGIN_KERNEL(squeeze,
6764
gcu,
6865
ALL_LAYOUT,
69-
custom_kernel::SqueezeInferKernel,
66+
custom_kernel::SqueezeKernel,
7067
bool,
7168
int,
7269
uint8_t,
@@ -76,10 +73,10 @@ PD_REGISTER_PLUGIN_KERNEL(squeeze_infer,
7673
phi::dtype::float16,
7774
double) {}
7875

79-
PD_REGISTER_PLUGIN_KERNEL(squeeze,
76+
PD_REGISTER_PLUGIN_KERNEL(squeeze_with_xshape,
8077
gcu,
8178
ALL_LAYOUT,
82-
custom_kernel::SqueezeKernel,
79+
custom_kernel::SqueezeWithXShapeKernel,
8380
bool,
8481
int,
8582
uint8_t,

backends/gcu/kernels/unsqueeze_kernel.cc

Lines changed: 16 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -67,10 +67,10 @@ inline phi::DDim GetUnsqueezeShape(const std::vector<int64_t> unsqz_dims,
6767
}
6868

6969
template <typename T, typename Context>
70-
void UnsqueezeInferKernel(const Context& dev_ctx,
71-
const phi::DenseTensor& x,
72-
const phi::IntArray& axes,
73-
phi::DenseTensor* out) {
70+
void UnsqueezeKernel(const Context& dev_ctx,
71+
const phi::DenseTensor& x,
72+
const phi::IntArray& axes,
73+
phi::DenseTensor* out) {
7474
PADDLE_GCU_KERNEL_TRACE("unsqueeze_infer");
7575
VLOG(6) << "[HOST_KERNEL] Impl on host for unsqueeze_infer";
7676
auto x_dims = x.dims();
@@ -86,37 +86,35 @@ void UnsqueezeInferKernel(const Context& dev_ctx,
8686
}
8787

8888
template <typename T, typename Context>
89-
void UnsqueezeKernel(const Context& dev_ctx,
90-
const phi::DenseTensor& x,
91-
const phi::IntArray& axes,
92-
phi::DenseTensor* out,
93-
phi::DenseTensor* xshape) {
89+
void UnsqueezeWithXShapeKernel(const Context& dev_ctx,
90+
const phi::DenseTensor& x,
91+
const phi::IntArray& axes,
92+
phi::DenseTensor* out,
93+
phi::DenseTensor* xshape) {
9494
PADDLE_GCU_KERNEL_TRACE("unsqueeze");
9595
VLOG(6) << "[HOST_KERNEL] Impl on host for unsqueeze";
96-
custom_kernel::UnsqueezeInferKernel<T, Context>(dev_ctx, x, axes, out);
96+
custom_kernel::UnsqueezeKernel<T, Context>(dev_ctx, x, axes, out);
9797
}
9898

9999
template <typename T, typename Context>
100100
void UnsqueezeGradKernel(const Context& dev_ctx,
101-
const phi::DenseTensor& x_shape,
101+
const phi::DenseTensor& x,
102102
const phi::DenseTensor& dout,
103103
phi::DenseTensor* dx) {
104104
PADDLE_GCU_KERNEL_TRACE("unsqueeze_grad");
105105
VLOG(6) << "[HOST_KERNEL] Impl on host for unsqueeze_grad";
106-
auto xshape_dims = x_shape.dims();
107-
auto x_dims = phi::slice_ddim(xshape_dims, 1, xshape_dims.size());
108-
106+
auto x_dims = dx->dims();
109107
dev_ctx.template Alloc<T>(dx);
110108
custom_kernel::TensorCopy(dev_ctx, dout, false, dx);
111109
dx->Resize(x_dims);
112110
}
113111

114112
} // namespace custom_kernel
115113

116-
PD_REGISTER_PLUGIN_KERNEL(unsqueeze_infer,
114+
PD_REGISTER_PLUGIN_KERNEL(unsqueeze,
117115
gcu,
118116
ALL_LAYOUT,
119-
custom_kernel::UnsqueezeInferKernel,
117+
custom_kernel::UnsqueezeKernel,
120118
float,
121119
double,
122120
phi::dtype::bfloat16,
@@ -129,10 +127,10 @@ PD_REGISTER_PLUGIN_KERNEL(unsqueeze_infer,
129127
phi::dtype::complex<float>,
130128
phi::dtype::complex<double>) {}
131129

132-
PD_REGISTER_PLUGIN_KERNEL(unsqueeze,
130+
PD_REGISTER_PLUGIN_KERNEL(unsqueeze_with_xshape,
133131
gcu,
134132
ALL_LAYOUT,
135-
custom_kernel::UnsqueezeKernel,
133+
custom_kernel::UnsqueezeWithXShapeKernel,
136134
float,
137135
double,
138136
phi::dtype::bfloat16,

backends/mps/kernels/squeeze_kernel.cc

Lines changed: 16 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -64,10 +64,10 @@ std::vector<int64_t> GetOutputShape(const std::vector<int> squeeze_dims,
6464
}
6565

6666
template <typename T>
67-
void SqueezeInferKernel(const phi::Context& dev_ctx,
68-
const phi::DenseTensor& x,
69-
const phi::IntArray& axes_int_array,
70-
phi::DenseTensor* out) {
67+
void SqueezeKernel(const phi::Context& dev_ctx,
68+
const phi::DenseTensor& x,
69+
const phi::IntArray& axes_int_array,
70+
phi::DenseTensor* out) {
7171
auto stream = dev_ctx.stream();
7272
std::vector<int32_t> axes(axes_int_array.GetData().begin(),
7373
axes_int_array.GetData().end());
@@ -83,18 +83,21 @@ void SqueezeInferKernel(const phi::Context& dev_ctx,
8383
}
8484

8585
template <typename T>
86-
void SqueezeKernel(const phi::Context& dev_ctx,
87-
const phi::DenseTensor& x,
88-
const phi::IntArray& axes_int_array,
89-
phi::DenseTensor* out,
90-
phi::DenseTensor* xshape) {
91-
custom_kernel::SqueezeInferKernel<T>(dev_ctx, x, axes_int_array, out);
86+
void SqueezeWithXShapeKernel(const phi::Context& dev_ctx,
87+
const phi::DenseTensor& x,
88+
const phi::IntArray& axes_int_array,
89+
phi::DenseTensor* out,
90+
phi::DenseTensor* xshape) {
91+
custom_kernel::SqueezeKernel<T>(dev_ctx, x, axes_int_array, out);
9292
}
9393

9494
} // namespace custom_kernel
9595

96-
PD_BUILD_PHI_KERNEL(
97-
squeeze_infer, mps, ALL_LAYOUT, custom_kernel::SqueezeInferKernel, float) {}
98-
9996
PD_BUILD_PHI_KERNEL(
10097
squeeze, mps, ALL_LAYOUT, custom_kernel::SqueezeKernel, float) {}
98+
99+
PD_BUILD_PHI_KERNEL(squeeze_with_xshape,
100+
mps,
101+
ALL_LAYOUT,
102+
custom_kernel::SqueezeWithXShapeKernel,
103+
float) {}

backends/mps/kernels/unsqueeze_kernel.cc

Lines changed: 17 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -52,10 +52,10 @@ inline std::vector<int64_t> GetUnsqueezeShape(
5252
}
5353

5454
template <typename T>
55-
void UnsqueezeInferKernel(const phi::Context& dev_ctx,
56-
const phi::DenseTensor& x,
57-
const phi::IntArray& axes,
58-
phi::DenseTensor* out) {
55+
void UnsqueezeKernel(const phi::Context& dev_ctx,
56+
const phi::DenseTensor& x,
57+
const phi::IntArray& axes,
58+
phi::DenseTensor* out) {
5959
auto x_dims = x.dims();
6060
auto out_dims = out->dims();
6161

@@ -76,21 +76,20 @@ void UnsqueezeInferKernel(const phi::Context& dev_ctx,
7676
}
7777

7878
template <typename T>
79-
void UnsqueezeKernel(const phi::Context& dev_ctx,
80-
const phi::DenseTensor& x,
81-
const phi::IntArray& axes,
82-
phi::DenseTensor* out,
83-
phi::DenseTensor* xshape) {
84-
custom_kernel::UnsqueezeInferKernel<T>(dev_ctx, x, axes, out);
79+
void UnsqueezeWithXShapeKernel(const phi::Context& dev_ctx,
80+
const phi::DenseTensor& x,
81+
const phi::IntArray& axes,
82+
phi::DenseTensor* out,
83+
phi::DenseTensor* xshape) {
84+
custom_kernel::UnsqueezeKernel<T>(dev_ctx, x, axes, out);
8585
}
8686

8787
template <typename T>
8888
void UnsqueezeGradKernel(const phi::Context& dev_ctx,
89-
const phi::DenseTensor& x_shape,
89+
const phi::DenseTensor& x,
9090
const phi::DenseTensor& dout,
9191
phi::DenseTensor* dx) {
92-
auto xshape_dims = x_shape.dims();
93-
auto x_dims = phi::slice_ddim(xshape_dims, 1, xshape_dims.size());
92+
auto x_dims = dx->dims();
9493

9594
dev_ctx.template Alloc<T>(dx);
9695
auto dout_data = dout.data<T>();
@@ -101,15 +100,15 @@ void UnsqueezeGradKernel(const phi::Context& dev_ctx,
101100

102101
} // namespace custom_kernel
103102

104-
PD_BUILD_PHI_KERNEL(unsqueeze_infer,
103+
PD_BUILD_PHI_KERNEL(
104+
unsqueeze, mps, ALL_LAYOUT, custom_kernel::UnsqueezeKernel, float) {}
105+
106+
PD_BUILD_PHI_KERNEL(unsqueeze_with_xshape,
105107
mps,
106108
ALL_LAYOUT,
107-
custom_kernel::UnsqueezeInferKernel,
109+
custom_kernel::UnsqueezeWithXShapeKernel,
108110
float) {}
109111

110-
PD_BUILD_PHI_KERNEL(
111-
unsqueeze, mps, ALL_LAYOUT, custom_kernel::UnsqueezeKernel, float) {}
112-
113112
PD_BUILD_PHI_KERNEL(unsqueeze_grad,
114113
mps,
115114
ALL_LAYOUT,

backends/npu/kernels/squeeze_kernel.cc

Lines changed: 16 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -18,10 +18,10 @@
1818
namespace custom_kernel {
1919

2020
template <typename T, typename Context>
21-
void SqueezeInferKernel(const Context& dev_ctx,
22-
const phi::DenseTensor& x,
23-
const phi::IntArray& axes_int_array,
24-
phi::DenseTensor* out) {
21+
void SqueezeKernel(const Context& dev_ctx,
22+
const phi::DenseTensor& x,
23+
const phi::IntArray& axes_int_array,
24+
phi::DenseTensor* out) {
2525
auto stream = dev_ctx.stream();
2626

2727
auto out_dims = out->dims();
@@ -33,36 +33,33 @@ void SqueezeInferKernel(const Context& dev_ctx,
3333
}
3434

3535
template <typename T, typename Context>
36-
void SqueezeKernel(const Context& dev_ctx,
37-
const phi::DenseTensor& x,
38-
const phi::IntArray& axes_int_array,
39-
phi::DenseTensor* out,
40-
phi::DenseTensor* xshape) {
41-
custom_kernel::SqueezeInferKernel<T, Context>(
42-
dev_ctx, x, axes_int_array, out);
36+
void SqueezeWithXShapeKernel(const Context& dev_ctx,
37+
const phi::DenseTensor& x,
38+
const phi::IntArray& axes_int_array,
39+
phi::DenseTensor* out,
40+
phi::DenseTensor* xshape) {
41+
custom_kernel::SqueezeKernel<T, Context>(dev_ctx, x, axes_int_array, out);
4342
}
4443

4544
template <typename T, typename Context>
4645
void SqueezeGradKernel(const Context& dev_ctx,
47-
const phi::DenseTensor& xshape,
46+
const phi::DenseTensor& x,
4847
const phi::DenseTensor& dout,
4948
const phi::IntArray& axes_int_array,
5049
phi::DenseTensor* dx) {
5150
auto stream = dev_ctx.stream();
5251

53-
auto xshape_dims = xshape.dims();
54-
auto x_dims = phi::slice_ddim(xshape_dims, 1, xshape_dims.size());
55-
52+
auto x_dims = dx->dims();
5653
TensorCopy(dev_ctx, dout, false, dx);
5754
dx->Resize(x_dims);
5855
}
5956

6057
} // namespace custom_kernel
6158

62-
PD_REGISTER_PLUGIN_KERNEL(squeeze_infer,
59+
PD_REGISTER_PLUGIN_KERNEL(squeeze,
6360
npu,
6461
ALL_LAYOUT,
65-
custom_kernel::SqueezeInferKernel,
62+
custom_kernel::SqueezeKernel,
6663
bool,
6764
int,
6865
uint8_t,
@@ -73,10 +70,10 @@ PD_REGISTER_PLUGIN_KERNEL(squeeze_infer,
7370
phi::dtype::bfloat16,
7471
double) {}
7572

76-
PD_REGISTER_PLUGIN_KERNEL(squeeze,
73+
PD_REGISTER_PLUGIN_KERNEL(squeeze_with_xshape,
7774
npu,
7875
ALL_LAYOUT,
79-
custom_kernel::SqueezeKernel,
76+
custom_kernel::SqueezeWithXShapeKernel,
8077
bool,
8178
int,
8279
uint8_t,

0 commit comments

Comments
 (0)