Skip to content

Commit 7671706

Browse files
edge case: tril!, triu! and I constructors with empty matrices (#642)
Co-authored-by: Christian Guinard <28689358+christiangnrd@users.noreply.github.com>
1 parent 8d569fc commit 7671706

File tree

4 files changed

+37
-20
lines changed

4 files changed

+37
-20
lines changed

src/host/construction.jl

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -37,9 +37,10 @@ end
3737
function (T::Type{<: AnyGPUArray{U}})(s::UniformScaling, dims::Dims{2}) where {U}
3838
res = similar(T, dims)
3939
fill!(res, zero(U))
40+
isempty(res) && return res
4041
kernel = identity_kernel(get_backend(res))
4142
kernel(res, size(res, 1), s.λ; ndrange=minimum(dims))
42-
res
43+
return res
4344
end
4445

4546
(T::Type{<: AnyGPUArray})(s::UniformScaling{U}, dims::Dims{2}) where U = T{U}(s, dims)
@@ -48,9 +49,10 @@ end
4849

4950
function Base.copyto!(A::AbstractGPUMatrix{T}, s::UniformScaling) where T
5051
fill!(A, zero(T))
52+
isempty(A) && return A
5153
kernel = identity_kernel(get_backend(A))
5254
kernel(A, size(A, 1), s.λ; ndrange=minimum(size(A)))
53-
A
55+
return A
5456
end
5557

5658
function _one(unit::T, x::AbstractGPUMatrix) where {T}

src/host/linalg.jl

Lines changed: 18 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -169,27 +169,29 @@ for T in (UpperTriangular, LowerTriangular, UnitUpperTriangular, UnitLowerTriang
169169
end
170170

171171
function LinearAlgebra.tril!(A::AbstractGPUMatrix{T}, d::Integer = 0) where T
172-
@kernel function tril_kernel!(_A, _d)
173-
I = @index(Global, Cartesian)
174-
i, j = Tuple(I)
175-
if i < j - _d
176-
@inbounds _A[i, j] = zero(T)
172+
isempty(A) && return A
173+
@kernel function tril_kernel!(_A, _d)
174+
I = @index(Global, Cartesian)
175+
i, j = Tuple(I)
176+
if i < j - _d
177+
@inbounds _A[i, j] = zero(T)
178+
end
177179
end
178-
end
179-
tril_kernel!(get_backend(A))(A, d; ndrange = size(A))
180-
return A
180+
tril_kernel!(get_backend(A))(A, d; ndrange = size(A))
181+
return A
181182
end
182183

183184
function LinearAlgebra.triu!(A::AbstractGPUMatrix{T}, d::Integer = 0) where T
184-
@kernel function triu_kernel!(_A, _d)
185-
I = @index(Global, Cartesian)
186-
i, j = Tuple(I)
187-
if j < i + _d
188-
@inbounds _A[i, j] = zero(T)
185+
isempty(A) && return A
186+
@kernel function triu_kernel!(_A, _d)
187+
I = @index(Global, Cartesian)
188+
i, j = Tuple(I)
189+
if j < i + _d
190+
@inbounds _A[i, j] = zero(T)
191+
end
189192
end
190-
end
191-
triu_kernel!(get_backend(A))(A, d; ndrange = size(A))
192-
return A
193+
triu_kernel!(get_backend(A))(A, d; ndrange = size(A))
194+
return A
193195
end
194196

195197
# check if upper triangular starting from the kth superdiagonal.

test/testsuite/construction.jl

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -188,6 +188,16 @@
188188
@test Array(x1) x
189189
end
190190

191+
@testset "empty" begin
192+
x = Matrix{Float32}(I, (0, 3))
193+
x1 = AT{Float32, 2}(I, (0, 3))
194+
195+
@test Array(x1) x
196+
197+
copyto!(x1, I)
198+
@test Array(x1) x
199+
end
200+
191201
@testset "JuliaGPU/GPUArrays.jl#439" begin
192202
x = AT{Float32}(I, 500, 300)
193203
y = Array{Float32}(I, 500, 300)

test/testsuite/linalg.jl

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -311,10 +311,13 @@
311311
@test_throws SingularException ldiv!(D, B)
312312
end
313313

314-
@testset "$f! with diagonal $d" for (f, f!) in ((triu, triu!), (tril, tril!)),
314+
@testset "$f with diagonal $d" for f in (triu, triu!, tril, tril!),
315315
d in -2:2
316316
A = randn(Float32, 10, 10)
317-
@test f(A, d) == Array(f!(AT(A), d))
317+
@test compare(f, AT, A, d)
318+
319+
A_empty = randn(Float32, 0, 0)
320+
@test compare(f, AT, A_empty, d)
318321
end
319322
end
320323

0 commit comments

Comments
 (0)