Skip to content
8 changes: 6 additions & 2 deletions Project.toml
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
name = "ForwardDiff"
uuid = "f6369f11-7733-5829-9624-2563aa707210"
version = "1.0.1"
version = "1.0.2"

[deps]
CommonSubexpressions = "bbf7d656-a473-5ed7-a52c-81e309532950"
Expand All @@ -15,9 +15,11 @@ Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
SpecialFunctions = "276daf66-3868-5448-9aa4-cd146d93841b"

[weakdeps]
GPUArraysCore = "46192b85-c4d5-4398-a991-12ede77f4527"
StaticArrays = "90137ffa-7385-5640-81b9-e52037218182"

[extensions]
ForwardDiffGPUArraysCoreExt = "GPUArraysCore"
ForwardDiffStaticArraysExt = "StaticArrays"

[compat]
Expand All @@ -26,6 +28,7 @@ CommonSubexpressions = "0.3"
DiffResults = "1.1"
DiffRules = "1.4"
DiffTests = "0.1"
GPUArraysCore = "0.1, 0.2"
IrrationalConstants = "0.1, 0.2"
LogExpFunctions = "0.3"
NaNMath = "1"
Expand All @@ -39,9 +42,10 @@ Calculus = "49dc2e85-a5d0-5ad3-a950-438e2897f1b9"
DiffTests = "de460e47-3fe3-5279-bb4a-814414816d5d"
InteractiveUtils = "b77e0a4c-d291-57a0-90e8-8db25a27a240"
IrrationalConstants = "92d709cd-6900-40b7-9082-c6be49f344b6"
JLArrays = "27aeb0d3-9eb9-45fb-866b-73c2ecf80fcb"
SparseArrays = "2f01184e-e22b-5df5-ae63-d93ebab69eaf"
StaticArrays = "90137ffa-7385-5640-81b9-e52037218182"
Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"

[targets]
test = ["Calculus", "DiffTests", "IrrationalConstants", "SparseArrays", "StaticArrays", "Test", "InteractiveUtils"]
test = ["Calculus", "DiffTests", "IrrationalConstants", "SparseArrays", "StaticArrays", "Test", "InteractiveUtils", "JLArrays"]
65 changes: 65 additions & 0 deletions ext/ForwardDiffGPUArraysCoreExt.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,65 @@
module ForwardDiffGPUArraysCoreExt

using GPUArraysCore: AbstractGPUArray
using ForwardDiff: ForwardDiff, Dual, Partials, npartials, partials

struct PartialsFn{T,D<:Dual}
dual::D
end
PartialsFn{T}(dual::Dual) where {T} = PartialsFn{T,typeof(dual)}(dual)

(f::PartialsFn{T})(i) where {T} = partials(T, f.dual, i)

function ForwardDiff.seed!(duals::AbstractGPUArray{Dual{T,V,N}}, x,
seed::Partials{N,V}) where {T,V,N}
idxs = collect(ForwardDiff.structural_eachindex(duals, x))
duals[idxs] .= Dual{T,V,N}.(view(x, idxs), Ref(seed))
return duals
end

function ForwardDiff.seed!(duals::AbstractGPUArray{Dual{T,V,N}}, x,
seeds::NTuple{N,Partials{N,V}}) where {T,V,N}
idxs = collect(Iterators.take(ForwardDiff.structural_eachindex(duals, x), N))
duals[idxs] .= Dual{T,V,N}.(view(x, idxs), getindex.(Ref(seeds), 1:length(idxs)))
return duals
end

function ForwardDiff.seed!(duals::AbstractGPUArray{Dual{T,V,N}}, x, index,
seed::Partials{N,V}) where {T,V,N}
offset = index - 1
idxs = collect(Iterators.drop(ForwardDiff.structural_eachindex(duals, x), offset))
duals[idxs] .= Dual{T,V,N}.(view(x, idxs), Ref(seed))
return duals
end

function ForwardDiff.seed!(duals::AbstractGPUArray{Dual{T,V,N}}, x, index,
seeds::NTuple{N,Partials{N,V}}, chunksize) where {T,V,N}
offset = index - 1
idxs = collect(
Iterators.take(Iterators.drop(ForwardDiff.structural_eachindex(duals, x), offset), chunksize)
)
duals[idxs] .= Dual{T,V,N}.(view(x, idxs), getindex.(Ref(seeds), 1:length(idxs)))
return duals
end

