diff --git a/DifferentiationInterface/Project.toml b/DifferentiationInterface/Project.toml index 372a97de8..3654bf71f 100644 --- a/DifferentiationInterface/Project.toml +++ b/DifferentiationInterface/Project.toml @@ -56,7 +56,7 @@ DifferentiationInterfaceTrackerExt = "Tracker" DifferentiationInterfaceZygoteExt = ["Zygote", "ForwardDiff"] [compat] -ADTypes = "1.17.0" +ADTypes = "1.18.0" ChainRulesCore = "1.23.0" DiffResults = "1.1.0" Diffractor = "=0.2.6" diff --git a/DifferentiationInterface/src/utils/check.jl b/DifferentiationInterface/src/utils/check.jl index 9e0510533..ed11dd10a 100644 --- a/DifferentiationInterface/src/utils/check.jl +++ b/DifferentiationInterface/src/utils/check.jl @@ -16,9 +16,13 @@ function check_available(backend::MixedMode) check_available(reverse_backend(backend)) end +check_available(::ADTypes.NoAutoDiff) = throw(ADTypes.NoAutoDiffSelectedError()) + """ check_inplace(backend) Check whether `backend` supports differentiation of in-place functions and return a `Bool`. """ check_inplace(backend::AbstractADType) = Bool(inplace_support(backend)) + +check_inplace(::ADTypes.NoAutoDiff) = throw(ADTypes.NoAutoDiffSelectedError()) diff --git a/DifferentiationInterface/test/Core/ZeroBackends/test.jl b/DifferentiationInterface/test/Core/ZeroBackends/test.jl index 396e7d455..771d265c5 100644 --- a/DifferentiationInterface/test/Core/ZeroBackends/test.jl +++ b/DifferentiationInterface/test/Core/ZeroBackends/test.jl @@ -10,6 +10,9 @@ using Test LOGGING = get(ENV, "CI", "false") == "false" +@test_throws NoAutoDiffSelectedError check_available(NoAutoDiff()) +@test_throws NoAutoDiffSelectedError check_inplace(NoAutoDiff()) + zero_backends = [AutoZeroForward(), AutoZeroReverse()] for backend in zero_backends