Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 3 additions & 2 deletions Project.toml
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
name = "ForwardDiff"
uuid = "f6369f11-7733-5829-9624-2563aa707210"
version = "1.1.0"
version = "1.1.1"

[deps]
CommonSubexpressions = "bbf7d656-a473-5ed7-a52c-81e309532950"
Expand Down Expand Up @@ -39,9 +39,10 @@ Calculus = "49dc2e85-a5d0-5ad3-a950-438e2897f1b9"
DiffTests = "de460e47-3fe3-5279-bb4a-814414816d5d"
InteractiveUtils = "b77e0a4c-d291-57a0-90e8-8db25a27a240"
IrrationalConstants = "92d709cd-6900-40b7-9082-c6be49f344b6"
JLArrays = "27aeb0d3-9eb9-45fb-866b-73c2ecf80fcb"
SparseArrays = "2f01184e-e22b-5df5-ae63-d93ebab69eaf"
StaticArrays = "90137ffa-7385-5640-81b9-e52037218182"
Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"

[targets]
test = ["Calculus", "DiffTests", "IrrationalConstants", "SparseArrays", "StaticArrays", "Test", "InteractiveUtils"]
test = ["Calculus", "DiffTests", "IrrationalConstants", "SparseArrays", "StaticArrays", "Test", "InteractiveUtils", "JLArrays"]
2 changes: 2 additions & 0 deletions ext/ForwardDiffStaticArraysExt.jl
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,8 @@ using ForwardDiff: Dual, partials, GradientConfig, JacobianConfig, HessianConfig
vector_mode_jacobian, vector_mode_jacobian!, valtype, value
using DiffResults: DiffResult, ImmutableDiffResult, MutableDiffResult

ForwardDiff.supports_fast_scalar_indexing(::StaticArray) = true

@generated function dualize(::Type{T}, x::StaticArray) where T
N = length(x)
dx = Expr(:tuple, [:(Dual{T}(x[$i], chunk, Val{$i}())) for i in 1:N]...)
Expand Down
1 change: 1 addition & 0 deletions src/ForwardDiff.jl
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@ include("derivative.jl")
include("gradient.jl")
include("jacobian.jl")
include("hessian.jl")
include("utils.jl")

export DiffResults

Expand Down
92 changes: 56 additions & 36 deletions src/apiutils.jl
Original file line number Diff line number Diff line change
Expand Up @@ -72,36 +72,46 @@ end

function seed!(duals::AbstractArray{Dual{T,V,N}}, x,
seed::Partials{N,V} = zero(Partials{N,V})) where {T,V,N}
if isbitstype(V)
for idx in structural_eachindex(duals, x)
duals[idx] = Dual{T,V,N}(x[idx], seed)
end
else
for idx in structural_eachindex(duals, x)
if isassigned(x, idx)
if supports_fast_scalar_indexing(duals)
if isbitstype(V)
for idx in structural_eachindex(duals, x)
duals[idx] = Dual{T,V,N}(x[idx], seed)
else
Base._unsetindex!(duals, idx)
end
else
for idx in structural_eachindex(duals, x)
if isassigned(x, idx)
duals[idx] = Dual{T,V,N}(x[idx], seed)
else
Base._unsetindex!(duals, idx)
end
end
end
else
idxs = collect(structural_eachindex(duals, x))
duals[idxs] .= Dual{T,V,N}.(view(x, idxs), Ref(seed))
end
return duals
end