# gradient
function ForwardDiff.extract_gradient!(::Type{T}, result::AbstractGPUArray,
dual::Dual) where {T}
fn = PartialsFn{T}(dual)
idxs = collect(Iterators.take(ForwardDiff.structural_eachindex(result), npartials(dual)))
result[idxs] .= fn.(1:length(idxs))
return result
end

function ForwardDiff.extract_gradient_chunk!(::Type{T}, result::AbstractGPUArray, dual,
index, chunksize) where {T}
fn = PartialsFn{T}(dual)
offset = index - 1
idxs = collect(
Iterators.take(Iterators.drop(ForwardDiff.structural_eachindex(result), offset), chunksize)
)
result[idxs] .= fn.(1:length(idxs))
return result
end

end
2 changes: 0 additions & 2 deletions src/dual.jl
Original file line number Diff line number Diff line change
Expand Up @@ -298,8 +298,6 @@ Base.copy(d::Dual) = d
Base.eps(d::Dual) = eps(value(d))
Base.eps(::Type{D}) where {D<:Dual} = eps(valtype(D))

# The `base` keyword was added in Julia 1.8:
# https://github.com/JuliaLang/julia/pull/42428
Base.precision(d::Dual; base::Integer=2) = precision(value(d); base=base)
function Base.precision(::Type{D}; base::Integer=2) where {D<:Dual}
precision(valtype(D); base=base)
Expand Down
1 change: 0 additions & 1 deletion test/DualTest.jl
Original file line number Diff line number Diff line change
Expand Up @@ -118,7 +118,6 @@ ForwardDiff.:≺(::Type{OuterTestTag}, ::Type{TestTag}) = false
@test precision(typeof(FDNUM)) === precision(V)
@test precision(NESTED_FDNUM) === precision(PRIMAL)
@test precision(typeof(NESTED_FDNUM)) === precision(V)

@test precision(FDNUM; base=10) === precision(PRIMAL; base=10)
@test precision(typeof(FDNUM); base=10) === precision(V; base=10)
@test precision(NESTED_FDNUM; base=10) === precision(PRIMAL; base=10)
Expand Down
22 changes: 22 additions & 0 deletions test/GradientTest.jl
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@ using ForwardDiff
using ForwardDiff: Dual, Tag
using StaticArrays
using DiffTests
using JLArrays

include(joinpath(dirname(@__FILE__), "utils.jl"))

Expand Down Expand Up @@ -255,4 +256,25 @@ end
end
end

@testset "GPUArraysCore" begin
fn(x) = sum(x .^ 2 ./ 2)

x = [1.0, 2.0, 3.0]
x_jl = JLArray(x)

grad = ForwardDiff.gradient(fn, x)
grad_jl = ForwardDiff.gradient(fn, x_jl)

@test grad_jl isa JLArray
@test Array(grad_jl) ≈ grad

cfg = ForwardDiff.GradientConfig(
fn, x_jl, ForwardDiff.Chunk{2}(), ForwardDiff.Tag(fn, eltype(x))
)
grad_jl = ForwardDiff.gradient(fn, x_jl, cfg)

@test grad_jl isa JLArray
@test Array(grad_jl) ≈ grad
end

end # module
14 changes: 14 additions & 0 deletions test/JacobianTest.jl
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@ using ForwardDiff: Dual, Tag, JacobianConfig
using StaticArrays
using DiffTests
using LinearAlgebra
using JLArrays

include(joinpath(dirname(@__FILE__), "utils.jl"))

Expand Down Expand Up @@ -279,4 +280,17 @@ end
end
end

@testset "GPUArraysCore" begin
f(x) = x .^ 2 ./ 2

x = [1.0, 2.0, 3.0]
x_jl = JLArray(x)

jac = ForwardDiff.jacobian(f, x)
jac_jl = ForwardDiff.jacobian(f, x_jl)

@test jac_jl isa JLArray
@test Array(jac_jl) ≈ jac
end

end # module
Loading