Skip to content

Commit 0d90d06

Browse files
authored
Merge pull request #2658 from devitocodes/JDBetteridge/superstep
dsl: Superstep
2 parents b90bb2b + a43feb7 commit 0d90d06

File tree

7 files changed

+1424
-0
lines changed

7 files changed

+1424
-0
lines changed

.github/workflows/examples.yml

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -36,6 +36,12 @@ jobs:
3636
- name: Checkout devito
3737
uses: actions/checkout@v5
3838

39+
- name: Checkout data
40+
uses: actions/checkout@v5
41+
with:
42+
repository: 'devitocodes/data'
43+
path: 'data'
44+
3945
- name: Setup conda
4046
uses: conda-incubator/setup-miniconda@v3
4147
with:
@@ -81,6 +87,13 @@ jobs:
8187
run: |
8288
python examples/cfd/example_diffusion.py
8389
90+
- name: Timestepping examples
91+
run: |
92+
python examples/timestepping/ic_superstep.py -d 1
93+
python examples/timestepping/ic_superstep.py -d 2
94+
python examples/timestepping/acoustic_superstep.py --model layered
95+
python examples/timestepping/acoustic_superstep.py --model marmousi
96+
8497
- name: Upload coverage to Codecov
8598
uses: codecov/codecov-action@v5
8699
with:

.github/workflows/tutorials.yml

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -142,3 +142,7 @@ jobs:
142142
- name: ABC Notebooks
143143
run: |
144144
${{ env.RUN_CMD }} py.test --nbval examples/seismic/abc_methods
145+
146+
- name: Timestepping Notebooks
147+
run: |
148+
${{ env.RUN_CMD }} py.test --nbval examples/timestepping

devito/timestepping/__init__.py

Whitespace-only changes.

devito/timestepping/superstep.py

