diff --git a/Project.toml b/Project.toml
index 6a15eed36..32128bb15 100644
--- a/Project.toml
+++ b/Project.toml
@@ -34,6 +34,7 @@ OptimizationOptimisers = "42dfb2eb-d2b4-4451-abcd-913932933ac1"
 Printf = "de0858da-6303-5e67-8744-51eddeeeb8d7"
 QuasiMonteCarlo = "8a4e6c94-4038-4cdc-81c3-7e6ffdb2a71b"
 Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
+Reactant = "3c362404-f566-11ee-1572-e11a4b42c853"
 RecursiveArrayTools = "731186ca-8d62-57ce-b412-fbd966d074cd"
 Reexport = "189a3867-3050-52da-a836-e630ba90ab69"
 RuntimeGeneratedFunctions = "7e49a35a-f44a-4d26-94aa-eba1b4ca6b47"
@@ -98,6 +99,7 @@ Printf = "1.10"
 QuasiMonteCarlo = "0.3.2"
 Random = "1"
 ReTestItems = "1.29.0"
+Reactant = "0.2.152"
 RecursiveArrayTools = "3.27.0"
 Reexport = "1.2"
 RuntimeGeneratedFunctions = "0.5.12"
diff --git a/src/NeuralPDE.jl b/src/NeuralPDE.jl
index b5735ec71..7aeb70a14 100644
--- a/src/NeuralPDE.jl
+++ b/src/NeuralPDE.jl
@@ -24,6 +24,7 @@ using Optimization: Optimization
 using OptimizationOptimisers: OptimizationOptimisers
 using Printf: @printf
 using Random: Random, AbstractRNG
+using Reactant: Reactant, reactant_device
 using RecursiveArrayTools: DiffEqArray
 using Reexport: @reexport
 using RuntimeGeneratedFunctions: RuntimeGeneratedFunctions, @RuntimeGeneratedFunction
@@ -61,7 +62,7 @@ abstract type AbstractPINN end
 
 abstract type AbstractTrainingStrategy end
 
-const cdev = CPUDevice()
+const cdev = CPUDevice()
 
 @inline safe_get_device(x) = safe_get_device(get_device(x), x)
 @inline safe_get_device(::Nothing, x) = cdev
diff --git a/src/dae_solve.jl b/src/dae_solve.jl
index 8cdd4a087..5dd63854f 100644
--- a/src/dae_solve.jl
+++ b/src/dae_solve.jl
@@ -28,6 +28,7 @@ standard `DAEProblem`.
 * `strategy`: The training strategy used to choose the points for the evaluations.
   By default, `GridTraining` is used with `dt` if given.
 """
+const ydev = reactant_device()
 @concrete struct NNDAE <: SciMLBase.AbstractDAEAlgorithm
     chain <: AbstractLuxLayer
     opt
@@ -88,7 +89,7 @@ function SciMLBase.__solve(
     t0 = tspan[1]
     (; chain, opt, autodiff, init_params) = alg
 
-    phi, init_params = generate_phi_θ(chain, t0, u0, init_params)
+    phi, init_params = generate_phi_θ(chain, t0, ydev(u0), ydev(init_params))
     init_params = ComponentArray(; depvar = init_params)
 
     @assert !isinplace(prob) "The NNODE solver only supports out-of-place DAE definitions, i.e. du=f(u,p,t)."
diff --git a/src/ode_solve.jl b/src/ode_solve.jl
index 282dd35fb..319dad500 100644
--- a/src/ode_solve.jl
+++ b/src/ode_solve.jl
@@ -88,6 +88,7 @@ Lagaris, Isaac E., Aristidis Likas, and Dimitrios I. Fotiadis. "Artificial neural networks
 for solving ordinary and partial differential equations." IEEE Transactions on Neural
 Networks 9, no. 5 (1998): 987-1000.
 """
+const xdev = reactant_device()
 @concrete struct NNODE
     chain <: AbstractLuxLayer
     opt
@@ -137,12 +138,11 @@ function generate_phi_θ(chain::AbstractLuxLayer, t, u0, init_params)
 end
 
 function (f::ODEPhi)(t, θ)
-    dev = safe_get_device(θ)
-    return f(dev, safe_expand(dev, t), θ)
+    return f(xdev, safe_expand(xdev, t), θ)
 end
 
 function (f::ODEPhi{<:Number})(dev, t::Number, θ)
-    res = only(cdev(f.smodel(dev([t]), θ.depvar)))
+    res = only(f.smodel(xdev([t]), θ.depvar))
     return f.u0 + (t - f.t0) * res
 end
@@ -361,7 +361,7 @@ function SciMLBase.__solve(
     (; param_estim, estim_collocate, dataset, chain, opt, autodiff, init_params,
         batch, additional_loss, estim_collocate) = alg
 
-    phi, init_params = generate_phi_θ(chain, t0, u0, init_params)
+    phi, init_params = generate_phi_θ(chain, t0, xdev(u0), xdev(init_params))
 
     (recursive_eltype(init_params) <: Complex && alg.strategy isa QuadratureTraining) &&
         error("QuadratureTraining cannot be used with complex parameters. Use other strategies.")
@@ -471,7 +471,7 @@ function SciMLBase.__solve(
     else
         u = [phi(t, res.u) for t in ts]
     end
-
+
     sol = SciMLBase.build_solution(prob, alg, ts, u; k = res, dense = true,
         interp = NNODEInterpolation(phi, res.u), calculate_error = false,
         retcode = ReturnCode.Success, original = res, resid = res.objective)