-
Notifications
You must be signed in to change notification settings - Fork 238
dsl: Superstep #2658
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Merged
Merged
dsl: Superstep #2658
Changes from all commits
Commits
Show all changes
28 commits
Select commit
Hold shift + click to select a range
044b14e
dsl: Add superstep generators to the dsl
JDBetteridge 90bfd21
examples: Add a 1D wave on a string superstep example
JDBetteridge 9b6a7de
examples: Add a simple 2D superstep example
JDBetteridge 6149d6f
examples: Add a 2D acoustic superstep example for layered and Marmousi
JDBetteridge 0170cd4
misc: Init
JDBetteridge c159970
dsl: Add superstep_solution_transfer function + tweaks
JDBetteridge 80df5de
examples: Add the first draft of the superstepping notebook
JDBetteridge 9ff9c43
lint: Fix my linter
JDBetteridge b799192
misc: Noodling
JDBetteridge 098f202
WIP
JDBetteridge 3a68616
misc: conflict?
JDBetteridge 7cb4485
misc: PR comments part 1
JDBetteridge 5e3b879
misc: Replace more triple quotes
JDBetteridge b2fd861
misc: Revert suggestion until fix found
JDBetteridge 16bd5d8
examples: Refactor 1D and 2D code into one file
JDBetteridge e005e6d
examples: Update superstepping script to use the examples.model
JDBetteridge 74ea6ea
misc: Typo
JDBetteridge ce39a89
examples: Review comments on notebook
JDBetteridge b7d4da6
misc: Add more verbose comment to non-superstepping solution transfer…
JDBetteridge 992abcc
examples: Tidy up 1D/2D example
JDBetteridge 11a1d66
ci: Add superstepping to the CI
JDBetteridge a2f478c
misc: Raise exception for time_order not equal 2
JDBetteridge f0b7d99
misc: Remove dead code
JDBetteridge 83d8783
misc: Add a paper reference to the superstepping module
JDBetteridge e8496a9
examples: Tidy superstepping notebook
JDBetteridge 75eeaf5
misc: Rename example filename
JDBetteridge 392e6da
ci: Grab correct source from data repo
JDBetteridge a43feb7
dsl: Remove iterative superstep generator
JDBetteridge File filter
Filter by extension
Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
There are no files selected for viewing
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Empty file.
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,94 @@ | ||
""" This module implements superstepping. | ||
This is a timestepping scheme for advancing a solution multiple timesteps | ||
at once. | ||
|
||
The method employed here takes inspiration from the following paper: | ||
- Nemeth, T et al. (2025): Superstep wavefield propagation | ||
""" | ||
from devito.types import Eq, Function, TimeFunction | ||
|
||
|
||
def superstep_generator(field, stencil, k, nt=0): | ||
""" | ||
Generate superstep using a binary decomposition: | ||
A^k = aⱼ A^2ʲ × ... × a₂ A^2² × a₁ A² × a₀ A | ||
where k = aⱼ·2ʲ + ... + a₂·2² + a₁·2¹ + a₀·2⁰ | ||
""" | ||
# New fields, for vector formulation both current and previous timestep are needed | ||
name = field.name | ||
grid = field.grid | ||
# time_order of `field` needs to be 2 | ||
if field.time_order != 2: | ||
raise ValueError( | ||
'Superstepping is currently only supports `time_order=2`' | ||
) | ||
u = TimeFunction( | ||
name=f'{name}_ss', | ||
grid=grid, | ||
time_order=field.time_order, | ||
space_order=2*k | ||
) | ||
u_prev = TimeFunction( | ||
name=f'{name}_ss_p', | ||
grid=grid, | ||
time_order=field.time_order, | ||
space_order=2*k | ||
) | ||
|
||
superstep_solution_transfer(field, u, u_prev, nt) | ||
|
||
# Substitute new fields into stencil | ||
ss_stencil = stencil.subs({field: u, field.backward: u_prev}, postprocess=False) | ||
ss_stencil = ss_stencil.expand().expand(add=True, nest=True) | ||
|
||
# Binary decomposition algorithm (see docstring): | ||
# Calculate the binary decomposition of the exponent (k) and accumulate the | ||
# resultant operator | ||
current = (ss_stencil, u) | ||
q, r = divmod(k, 2) | ||
accumulate = current if r else (1, 1) | ||
while q: | ||
q, r = divmod(q, 2) | ||
current = _combine_superstep(current, current, u, u_prev, k) | ||
if r: | ||
accumulate = _combine_superstep(accumulate, current, u, u_prev, k) | ||
|
||
return u, u_prev, Eq(u.forward, accumulate[0]), Eq(u_prev.forward, accumulate[1]) | ||
|
||
|
||
def superstep_solution_transfer(old, new, new_p, nt): | ||
""" | ||
Transfer state from a previous TimeFunction to a 2 field superstep | ||
Used after injecting source using standard timestepping. | ||
""" | ||
# This method is completely generic for future development, but currently | ||
# only time_order == 2 is implemented! | ||
idx = nt % (old.time_order + 1) if old.save is None else -1 | ||
for ii in range(old.time_order): | ||
new.data[ii, :] = old.data[idx - ii - 1] | ||
new_p.data[ii, :] = old.data[idx - ii - 2] | ||
|
||
|
||
def _combine_superstep(stencil_a, stencil_b, u, u_prev, k): | ||
""" | ||
Combine two arbitrary order supersteps | ||
""" | ||
# Placeholder fields for forming the superstep | ||
grid = u.grid | ||
# Can I use a TempFunction here? | ||
a_tmp = Function(name="a_tmp", grid=grid, space_order=2*k) | ||
b_tmp = Function(name="b_tmp", grid=grid, space_order=2*k) | ||
|
||
new = [] | ||
if stencil_a == (1, 1): | ||
new = stencil_b | ||
else: | ||
for stencil in stencil_a: | ||
new_stencil = stencil.subs({u: a_tmp, u_prev: b_tmp}, postprocess=False) | ||
new_stencil = new_stencil.subs( | ||
{a_tmp: stencil_b[0], b_tmp: stencil_b[1]}, postprocess=False | ||
) | ||
new_stencil = new_stencil.expand().expand(add=True, nest=True) | ||
JDBetteridge marked this conversation as resolved.
Show resolved
Hide resolved
|
||
new.append(new_stencil) | ||
|
||
return new |
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,234 @@ | ||
""" Script that demonstrates the functionality of the superstep in 2D | ||
Acoustic wave equation with source injection | ||
""" | ||
import os | ||
from argparse import ArgumentParser | ||
|
||
import matplotlib.pyplot as plt | ||
import numpy as np | ||
|
||
from devito import ( | ||
ConditionalDimension, | ||
Eq, | ||
Operator, | ||
SparseTimeFunction, | ||
TimeFunction, | ||
solve, | ||
) | ||
from devito.timestepping.superstep import superstep_generator | ||
from examples.seismic import demo_model, SeismicModel | ||
|
||
|
||
def ricker(t, f=10, A=1): | ||
""" | ||
The Ricker wavelet | ||
f - freq in Hz | ||
A - amplitude | ||
""" | ||
trm = (np.pi * f * (t - 1 / f)) ** 2 | ||
return A * (1 - 2 * trm) * np.exp(-trm) | ||
|
||
|
||
def acoustic_model(model, t0, t1, t2, critical_dt, source, step=1, snapshots=1): | ||
# Construct 2D Grid | ||
x, y = model.grid.dimensions | ||
velocity = model.vp | ||
u = TimeFunction(name="u", grid=model.grid, time_order=2, space_order=2) | ||
|
||
pde = (1/velocity**2)*u.dt2 - u.laplace | ||
stencil = Eq(u.forward, solve(pde, u.forward)) | ||
|
||
nt1 = int(np.ceil((t1 - t0)/critical_dt)) | ||
dt = (t1 - t0)/nt1 | ||
|
||
# Source | ||
t = np.linspace(t0, t1, nt1) | ||
rick = ricker(t) | ||
source = SparseTimeFunction( | ||
name="ricker", | ||
npoint=1, | ||
coordinates=[source], | ||
nt=nt1, | ||
grid=model.grid, | ||
time_order=2, | ||
space_order=4 | ||
) | ||
source.data[:, 0] = rick | ||
src_term = source.inject( | ||
field=u.forward, | ||
expr=source*velocity**2*dt**2 | ||
) | ||
|
||
op1 = Operator([stencil] + src_term) | ||
op1(time=nt1 - 1, dt=dt) | ||
|
||
# Stencil and operator | ||
idx = nt1 % 3 | ||
if step == 1: | ||
# Non-superstep case | ||
# In this case we need to create a new `TimeFunction` and copy | ||
# the previous soluton into that new function. This is necessary | ||
# when a rotating buffer is used in the `TimeFunction` and the | ||
# order of the timesteps is not necessarily the right order for | ||
# resuming the simulation. We also create a new stencil that | ||
# writes to the new `TimeFunction`. | ||
new_u = TimeFunction( | ||
name="new_u", | ||
grid=model.grid, | ||
time_order=2, | ||
space_order=2 | ||
) | ||
stencil = [stencil.subs( | ||
{u.forward: new_u.forward, u: new_u, u.backward: new_u.backward} | ||
)] | ||
new_u.data[0, :] = u.data[idx - 2] | ||
new_u.data[1, :] = u.data[idx - 1] | ||
new_u.data[2, :] = u.data[idx] | ||
else: | ||
new_u, new_u_p, *stencil = superstep_generator(u, stencil.rhs, step, nt=nt1) | ||
|
||
nt2 = int(np.ceil((t2 - t1)/critical_dt)) | ||
dt = (t2 - t1)/nt2 | ||
|
||
# Snapshot the solution | ||
factor = int(np.ceil(nt2/(snapshots + 1))) | ||
t_sub = ConditionalDimension( | ||
't_sub', | ||
parent=model.grid.time_dim, | ||
factor=factor | ||
) | ||
u_save = TimeFunction( | ||
name='usave', | ||
grid=model.grid, | ||
time_order=0, | ||
space_order=2, | ||
save=snapshots//step + 1, | ||
time_dim=t_sub | ||
) | ||
save = Eq(u_save, new_u) | ||
|
||
op = Operator([*stencil, save]) | ||
op(dt=dt) | ||
|
||
if step == 1: | ||
u_save.data[0, :, :] = u.data[idx] | ||
|
||
return u_save.data | ||
|
||
|
||
if __name__ == '__main__': | ||
parser = ArgumentParser() | ||
parser.add_argument('--model', default='layered', choices=['layered', 'marmousi']) | ||
args = parser.parse_args() | ||
|
||
t0 = 0 | ||
t1 = 0.2 | ||
if args.model == 'layered': | ||
source = (500, 20) | ||
t2 = 0.65 | ||
critical_dt = 0.002357 | ||
zlim = 30 | ||
else: # Marmousi | ||
# This requires the `devitocodes/data` repository, which we | ||
# assume to be checked out at `$VIRTUAL_ENV/src/data`. | ||
source = (1500, 1500) | ||
t2 = 0.5 | ||
critical_dt = 0.0013728 | ||
zlim = 20 | ||
try: | ||
path = f'{os.environ["VIRTUAL_ENV"]}/src' | ||
except KeyError: | ||
path = str(os.environ['GITHUB_WORKSPACE']) | ||
tmp_model = demo_model( | ||
'marmousi-isotropic', | ||
space_order=2, | ||
data_path=f'{path}/data', | ||
nbl=0 | ||
) | ||
cropped = tmp_model.vp.data[400:701, -321:-20] | ||
|
||
# Supersteps | ||
k = [1, 4] | ||
# Snapshots | ||
m = 13 | ||
fig, axes = plt.subplots(len(k), m) | ||
|
||
for step, ax_row in zip(k, axes, strict=True): | ||
# Redefine the model every iteration because we need to adjust | ||
# the space order | ||
if args.model == 'layered': | ||
model = demo_model( | ||
'layers-isotropic', | ||
space_order=(2, step, step), | ||
nlayers=4, | ||
vp_top=1500, | ||
vp_bottom=3000, | ||
nbl=0 | ||
) | ||
else: # Marmousi | ||
model = SeismicModel( | ||
space_order=(2, step, step), | ||
vp=1000*cropped, | ||
nbl=0, | ||
origin=(0, 0), | ||
shape=cropped.shape, | ||
spacing=(10, 10) | ||
) | ||
|
||
plot_extent = [ | ||
model.origin[0], | ||
model.origin[0] + model.grid.extent[0], | ||
model.origin[1] + model.grid.extent[1], | ||
model.origin[1] | ||
] | ||
data = acoustic_model( | ||
model, t0, t1, t2, critical_dt, source, step=step, snapshots=m | ||
) | ||
time = np.linspace(t1, t2, (m - 1)//step + 1) | ||
idx = 0 | ||
for ii, ax in enumerate(ax_row): | ||
if ii % step == 0: | ||
ax.imshow( | ||
data[idx, :, :].T, | ||
extent=plot_extent, | ||
vmin=-zlim, vmax=zlim, | ||
cmap='seismic' | ||
) | ||
ax.imshow(model.vp.data.T, cmap='grey', extent=plot_extent, alpha=0.2) | ||
ax.set_title(f't={time[idx]:0.3f}') | ||
idx += 1 | ||
if ii > 0: | ||
ax.set_xticklabels([]) | ||
ax.set_yticklabels([]) | ||
else: | ||
xticks = ax.get_xticks() | ||
ax.set_xticks(np.array(( | ||
model.origin[0], | ||
model.origin[0] + model.grid.extent[0] | ||
))) | ||
ax.set_xlim( | ||
model.origin[0], | ||
model.origin[0] + model.grid.extent[0] | ||
) | ||
yticks = ax.get_yticks() | ||
ax.set_yticks(np.array(( | ||
model.origin[1], | ||
model.origin[1] + model.grid.extent[1] | ||
))) | ||
ax.set_ylim( | ||
model.origin[1] + model.grid.extent[1], | ||
model.origin[1] | ||
) | ||
else: | ||
ax.remove() | ||
|
||
fig.set_size_inches(16, 3.5) | ||
fig.subplots_adjust( | ||
left=0.05, | ||
bottom=0.025, | ||
right=0.99, | ||
top=0.97, | ||
wspace=0.06, | ||
hspace=0.06 | ||
) | ||
fig.savefig(f'{args.model}.png', dpi=300) |
Oops, something went wrong.
Oops, something went wrong.
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Are these
Function
s just used as symbolic objects? If so, why not just useSymbol
?There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
It doesn't work, trying
TempFunction
although I have no idea if it's the correct type to tryThere was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
TempFunction
also doesn't work, I'm honestly not sure what using a different object gains you here. These temporaries are never used, so should never be allocated any memory. Unless I'm misunderstanding something?