Skip to content

Commit e4f5276

Browse files
committed
update
1 parent 2906aa3 commit e4f5276

File tree

2 files changed

+256
-0
lines changed

2 files changed

+256
-0
lines changed
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,20 @@
1+
// Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.
2+
//
3+
// Licensed under the Apache License, Version 2.0 (the "License");
4+
// you may not use this file except in compliance with the License.
5+
// You may obtain a copy of the License at
6+
//
7+
// http://www.apache.org/licenses/LICENSE-2.0
8+
//
9+
// Unless required by applicable law or agreed to in writing, software
10+
// distributed under the License is distributed on an "AS IS" BASIS,
11+
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
// See the License for the specific language governing permissions and
13+
// limitations under the License.
14+
15+
#include "paddle/phi/kernels/svdvals_grad_kernel.h"
16+
#include "paddle/phi/core/kernel_registry.h"
17+
#include "paddle/phi/kernels/impl/svdvals_grad_kernel_impl.h"
18+
19+
PD_REGISTER_KERNEL(
20+
svdvals_grad, GPU, ALL_LAYOUT, phi::SvdvalsGradKernel, float, double) {}
+236
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,236 @@
1+
// Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.
2+
//
3+
// Licensed under the Apache License, Version 2.0 (the "License");
4+
// you may not use this file except in compliance with the License.
5+
// You may obtain a copy of the License at
6+
//
7+
// http://www.apache.org/licenses/LICENSE-2.0
8+
//
9+
// Unless required by applicable law or agreed to in writing, software
10+
// distributed under the License is distributed on an "AS IS" BASIS,
11+
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
// See the License for the specific language governing permissions and
13+
// limitations under the License.
14+
15+
#include "paddle/phi/backends/dynload/cusolver.h"
16+
#include "paddle/phi/common/memory_utils.h"
17+
#include "paddle/phi/core/kernel_registry.h"
18+
19+
namespace phi {
20+
template <class T>
21+
static void GesvdjBatchedSvdvals(const phi::GPUContext& dev_ctx,
22+
int batchSize,
23+
int m,
24+
int n,
25+
int k,
26+
T* A,
27+
T* U,
28+
T* V,
29+
T* S,
30+
int* info,
31+
int thin_UV = 0 // only compute UV
32+
);
33+
34+
template <>
35+
void GesvdjBatchedSvdvals<float>(const phi::GPUContext& dev_ctx,
36+
int batchSize,
37+
int m,
38+
int n,
39+
int k,
40+
float* A,
41+
float* S,
42+
int* info,
43+
int thin_UV) {
44+
const cusolverEigMode_t jobz = CUSOLVER_EIG_MODE_NOVECTOR;
45+
gesvdjInfo_t gesvdj_params = NULL;
46+
int lda = m;
47+
int ldu = 1;
48+
int ldv = 1;
49+
int lwork = 0;
50+
auto handle = dev_ctx.cusolver_dn_handle();
51+
PADDLE_ENFORCE_GPU_SUCCESS(
52+
phi::dynload::cusolverDnCreateGesvdjInfo(&gesvdj_params));
53+
PADDLE_ENFORCE_GPU_SUCCESS(
54+
phi::dynload::cusolverDnSgesvdj_bufferSize(handle,
55+
jobz,
56+
thin_UV,
57+
m,
58+
n,
59+
A,
60+
lda,
61+
S,
62+
nullptr,
63+
ldu,
64+
nullptr,
65+
ldv,
66+
&lwork,
67+
gesvdj_params));
68+
auto workspace = phi::memory_utils::Alloc(
69+
dev_ctx.GetPlace(),
70+
lwork * sizeof(float),
71+
phi::Stream(reinterpret_cast<phi::StreamId>(dev_ctx.stream())));
72+
float* workspace_ptr = reinterpret_cast<float*>(workspace->ptr());
73+
int stride_A = lda * n;
74+
for (int i = 0; i < batchSize; ++i) {
75+
PADDLE_ENFORCE_GPU_SUCCESS(phi::dynload::cusolverDnSgesvdj(handle,
76+
jobz,
77+
thin_UV,
78+
m,
79+
n,
80+
A + stride_A * i,
81+
lda,
82+
S + k * i,
83+
nullptr,
84+
ldu,
85+
nullptr,
86+
ldv,
87+
workspace_ptr,
88+
lwork,
89+
info,
90+
gesvdj_params));
91+
// check the error info
92+
int error_info;
93+
memory_utils::Copy(phi::CPUPlace(),
94+
&error_info,
95+
dev_ctx.GetPlace(),
96+
info,
97+
sizeof(int),
98+
dev_ctx.stream());
99+
PADDLE_ENFORCE_EQ(
100+
error_info,
101+
0,
102+
common::errors::PreconditionNotMet(
103+
"For batch [%d]: CUSolver SVD is not zero. [%d]", i, error_info));
104+
}
105+
PADDLE_ENFORCE_GPU_SUCCESS(
106+
phi::dynload::cusolverDnDestroyGesvdjInfo(gesvdj_params));
107+
}
108+
109+
template <>
110+
void GesvdjBatchedSvdvals<double>(const phi::GPUContext& dev_ctx,
111+
int batchSize,
112+
int m,
113+
int n,
114+
int k,
115+
double* A,
116+
double* S,
117+
int* info,
118+
int thin_UV) {
119+
const cusolverEigMode_t jobz = CUSOLVER_EIG_MODE_NOVECTOR;
120+
gesvdjInfo_t gesvdj_params = NULL;
121+
int lda = m;
122+
int ldu = 1;
123+
int ldv = 1;
124+
int lwork = 0;
125+
auto handle = dev_ctx.cusolver_dn_handle();
126+
PADDLE_ENFORCE_GPU_SUCCESS(
127+
phi::dynload::cusolverDnCreateGesvdjInfo(&gesvdj_params));
128+
PADDLE_ENFORCE_GPU_SUCCESS(
129+
phi::dynload::cusolverDnDgesvdj_bufferSize(handle,
130+
jobz,
131+
thin_UV,
132+
m,
133+
n,
134+
A,
135+
lda,
136+
S,
137+
nullptr,
138+
ldu,
139+
nullptr,
140+
ldv,
141+
&lwork,
142+
gesvdj_params));
143+
auto workspace = phi::memory_utils::Alloc(
144+
dev_ctx.GetPlace(),
145+
lwork * sizeof(double),
146+
phi::Stream(reinterpret_cast<phi::StreamId>(dev_ctx.stream())));
147+
double* workspace_ptr = reinterpret_cast<double*>(workspace->ptr());
148+
int stride_A = lda * n;
149+
for (int i = 0; i < batchSize; ++i) {
150+
PADDLE_ENFORCE_GPU_SUCCESS(phi::dynload::cusolverDnDgesvdj(handle,
151+
jobz,
152+
thin_UV,
153+
m,
154+
n,
155+
A + stride_A * i,
156+
lda,
157+
S + k * i,
158+
nullptr,
159+
ldu,
160+
nullptr,
161+
ldv,
162+
workspace_ptr,
163+
lwork,
164+
info,
165+
gesvdj_params));
166+
// check the error info
167+
int error_info;
168+
memory_utils::Copy(phi::CPUPlace(),
169+
&error_info,
170+
dev_ctx.GetPlace(),
171+
info,
172+
sizeof(int),
173+
dev_ctx.stream());
174+
PADDLE_ENFORCE_EQ(
175+
error_info,
176+
0,
177+
common::errors::PreconditionNotMet(
178+
"For batch [%d]: CUSolver SVD is not zero. [%d]", i, error_info));
179+
}
180+
PADDLE_ENFORCE_GPU_SUCCESS(
181+
phi::dynload::cusolverDnDestroyGesvdjInfo(gesvdj_params));
182+
}
183+
184+
template <typename T, typename Context>
185+
void SvdvalsKernel(const Context& dev_ctx,
186+
const DenseTensor& X,
187+
DenseTensor* S) {
188+
auto& dims = X.dims();
189+
int rows = static_cast<int>(dims[dims.size() - 2]);
190+
int cols = static_cast<int>(dims[dims.size() - 1]);
191+
int k = std::min(rows, cols);
192+
int batches = static_cast<int>(X.numel() / (rows * cols));
193+
PADDLE_ENFORCE_GT(
194+
rows,
195+
0,
196+
common::errors::InvalidArgument("Rows of X must be greater than 0."));
197+
PADDLE_ENFORCE_GT(
198+
cols,
199+
0,
200+
common::errors::InvalidArgument("Cols of X must be greater than 0."));
201+
PADDLE_ENFORCE_GT(
202+
batches,
203+
0,
204+
common::errors::InvalidArgument("Batch size must be greater than 0."));
205+
206+
auto* S_out = dev_ctx.template Alloc<phi::dtype::Real<T>>(S);
207+
DDim S_dims;
208+
if (dims.size() <= 2) {
209+
S_dims = {k};
210+
} else {
211+
S_dims = {batches, k};
212+
}
213+
S->Resize(S_dims);
214+
auto* S_out = dev_ctx.template Alloc<phi::dtype::Real<T>>(S);
215+
216+
auto info = Empty<int, Context>(dev_ctx, {batches});
217+
int* info_ptr = reinterpret_cast<int*>(info.data());
218+
219+
DenseTensor x_tmp;
220+
Copy(dev_ctx, X, dev_ctx.GetPlace(), false, &x_tmp);
221+
222+
GesvdjBatchedSvdvals<T>(dev_ctx,
223+
batches,
224+
rows,
225+
cols,
226+
k,
227+
dev_ctx.template Alloc<T>(&x_tmp),
228+
S_out,
229+
info_ptr,
230+
0);
231+
}
232+
233+
} // namespace phi
234+
235+
PD_REGISTER_KERNEL(
236+
svdvals, GPU, ALL_LAYOUT, phi::SvdvalsKernel, float, double) {}

0 commit comments

Comments
 (0)