From b231cf60decd530e4a58d11395ad83f4f1621c79 Mon Sep 17 00:00:00 2001 From: Sharv Date: Wed, 4 Jun 2025 20:30:48 -0700 Subject: [PATCH 1/7] Fix math formatting for SciMLSensitivity --- docs/src/examples/pde/brusselator.md | 67 ++++++++++++++-------------- 1 file changed, 33 insertions(+), 34 deletions(-) diff --git a/docs/src/examples/pde/brusselator.md b/docs/src/examples/pde/brusselator.md index b57e2b001..cd6c724b2 100644 --- a/docs/src/examples/pde/brusselator.md +++ b/docs/src/examples/pde/brusselator.md @@ -8,55 +8,54 @@ The Brusselator is a mathematical model used to describe oscillating chemical re The Brusselator PDE is defined on a unit square periodic domain as follows: -$$ -\frac{\partial U}{\partial t} = B + U^2V - (A+1)U + \alpha \nabla^2 U + f(x, y, t) -$$ +```math +\frac{\partial U}{\partial t} = B + U^2V - (A+1)U + \alpha \nabla^2 U + f(x, y, t) +``` -$$ -\frac{\partial V}{\partial t} = AU - U^2V + \alpha \nabla^2 V -$$ +```math +\frac{\partial V}{\partial t} = AU - U^2V + \alpha \nabla^2 +``` where $A=3.4, B=1$ and the forcing term is: -$$ +```math f(x, y, t) = \begin{cases} 5 & \text{if } (x - 0.3)^2 + (y - 0.6)^2 \leq 0.1^2 \text{ and } t \geq 1.1 \\ 0 & \text{otherwise} \end{cases} -$$ +``` and the Laplacian operator is: -$$ +```math \nabla^2 = \frac{\partial^2}{\partial x^2} + \frac{\partial^2}{\partial y^2} -$$ +``` These equations are solved over the time interval: -$$ +```math t \in [0, 11.5] -$$ +``` with the initial conditions: -$$ -U(x, y, 0) = 22 \cdot \left( y(1 - y) \right)^{3/2} -$$ +```math +U(x, y, 0) = 22 \cdot \left( y(1 - y) \right)^{3/2} +``` -$$ +```math V(x, y, 0) = 27 \cdot \left( x(1 - x) \right)^{3/2} -$$ +``` and the periodic boundary conditions: -$$ -U(x + 1, y, t) = U(x, y, t) -$$ - -$$ +```math +U(x + 1, y, t) = U(x, y, t) +``` +```math V(x, y + 1, t) = V(x, y, t) -$$ +``` ## Numerical Discretization @@ -64,15 +63,15 @@ To numerically solve this PDE, we discretize the unit square domain using $N$ gr We represent the spatially discretized fields as: -$$ +```math U[i,j] = U(i \cdot \Delta x, j \cdot \Delta y), \quad V[i,j] = V(i \cdot \Delta x, j \cdot \Delta y), -$$ +``` where $\Delta x = \Delta y = \frac{1}{N}$ for a grid of size $N \times N$. To organize the simulation state efficiently, we store both $ U $ and $ V $ in a single 3D array: -$$ +```math u[i,j,1] = U[i,j], \quad u[i,j,2] = V[i,j], -$$ +``` giving us a field tensor of shape $(N, N, 2)$. This structure is flexible and extends naturally to systems with additional field variables. @@ -81,11 +80,11 @@ giving us a field tensor of shape $(N, N, 2)$. This structure is flexible and ex For spatial derivatives, we apply a second-order central difference scheme using a three-point stencil. The Laplacian is discretized as: -$$ +```math [\ 1,\ -2,\ 1\ ] -$$ +``` -in both the $ x $ and $ y $ directions, forming a tridiagonal structure in both the x and y directions; applying this 1D stencil (scaled appropriately by $\frac{1}{Δx^2}$ or $\frac{1}{Δy^2}$) along each axis and summing the contributions yields the standard 5-point stencil computation for the 2D Laplacian. Periodic boundary conditions are incorporated by wrapping the stencil at the domain edges, effectively connecting the boundaries. The nonlinear interaction terms are computed directly at each grid point, making the implementation straightforward and local in nature. +in both the $x$ and $y$ directions, forming a tridiagonal structure in both the x and y directions; applying this 1D stencil (scaled appropriately by $\frac{1}{Δx^2}$ or $\frac{1}{Δy^2}$) along each axis and summing the contributions yields the standard 5-point stencil computation for the 2D Laplacian. Periodic boundary conditions are incorporated by wrapping the stencil at the domain edges, effectively connecting the boundaries. The nonlinear interaction terms are computed directly at each grid point, making the implementation straightforward and local in nature. ## Generating Training Data @@ -164,13 +163,13 @@ In the original Brusselator model, the nonlinear reaction term \( U^2V \) govern The resulting system becomes: -$$ +```math \frac{\partial U}{\partial t} = 1 + \mathcal{N}_\theta(U, V) - 4.4U + \alpha \nabla^2 U + f(x, y, t) -$$ +``` -$$ +```math \frac{\partial V}{\partial t} = 3.4U - \mathcal{N}_\theta(U, V) + \alpha \nabla^2 V -$$ +``` Here, $\mathcal{N}_\theta(U, V)$ is trained to approximate the true interaction term $U^2V$ using simulation data. This hybrid formulation allows us to recover unknown or partially known physical processes while preserving the known structural components of the PDE. From 92355821c3b6b4e2bc57185254aa741d31c97773 Mon Sep 17 00:00:00 2001 From: Sharv Date: Thu, 26 Jun 2025 19:38:30 -0700 Subject: [PATCH 2/7] Add multiple shooting to Brusselator UDE example This update enhances the 2D Brusselator PDE Universal Differential Equation (UDE) example by implementing a multiple shooting training strategy. --- docs/src/examples/pde/brusselator.md | 125 +++++++++++++++++---------- 1 file changed, 80 insertions(+), 45 deletions(-) diff --git a/docs/src/examples/pde/brusselator.md b/docs/src/examples/pde/brusselator.md index cd6c724b2..221caa3fe 100644 --- a/docs/src/examples/pde/brusselator.md +++ b/docs/src/examples/pde/brusselator.md @@ -1,4 +1,4 @@ -# Learning Nonlinear Reaction Dynamics in the 2D Brusselator PDE Using Universal Differential Equations +# Learning Nonlinear Reaction Dynamics in the 2D Brusselator PDE Using Universal Differential Equations and Multiple Shooting ## Introduction @@ -58,7 +58,7 @@ V(x, y + 1, t) = V(x, y, t) ``` ## Numerical Discretization - +f To numerically solve this PDE, we discretize the unit square domain using $N$ grid points along each spatial dimension. The variables $U[i,j]$ and $V[i,j]$ then denote the concentrations at the grid point $(i, j)$ at a given time $t$. We represent the spatially discretized fields as: @@ -91,33 +91,35 @@ in both the $x$ and $y$ directions, forming a tridiagonal structure in both the This provides us with an `ODEProblem` that can be solved to obtain training data. ```@example bruss -using ComponentArrays, Random, Plots, OrdinaryDiffEq +using ComponentArrays, Random, Plots, OrdinaryDiffEq, Statistics +using Lux, Optimization, OptimizationOptimJL, SciMLSensitivity, Zygote, OptimizationOptimisers +# Grid and Time Setup N_GRID = 16 XYD = range(0f0, stop = 1f0, length = N_GRID) dx = step(XYD) T_FINAL = 11.5f0 SAVE_AT = 0.5f0 tspan = (0.0f0, T_FINAL) -t_points = range(tspan[1], stop=tspan[2], step=SAVE_AT) +t_points = collect(range(tspan[1], stop=tspan[2], step=SAVE_AT)) A, B, alpha = 3.4f0, 1.0f0, 10.0f0 -brusselator_f(x, y, t) = (((x - 0.3f0)^2 + (y - 0.6f0)^2) <= 0.01f0) * (t >= 1.1f0) * 5.0f0 +# Helper Functions limit(a, N) = a == 0 ? N : a == N+1 ? 1 : a +brusselator_f(x, y, t) = (((x - 0.3f0)^2 + (y - 0.6f0)^2) <= 0.01f0) * (t >= 1.1f0) * 5.0f0 + function init_brusselator(xyd) - println("[Init] Creating initial condition array...") u0 = zeros(Float32, N_GRID, N_GRID, 2) for I in CartesianIndices((N_GRID, N_GRID)) x, y = xyd[I[1]], xyd[I[2]] u0[I,1] = 22f0 * (y * (1f0 - y))^(3f0/2f0) u0[I,2] = 27f0 * (x * (1f0 - x))^(3f0/2f0) end - println("[Init] Done.") return u0 end -u0 = init_brusselator(XYD) +# Ground Truth PDE function pde_truth!(du, u, p, t) A, B, alpha, dx = p αdx = alpha / dx^2 @@ -134,8 +136,9 @@ function pde_truth!(du, u, p, t) end end +u0 = init_brusselator(XYD) p_tuple = (A, B, alpha, dx) -@time sol_truth = solve(ODEProblem(pde_truth!, u0, tspan, p_tuple), FBDF(), saveat=t_points) +sol_truth = solve(ODEProblem(pde_truth!, u0, tspan, p_tuple), FBDF(), saveat=t_points) u_true = Array(sol_truth) ``` @@ -143,8 +146,6 @@ u_true = Array(sol_truth) We can now use this code for training our UDE, and generating time-series plots of the concentrations of species of U and V using the code: ```@example bruss -using Plots, Statistics - # Compute average concentration at each timestep avg_U = [mean(snapshot[:, :, 1]) for snapshot in sol_truth.u] avg_V = [mean(snapshot[:, :, 2]) for snapshot in sol_truth.u] @@ -157,11 +158,7 @@ plot!(sol_truth.t, avg_V, label="Mean V", lw=2, linestyle=:dash) With the ground truth data generated and visualized, we are now ready to construct a Universal Differential Equation (UDE) by replacing the nonlinear term $U^2V$ with a neural network. The next section outlines how we define this hybrid model and train it to recover the reaction dynamics from data. -## Universal Differential Equation (UDE) Formulation - -In the original Brusselator model, the nonlinear reaction term \( U^2V \) governs key dynamic behavior. In our UDE approach, we replace this known term with a trainable neural network \( \mathcal{N}_\theta(U, V) \), where \( \theta \) are the learnable parameters. - -The resulting system becomes: +## Universal Differential Equation (UDE) Formulation with Multiple Shooting ```math \frac{\partial U}{\partial t} = 1 + \mathcal{N}_\theta(U, V) - 4.4U + \alpha \nabla^2 U + f(x, y, t) @@ -174,10 +171,7 @@ The resulting system becomes: Here, $\mathcal{N}_\theta(U, V)$ is trained to approximate the true interaction term $U^2V$ using simulation data. This hybrid formulation allows us to recover unknown or partially known physical processes while preserving the known structural components of the PDE. First, we have to define and configure the neural network that has to be used for the training. The implementation for that is as follows: - ```@example bruss -using Lux, Random, Optimization, OptimizationOptimJL, SciMLSensitivity, Zygote - model = Lux.Chain(Dense(2 => 16, tanh), Dense(16 => 1)) rng = Random.default_rng() ps_init, st = Lux.setup(rng, model) @@ -215,41 +209,83 @@ function pde_ude!(du, u, ps_nn, t) end prob_ude_template = ODEProblem(pde_ude!, u0, tspan, ps_init) ``` -## Loss Function and Optimization -To train the neural network -$\mathcal{N}_\theta(U, V)$ embedded in the UDE, we define a loss function that measures how closely the solution of the UDE matches the ground truth data generated earlier. -The loss is computed as the sum of squared errors between the predicted solution from the UDE and the true solution at each saved time point. If the solver fails (e.g., due to numerical instability or incorrect parameters), we return an infinite loss to discard that configuration during optimization. We use ```FBDF()``` as the solver due to the stiff nature of the brusselators euqation. Other solvers like ```KenCarp47()``` could also be used. +### Multiple Shooting +Traditional single-shooting training for stiff PDEs like the Brusselator often leads to instability or suboptimal learning due to long simulation horizons. Multiple shooting mitigates this by dividing the overall time span into shorter, manageable segments. This: -To efficiently compute gradients of the loss with respect to the neural network parameters, we use an adjoint sensitivity method (`GaussAdjoint`), which performs high-accuracy quadrature-based integration of the adjoint equations. This approach enables scalable and memory-efficient training for stiff PDEs by avoiding full trajectory storage while maintaining accurate gradient estimates. +* Prevents error accumulation, +* Encourages better generalization, +* And enforces continuity between segments. -The loss function and initial evaluation are implemented as follows: +First, we have to conduct the time segmentation: +```@example bruss +n_segments = 5 +segment_times = range(tspan[1], stop=tspan[2], length=n_segments+1) +segment_spans = [(segment_times[i], segment_times[i+1]) for i in 1:n_segments] +segment_saves = [collect(range(t[1], stop=t[2], step=SAVE_AT)) for t in segment_spans] +``` +We also compute the indices in the original `t_points` that correspond to each segment: ```@example bruss -println("[Loss] Defining loss function...") -function loss_fn(ps, _) - prob = remake(prob_ude_template, p=ps) - sol = solve(prob, FBDF(), saveat=t_points) - # Failed solve - if !SciMLBase.successful_retcode(sol) - return Inf32 - end - pred = Array(sol) - lval = sum(abs2, pred .- u_true) / length(u_true) - return lval +function match_time_indices(t_points, segment_saves) + return [map(ti -> findmin(abs.(t_points .- ti))[2], segment_saves[i]) for i in 1:length(segment_saves)] +end + +segment_time_indices = match_time_indices(t_points, segment_saves) +``` + +Then, we create an individual problem for each segment: +```@example bruss +function get_segment_prob(ps, u0_seg, seg_idx) + remake(prob_ude_template, u0=u0_seg, tspan=segment_spans[seg_idx], p=ps) end ``` -Once the loss function is defined, we use the ADAM optimizer to train the neural network. The optimization problem is defined using SciML's ```Optimization.jl``` tools, and gradients are computed via automatic differentiation using ```AutoZygote()``` from ```SciMLSensitivity```: +#### Loss Function and Optimization +To train the neural network +$\mathcal{N}_\theta(U, V)$ embedded in the UDE, we implement a multiple shooting loss function that segments the full simulation into smaller time intervals and enforces temporal consistency across them. + +For each segment, the loss is computed as the sum of squared errors between the predicted solution and the ground truth data at saved time points. To ensure continuity across segments, we introduce a penalty that measures the difference between the final predicted state of one segment and the initial true state of the next. If any segment fails to solve (due to instability or divergence), an infinite loss is returned to discard that parameter configuration during optimization. + +Although adjoint sensitivity methods such as `GaussAdjoint` are often used in stiff problems to reduce memory load, multiple shooting naturally mitigates this need by shortening the integration window for each segment. Hence, we rely on `AutoZygote()` for automatic differentiation in our implementation. + +This approach improves training robustness by constraining long-term predictions and encouraging accurate short-term learning within each segment. The final optimization is carried out using the `ADAM` algorithm over all neural network parameters. + +The loss function is defined below: ```@example bruss -println("[Training] Starting optimization...") -using OptimizationOptimisers -optf = OptimizationFunction(loss_fn, AutoZygote()) +function loss_fn_multi(ps, _) + total_loss = 0f0 + u0_seg = copy(u0) + for i in 1:n_segments + prob_i = get_segment_prob(ps, u0_seg, i) + sol_i = solve(prob_i, FBDF(), saveat=segment_saves[i]) + if !SciMLBase.successful_retcode(sol_i) + return Inf32 + end + pred_i = Array(sol_i) + t_idxs = segment_time_indices[i] + println("Segment $i: matched indices = ", t_idxs) + if isempty(t_idxs) + error("No matching time points for segment $i — check SAVE_AT, t_points, or tolerance.") + end + true_i = u_true[:,:,:,t_idxs] + total_loss += sum(abs2, pred_i .- true_i) / length(true_i) + if i < n_segments + u0_seg = pred_i[:,:,:,end] + next_u0 = u_true[:,:,:,t_idxs[end]+1] + total_loss += sum(abs2, u0_seg .- next_u0) / length(next_u0) + end + end + return total_loss +end +``` +Once the loss function is defined, we use the ADAM optimizer to train the neural network. The optimization problem is defined using SciML's ```Optimization.jl``` tools, and gradients are computed via automatic differentiation using ```AutoZygote()``` from ```SciMLSensitivity```: +```@example bruss +optf = OptimizationFunction(loss_fn_multi, AutoZygote()) optprob = OptimizationProblem(optf, ps_init) loss_history = Float32[] - callback = (ps, l) -> begin push!(loss_history, l) println("Epoch $(length(loss_history)): Loss = $l") @@ -260,7 +296,7 @@ end Finally to run everything: ```@example bruss -res = solve(optprob, Optimisers.Adam(0.01), callback=callback, maxiters=100) +res = solve(optprob, Optimisers.Adam(0.01), callback=callback, maxiters=5000) ``` ```@example bruss @@ -268,7 +304,6 @@ res.objective ``` ```@example bruss -println("[Plot] Final U/V comparison plots...") center = N_GRID ÷ 2 sol_final = solve(remake(prob_ude_template, p=res.u), FBDF(), saveat=t_points) pred = Array(sol_final) @@ -286,6 +321,6 @@ plot(p1, p2, layout=(1,2), size=(900,400)) ## Results and Conclusion -After training the Universal Differential Equation (UDE), we compared the predicted dynamics to the ground truth for both chemical species. +After training the Universal Differential Equation (UDE) using the multiple shooting strategy, we compared the predicted dynamics to the ground truth for both chemical species. -The low training loss shows us that the neural network in the UDE was able to understand the underlying dynamics, and it was able to learn the $U^2V$ term in the partial differential equation. +The low training loss across segments demonstrates that the neural network was able to accurately capture the underlying reaction dynamics. The model effectively learned the nonlinear $U^2V$ term through a segment-wise optimization process that enforces both data fidelity and inter-segment continuity. This confirms that multiple shooting not only stabilizes training but also enhances temporal consistency in learning complex spatiotemporal PDE systems. \ No newline at end of file From 12d35c1bbd15c9223118b51c3bd38ee18c18f660 Mon Sep 17 00:00:00 2001 From: Sharv Date: Sun, 6 Jul 2025 14:53:14 -0700 Subject: [PATCH 3/7] Update brusselator.md --- docs/src/examples/pde/brusselator.md | 26 ++++++++++++++++++-------- 1 file changed, 18 insertions(+), 8 deletions(-) diff --git a/docs/src/examples/pde/brusselator.md b/docs/src/examples/pde/brusselator.md index 221caa3fe..de1d55626 100644 --- a/docs/src/examples/pde/brusselator.md +++ b/docs/src/examples/pde/brusselator.md @@ -222,7 +222,14 @@ First, we have to conduct the time segmentation: n_segments = 5 segment_times = range(tspan[1], stop=tspan[2], length=n_segments+1) segment_spans = [(segment_times[i], segment_times[i+1]) for i in 1:n_segments] -segment_saves = [collect(range(t[1], stop=t[2], step=SAVE_AT)) for t in segment_spans] +segment_saves = [ + let ts = collect(range(t1, stop=t2, step=SAVE_AT)) + if ts[end] != t2 + push!(ts, t2) + end + ts + end for (t1,t2) in segment_spans +] ``` We also compute the indices in the original `t_points` that correspond to each segment: @@ -247,19 +254,20 @@ $\mathcal{N}_\theta(U, V)$ embedded in the UDE, we implement a multiple shooting For each segment, the loss is computed as the sum of squared errors between the predicted solution and the ground truth data at saved time points. To ensure continuity across segments, we introduce a penalty that measures the difference between the final predicted state of one segment and the initial true state of the next. If any segment fails to solve (due to instability or divergence), an infinite loss is returned to discard that parameter configuration during optimization. -Although adjoint sensitivity methods such as `GaussAdjoint` are often used in stiff problems to reduce memory load, multiple shooting naturally mitigates this need by shortening the integration window for each segment. Hence, we rely on `AutoZygote()` for automatic differentiation in our implementation. +Although adjoint sensitivity methods such as `GaussAdjoint` are often used in stiff problems to reduce memory load, multiple shooting naturally mitigates this need by shortening the integration window for each segment. Hence, we rely on `AutoZygote()` for automatic differentiation in our implementation. -This approach improves training robustness by constraining long-term predictions and encouraging accurate short-term learning within each segment. The final optimization is carried out using the `ADAM` algorithm over all neural network parameters. +This approach improves training robustness by constraining long-term predictions and encouraging accurate short-term learning within each segment. The final optimization is carried out using the `ADAM` algorithm over all neural network parameters. The loss function is defined below: ```@example bruss +const λ = 5f0 function loss_fn_multi(ps, _) total_loss = 0f0 u0_seg = copy(u0) for i in 1:n_segments prob_i = get_segment_prob(ps, u0_seg, i) - sol_i = solve(prob_i, FBDF(), saveat=segment_saves[i]) + sol_i = solve(prob_i, FBDF(), saveat=segment_saves[i], abstol=1e-6, reltol=1e-6) if !SciMLBase.successful_retcode(sol_i) return Inf32 end @@ -272,9 +280,11 @@ function loss_fn_multi(ps, _) true_i = u_true[:,:,:,t_idxs] total_loss += sum(abs2, pred_i .- true_i) / length(true_i) if i < n_segments - u0_seg = pred_i[:,:,:,end] - next_u0 = u_true[:,:,:,t_idxs[end]+1] - total_loss += sum(abs2, u0_seg .- next_u0) / length(next_u0) + t_end = segment_spans[i][2] + pred_end = sol_i(t_end) + true_end = sol_truth(t_end) + total_loss += λ * sum(abs2, pred_end .- true_end) / length(true_end) + u0_seg = pred_end end end return total_loss @@ -305,7 +315,7 @@ res.objective ```@example bruss center = N_GRID ÷ 2 -sol_final = solve(remake(prob_ude_template, p=res.u), FBDF(), saveat=t_points) +sol_final = solve(remake(prob_ude_template, p=res.u), FBDF(), saveat=t_points, abstol=1e-6, reltol=1e-6) pred = Array(sol_final) p1 = plot(t_points, u_true[center,center,1,:], lw=2, label="U True") From 38ca0a5a66ce890a9ad1638c7b4cec685db204af Mon Sep 17 00:00:00 2001 From: Sharv Date: Sun, 6 Jul 2025 15:38:11 -0700 Subject: [PATCH 4/7] Update brusselator.md --- docs/src/examples/pde/brusselator.md | 9 ++++----- 1 file changed, 4 insertions(+), 5 deletions(-) diff --git a/docs/src/examples/pde/brusselator.md b/docs/src/examples/pde/brusselator.md index de1d55626..4acc83825 100644 --- a/docs/src/examples/pde/brusselator.md +++ b/docs/src/examples/pde/brusselator.md @@ -280,11 +280,9 @@ function loss_fn_multi(ps, _) true_i = u_true[:,:,:,t_idxs] total_loss += sum(abs2, pred_i .- true_i) / length(true_i) if i < n_segments - t_end = segment_spans[i][2] - pred_end = sol_i(t_end) - true_end = sol_truth(t_end) - total_loss += λ * sum(abs2, pred_end .- true_end) / length(true_end) - u0_seg = pred_end + u0_seg = pred_i[:,:,:,end] + next_u0 = u_true[:,:,:,t_idxs[end]+1] + total_loss += λ * sum(abs2, u0_seg .- next_u0) / length(next_u0) end end return total_loss @@ -316,6 +314,7 @@ res.objective ```@example bruss center = N_GRID ÷ 2 sol_final = solve(remake(prob_ude_template, p=res.u), FBDF(), saveat=t_points, abstol=1e-6, reltol=1e-6) + pred = Array(sol_final) p1 = plot(t_points, u_true[center,center,1,:], lw=2, label="U True") From b5ab4224dae0e9969f545d5a9db13f0b215ae288 Mon Sep 17 00:00:00 2001 From: Sharv Date: Sun, 6 Jul 2025 22:05:24 -0700 Subject: [PATCH 5/7] Update brusselator.md --- docs/src/examples/pde/brusselator.md | 58 +++++++++++++--------------- 1 file changed, 26 insertions(+), 32 deletions(-) diff --git a/docs/src/examples/pde/brusselator.md b/docs/src/examples/pde/brusselator.md index 4acc83825..190356790 100644 --- a/docs/src/examples/pde/brusselator.md +++ b/docs/src/examples/pde/brusselator.md @@ -219,26 +219,13 @@ Traditional single-shooting training for stiff PDEs like the Brusselator often l First, we have to conduct the time segmentation: ```@example bruss +# ---------------------------- Multiple Shooting ---------------------------- # +N = length(t_points) # 24 points: 0.0,0.5,…,11.5 n_segments = 5 -segment_times = range(tspan[1], stop=tspan[2], length=n_segments+1) -segment_spans = [(segment_times[i], segment_times[i+1]) for i in 1:n_segments] -segment_saves = [ - let ts = collect(range(t1, stop=t2, step=SAVE_AT)) - if ts[end] != t2 - push!(ts, t2) - end - ts - end for (t1,t2) in segment_spans -] -``` - -We also compute the indices in the original `t_points` that correspond to each segment: -```@example bruss -function match_time_indices(t_points, segment_saves) - return [map(ti -> findmin(abs.(t_points .- ti))[2], segment_saves[i]) for i in 1:length(segment_saves)] -end - -segment_time_indices = match_time_indices(t_points, segment_saves) +ends = round.(Int, LinRange(1, N, n_segments+1)) +segment_time_indices = [ ends[i]:ends[i+1] for i in 1:n_segments ] +segment_saves = [ t_points[idxs] for idxs in segment_time_indices ] +segment_saves = [ t_points[segment_time_indices[i]] for i in 1:length(segment_spans) ] ``` Then, we create an individual problem for each segment: @@ -261,30 +248,37 @@ This approach improves training robustness by constraining long-term predictions The loss function is defined below: ```@example bruss -const λ = 5f0 function loss_fn_multi(ps, _) total_loss = 0f0 - u0_seg = copy(u0) + u0_seg = copy(u0) + for i in 1:n_segments + # Build & solve the i-th segment problem prob_i = get_segment_prob(ps, u0_seg, i) - sol_i = solve(prob_i, FBDF(), saveat=segment_saves[i], abstol=1e-6, reltol=1e-6) + sol_i = solve(prob_i, FBDF(), saveat=segment_saves[i]) if !SciMLBase.successful_retcode(sol_i) return Inf32 end - pred_i = Array(sol_i) - t_idxs = segment_time_indices[i] - println("Segment $i: matched indices = ", t_idxs) - if isempty(t_idxs) - error("No matching time points for segment $i — check SAVE_AT, t_points, or tolerance.") - end - true_i = u_true[:,:,:,t_idxs] + + # Extract prediction & truth at the exact same grid-times + pred_i = Array(sol_i) # dims: (nx,ny,2,Ni) + idxs = segment_time_indices[i] # e.g. 1:7 or 7:12 + true_i = u_true[:,:,:, idxs] # same dims + + # 1) data-fit loss on this segment total_loss += sum(abs2, pred_i .- true_i) / length(true_i) + + # 2) continuity penalty at segment boundary if i < n_segments - u0_seg = pred_i[:,:,:,end] - next_u0 = u_true[:,:,:,t_idxs[end]+1] - total_loss += λ * sum(abs2, u0_seg .- next_u0) / length(next_u0) + # last time‐slice of this segment + u0_seg = pred_i[:,:,:, end] + + # truth at that same last slice + boundary_truth = true_i[:,:,:, end] + total_loss += sum(abs2, u0_seg .- boundary_truth) / length(boundary_truth) end end + return total_loss end ``` From 144419ca9c3c8f37fc9ffce163735c6130fedb33 Mon Sep 17 00:00:00 2001 From: Sharv Date: Wed, 9 Jul 2025 15:18:58 -0700 Subject: [PATCH 6/7] Fix segment times --- docs/src/examples/pde/brusselator.md | 61 +++++++++++++++------------- 1 file changed, 32 insertions(+), 29 deletions(-) diff --git a/docs/src/examples/pde/brusselator.md b/docs/src/examples/pde/brusselator.md index 190356790..2d464e139 100644 --- a/docs/src/examples/pde/brusselator.md +++ b/docs/src/examples/pde/brusselator.md @@ -219,13 +219,21 @@ Traditional single-shooting training for stiff PDEs like the Brusselator often l First, we have to conduct the time segmentation: ```@example bruss -# ---------------------------- Multiple Shooting ---------------------------- # -N = length(t_points) # 24 points: 0.0,0.5,…,11.5 -n_segments = 5 -ends = round.(Int, LinRange(1, N, n_segments+1)) -segment_time_indices = [ ends[i]:ends[i+1] for i in 1:n_segments ] -segment_saves = [ t_points[idxs] for idxs in segment_time_indices ] -segment_saves = [ t_points[segment_time_indices[i]] for i in 1:length(segment_spans) ] +segment_duration = 2.5f0 # 5 steps of SAVE_AT +n_segments = floor(Int, T_FINAL / segment_duration) # This will calculate n_segments = 4 + +# Create segments based on the duration, not a fixed number +segment_times = range(tspan[1], step=segment_duration, length=n_segments + 1) +segment_spans = [(segment_times[i], segment_times[i+1]) for i in 1:n_segments] + +# The rest of the code remains the same +segment_saves = [collect(range(t[1], stop=t[2], step=SAVE_AT)) for t in segment_spans] + +function match_time_indices(t_points, segment_saves) + return [map(ti -> findmin(abs.(t_points .- ti))[2], segment_saves[i]) for i in 1:length(segment_saves)] +end + +segment_time_indices = match_time_indices(t_points, segment_saves) ``` Then, we create an individual problem for each segment: @@ -239,7 +247,7 @@ end To train the neural network $\mathcal{N}_\theta(U, V)$ embedded in the UDE, we implement a multiple shooting loss function that segments the full simulation into smaller time intervals and enforces temporal consistency across them. -For each segment, the loss is computed as the sum of squared errors between the predicted solution and the ground truth data at saved time points. To ensure continuity across segments, we introduce a penalty that measures the difference between the final predicted state of one segment and the initial true state of the next. If any segment fails to solve (due to instability or divergence), an infinite loss is returned to discard that parameter configuration during optimization. +For each segment, the loss is computed as the sum of squared errors between the predicted solution and the ground truth data at saved time points. To ensure continuity across segments, we introduce a penalty ($\lambda$) that measures the difference between the final predicted state of one segment and the initial true state of the next. If any segment fails to solve (due to instability or divergence), an infinite loss is returned to discard that parameter configuration during optimization. Although adjoint sensitivity methods such as `GaussAdjoint` are often used in stiff problems to reduce memory load, multiple shooting naturally mitigates this need by shortening the integration window for each segment. Hence, we rely on `AutoZygote()` for automatic differentiation in our implementation. @@ -248,37 +256,30 @@ This approach improves training robustness by constraining long-term predictions The loss function is defined below: ```@example bruss +λ = 10.0f0 function loss_fn_multi(ps, _) total_loss = 0f0 - u0_seg = copy(u0) - + u0_seg = copy(u0) for i in 1:n_segments - # Build & solve the i-th segment problem prob_i = get_segment_prob(ps, u0_seg, i) - sol_i = solve(prob_i, FBDF(), saveat=segment_saves[i]) + sol_i = solve(prob_i, FBDF(), saveat=segment_saves[i]) if !SciMLBase.successful_retcode(sol_i) return Inf32 end - - # Extract prediction & truth at the exact same grid-times - pred_i = Array(sol_i) # dims: (nx,ny,2,Ni) - idxs = segment_time_indices[i] # e.g. 1:7 or 7:12 - true_i = u_true[:,:,:, idxs] # same dims - - # 1) data-fit loss on this segment + pred_i = Array(sol_i) + t_idxs = segment_time_indices[i] + println("Segment $i: matched indices = ", t_idxs) + if isempty(t_idxs) + error("No matching time points for segment $i — check SAVE_AT, t_points, or tolerance.") + end + true_i = u_true[:,:,:,t_idxs] total_loss += sum(abs2, pred_i .- true_i) / length(true_i) - - # 2) continuity penalty at segment boundary if i < n_segments - # last time‐slice of this segment - u0_seg = pred_i[:,:,:, end] - - # truth at that same last slice - boundary_truth = true_i[:,:,:, end] - total_loss += sum(abs2, u0_seg .- boundary_truth) / length(boundary_truth) + u0_seg = pred_i[:,:,:,end] + next_u0 = u_true[:,:,:,t_idxs[end]+1] + total_loss += λ * sum(abs2, u0_seg .- next_u0) / length(next_u0) end end - return total_loss end ``` @@ -288,9 +289,11 @@ optf = OptimizationFunction(loss_fn_multi, AutoZygote()) optprob = OptimizationProblem(optf, ps_init) loss_history = Float32[] +epoch_counter = Ref(0) callback = (ps, l) -> begin + epoch_counter[] += 1 push!(loss_history, l) - println("Epoch $(length(loss_history)): Loss = $l") + println("Epoch $(epoch_counter[]): Loss = $l") false end ``` From 362065879d46a03a7247f7622bbc226073164fe7 Mon Sep 17 00:00:00 2001 From: Sharv Date: Fri, 1 Aug 2025 22:16:40 +0530 Subject: [PATCH 7/7] Update brusselator.md --- docs/src/examples/pde/brusselator.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/docs/src/examples/pde/brusselator.md b/docs/src/examples/pde/brusselator.md index 2d464e139..48503b8eb 100644 --- a/docs/src/examples/pde/brusselator.md +++ b/docs/src/examples/pde/brusselator.md @@ -301,7 +301,7 @@ end Finally to run everything: ```@example bruss -res = solve(optprob, Optimisers.Adam(0.01), callback=callback, maxiters=5000) +res = solve(optprob, Optimisers.Adam(0.01), callback=callback, maxiters=10000) ``` ```@example bruss