Skip to content

Commit b30cae9

Browse files
authored
Fix complex CSC * dense vec (#2957)
1 parent 18bbb5d commit b30cae9

File tree

2 files changed

+27
-37
lines changed

2 files changed

+27
-37
lines changed

lib/cusparse/generic.jl

Lines changed: 2 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -158,16 +158,8 @@ function mv!(transa::SparseChar, alpha::Number, A::Union{CuSparseMatrixCSC{TA},C
158158
# Support transa = 'C' for real matrices
159159
transa = T <: Real && transa == 'C' ? 'T' : transa
160160

161-
if isa(A, CuSparseMatrixCSC)
162-
# cusparseSpMV completely supports CSC matrices with CUSPARSE.version() ≥ v"12.0".
163-
# We use Aᵀ to model them as CSR matrices for older versions of CUSPARSE.
164-
descA = CuSparseMatrixDescriptor(A, index, transposed=true)
165-
n,m = size(A)
166-
transa = transa == 'N' ? 'T' : 'N'
167-
else
168-
descA = CuSparseMatrixDescriptor(A, index)
169-
m,n = size(A)
170-
end
161+
descA = CuSparseMatrixDescriptor(A, index)
162+
m,n = size(A)
171163

172164
if transa == 'N'
173165
chkmvdims(X,n,Y,m)

test/libraries/cusparse/interfaces.jl

Lines changed: 25 additions & 27 deletions
Original file line numberDiff line numberDiff line change
@@ -137,35 +137,33 @@ nB = 2
137137
@test collect(dC) C
138138
end
139139
end
140-
if !(SparseMatrixType == CuSparseMatrixCSC && elty <: Complex && opa == adjoint)
141-
@testset "A * CuVector" begin
142-
@testset "A * b" begin
143-
c = opa(geam_A) * b_vec
144-
dc = opa(d_geam_A) * db_vec
145-
@test c collect(dc)
146-
end
147-
@testset "mul!(c, A, b)" begin
148-
c = rand(elty, n)
149-
dc = CuArray(c)
140+
@testset "A * CuVector" begin
141+
@testset "A * b" begin
142+
c = opa(geam_A) * b_vec
143+
dc = opa(d_geam_A) * db_vec
144+
@test c collect(dc)
145+
end
146+
@testset "mul!(c, A, b)" begin
147+
c = rand(elty, n)
148+
dc = CuArray(c)
150149

151-
mul!(c, opa(geam_A), b_vec, alpha, beta)
152-
mul!(dc, opa(d_geam_A), db_vec, alpha, beta)
153-
@test c collect(dc)
154-
end
150+
mul!(c, opa(geam_A), b_vec, alpha, beta)
151+
mul!(dc, opa(d_geam_A), db_vec, alpha, beta)
152+
@test c collect(dc)
153+
end
154+
end
155+
@testset "A * CuSparseVector" begin
156+
@testset "A * b" begin
157+
c = opa(geam_A) * b_spvec
158+
dc = opa(d_geam_A) * db_spvec
159+
@test c collect(dc)
155160
end
156-
@testset "A * CuSparseVector" begin
157-
@testset "A * b" begin
158-
c = opa(geam_A) * b_spvec
159-
dc = opa(d_geam_A) * db_spvec
160-
@test c collect(dc)
161-
end
162-
@testset "mul!(c, A, b)" begin
163-
c = rand(elty, n)
164-
dc = CuArray(c)
165-
mul!(c, opa(geam_A), b_spvec, alpha, beta)
166-
mul!(dc, opa(d_geam_A), db_spvec, alpha, beta)
167-
@test c collect(dc)
168-
end
161+
@testset "mul!(c, A, b)" begin
162+
c = rand(elty, n)
163+
dc = CuArray(c)
164+
mul!(c, opa(geam_A), b_spvec, alpha, beta)
165+
mul!(dc, opa(d_geam_A), db_spvec, alpha, beta)
166+
@test c collect(dc)
169167
end
170168
end
171169
end

0 commit comments

Comments
 (0)