From 1cb1abbcabdaea8bd254228f36e84a7c1c66d480 Mon Sep 17 00:00:00 2001 From: Steven Hahn Date: Tue, 28 Oct 2025 13:27:30 -0400 Subject: [PATCH 1/3] syevBatched! interface accepts 3D CuArray Signed-off-by: Steven Hahn --- lib/cusolver/dense_generic.jl | 37 ++++++++++++++++++++++++ test/libraries/cusolver/dense_generic.jl | 31 ++++++++++++++++++++ 2 files changed, 68 insertions(+) diff --git a/lib/cusolver/dense_generic.jl b/lib/cusolver/dense_generic.jl index e7f5c27134..aa8c8a15b1 100644 --- a/lib/cusolver/dense_generic.jl +++ b/lib/cusolver/dense_generic.jl @@ -505,6 +505,43 @@ function Xgeev!(jobvl::Char, jobvr::Char, A::StridedCuMatrix{T}) where {T <: Bla end # XsyevBatched +function XsyevBatched!(jobz::Char, uplo::Char, A::StridedCuArray{T,3}) where {T <: BlasFloat} + CUSOLVER.version() < v"11.7.1" && throw(ErrorException("This operation is not supported by the current CUDA version.")) + chkuplo(uplo) + n = checksquare(A) + batch_size = size(A,3) + R = real(T) + lda = max(1, stride(A,2)) + W = CuMatrix{R}(undef, n, batch_size) + params = CuSolverParameters() + dh = dense_handle() + resize!(dh.info, batch_size) + + function bufferSize() + out_cpu = Ref{Csize_t}(0) + out_gpu = Ref{Csize_t}(0) + cusolverDnXsyevBatched_bufferSize(dh, params, jobz, uplo, n, + T, A, lda, R, W, T, out_gpu, out_cpu, batch_size) + out_gpu[], out_cpu[] + end + with_workspaces(dh.workspace_gpu, dh.workspace_cpu, bufferSize()...) do buffer_gpu, buffer_cpu + cusolverDnXsyevBatched(dh, params, jobz, uplo, n, T, A, + lda, R, W, T, buffer_gpu, sizeof(buffer_gpu), + buffer_cpu, sizeof(buffer_cpu), dh.info, batch_size) + end + + info = @allowscalar collect(dh.info) + for i = 1:batch_size + chkargsok(info[i] |> BlasInt) + end + + if jobz == 'N' + return W + elseif jobz == 'V' + return W, A + end +end + function XsyevBatched!(jobz::Char, uplo::Char, A::StridedCuMatrix{T}) where {T <: BlasFloat} CUSOLVER.version() < v"11.7.1" && throw(ErrorException("This operation is not supported by the current CUDA version.")) chkuplo(uplo) diff --git a/test/libraries/cusolver/dense_generic.jl b/test/libraries/cusolver/dense_generic.jl index bca2cae46a..2600d137ea 100644 --- a/test/libraries/cusolver/dense_generic.jl +++ b/test/libraries/cusolver/dense_generic.jl @@ -33,6 +33,36 @@ p = 5 end @testset "syevBatched!" begin + batch_size = 5 + for uplo in ('L', 'U') + (CUSOLVER.version() < v"11.7.2") && (uplo == 'L') && (elty == ComplexF32) && continue + + A = rand(elty, n, n, batch_size) + B = rand(elty, n, n, batch_size) + for i = 1:batch_size + S = rand(elty,n,n) + S = S * S' + I + B[:,:,i] .= S + S = uplo == 'L' ? tril(S) : triu(S) + A[:,:,i] .= S + end + d_A = CuArray(A) + d_W, d_V = CUSOLVER.XsyevBatched!('V', uplo, d_A) + W = collect(d_W) + V = collect(d_V) + for i = 1:batch_size + Bᵢ = B[:,:,i] + Wᵢ = Diagonal(W[:,i]) + Vᵢ = V[:,:,i] + @test Bᵢ * Vᵢ ≈ Vᵢ * Diagonal(Wᵢ) + end + + d_A = CuArray(A) + d_W = CUSOLVER.XsyevBatched!('N', uplo, d_A) + end + end + + @testset "syevBatched! updated" begin batch_size = 5 for uplo in ('L', 'U') (CUSOLVER.version() < v"11.7.2") && (uplo == 'L') && (elty == ComplexF32) && continue @@ -61,6 +91,7 @@ p = 5 d_W = CUSOLVER.XsyevBatched!('N', uplo, d_A) end end + end if CUSOLVER.version() >= v"11.6.0" From e3dbe9bb1d2be6de6fc3f22bda64df2e7d027861 Mon Sep 17 00:00:00 2001 From: Steven Hahn Date: Tue, 28 Oct 2025 14:55:10 -0400 Subject: [PATCH 2/3] formatting recommendations Signed-off-by: Steven Hahn --- lib/cusolver/dense_generic.jl | 24 ++++++++++++++---------- test/libraries/cusolver/dense_generic.jl | 16 ++++++++-------- 2 files changed, 22 insertions(+), 18 deletions(-) diff --git a/lib/cusolver/dense_generic.jl b/lib/cusolver/dense_generic.jl index aa8c8a15b1..866b282103 100644 --- a/lib/cusolver/dense_generic.jl +++ b/lib/cusolver/dense_generic.jl @@ -505,13 +505,13 @@ function Xgeev!(jobvl::Char, jobvr::Char, A::StridedCuMatrix{T}) where {T <: Bla end # XsyevBatched -function XsyevBatched!(jobz::Char, uplo::Char, A::StridedCuArray{T,3}) where {T <: BlasFloat} +function XsyevBatched!(jobz::Char, uplo::Char, A::StridedCuArray{T, 3}) where {T <: BlasFloat} CUSOLVER.version() < v"11.7.1" && throw(ErrorException("This operation is not supported by the current CUDA version.")) chkuplo(uplo) n = checksquare(A) - batch_size = size(A,3) + batch_size = size(A, 3) R = real(T) - lda = max(1, stride(A,2)) + lda = max(1, stride(A, 2)) W = CuMatrix{R}(undef, n, batch_size) params = CuSolverParameters() dh = dense_handle() @@ -520,18 +520,22 @@ function XsyevBatched!(jobz::Char, uplo::Char, A::StridedCuArray{T,3}) where {T function bufferSize() out_cpu = Ref{Csize_t}(0) out_gpu = Ref{Csize_t}(0) - cusolverDnXsyevBatched_bufferSize(dh, params, jobz, uplo, n, - T, A, lda, R, W, T, out_gpu, out_cpu, batch_size) - out_gpu[], out_cpu[] + cusolverDnXsyevBatched_bufferSize( + dh, params, jobz, uplo, n, + T, A, lda, R, W, T, out_gpu, out_cpu, batch_size + ) + return out_gpu[], out_cpu[] end with_workspaces(dh.workspace_gpu, dh.workspace_cpu, bufferSize()...) do buffer_gpu, buffer_cpu - cusolverDnXsyevBatched(dh, params, jobz, uplo, n, T, A, - lda, R, W, T, buffer_gpu, sizeof(buffer_gpu), - buffer_cpu, sizeof(buffer_cpu), dh.info, batch_size) + cusolverDnXsyevBatched( + dh, params, jobz, uplo, n, T, A, + lda, R, W, T, buffer_gpu, sizeof(buffer_gpu), + buffer_cpu, sizeof(buffer_cpu), dh.info, batch_size + ) end info = @allowscalar collect(dh.info) - for i = 1:batch_size + for i in 1:batch_size chkargsok(info[i] |> BlasInt) end diff --git a/test/libraries/cusolver/dense_generic.jl b/test/libraries/cusolver/dense_generic.jl index 2600d137ea..c67dbe5693 100644 --- a/test/libraries/cusolver/dense_generic.jl +++ b/test/libraries/cusolver/dense_generic.jl @@ -39,21 +39,21 @@ p = 5 A = rand(elty, n, n, batch_size) B = rand(elty, n, n, batch_size) - for i = 1:batch_size - S = rand(elty,n,n) + for i in 1:batch_size + S = rand(elty, n, n) S = S * S' + I - B[:,:,i] .= S + B[:, :, i] .= S S = uplo == 'L' ? tril(S) : triu(S) - A[:,:,i] .= S + A[:, :, i] .= S end d_A = CuArray(A) d_W, d_V = CUSOLVER.XsyevBatched!('V', uplo, d_A) W = collect(d_W) V = collect(d_V) - for i = 1:batch_size - Bᵢ = B[:,:,i] - Wᵢ = Diagonal(W[:,i]) - Vᵢ = V[:,:,i] + for i in 1:batch_size + Bᵢ = B[:, :, i] + Wᵢ = Diagonal(W[:, i]) + Vᵢ = V[:, :, i] @test Bᵢ * Vᵢ ≈ Vᵢ * Diagonal(Wᵢ) end From 09d90258ba97b9915b126eb6670a7526859e01de Mon Sep 17 00:00:00 2001 From: Steven Hahn Date: Fri, 14 Nov 2025 09:23:58 -0500 Subject: [PATCH 3/3] update version check Signed-off-by: Steven Hahn --- lib/cusolver/dense_generic.jl | 8 ++++++-- test/libraries/cusolver/dense_generic.jl | 1 - 2 files changed, 6 insertions(+), 3 deletions(-) diff --git a/lib/cusolver/dense_generic.jl b/lib/cusolver/dense_generic.jl index 866b282103..bb7260c49a 100644 --- a/lib/cusolver/dense_generic.jl +++ b/lib/cusolver/dense_generic.jl @@ -506,7 +506,9 @@ end # XsyevBatched function XsyevBatched!(jobz::Char, uplo::Char, A::StridedCuArray{T, 3}) where {T <: BlasFloat} - CUSOLVER.version() < v"11.7.1" && throw(ErrorException("This operation is not supported by the current CUDA version.")) + minimum_version = v"11.7.1" + CUSOLVER.version() < minimum_version && throw(ErrorException("This operation requires cuSOLVER + $(minimum_version) or later. Current cuSOLVER version: $(CUSOLVER.version()).")) chkuplo(uplo) n = checksquare(A) batch_size = size(A, 3) @@ -547,7 +549,9 @@ function XsyevBatched!(jobz::Char, uplo::Char, A::StridedCuArray{T, 3}) where {T end function XsyevBatched!(jobz::Char, uplo::Char, A::StridedCuMatrix{T}) where {T <: BlasFloat} - CUSOLVER.version() < v"11.7.1" && throw(ErrorException("This operation is not supported by the current CUDA version.")) + minimum_version = v"11.7.1" + CUSOLVER.version() < minimum_version && throw(ErrorException("This operation requires cuSOLVER + $(minimum_version) or later. Current cuSOLVER version: $(CUSOLVER.version()).")) chkuplo(uplo) n, num_matrices = size(A) batch_size = num_matrices ÷ n diff --git a/test/libraries/cusolver/dense_generic.jl b/test/libraries/cusolver/dense_generic.jl index c67dbe5693..ec63ca82fb 100644 --- a/test/libraries/cusolver/dense_generic.jl +++ b/test/libraries/cusolver/dense_generic.jl @@ -91,7 +91,6 @@ p = 5 d_W = CUSOLVER.XsyevBatched!('N', uplo, d_A) end end - end if CUSOLVER.version() >= v"11.6.0"