Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
28 commits
Select commit Hold shift + click to select a range
044b14e
dsl: Add superstep generators to the dsl
JDBetteridge Jun 30, 2025
90bfd21
examples: Add a 1D wave on a string superstep example
JDBetteridge Jun 30, 2025
9b6a7de
examples: Add a simple 2D superstep example
JDBetteridge Jun 30, 2025
6149d6f
examples: Add a 2D acoustic superstep example for layered and Marmousi
JDBetteridge Jun 30, 2025
0170cd4
misc: Init
JDBetteridge Jun 30, 2025
c159970
dsl: Add superstep_solution_transfer function + tweaks
JDBetteridge Jul 3, 2025
80df5de
examples: Add the first draft of the superstepping notebook
JDBetteridge Jul 3, 2025
9ff9c43
lint: Fix my linter
JDBetteridge Jul 3, 2025
b799192
misc: Noodling
JDBetteridge Jul 14, 2025
098f202
WIP
JDBetteridge Jul 15, 2025
3a68616
misc: conflict?
JDBetteridge Jul 16, 2025
7cb4485
misc: PR comments part 1
JDBetteridge Aug 18, 2025
5e3b879
misc: Replace more triple quotes
JDBetteridge Aug 18, 2025
b2fd861
misc: Revert suggestion until fix found
JDBetteridge Aug 18, 2025
16bd5d8
examples: Refactor 1D and 2D code into one file
JDBetteridge Aug 18, 2025
e005e6d
examples: Update superstepping script to use the examples.model
JDBetteridge Aug 20, 2025
74ea6ea
misc: Typo
JDBetteridge Aug 20, 2025
ce39a89
examples: Review comments on notebook
JDBetteridge Aug 20, 2025
b7d4da6
misc: Add more verbose comment to non-superstepping solution transfer…
JDBetteridge Aug 20, 2025
992abcc
examples: Tidy up 1D/2D example
JDBetteridge Aug 20, 2025
11a1d66
ci: Add superstepping to the CI
JDBetteridge Aug 20, 2025
a2f478c
misc: Raise exception for time_order not equal 2
JDBetteridge Aug 20, 2025
f0b7d99
misc: Remove dead code
JDBetteridge Aug 20, 2025
83d8783
misc: Add a paper reference to the superstepping module
JDBetteridge Aug 20, 2025
e8496a9
examples: Tidy superstepping notebook
JDBetteridge Aug 20, 2025
75eeaf5
misc: Rename example filename
JDBetteridge Sep 9, 2025
392e6da
ci: Grab correct source from data repo
JDBetteridge Sep 10, 2025
a43feb7
dsl: Remove iterative superstep generator
JDBetteridge Sep 10, 2025
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
13 changes: 13 additions & 0 deletions .github/workflows/examples.yml
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,12 @@ jobs:
- name: Checkout devito
uses: actions/checkout@v5

- name: Checkout data
uses: actions/checkout@v5
with:
repository: 'devitocodes/data'
path: 'data'

- name: Setup conda
uses: conda-incubator/setup-miniconda@v3
with:
Expand Down Expand Up @@ -81,6 +87,13 @@ jobs:
run: |
python examples/cfd/example_diffusion.py

- name: Timestepping examples
run: |
python examples/timestepping/ic_superstep.py -d 1
python examples/timestepping/ic_superstep.py -d 2
python examples/timestepping/acoustic_superstep.py --model layered
python examples/timestepping/acoustic_superstep.py --model marmousi

- name: Upload coverage to Codecov
uses: codecov/codecov-action@v5
with:
Expand Down
4 changes: 4 additions & 0 deletions .github/workflows/tutorials.yml
Original file line number Diff line number Diff line change
Expand Up @@ -142,3 +142,7 @@ jobs:
- name: ABC Notebooks
run: |
${{ env.RUN_CMD }} py.test --nbval examples/seismic/abc_methods

- name: Timestepping Notebooks
run: |
${{ env.RUN_CMD }} py.test --nbval examples/timestepping
Empty file added devito/timestepping/__init__.py
Empty file.
94 changes: 94 additions & 0 deletions devito/timestepping/superstep.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,94 @@
""" This module implements superstepping.
This is a timestepping scheme for advancing a solution multiple timesteps
at once.

The method employed here takes inspiration from the following paper:
- Nemeth, T et al. (2025): Superstep wavefield propagation
"""
from devito.types import Eq, Function, TimeFunction


