
Commit 252b746

【complex op No.26】add complex support for inv (PaddlePaddle#63229)
* add complex dtype for inv
* fix
* update
* fix
* fix
1 parent 86d347b

13 files changed: +289 -32 lines

paddle/fluid/platform/dynload/cublas.h

Lines changed: 6 additions & 0 deletions
@@ -80,8 +80,14 @@ namespace dynload {
   __macro(cublasSgetriBatched);  \
   __macro(cublasDgetrfBatched);  \
   __macro(cublasDgetriBatched);  \
+  __macro(cublasCgetrfBatched);  \
+  __macro(cublasCgetriBatched);  \
+  __macro(cublasZgetrfBatched);  \
+  __macro(cublasZgetriBatched);  \
   __macro(cublasSmatinvBatched); \
   __macro(cublasDmatinvBatched); \
+  __macro(cublasCmatinvBatched); \
+  __macro(cublasZmatinvBatched); \
   __macro(cublasSgetrsBatched);  \
   __macro(cublasDgetrsBatched);

paddle/phi/backends/dynload/cublas.h

Lines changed: 6 additions & 0 deletions
@@ -94,8 +94,14 @@ extern void *cublas_dso_handle;
   __macro(cublasSgetriBatched);  \
   __macro(cublasDgetrfBatched);  \
   __macro(cublasDgetriBatched);  \
+  __macro(cublasCgetrfBatched);  \
+  __macro(cublasCgetriBatched);  \
+  __macro(cublasZgetrfBatched);  \
+  __macro(cublasZgetriBatched);  \
   __macro(cublasSmatinvBatched); \
   __macro(cublasDmatinvBatched); \
+  __macro(cublasCmatinvBatched); \
+  __macro(cublasZmatinvBatched); \
   __macro(cublasSgetrsBatched);  \
   __macro(cublasDgetrsBatched);
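These __macro entries feed Paddle's dynamic-loading machinery, so the new complex batched LU and inversion symbols are resolved from the cuBLAS shared library on first use rather than linked at build time. Below is a minimal, generic sketch of that lazy dlopen/dlsym pattern for orientation; LoadCublasSymbol and its error handling are hypothetical illustrations, not the actual macro expansion in these headers.

// Illustrative only: a generic dlopen/dlsym lazy-binding helper of the kind
// such dynload macros typically expand to. Names here are hypothetical.
#include <dlfcn.h>

#include <stdexcept>
#include <string>

template <typename FuncPtr>
FuncPtr LoadCublasSymbol(const char* name) {
  // Open libcublas once and cache the handle (thread-safe static init).
  static void* handle = [] {
    void* h = dlopen("libcublas.so", RTLD_LAZY | RTLD_GLOBAL);
    if (h == nullptr) {
      throw std::runtime_error("failed to dlopen libcublas.so");
    }
    return h;
  }();
  void* sym = dlsym(handle, name);
  if (sym == nullptr) {
    throw std::runtime_error(std::string("missing cuBLAS symbol: ") + name);
  }
  return reinterpret_cast<FuncPtr>(sym);
}

// Example: resolve the newly registered complex batched-inversion entry point
// on first use (the function-pointer type follows the public cuBLAS docs).
// using CmatinvBatchedFn = cublasStatus_t (*)(cublasHandle_t, int,
//                                             const cuComplex* const*, int,
//                                             cuComplex* const*, int,
//                                             int*, int);
// static const auto cmatinv =
//     LoadCublasSymbol<CmatinvBatchedFn>("cublasCmatinvBatched");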

paddle/phi/kernels/cpu/inverse_grad_kernel.cc

Lines changed: 8 additions & 2 deletions
@@ -16,5 +16,11 @@

 #include "paddle/phi/core/kernel_registry.h"

-PD_REGISTER_KERNEL(
-    inverse_grad, CPU, ALL_LAYOUT, phi::InverseGradKernel, float, double) {}
+PD_REGISTER_KERNEL(inverse_grad,
+                   CPU,
+                   ALL_LAYOUT,
+                   phi::InverseGradKernel,
+                   float,
+                   double,
+                   phi::dtype::complex<float>,
+                   phi::dtype::complex<double>) {}

paddle/phi/kernels/cpu/inverse_kernel.cc

Lines changed: 8 additions & 2 deletions
@@ -16,5 +16,11 @@

 #include "paddle/phi/core/kernel_registry.h"

-PD_REGISTER_KERNEL(
-    inverse, CPU, ALL_LAYOUT, phi::InverseKernel, float, double) {}
+PD_REGISTER_KERNEL(inverse,
+                   CPU,
+                   ALL_LAYOUT,
+                   phi::InverseKernel,
+                   float,
+                   double,
+                   phi::dtype::complex<float>,
+                   phi::dtype::complex<double>) {}

paddle/phi/kernels/funcs/blas/blas_impl.cu.h

Lines changed: 114 additions & 0 deletions
@@ -685,6 +685,63 @@ struct CUBlas<phi::dtype::complex<float>> {
         ldb,
         batch_size));
   }
+
+  static void GETRF_BATCH(cublasHandle_t handle,
+                          int n,
+                          phi::dtype::complex<float> **A,
+                          int lda,
+                          int *ipiv,
+                          int *info,
+                          int batch_size) {
+    PADDLE_ENFORCE_GPU_SUCCESS(phi::dynload::cublasCgetrfBatched(
+        handle,
+        n,
+        reinterpret_cast<cuFloatComplex **>(A),
+        lda,
+        ipiv,
+        info,
+        batch_size));
+  }
+
+  static void GETRI_BATCH(cublasHandle_t handle,
+                          int n,
+                          const phi::dtype::complex<float> **A,
+                          int lda,
+                          const int *ipiv,
+                          phi::dtype::complex<float> **Ainv,
+                          int ldc,
+                          int *info,
+                          int batch_size) {
+    PADDLE_ENFORCE_GPU_SUCCESS(phi::dynload::cublasCgetriBatched(
+        handle,
+        n,
+        reinterpret_cast<const cuFloatComplex **>(A),
+        lda,
+        ipiv,
+        reinterpret_cast<cuFloatComplex **>(Ainv),
+        ldc,
+        info,
+        batch_size));
+  }
+
+  static void MATINV_BATCH(cublasHandle_t handle,
+                           int n,
+                           const phi::dtype::complex<float> **A,
+                           int lda,
+                           phi::dtype::complex<float> **Ainv,
+                           int lda_inv,
+                           int *info,
+                           int batch_size) {
+    PADDLE_ENFORCE_GPU_SUCCESS(phi::dynload::cublasCmatinvBatched(
+        handle,
+        n,
+        reinterpret_cast<const cuFloatComplex **>(A),
+        lda,
+        reinterpret_cast<cuFloatComplex **>(Ainv),
+        lda_inv,
+        info,
+        batch_size));
+  }
 };

 template <>
@@ -923,6 +980,63 @@ struct CUBlas<phi::dtype::complex<double>> {
         "cublasGemmEx is not supported on cuda <= 7.5"));
 #endif
   }
+
+  static void GETRF_BATCH(cublasHandle_t handle,
+                          int n,
+                          phi::dtype::complex<double> **A,
+                          int lda,
+                          int *ipiv,
+                          int *info,
+                          int batch_size) {
+    PADDLE_ENFORCE_GPU_SUCCESS(phi::dynload::cublasZgetrfBatched(
+        handle,
+        n,
+        reinterpret_cast<cuDoubleComplex **>(A),
+        lda,
+        ipiv,
+        info,
+        batch_size));
+  }
+
+  static void GETRI_BATCH(cublasHandle_t handle,
+                          int n,
+                          const phi::dtype::complex<double> **A,
+                          int lda,
+                          const int *ipiv,
+                          phi::dtype::complex<double> **Ainv,
+                          int ldc,
+                          int *info,
+                          int batch_size) {
+    PADDLE_ENFORCE_GPU_SUCCESS(phi::dynload::cublasZgetriBatched(
+        handle,
+        n,
+        reinterpret_cast<const cuDoubleComplex **>(A),
+        lda,
+        ipiv,
+        reinterpret_cast<cuDoubleComplex **>(Ainv),
+        ldc,
+        info,
+        batch_size));
+  }
+
+  static void MATINV_BATCH(cublasHandle_t handle,
+                           int n,
+                           const phi::dtype::complex<double> **A,
+                           int lda,
+                           phi::dtype::complex<double> **Ainv,
+                           int lda_inv,
+                           int *info,
+                           int batch_size) {
+    PADDLE_ENFORCE_GPU_SUCCESS(phi::dynload::cublasZmatinvBatched(
+        handle,
+        n,
+        reinterpret_cast<const cuDoubleComplex **>(A),
+        lda,
+        reinterpret_cast<cuDoubleComplex **>(Ainv),
+        lda_inv,
+        info,
+        batch_size));
+  }
 };

 template <>
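For orientation, here is a compact sketch of how a batch of small complex matrices would be inverted through the cuBLAS entry point these wrappers now expose, written against the raw cuBLAS API rather than the CUBlas<> helpers above. It is an illustration with assumptions: handle management and error checking are minimal, and per the public cuBLAS documentation cublasCmatinvBatched is only valid for n <= 32, with larger matrices going through the getrf/getri pair, which is presumably why both paths are wrapped above.

// Minimal sketch: invert a batch of small n x n complex matrices with
// cublasCmatinvBatched (n <= 32 per cuBLAS docs); larger matrices would
// go through cublasCgetrfBatched followed by cublasCgetriBatched.
#include <cublas_v2.h>
#include <cuda_runtime.h>

#include <vector>

void BatchedComplexInverse(const cuComplex* d_A,  // [batch, n, n], on device
                           cuComplex* d_Ainv,     // [batch, n, n], on device
                           int n, int batch) {
  cublasHandle_t handle;
  cublasCreate(&handle);

  // cuBLAS batched APIs take a device array of per-matrix device pointers.
  std::vector<const cuComplex*> h_Aptrs(batch);
  std::vector<cuComplex*> h_Ainvptrs(batch);
  for (int i = 0; i < batch; ++i) {
    h_Aptrs[i] = d_A + static_cast<size_t>(i) * n * n;
    h_Ainvptrs[i] = d_Ainv + static_cast<size_t>(i) * n * n;
  }
  const cuComplex** d_Aptrs = nullptr;
  cuComplex** d_Ainvptrs = nullptr;
  int* d_info = nullptr;
  cudaMalloc(&d_Aptrs, batch * sizeof(cuComplex*));
  cudaMalloc(&d_Ainvptrs, batch * sizeof(cuComplex*));
  cudaMalloc(&d_info, batch * sizeof(int));
  cudaMemcpy(d_Aptrs, h_Aptrs.data(), batch * sizeof(cuComplex*),
             cudaMemcpyHostToDevice);
  cudaMemcpy(d_Ainvptrs, h_Ainvptrs.data(), batch * sizeof(cuComplex*),
             cudaMemcpyHostToDevice);

  // info[i] == 0 on success; a nonzero value flags a singular matrix i.
  cublasCmatinvBatched(handle, n, d_Aptrs, n, d_Ainvptrs, n, d_info, batch);
  cudaDeviceSynchronize();

  cudaFree(d_Aptrs);
  cudaFree(d_Ainvptrs);
  cudaFree(d_info);
  cublasDestroy(handle);
}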

paddle/phi/kernels/funcs/matrix_inverse.cc

Lines changed: 2 additions & 0 deletions
@@ -28,6 +28,8 @@ void MatrixInverseFunctor<Context, T>::operator()(const Context& dev_ctx,

 template class MatrixInverseFunctor<CPUContext, float>;
 template class MatrixInverseFunctor<CPUContext, double>;
+template class MatrixInverseFunctor<CPUContext, phi::dtype::complex<float>>;
+template class MatrixInverseFunctor<CPUContext, phi::dtype::complex<double>>;

 }  // namespace funcs
 }  // namespace phi

paddle/phi/kernels/funcs/matrix_inverse.cu

Lines changed: 2 additions & 0 deletions
@@ -131,6 +131,8 @@ void MatrixInverseFunctor<Context, T>::operator()(const Context& dev_ctx,

 template class MatrixInverseFunctor<GPUContext, float>;
 template class MatrixInverseFunctor<GPUContext, double>;
+template class MatrixInverseFunctor<GPUContext, phi::dtype::complex<float>>;
+template class MatrixInverseFunctor<GPUContext, phi::dtype::complex<double>>;

 }  // namespace funcs
 }  // namespace phi

paddle/phi/kernels/funcs/matrix_inverse.h

Lines changed: 65 additions & 14 deletions
@@ -25,14 +25,69 @@ limitations under the License. */
 namespace phi {
 namespace funcs {

+template <typename Context, typename T>
+struct MapMatrixInverseFunctor {
+  void operator()(
+      const Context& dev_ctx, const T* a_ptr, T* a_inv_ptr, int offset, int n) {
+    using Matrix =
+        Eigen::Matrix<T, Eigen::Dynamic, Eigen::Dynamic, Eigen::RowMajor>;
+    using EigenMatrixMap = Eigen::Map<Matrix>;
+    using ConstEigenMatrixMap = Eigen::Map<const Matrix>;
+
+    ConstEigenMatrixMap mat(a_ptr + offset, n, n);
+    EigenMatrixMap mat_inv(a_inv_ptr + offset, n, n);
+    Eigen::PartialPivLU<Matrix> lu;
+    lu.compute(mat);
+
+    const T min_abs_pivot = lu.matrixLU().diagonal().cwiseAbs().minCoeff();
+    PADDLE_ENFORCE_GT(min_abs_pivot,
+                      static_cast<T>(0),
+                      errors::InvalidArgument("Input is not invertible."));
+    mat_inv.noalias() = lu.inverse();
+  }
+};
+
+template <typename Context, typename T>
+struct MapMatrixInverseFunctor<Context, phi::dtype::complex<T>> {
+  void operator()(const Context& dev_ctx,
+                  const phi::dtype::complex<T>* a_ptr,
+                  phi::dtype::complex<T>* a_inv_ptr,
+                  int offset,
+                  int n) {
+    using Matrix = Eigen::Matrix<std::complex<T>,
+                                 Eigen::Dynamic,
+                                 Eigen::Dynamic,
+                                 Eigen::RowMajor>;
+    using EigenMatrixMap = Eigen::Map<Matrix>;
+    using ConstEigenMatrixMap = Eigen::Map<const Matrix>;
+    std::complex<T>* std_ptr = new std::complex<T>[n * n];
+    std::complex<T>* std_inv_ptr = new std::complex<T>[n * n];
+    for (int i = 0; i < n * n; i++) {
+      *(std_ptr + i) = static_cast<std::complex<T>>(*(a_ptr + offset + i));
+    }
+    ConstEigenMatrixMap mat(std_ptr, n, n);
+    EigenMatrixMap mat_inv(std_inv_ptr, n, n);
+    Eigen::PartialPivLU<Matrix> lu;
+    lu.compute(mat);
+
+    const T min_abs_pivot = lu.matrixLU().diagonal().cwiseAbs().minCoeff();
+    PADDLE_ENFORCE_NE(min_abs_pivot,
+                      static_cast<std::complex<T>>(0),
+                      errors::InvalidArgument("Input is not invertible."));
+    mat_inv.noalias() = lu.inverse();
+    for (int i = 0; i < n * n; i++) {
+      *(a_inv_ptr + offset + i) =
+          static_cast<phi::dtype::complex<T>>(*(std_inv_ptr + i));
+    }
+    delete[] std_ptr;
+    delete[] std_inv_ptr;
+  }
+};
+
 template <typename Context, typename T>
 void ComputeInverseEigen(const Context& dev_ctx,
                          const DenseTensor& a,
                          DenseTensor* a_inv) {
-  using Matrix =
-      Eigen::Matrix<T, Eigen::Dynamic, Eigen::Dynamic, Eigen::RowMajor>;
-  using EigenMatrixMap = Eigen::Map<Matrix>;
-  using ConstEigenMatrixMap = Eigen::Map<const Matrix>;
   const auto& mat_dims = a.dims();
   const int rank = mat_dims.size();
   int n = mat_dims[rank - 1];
@@ -41,17 +96,13 @@ void ComputeInverseEigen(const Context& dev_ctx,
   const T* a_ptr = a.data<T>();
   T* a_inv_ptr = dev_ctx.template Alloc<T>(a_inv);

+  // Putting phi::dtype::complex directly into an Eigen matrix does not
+  // produce correct results, so the complex specialization converts the
+  // data to std::complex first and converts back after the inversion.
   for (int i = 0; i < batch_size; ++i) {
-    ConstEigenMatrixMap mat(a_ptr + i * n * n, n, n);
-    EigenMatrixMap mat_inv(a_inv_ptr + i * n * n, n, n);
-    Eigen::PartialPivLU<Matrix> lu;
-    lu.compute(mat);
-
-    const T min_abs_pivot = lu.matrixLU().diagonal().cwiseAbs().minCoeff();
-    PADDLE_ENFORCE_GT(min_abs_pivot,
-                      static_cast<T>(0),
-                      errors::InvalidArgument("Input is not invertible."));
-    mat_inv.noalias() = lu.inverse();
+    MapMatrixInverseFunctor<Context, T> functor;
+    functor(dev_ctx, a_ptr, a_inv_ptr, i * n * n, n);
   }
 }
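To see the conversion path in isolation, the following standalone sketch inverts a single complex matrix the same way the specialization above does: stage the data in a std::complex buffer, factor with Eigen's PartialPivLU, check the pivots, and write the inverse back. It compiles against Eigen alone; the buffers and the 2x2 test matrix are illustrative stand-ins for the phi tensors, not code from the Paddle tree.

// Standalone sketch (Eigen only): LU-based inversion of an n x n
// std::complex<float> matrix stored in row-major order, mirroring the
// complex specialization of MapMatrixInverseFunctor above.
#include <Eigen/Dense>

#include <complex>
#include <iostream>
#include <vector>

int main() {
  using Scalar = std::complex<float>;
  using Matrix =
      Eigen::Matrix<Scalar, Eigen::Dynamic, Eigen::Dynamic, Eigen::RowMajor>;

  const int n = 2;
  // Row-major input buffer: [[1+i, 2], [0, 1-i]], which is invertible.
  std::vector<Scalar> a = {{1.f, 1.f}, {2.f, 0.f}, {0.f, 0.f}, {1.f, -1.f}};
  std::vector<Scalar> a_inv(n * n);

  Eigen::Map<const Matrix> mat(a.data(), n, n);
  Eigen::Map<Matrix> mat_inv(a_inv.data(), n, n);

  Eigen::PartialPivLU<Matrix> lu;
  lu.compute(mat);

  // Reject singular inputs (zero pivot), as the kernel does.
  const float min_abs_pivot = lu.matrixLU().diagonal().cwiseAbs().minCoeff();
  if (min_abs_pivot == 0.f) {
    std::cerr << "Input is not invertible.\n";
    return 1;
  }
  mat_inv.noalias() = lu.inverse();

  // A * A^{-1} should print (approximately) the identity matrix.
  std::cout << mat * mat_inv << "\n";
  return 0;
}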

paddle/phi/kernels/gpu/inverse_grad_kernel.cu

Lines changed: 8 additions & 2 deletions
@@ -18,5 +18,11 @@
 #include "paddle/phi/core/kernel_registry.h"
 #include "paddle/phi/kernels/impl/inverse_grad_kernel_impl.h"

-PD_REGISTER_KERNEL(
-    inverse_grad, GPU, ALL_LAYOUT, phi::InverseGradKernel, float, double) {}
+PD_REGISTER_KERNEL(inverse_grad,
+                   GPU,
+                   ALL_LAYOUT,
+                   phi::InverseGradKernel,
+                   float,
+                   double,
+                   phi::dtype::complex<float>,
+                   phi::dtype::complex<double>) {}

paddle/phi/kernels/gpu/inverse_kernel.cu

Lines changed: 8 additions & 2 deletions
@@ -18,5 +18,11 @@
 #include "paddle/phi/core/kernel_registry.h"
 #include "paddle/phi/kernels/impl/inverse_kernel_impl.h"

-PD_REGISTER_KERNEL(
-    inverse, GPU, ALL_LAYOUT, phi::InverseKernel, float, double) {}
+PD_REGISTER_KERNEL(inverse,
+                   GPU,
+                   ALL_LAYOUT,
+                   phi::InverseKernel,
+                   float,
+                   double,
+                   phi::dtype::complex<float>,
+                   phi::dtype::complex<double>) {}
