Skip to content
Merged
Show file tree
Hide file tree
Changes from 5 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 4 additions & 2 deletions src/host/construction.jl
Original file line number Diff line number Diff line change
Expand Up @@ -37,9 +37,10 @@ end
function (T::Type{<: AnyGPUArray{U}})(s::UniformScaling, dims::Dims{2}) where {U}
res = similar(T, dims)
fill!(res, zero(U))
isempty(res) && return res
kernel = identity_kernel(get_backend(res))
kernel(res, size(res, 1), s.λ; ndrange=minimum(dims))
res
return res
end

(T::Type{<: AnyGPUArray})(s::UniformScaling{U}, dims::Dims{2}) where U = T{U}(s, dims)
Expand All @@ -48,9 +49,10 @@ end

function Base.copyto!(A::AbstractGPUMatrix{T}, s::UniformScaling) where T
fill!(A, zero(T))
isempty(A) && return A
kernel = identity_kernel(get_backend(A))
kernel(A, size(A, 1), s.λ; ndrange=minimum(size(A)))
A
return A
end

function _one(unit::T, x::AbstractGPUMatrix) where {T}
Expand Down
34 changes: 18 additions & 16 deletions src/host/linalg.jl
Original file line number Diff line number Diff line change
Expand Up @@ -165,27 +165,29 @@ for T in (UpperTriangular, LowerTriangular, UnitUpperTriangular, UnitLowerTriang
end

function LinearAlgebra.tril!(A::AbstractGPUMatrix{T}, d::Integer = 0) where T
@kernel function tril_kernel!(_A, _d)
I = @index(Global, Cartesian)
i, j = Tuple(I)
if i < j - _d
@inbounds _A[i, j] = zero(T)
isempty(A) && return A
@kernel function tril_kernel!(_A, _d)
I = @index(Global, Cartesian)
i, j = Tuple(I)
if i < j - _d
@inbounds _A[i, j] = zero(T)
end
end
end
tril_kernel!(get_backend(A))(A, d; ndrange = size(A))
return A
tril_kernel!(get_backend(A))(A, d; ndrange = size(A))
return A
end

function LinearAlgebra.triu!(A::AbstractGPUMatrix{T}, d::Integer = 0) where T
@kernel function triu_kernel!(_A, _d)
I = @index(Global, Cartesian)
i, j = Tuple(I)
if j < i + _d
@inbounds _A[i, j] = zero(T)
isempty(A) && return A
@kernel function triu_kernel!(_A, _d)
I = @index(Global, Cartesian)
i, j = Tuple(I)
if j < i + _d
@inbounds _A[i, j] = zero(T)
end
end
end
triu_kernel!(get_backend(A))(A, d; ndrange = size(A))
return A
triu_kernel!(get_backend(A))(A, d; ndrange = size(A))
return A
end

# check if upper triangular starting from the kth superdiagonal.
Expand Down
10 changes: 10 additions & 0 deletions test/testsuite/construction.jl
Original file line number Diff line number Diff line change
Expand Up @@ -188,6 +188,16 @@
@test Array(x1) ≈ x
end

@testset "empty" begin
x = Matrix{Float32}(I, (0, 3))
x1 = AT{Float32, 2}(I, (0, 3))

@test Array(x1) ≈ x

copyto!(x1, I)
@test Array(x1) ≈ x
end

@testset "JuliaGPU/GPUArrays.jl#439" begin
x = AT{Float32}(I, 500, 300)
y = Array{Float32}(I, 500, 300)
Expand Down
7 changes: 5 additions & 2 deletions test/testsuite/linalg.jl
Original file line number Diff line number Diff line change
Expand Up @@ -307,10 +307,13 @@
@test_throws SingularException ldiv!(D, B)
end

@testset "$f! with diagonal $d" for (f, f!) in ((triu, triu!), (tril, tril!)),
@testset "$f with diagonal $d" for f in (triu, triu!, tril, tril!),
d in -2:2
A = randn(Float32, 10, 10)
@test f(A, d) == Array(f!(AT(A), d))
@test compare(f, AT, A)

A_empty = randn(Float32, 0, 0)
@test compare(f, AT, A_empty)
end
end

Expand Down
Loading