def superstep_generator(field, stencil, k, nt=0):
"""
Generate superstep using a binary decomposition:
A^k = aⱼ A^2ʲ × ... × a₂ A^2² × a₁ A² × a₀ A
where k = aⱼ·2ʲ + ... + a₂·2² + a₁·2¹ + a₀·2⁰
"""
# New fields, for vector formulation both current and previous timestep are needed
name = field.name
grid = field.grid
# time_order of `field` needs to be 2
if field.time_order != 2:
raise ValueError(
'Superstepping is currently only supports `time_order=2`'
)
u = TimeFunction(
name=f'{name}_ss',
grid=grid,
time_order=field.time_order,
space_order=2*k
)
u_prev = TimeFunction(
name=f'{name}_ss_p',
grid=grid,
time_order=field.time_order,
space_order=2*k
)

superstep_solution_transfer(field, u, u_prev, nt)

# Substitute new fields into stencil
ss_stencil = stencil.subs({field: u, field.backward: u_prev}, postprocess=False)
ss_stencil = ss_stencil.expand().expand(add=True, nest=True)

# Binary decomposition algorithm (see docstring):
# Calculate the binary decomposition of the exponent (k) and accumulate the
# resultant operator
current = (ss_stencil, u)
q, r = divmod(k, 2)
accumulate = current if r else (1, 1)
while q:
q, r = divmod(q, 2)
current = _combine_superstep(current, current, u, u_prev, k)
if r:
accumulate = _combine_superstep(accumulate, current, u, u_prev, k)

return u, u_prev, Eq(u.forward, accumulate[0]), Eq(u_prev.forward, accumulate[1])


def superstep_solution_transfer(old, new, new_p, nt):
"""
Transfer state from a previous TimeFunction to a 2 field superstep
Used after injecting source using standard timestepping.
"""
# This method is completely generic for future development, but currently
# only time_order == 2 is implemented!
idx = nt % (old.time_order + 1) if old.save is None else -1
for ii in range(old.time_order):
new.data[ii, :] = old.data[idx - ii - 1]
new_p.data[ii, :] = old.data[idx - ii - 2]


def _combine_superstep(stencil_a, stencil_b, u, u_prev, k):
"""
Combine two arbitrary order supersteps
"""
# Placeholder fields for forming the superstep
grid = u.grid
# Can I use a TempFunction here?
a_tmp = Function(name="a_tmp", grid=grid, space_order=2*k)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Are these Functions just used as symbolic objects? If so, why not just use Symbol?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It doesn't work, trying TempFunction although I have no idea if it's the correct type to try

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

TempFunction also doesn't work, I'm honestly not sure what using a different object gains you here. These temporaries are never used, so should never be allocated any memory. Unless I'm misunderstanding something?

b_tmp = Function(name="b_tmp", grid=grid, space_order=2*k)

new = []
if stencil_a == (1, 1):
new = stencil_b
else:
for stencil in stencil_a:
new_stencil = stencil.subs({u: a_tmp, u_prev: b_tmp}, postprocess=False)
new_stencil = new_stencil.subs(
{a_tmp: stencil_b[0], b_tmp: stencil_b[1]}, postprocess=False
)
new_stencil = new_stencil.expand().expand(add=True, nest=True)
new.append(new_stencil)

return new
234 changes: 234 additions & 0 deletions examples/timestepping/acoustic_superstep.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,234 @@
""" Script that demonstrates the functionality of the superstep in 2D
Acoustic wave equation with source injection
"""
import os
from argparse import ArgumentParser

import matplotlib.pyplot as plt
import numpy as np

from devito import (
ConditionalDimension,
Eq,
Operator,
SparseTimeFunction,
TimeFunction,
solve,
)
from devito.timestepping.superstep import superstep_generator
from examples.seismic import demo_model, SeismicModel


def ricker(t, f=10, A=1):
"""
The Ricker wavelet
f - freq in Hz
A - amplitude
"""
trm = (np.pi * f * (t - 1 / f)) ** 2
return A * (1 - 2 * trm) * np.exp(-trm)


def acoustic_model(model, t0, t1, t2, critical_dt, source, step=1, snapshots=1):
# Construct 2D Grid
x, y = model.grid.dimensions
velocity = model.vp
u = TimeFunction(name="u", grid=model.grid, time_order=2, space_order=2)

pde = (1/velocity**2)*u.dt2 - u.laplace
stencil = Eq(u.forward, solve(pde, u.forward))

nt1 = int(np.ceil((t1 - t0)/critical_dt))
dt = (t1 - t0)/nt1

# Source
t = np.linspace(t0, t1, nt1)
rick = ricker(t)
source = SparseTimeFunction(
name="ricker",
npoint=1,
coordinates=[source],
nt=nt1,
grid=model.grid,
time_order=2,
space_order=4
)
source.data[:, 0] = rick
src_term = source.inject(
field=u.forward,
expr=source*velocity**2*dt**2
)

op1 = Operator([stencil] + src_term)
op1(time=nt1 - 1, dt=dt)