function seed!(duals::AbstractArray{Dual{T,V,N}}, x,
seeds::NTuple{N,Partials{N,V}}) where {T,V,N}
if isbitstype(V)
for (i, idx) in zip(1:N, structural_eachindex(duals, x))
duals[idx] = Dual{T,V,N}(x[idx], seeds[i])
end
else
for (i, idx) in zip(1:N, structural_eachindex(duals, x))
if isassigned(x, idx)
if supports_fast_scalar_indexing(duals)
if isbitstype(V)
for (i, idx) in zip(1:N, structural_eachindex(duals, x))
duals[idx] = Dual{T,V,N}(x[idx], seeds[i])
else
Base._unsetindex!(duals, idx)
end
else
for (i, idx) in zip(1:N, structural_eachindex(duals, x))
if isassigned(x, idx)
duals[idx] = Dual{T,V,N}(x[idx], seeds[i])
else
Base._unsetindex!(duals, idx)
end
end
end
else
idxs = collect(Iterators.take(structural_eachindex(duals, x), N))
duals[idxs] .= Dual{T,V,N}.(view(x, idxs), getindex.(Ref(seeds), 1:length(idxs)))
end
return duals
end
Expand All @@ -110,18 +120,23 @@ function seed!(duals::AbstractArray{Dual{T,V,N}}, x, index,
seed::Partials{N,V} = zero(Partials{N,V})) where {T,V,N}
offset = index - 1
idxs = Iterators.drop(structural_eachindex(duals, x), offset)
if isbitstype(V)
for idx in idxs
duals[idx] = Dual{T,V,N}(x[idx], seed)
end
else
for idx in idxs
if isassigned(x, idx)
if supports_fast_scalar_indexing(duals)
if isbitstype(V)
for idx in idxs
duals[idx] = Dual{T,V,N}(x[idx], seed)
else
Base._unsetindex!(duals, idx)
end
else
for idx in idxs
if isassigned(x, idx)
duals[idx] = Dual{T,V,N}(x[idx], seed)
else
Base._unsetindex!(duals, idx)
end
end
end
else
idxs = collect(idxs)
duals[idxs] .= Dual{T,V,N}.(view(x, idxs), Ref(seed))
end
return duals
end
Expand All @@ -130,18 +145,23 @@ function seed!(duals::AbstractArray{Dual{T,V,N}}, x, index,
seeds::NTuple{N,Partials{N,V}}, chunksize = N) where {T,V,N}
offset = index - 1
idxs = Iterators.drop(structural_eachindex(duals, x), offset)
if isbitstype(V)
for (i, idx) in zip(1:chunksize, idxs)
duals[idx] = Dual{T,V,N}(x[idx], seeds[i])
end
else
for (i, idx) in zip(1:chunksize, idxs)
if isassigned(x, idx)
if supports_fast_scalar_indexing(duals)
if isbitstype(V)
for (i, idx) in zip(1:chunksize, idxs)
duals[idx] = Dual{T,V,N}(x[idx], seeds[i])
else
Base._unsetindex!(duals, idx)
end
else
for (i, idx) in zip(1:chunksize, idxs)
if isassigned(x, idx)
duals[idx] = Dual{T,V,N}(x[idx], seeds[i])
else
Base._unsetindex!(duals, idx)
end
end
end
else
idxs = collect(Iterators.take(idxs, chunksize))
duals[idxs] .= Dual{T,V,N}.(view(x, idxs), getindex.(Ref(seeds), 1:length(idxs)))
end
return duals
end
2 changes: 0 additions & 2 deletions src/dual.jl
Original file line number Diff line number Diff line change
Expand Up @@ -298,8 +298,6 @@ Base.copy(d::Dual) = d
Base.eps(d::Dual) = eps(value(d))
Base.eps(::Type{D}) where {D<:Dual} = eps(valtype(D))

# The `base` keyword was added in Julia 1.8:
# https://github.com/JuliaLang/julia/pull/42428
Base.precision(d::Dual; base::Integer=2) = precision(value(d); base=base)
function Base.precision(::Type{D}; base::Integer=2) where {D<:Dual}
precision(valtype(D); base=base)
Expand Down
24 changes: 18 additions & 6 deletions src/gradient.jl
Original file line number Diff line number Diff line change
Expand Up @@ -65,17 +65,29 @@ end
extract_gradient!(::Type{T}, result::AbstractArray, y::Real) where {T} = fill!(result, zero(y))
function extract_gradient!(::Type{T}, result::AbstractArray, dual::Dual) where {T}
idxs = structural_eachindex(result)
for (i, idx) in zip(1:npartials(dual), idxs)
result[idx] = partials(T, dual, i)
if supports_fast_scalar_indexing(result)
for (i, idx) in zip(1:npartials(dual), idxs)
result[idx] = partials(T, dual, i)
end
else
fn = PartialsFn{T}(dual)
idxs = collect(Iterators.take(idxs, npartials(dual)))
result[idxs] .= fn.(1:length(idxs))
return result
end
return result
end

