-
Notifications
You must be signed in to change notification settings - Fork 239
dsl: Introduce abstractions for multi-stage time integrators #2599
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
Changes from 1 commit
1d830b8
7f087b3
214d882
d6c4d4a
78f8a0b
1c9d517
11db48b
83dfb04
d47a106
1f93a45
eea3a52
11d1429
4637ac2
ac1da7e
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.
Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.
Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.
Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.
Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.
Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,109 @@ | ||
from devito import Function, Eq | ||
from devito.symbolics import uxreplace | ||
from sympy import Basic | ||
|
||
|
||
class MultiStage(Basic): | ||
fernanvr marked this conversation as resolved.
Outdated
Show resolved
Hide resolved
|
||
def __new__(cls, eq, method): | ||
assert isinstance(eq, Eq) | ||
return Basic.__new__(cls, eq, method) | ||
|
||
@property | ||
def eq(self): | ||
return self.args[0] | ||
|
||
@property | ||
def method(self): | ||
return self.args[1] | ||
|
||
|
||
class RK(Basic): | ||
""" | ||
A class representing an explicit Runge-Kutta method via its Butcher tableau. | ||
fernanvr marked this conversation as resolved.
Outdated
Show resolved
Hide resolved
|
||
Parameters | ||
---------- | ||
a : list[list[float]] | ||
fernanvr marked this conversation as resolved.
Outdated
Show resolved
Hide resolved
|
||
Lower-triangular coefficient matrix (stage dependencies). | ||
b : list[float] | ||
Weights for the final combination step. | ||
c : list[float] | ||
Weights for the stages time step. | ||
""" | ||
|
||
def __init__(self, a, b, c): | ||
self.a = a | ||
self.b = b | ||
self.c = c | ||
self.s = len(b) # number of stages | ||
|
||
|
||
self._validate() | ||
|
||
def _validate(self): | ||
assert len(self.a) == self.s, "'a' must have s rows" | ||
fernanvr marked this conversation as resolved.
Outdated
Show resolved
Hide resolved
|
||
for i, row in enumerate(self.a): | ||
assert len(row) == i, f"Row {i} in 'a' must have {i} entries for explicit RK" | ||
|
||
def expand_stages(self, base_eq, eq_num=0): | ||
fernanvr marked this conversation as resolved.
Outdated
Show resolved
Hide resolved
|
||
""" | ||
Expand a single Eq into a list of stage-wise Eqs for this RK method. | ||
Parameters | ||
---------- | ||
base_eq : Eq | ||
The equation Eq(u.forward, rhs) to be expanded into RK stages. | ||
eq_number : integer, optional | ||
|
||
The equation number to idetify the k_i's stages | ||
Returns | ||
------- | ||
list of Eq | ||
Stage-wise equations: [k0=..., k1=..., ..., u.forward=...] | ||
""" | ||
u = base_eq.lhs.function | ||
rhs = base_eq.rhs | ||
grid = u.grid | ||
dt = grid.stepping_dim.spacing | ||
t = grid.time_dim | ||
|
||
|
||
# Create temporary Functions to hold each stage | ||
k = [Function(name=f'k{eq_num}{i}', grid=grid, space_order=u.space_order, dtype=u.dtype) | ||
|
||
for i in range(self.s)] | ||
|
||
stage_eqs = [] | ||
|
||
# Build each stage | ||
for i in range(self.s): | ||
u_temp = u | ||
for j in range(i): | ||
if self.a[i][j] != 0: | ||
u_temp += self.a[i][j] * dt * k[j] | ||
t_shift = t + self.c[i] * dt | ||
|
||
# Evaluate RHS at intermediate value | ||
stage_rhs = uxreplace(rhs, {u: u_temp, t: t_shift}) | ||
stage_eqs.append(Eq(k[i], stage_rhs)) | ||
|
||
# Final update: u.forward = u + dt * sum(bᵢ * kᵢ) | ||
u_next = u | ||
for i in range(self.s): | ||
u_next += self.b[i] * dt * k[i] | ||
fernanvr marked this conversation as resolved.
Outdated
Show resolved
Hide resolved
|
||
stage_eqs.append(Eq(u.forward, u_next)) | ||
|
||
return stage_eqs | ||
|
||
# ---- Named methods for convenience ---- | ||
@classmethod | ||
def RK44(cls): | ||
"""Classical Runge-Kutta of 4 stages and 4th order""" | ||
a = [ | ||
[], | ||
[1 / 2], | ||
|
||
[0, 1 / 2], | ||
[0, 0, 1] | ||
] | ||
b = [1 / 6, 1 / 3, 1 / 3, 1 / 6] | ||
c = [0, 1 / 2, 1 / 2, 1] | ||
return cls(a, b, c) | ||
|
||
|
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -36,6 +36,7 @@ | |
disk_layer) | ||
from devito.types.dimension import Thickness | ||
|
||
from devito.operator.new_classes import MultiStage | ||
|
||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Please run the linter ( |
||
__all__ = ['Operator'] | ||
|
||
|
@@ -184,8 +185,9 @@ def _sanitize_exprs(cls, expressions, **kwargs): | |
expressions = as_tuple(expressions) | ||
|
||
for i in expressions: | ||
if not isinstance(i, Evaluable): | ||
raise CompilationError(f"`{i!s}` is not an Evaluable object; " | ||
i_check = i.eq if isinstance(i, MultiStage) else i | ||
fernanvr marked this conversation as resolved.
Outdated
Show resolved
Hide resolved
|
||
if not isinstance(i_check, Evaluable): | ||
raise CompilationError(f"`{i_check!s}` is not an Evaluable object; " | ||
"check your equation again") | ||
|
||
return expressions | ||
|
@@ -271,6 +273,9 @@ def _lower(cls, expressions, **kwargs): | |
# expression for which a partial or complete lowering is desired | ||
kwargs['rcompile'] = cls._rcompile_wrapper(**kwargs) | ||
|
||
# [MultiStage] -> [Eqs] | ||
expressions = cls._lower_multistage(expressions, **kwargs) | ||
|
||
|
||
# [Eq] -> [LoweredEq] | ||
expressions = cls._lower_exprs(expressions, **kwargs) | ||
|
||
|
@@ -314,6 +319,27 @@ def _specialize_exprs(cls, expressions, **kwargs): | |
""" | ||
return expressions | ||
|
||
@classmethod | ||
@timed_pass(name='lowering.MultiStages') | ||
def _lower_multistage(cls, expressions, **kwargs): | ||
|
||
""" | ||
Separating the multi-stage time-integrator scheme in stages: | ||
|
||
* Check if the time-integrator is Multistage; | ||
* Creating the stages of the method. | ||
""" | ||
|
||
lowered = [] | ||
for i, eq in enumerate(as_tuple(expressions)): | ||
if isinstance(eq, MultiStage): | ||
|
||
time_int = eq.method | ||
stage_eqs = time_int.expand_stages(eq.eq, eq_num=i) | ||
lowered.extend(stage_eqs) | ||
else: | ||
lowered.append(eq) | ||
|
||
return lowered | ||
|
||
@classmethod | ||
@timed_pass(name='lowering.Expressions') | ||
def _lower_exprs(cls, expressions, **kwargs): | ||
|
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,79 @@ | ||
import numpy as np | ||
import sympy as sym | ||
import matplotlib.pyplot as plt | ||
|
||
import devito as dv | ||
fernanvr marked this conversation as resolved.
Outdated
Show resolved
Hide resolved
|
||
from examples.seismic import Receiver, TimeAxis | ||
|
||
from devito.operator.new_classes import RK, MultiStage | ||
# Set logging level for debugging | ||
dv.configuration['log-level'] = 'DEBUG' | ||
|
||
# Parameters | ||
|
||
space_order = 2 | ||
fd_order = 2 | ||
extent = (1000, 1000) | ||
shape = (201, 201) | ||
origin = (0, 0) | ||
|
||
# Grid setup | ||
grid = dv.Grid(origin=origin, extent=extent, shape=shape, dtype=np.float64) | ||
x, y = grid.dimensions | ||
dt = grid.stepping_dim.spacing | ||
t = grid.time_dim | ||
dx = extent[0] / (shape[0] - 1) | ||
|
||
# Medium velocity model | ||
vel = dv.Function(name="vel", grid=grid, space_order=space_order, dtype=np.float64) | ||
vel.data[:] = 1.0 | ||
vel.data[150:, :] = 1.3 | ||
|
||
# Define wavefield unknowns: u (displacement) and v (velocity) | ||
fun_labels = ['u', 'v'] | ||
U = [dv.TimeFunction(name=name, grid=grid, space_order=space_order, | ||
time_order=1, dtype=np.float64) for name in fun_labels] | ||
|
||
# Time axis | ||
t0, tn = 0.0, 500.0 | ||
dt0 = np.max(vel.data) / dx**2 | ||
nt = int((tn - t0) / dt0) | ||
dt0 = tn / nt | ||
time_range = TimeAxis(start=t0, stop=tn, num=nt + 1) | ||
|
||
# Receiver setup | ||
rec = Receiver(name='rec', grid=grid, npoint=3, time_range=time_range) | ||
rec.coordinates.data[:, 0] = np.linspace(0, 1, 3) | ||
rec.coordinates.data[:, 1] = 0.5 | ||
rec = rec.interpolate(expr=U[0].forward) | ||
|
||
# Source definition | ||
src_spatial = dv.Function(name="src_spat", grid=grid, space_order=space_order, dtype=np.float64) | ||
src_spatial.data[100, 100] = 1 / dx**2 | ||
|
||
f0 = 0.01 | ||
src_temporal = (1 - 2 * (np.pi * f0 * (t * dt - 1/f0))**2) * sym.exp(-(np.pi * f0 * (t * dt - 1/f0))**2) | ||
|
||
# PDE system (2D acoustic) | ||
system_eqs = [U[1], | ||
|
||
(dv.Derivative(U[0], (x, 2), fd_order=fd_order) + | ||
dv.Derivative(U[0], (y, 2), fd_order=fd_order) + | ||
src_spatial * src_temporal) * vel**2] | ||
|
||
# Time integration scheme | ||
rk = RK.RK44() | ||
|
||
# MultiStage object | ||
pdes = [MultiStage(dv.Eq(U[i], system_eqs[i]), rk) for i in range(2)] | ||
|
||
# Construct and run operator | ||
op = dv.Operator(pdes + [rec], subs=grid.spacing_map) | ||
op(dt=dt0, time=nt) | ||
|
||
|
||
# Plot final wavefield | ||
plt.imshow(U[0].data[1, :], cmap="seismic") | ||
plt.colorbar(label="Amplitude") | ||
plt.title("Wavefield snapshot (t = final)") | ||
plt.xlabel("x") | ||
plt.ylabel("y") | ||
plt.tight_layout() | ||
plt.show() |
Uh oh!
There was an error while loading. Please reload this page.