# Stencil and operator
idx = nt1 % 3
if step == 1:
# Non-superstep case
# In this case we need to create a new `TimeFunction` and copy
# the previous soluton into that new function. This is necessary
# when a rotating buffer is used in the `TimeFunction` and the
# order of the timesteps is not necessarily the right order for
# resuming the simulation. We also create a new stencil that
# writes to the new `TimeFunction`.
new_u = TimeFunction(
name="new_u",
grid=model.grid,
time_order=2,
space_order=2
)
stencil = [stencil.subs(
{u.forward: new_u.forward, u: new_u, u.backward: new_u.backward}
)]
new_u.data[0, :] = u.data[idx - 2]
new_u.data[1, :] = u.data[idx - 1]
new_u.data[2, :] = u.data[idx]
else:
new_u, new_u_p, *stencil = superstep_generator(u, stencil.rhs, step, nt=nt1)

nt2 = int(np.ceil((t2 - t1)/critical_dt))
dt = (t2 - t1)/nt2

# Snapshot the solution
factor = int(np.ceil(nt2/(snapshots + 1)))
t_sub = ConditionalDimension(
't_sub',
parent=model.grid.time_dim,
factor=factor
)
u_save = TimeFunction(
name='usave',
grid=model.grid,
time_order=0,
space_order=2,
save=snapshots//step + 1,
time_dim=t_sub
)
save = Eq(u_save, new_u)

op = Operator([*stencil, save])
op(dt=dt)

if step == 1:
u_save.data[0, :, :] = u.data[idx]

return u_save.data


if __name__ == '__main__':
parser = ArgumentParser()
parser.add_argument('--model', default='layered', choices=['layered', 'marmousi'])
args = parser.parse_args()

t0 = 0
t1 = 0.2
if args.model == 'layered':
source = (500, 20)
t2 = 0.65
critical_dt = 0.002357
zlim = 30
else: # Marmousi
# This requires the `devitocodes/data` repository, which we
# assume to be checked out at `$VIRTUAL_ENV/src/data`.
source = (1500, 1500)
t2 = 0.5
critical_dt = 0.0013728
zlim = 20
try:
path = f'{os.environ["VIRTUAL_ENV"]}/src'
except KeyError:
path = str(os.environ['GITHUB_WORKSPACE'])
tmp_model = demo_model(
'marmousi-isotropic',
space_order=2,
data_path=f'{path}/data',
nbl=0
)
cropped = tmp_model.vp.data[400:701, -321:-20]

# Supersteps
k = [1, 4]
# Snapshots
m = 13
fig, axes = plt.subplots(len(k), m)

for step, ax_row in zip(k, axes, strict=True):
# Redefine the model every iteration because we need to adjust
# the space order
if args.model == 'layered':
model = demo_model(
'layers-isotropic',
space_order=(2, step, step),
nlayers=4,
vp_top=1500,
vp_bottom=3000,
nbl=0
)
else: # Marmousi
model = SeismicModel(
space_order=(2, step, step),
vp=1000*cropped,
nbl=0,
origin=(0, 0),
shape=cropped.shape,
spacing=(10, 10)
)

plot_extent = [
model.origin[0],
model.origin[0] + model.grid.extent[0],
model.origin[1] + model.grid.extent[1],
model.origin[1]
]
data = acoustic_model(
model, t0, t1, t2, critical_dt, source, step=step, snapshots=m
)
time = np.linspace(t1, t2, (m - 1)//step + 1)
idx = 0
for ii, ax in enumerate(ax_row):
if ii % step == 0:
ax.imshow(
data[idx, :, :].T,
extent=plot_extent,
vmin=-zlim, vmax=zlim,
cmap='seismic'
)
ax.imshow(model.vp.data.T, cmap='grey', extent=plot_extent, alpha=0.2)
ax.set_title(f't={time[idx]:0.3f}')
idx += 1
if ii > 0:
ax.set_xticklabels([])
ax.set_yticklabels([])
else:
xticks = ax.get_xticks()
ax.set_xticks(np.array((
model.origin[0],
model.origin[0] + model.grid.extent[0]
)))
ax.set_xlim(
model.origin[0],
model.origin[0] + model.grid.extent[0]
)
yticks = ax.get_yticks()
ax.set_yticks(np.array((
model.origin[1],
model.origin[1] + model.grid.extent[1]
)))
ax.set_ylim(
model.origin[1] + model.grid.extent[1],
model.origin[1]
)
else:
ax.remove()

fig.set_size_inches(16, 3.5)
fig.subplots_adjust(
left=0.05,
bottom=0.025,
right=0.99,
top=0.97,
wspace=0.06,
hspace=0.06
)
fig.savefig(f'{args.model}.png', dpi=300)
Loading
Loading