function extract_gradient_chunk!(::Type{T}, result, dual, index, chunksize) where {T}
offset = index - 1
idxs = Iterators.drop(structural_eachindex(result), offset)
for (i, idx) in zip(1:chunksize, idxs)
result[idx] = partials(T, dual, i)
idxs = Iterators.drop(structural_eachindex(result), index - 1)
if supports_fast_scalar_indexing(result)
for (i, idx) in zip(1:chunksize, idxs)
result[idx] = partials(T, dual, i)
end
else
fn = PartialsFn{T}(dual)
idxs = collect(Iterators.take(idxs, chunksize))
result[idxs] .= fn.(1:length(idxs))
end
return result
end
Expand Down
14 changes: 14 additions & 0 deletions src/utils.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,14 @@
# overload for array types that
supports_fast_scalar_indexing(::Array) = true

function supports_fast_scalar_indexing(x::AbstractArray)
return parent(x) !== x && supports_fast_scalar_indexing(parent(x))
end

# Helper function for broadcasting
struct PartialsFn{T,D<:Dual}
dual::D
end
PartialsFn{T}(dual::Dual) where {T} = PartialsFn{T,typeof(dual)}(dual)

(f::PartialsFn{T})(i) where {T} = partials(T, f.dual, i)
27 changes: 27 additions & 0 deletions test/GradientTest.jl
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@ using ForwardDiff
using ForwardDiff: Dual, Tag
using StaticArrays
using DiffTests
using JLArrays

include(joinpath(dirname(@__FILE__), "utils.jl"))

Expand Down Expand Up @@ -255,4 +256,30 @@ end
end
end

@testset "GPUArraysCore" begin
fn(x) = sum(x .^ 2 ./ 2)

x = [1.0, 2.0, 3.0]
x_jl = JLArray(x)

grad = ForwardDiff.gradient(fn, x)
grad_jl = ForwardDiff.gradient(fn, x_jl)

@test grad_jl isa JLArray
@test Array(grad_jl) ≈ grad

cfg = ForwardDiff.GradientConfig(
fn, x_jl, ForwardDiff.Chunk{2}(), ForwardDiff.Tag(fn, eltype(x))
)
grad_jl = ForwardDiff.gradient(fn, x_jl, cfg)

@test grad_jl isa JLArray
@test Array(grad_jl) ≈ grad
end

@testset "Scalar Indexing Checks" begin
@test ForwardDiff.supports_fast_scalar_indexing(UnitLowerTriangular(view(rand(6, 6), 1:3, 1:3)))
@test !ForwardDiff.supports_fast_scalar_indexing(UnitLowerTriangular(view(JLArray(rand(6, 6)), 1:3, 1:3)))
end

end # module
14 changes: 14 additions & 0 deletions test/JacobianTest.jl
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@ using ForwardDiff: Dual, Tag, JacobianConfig
using StaticArrays
using DiffTests
using LinearAlgebra
using JLArrays

include(joinpath(dirname(@__FILE__), "utils.jl"))

Expand Down Expand Up @@ -308,4 +309,17 @@ end
end
end

@testset "GPUArraysCore" begin
f(x) = x .^ 2 ./ 2

x = [1.0, 2.0, 3.0]
x_jl = JLArray(x)

jac = ForwardDiff.jacobian(f, x)
jac_jl = ForwardDiff.jacobian(f, x_jl)

@test jac_jl isa JLArray
@test Array(jac_jl) ≈ jac
end

end # module
Loading