Skip to content

Commit 17cb8cc

Browse files
committed
fix
1 parent ce9ad8a commit 17cb8cc

File tree

2 files changed

+11
-4
lines changed

2 files changed

+11
-4
lines changed

paddle/phi/kernels/gpu/slogdeterminant_kernel.cu

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -186,7 +186,7 @@ __global__ void GetSlogDetFromLUComplex(const Complex_T* lu_data,
186186
T epsilon = std::numeric_limits<T>::epsilon();
187187

188188
if (abs_det <= epsilon) {
189-
sign[idx] = Complex_T(1.0, 0.0);
189+
sign[idx] = Complex_T(0.0, 0.0);
190190
logdet[idx] = -std::numeric_limits<T>::infinity();
191191
} else {
192192
Complex_T abs_det_complex = static_cast<Complex_T>(abs_det);

paddle/phi/kernels/impl/slogdeterminant_kernel_impl.h

Lines changed: 10 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -105,9 +105,16 @@ struct SlogDeterminantFunctor<phi::dtype::complex<T>, Context> {
105105
VLOG(2) << "matrix val: " << matrix;
106106
std::complex<T> det_val = matrix.determinant();
107107
T abs_det_val = std::abs(det_val);
108-
sign_data[i] = static_cast<Complex_T>(
109-
phi::sign(det_val, static_cast<std::complex<T>>(abs_det_val)));
110-
logdet_data[i] = std::log(abs_det_val);
108+
T epsilon = std::numeric_limits<T>::epsilon();
109+
110+
if (abs_det_val <= epsilon) {
111+
sign_data[i] = Complex_T(0.0, 0.0);
112+
logdet_data[i] = -std::numeric_limits<T>::infinity();
113+
} else {
114+
sign_data[i] = static_cast<Complex_T>(
115+
phi::sign(det_val, static_cast<std::complex<T>>(abs_det_val)));
116+
logdet_data[i] = std::log(abs_det_val);
117+
}
111118
}
112119
}
113120
};

0 commit comments

Comments
 (0)