Skip to content

Commit 2e983fe

Browse files
authored
Fixes zero-dim matmatmul & matvecmul (#2958)
* workaround rmul!( , true/false) * fix matmatmul * type generic tests * also matvecmul
1 parent b30cae9 commit 2e983fe

File tree

2 files changed

+14
-2
lines changed

2 files changed

+14
-2
lines changed

lib/cublas/linalg.jl

Lines changed: 8 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,12 @@ end
99
# BLAS 1
1010
#
1111

12+
function LinearAlgebra.rmul!(x::CuArray{<:CublasFloat}, k::Bool)
13+
# explicitly fill x with zero to comply with julias "false = strong zero"
14+
!k && fill!(x, zero(eltype(x)))
15+
return x
16+
end
17+
1218
LinearAlgebra.rmul!(x::StridedCuArray{<:CublasFloat}, k::Number) =
1319
scal!(length(x), k, x)
1420

@@ -267,7 +273,7 @@ function LinearAlgebra.generic_matvecmul!(Y::StridedCuVector, tA::AbstractChar,
267273
end
268274

269275
if nA == 0
270-
return rmul!(Y, 0)
276+
return rmul!(Y, beta)
271277
end
272278

273279
T = eltype(Y)
@@ -356,7 +362,7 @@ function LinearAlgebra.generic_matmatmul!(C::StridedCuVecOrMat, tA, tB, A::Strid
356362
if size(C) != (mA, nB)
357363
throw(DimensionMismatch("C has dimensions $(size(C)), should have ($mA,$nB)"))
358364
end
359-
return LinearAlgebra.rmul!(C, 0)
365+
return LinearAlgebra.rmul!(C, beta)
360366
end
361367

362368
if all(in(('N', 'T', 'C')), (tA, tB))

test/libraries/cublas/level1.jl

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -41,6 +41,12 @@ k = 13
4141
@test dz z
4242
end
4343

44+
@testset "rmul! strong zero" begin
45+
@test testf(rmul!, fill(T(NaN), 3), false)
46+
@test testf(rmul!, rand(T, 3), false)
47+
@test testf(rmul!, rand(T, 3), true)
48+
end
49+
4450
@testset "rotate!" begin
4551
@test testf(rotate!, rand(T, m), rand(T, m), rand(real(T)), rand(real(T)))
4652
@test testf(rotate!, rand(T, m), rand(T, m), rand(real(T)), rand(T))

0 commit comments

Comments
 (0)