Lines changed: 94 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,94 @@
1+
""" This module implements superstepping.
2+
This is a timestepping scheme for advancing a solution multiple timesteps
3+
at once.
4+
5+
The method employed here takes inspiration from the following paper:
6+
- Nemeth, T et al. (2025): Superstep wavefield propagation
7+
"""
8+
from devito.types import Eq, Function, TimeFunction
9+
10+
11+
def superstep_generator(field, stencil, k, nt=0):
12+
"""
13+
Generate superstep using a binary decomposition:
14+
A^k = aⱼ A^2ʲ × ... × a₂ A^2² × a₁ A² × a₀ A
15+
where k = aⱼ·2ʲ + ... + a₂·2² + a₁·2¹ + a₀·2⁰
16+
"""
17+
# New fields, for vector formulation both current and previous timestep are needed
18+
name = field.name
19+
grid = field.grid
20+
# time_order of `field` needs to be 2
21+
if field.time_order != 2:
22+
raise ValueError(
23+
'Superstepping is currently only supports `time_order=2`'
24+
)
25+
u = TimeFunction(
26+
name=f'{name}_ss',
27+
grid=grid,
28+
time_order=field.time_order,
29+
space_order=2*k
30+
)
31+
u_prev = TimeFunction(
32+
name=f'{name}_ss_p',
33+
grid=grid,
34+
time_order=field.time_order,
35+
space_order=2*k
36+
)
37+
38+
superstep_solution_transfer(field, u, u_prev, nt)
39+
40+
# Substitute new fields into stencil
41+
ss_stencil = stencil.subs({field: u, field.backward: u_prev}, postprocess=False)
42+
ss_stencil = ss_stencil.expand().expand(add=True, nest=True)
43+
44+
# Binary decomposition algorithm (see docstring):
45+
# Calculate the binary decomposition of the exponent (k) and accumulate the
46+
# resultant operator
47+
current = (ss_stencil, u)
48+
q, r = divmod(k, 2)
49+
accumulate = current if r else (1, 1)
50+
while q:
51+
q, r = divmod(q, 2)
52+
current = _combine_superstep(current, current, u, u_prev, k)
53+
if r:
54+
accumulate = _combine_superstep(accumulate, current, u, u_prev, k)
55+
56+
return u, u_prev, Eq(u.forward, accumulate[0]), Eq(u_prev.forward, accumulate[1])
57+
58+
59+
def superstep_solution_transfer(old, new, new_p, nt):
60+
"""
61+
Transfer state from a previous TimeFunction to a 2 field superstep
62+
Used after injecting source using standard timestepping.
63+
"""
64+
# This method is completely generic for future development, but currently
65+
# only time_order == 2 is implemented!
66+
idx = nt % (old.time_order + 1) if old.save is None else -1
67+
for ii in range(old.time_order):
68+
new.data[ii, :] = old.data[idx - ii - 1]
69+
new_p.data[ii, :] = old.data[idx - ii - 2]
70+
71+
72+
def _combine_superstep(stencil_a, stencil_b, u, u_prev, k):
73+
"""
74+
Combine two arbitrary order supersteps
75+
"""
76+
# Placeholder fields for forming the superstep
77+
grid = u.grid
78+
# Can I use a TempFunction here?
79+
a_tmp = Function(name="a_tmp", grid=grid, space_order=2*k)
80+
b_tmp = Function(name="b_tmp", grid=grid, space_order=2*k)
81+
82+
new = []
83+
if stencil_a == (1, 1):
84+
new = stencil_b
85+
else:
86+
for stencil in stencil_a:
87+
new_stencil = stencil.subs({u: a_tmp, u_prev: b_tmp}, postprocess=False)
88+
new_stencil = new_stencil.subs(
89+
{a_tmp: stencil_b[0], b_tmp: stencil_b[1]}, postprocess=False
90+
)
91+
new_stencil = new_stencil.expand().expand(add=True, nest=True)
92+
new.append(new_stencil)
93+
94+
return new
Lines changed: 234 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,234 @@
1+
""" Script that demonstrates the functionality of the superstep in 2D
2+
Acoustic wave equation with source injection
3+
"""
4+
import os
5+
from argparse import ArgumentParser
6+
7+
import matplotlib.pyplot as plt
8+
import numpy as np
9+
10+
from devito import (
11+
ConditionalDimension,
12+
Eq,
13+
Operator,
14+
SparseTimeFunction,
15+
TimeFunction,
16+
solve,
17+
)
18+
from devito.timestepping.superstep import superstep_generator
19+
from examples.seismic import demo_model, SeismicModel
20+
21+
22+
def ricker(t, f=10, A=1):
23+
"""
24+
The Ricker wavelet
25+
f - freq in Hz
26+
A - amplitude
27+
"""
28+
trm = (np.pi * f * (t - 1 / f)) ** 2
29+
return A * (1 - 2 * trm) * np.exp(-trm)
30+
31+
32+
def acoustic_model(model, t0, t1, t2, critical_dt, source, step=1, snapshots=1):
33+
# Construct 2D Grid
34+
x, y = model.grid.dimensions
35+
velocity = model.vp
36+
u = TimeFunction(name="u", grid=model.grid, time_order=2, space_order=2)
37+
38+
pde = (1/velocity**2)*u.dt2 - u.laplace
39+
stencil = Eq(u.forward, solve(pde, u.forward))
40+
41+
nt1 = int(np.ceil((t1 - t0)/critical_dt))
42+
dt = (t1 - t0)/nt1
43+
44+
# Source
45+
t = np.linspace(t0, t1, nt1)
46+
rick = ricker(t)
47+
source = SparseTimeFunction(
48+
name="ricker",
49+
npoint=1,
50+
coordinates=[source],
51+
nt=nt1,
52+
grid=model.grid,
53+
time_order=2,
54+
space_order=4
55+
)
56+
source.data[:, 0] = rick
57+
src_term = source.inject(
58+
field=u.forward,
59+
expr=source*velocity**2*dt**2
60+
)
61+
62+
op1 = Operator([stencil] + src_term)
63+
op1(time=nt1 - 1, dt=dt)
64+
65+
# Stencil and operator
66+
idx = nt1 % 3
67+
if step == 1:
68+
# Non-superstep case
69+
# In this case we need to create a new `TimeFunction` and copy
70+
# the previous soluton into that new function. This is necessary
71+
# when a rotating buffer is used in the `TimeFunction` and the
72+
# order of the timesteps is not necessarily the right order for
73+
# resuming the simulation. We also create a new stencil that
74+
# writes to the new `TimeFunction`.
75+
new_u = TimeFunction(
76+
name="new_u",
77+
grid=model.grid,
78+
time_order=2,
79+
space_order=2
80+
)
81+
stencil = [stencil.subs(
82+
{u.forward: new_u.forward, u: new_u, u.backward: new_u.backward}
83+
)]
84+
new_u.data[0, :] = u.data[idx - 2]
85+
new_u.data[1, :] = u.data[idx - 1]
86+
new_u.data[2, :] = u.data[idx]
87+
else:
88+
new_u, new_u_p, *stencil = superstep_generator(u, stencil.rhs, step, nt=nt1)
89+
90+
nt2 = int(np.ceil((t2 - t1)/critical_dt))
91+
dt = (t2 - t1)/nt2
92+
93+
# Snapshot the solution
94+
factor = int(np.ceil(nt2/(snapshots + 1)))
95+
t_sub = ConditionalDimension(
96+
't_sub',
97+
parent=model.grid.time_dim,
98+
factor=factor
99+
)
100+
u_save = TimeFunction(
101+
name='usave',
102+
grid=model.grid,
103+
time_order=0,
104+
space_order=2,
105+
save=snapshots//step + 1,
106+
time_dim=t_sub
107+
)
108+
save = Eq(u_save, new_u)
109+
110+
op = Operator([*stencil, save])
111+
op(dt=dt)
112+
113+
if step == 1:
114+
u_save.data[0, :, :] = u.data[idx]
115+
116+
return u_save.data
117+
118+
119+
if __name__ == '__main__':
120+
parser = ArgumentParser()
121+
parser.add_argument('--model', default='layered', choices=['layered', 'marmousi'])
122+
args = parser.parse_args()
123+
124+
t0 = 0
125+
t1 = 0.2
126+
if args.model == 'layered':
127+
source = (500, 20)
128+
t2 = 0.65
129+
critical_dt = 0.002357
130+
zlim = 30
131+
else: # Marmousi
132+
# This requires the `devitocodes/data` repository, which we
133+
# assume to be checked out at `$VIRTUAL_ENV/src/data`.
134+
source = (1500, 1500)
135+
t2 = 0.5
136+
critical_dt = 0.0013728
137+
zlim = 20
138+
try:
139+
path = f'{os.environ["VIRTUAL_ENV"]}/src'
140+
except KeyError:
141+
path = str(os.environ['GITHUB_WORKSPACE'])
142+
tmp_model = demo_model(
143+
'marmousi-isotropic',
144+
space_order=2,
145+
data_path=f'{path}/data',
146+
nbl=0
147+
)
148+
cropped = tmp_model.vp.data[400:701, -321:-20]
149+
150+
# Supersteps
151+
k = [1, 4]
152+
# Snapshots
153+
m = 13
154+
fig, axes = plt.subplots(len(k), m)
155+
156+
for step, ax_row in zip(k, axes, strict=True):
157+
# Redefine the model every iteration because we need to adjust
158+
# the space order
159+
if args.model == 'layered':
160+
model = demo_model(
161+
'layers-isotropic',
162+
space_order=(2, step, step),
163+
nlayers=4,
164+
vp_top=1500,
165+
vp_bottom=3000,
166+
nbl=0
167+
)
168+
else: # Marmousi
169+
model = SeismicModel(
170+
space_order=(2, step, step),
171+
vp=1000*cropped,
172+
nbl=0,
173+
origin=(0, 0),
174+
shape=cropped.shape,
175+
spacing=(10, 10)
176+
)
177+
178+
plot_extent = [
179+
model.origin[0],
180+
model.origin[0] + model.grid.extent[0],
181+
model.origin[1] + model.grid.extent[1],
182+
model.origin[1]
183+
]
184+
data = acoustic_model(
185+
model, t0, t1, t2, critical_dt, source, step=step, snapshots=m
186+
)
187+
time = np.linspace(t1, t2, (m - 1)//step + 1)
188+
idx = 0
189+
for ii, ax in enumerate(ax_row):
190+
if ii % step == 0:
191+
ax.imshow(
192+
data[idx, :, :].T,
193+
extent=plot_extent,
194+
vmin=-zlim, vmax=zlim,
195+
cmap='seismic'
196+
)
197+
ax.imshow(model.vp.data.T, cmap='grey', extent=plot_extent, alpha=0.2)
198+
ax.set_title(f't={time[idx]:0.3f}')
199+
idx += 1
200+
if ii > 0:
201+
ax.set_xticklabels([])
202+
ax.set_yticklabels([])
203+
else:
204+
xticks = ax.get_xticks()
205+
ax.set_xticks(np.array((
206+
model.origin[0],
207+
model.origin[0] + model.grid.extent[0]
208+
)))
209+
ax.set_xlim(
210+
model.origin[0],
211+
model.origin[0] + model.grid.extent[0]
212+
)
213+
yticks = ax.get_yticks()
214+
ax.set_yticks(np.array((
215+
model.origin[1],
216+
model.origin[1] + model.grid.extent[1]
217+
)))
218+
ax.set_ylim(
219+
model.origin[1] + model.grid.extent[1],
220+
model.origin[1]
221+
)
222+
else:
223+
ax.remove()
224+
225+
fig.set_size_inches(16, 3.5)
226+
fig.subplots_adjust(
227+
left=0.05,
228+
bottom=0.025,
229+
right=0.99,
230+
top=0.97,
231+
wspace=0.06,
232+
hspace=0.06
233+
)
234+
fig.savefig(f'{args.model}.png', dpi=300)

0 commit comments

Comments
 (0)