Skip to content
Open
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions .idea/.gitignore

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

15 changes: 15 additions & 0 deletions .idea/devito.iml

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

6 changes: 6 additions & 0 deletions .idea/inspectionProfiles/profiles_settings.xml

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

7 changes: 7 additions & 0 deletions .idea/misc.xml

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

8 changes: 8 additions & 0 deletions .idea/modules.xml

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

6 changes: 6 additions & 0 deletions .idea/vcs.xml

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

109 changes: 109 additions & 0 deletions devito/operator/new_classes.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,109 @@
from devito import Function, Eq
from devito.symbolics import uxreplace
from sympy import Basic


class MultiStage(Basic):
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Would it make sense for this to subclass Eq?

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This would move all of the MultiStage API classes to devito.types.equation

Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think you’re right — I’ve updated the class inheritance accordingly.

def __new__(cls, eq, method):
assert isinstance(eq, Eq)
return Basic.__new__(cls, eq, method)

@property
def eq(self):
return self.args[0]

@property
def method(self):
return self.args[1]


class RK(Basic):
"""
A class representing an explicit Runge-Kutta method via its Butcher tableau.
Parameters
----------
a : list[list[float]]
Lower-triangular coefficient matrix (stage dependencies).
b : list[float]
Weights for the final combination step.
c : list[float]
Weights for the stages time step.
"""

def __init__(self, a, b, c):
self.a = a
self.b = b
self.c = c
self.s = len(b) # number of stages
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Nitpick: two spaces before inline comment, and should start with a capital letter

Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

thanks and done


self._validate()

def _validate(self):
assert len(self.a) == self.s, "'a' must have s rows"
for i, row in enumerate(self.a):
assert len(row) == i, f"Row {i} in 'a' must have {i} entries for explicit RK"

def expand_stages(self, base_eq, eq_num=0):
"""
Expand a single Eq into a list of stage-wise Eqs for this RK method.
Parameters
----------
base_eq : Eq
The equation Eq(u.forward, rhs) to be expanded into RK stages.
eq_number : integer, optional
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Nitpick: double-check docstrings, there are some inconsistencies and typos in here

Copy link
Author

@fernanvr fernanvr Jun 13, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think I’ve resolved it, but let me know if anything still looks off.

The equation number to idetify the k_i's stages
Returns
-------
list of Eq
Stage-wise equations: [k0=..., k1=..., ..., u.forward=...]
"""
u = base_eq.lhs.function
rhs = base_eq.rhs
grid = u.grid
dt = grid.stepping_dim.spacing
t = grid.time_dim
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

t = grid.time_dim
dt = t.spacing

would be a little neater

Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

appreciated and done


# Create temporary Functions to hold each stage
k = [Function(name=f'k{eq_num}{i}', grid=grid, space_order=u.space_order, dtype=u.dtype)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This needs to use SymbolRegistry to prevent clashes when creating temporary function names.

Given that these Functions are instantiated during compilation rather than by the user, Array might be sufficient?

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yes I think these need to be Array objects (or a subclass). Since they are Function objects, they will appear as arguments to the main Kernel, which I don't think is necessary

Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This part was tricky, I still haven’t figured it out. I left a commented line where I tried defining the k functions as an Array, but it triggered an error in the Operator construction. I’m likely missing something.

for i in range(self.s)]

stage_eqs = []

# Build each stage
for i in range(self.s):
u_temp = u
for j in range(i):
if self.a[i][j] != 0:
u_temp += self.a[i][j] * dt * k[j]
t_shift = t + self.c[i] * dt

# Evaluate RHS at intermediate value
stage_rhs = uxreplace(rhs, {u: u_temp, t: t_shift})
stage_eqs.append(Eq(k[i], stage_rhs))

# Final update: u.forward = u + dt * sum(bᵢ * kᵢ)
u_next = u
for i in range(self.s):
u_next += self.b[i] * dt * k[i]
stage_eqs.append(Eq(u.forward, u_next))

return stage_eqs

# ---- Named methods for convenience ----
@classmethod
def RK44(cls):
"""Classical Runge-Kutta of 4 stages and 4th order"""
a = [
[],
[1 / 2],
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Nitpick: 1/2 rather than 1 / 2 would improve readability

Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

appreciated and done

[0, 1 / 2],
[0, 0, 1]
]
b = [1 / 6, 1 / 3, 1 / 3, 1 / 6]
c = [0, 1 / 2, 1 / 2, 1]
return cls(a, b, c)


30 changes: 28 additions & 2 deletions devito/operator/operator.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,7 @@
disk_layer)
from devito.types.dimension import Thickness

from devito.operator.new_classes import MultiStage

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Please run the linter (flake8) 🙂

__all__ = ['Operator']

Expand Down Expand Up @@ -184,8 +185,9 @@ def _sanitize_exprs(cls, expressions, **kwargs):
expressions = as_tuple(expressions)

for i in expressions:
if not isinstance(i, Evaluable):
raise CompilationError(f"`{i!s}` is not an Evaluable object; "
i_check = i.eq if isinstance(i, MultiStage) else i
if not isinstance(i_check, Evaluable):
raise CompilationError(f"`{i_check!s}` is not an Evaluable object; "
"check your equation again")

return expressions
Expand Down Expand Up @@ -271,6 +273,9 @@ def _lower(cls, expressions, **kwargs):
# expression for which a partial or complete lowering is desired
kwargs['rcompile'] = cls._rcompile_wrapper(**kwargs)

# [MultiStage] -> [Eqs]
expressions = cls._lower_multistage(expressions, **kwargs)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

this should be called inside _lower_dsl, which in turn is called within _lower_exprs

Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I changed it to be called within _lower_exprs, but I couldn't find the _lower_dsl function...


# [Eq] -> [LoweredEq]
expressions = cls._lower_exprs(expressions, **kwargs)

Expand Down Expand Up @@ -314,6 +319,27 @@ def _specialize_exprs(cls, expressions, **kwargs):
"""
return expressions

@classmethod
@timed_pass(name='lowering.MultiStages')
def _lower_multistage(cls, expressions, **kwargs):
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

True MultiStage lowering should probably be pulled out into devito.ir.equations.algorithms

Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

done

"""
Separating the multi-stage time-integrator scheme in stages:

* Check if the time-integrator is Multistage;
* Creating the stages of the method.
"""

lowered = []
for i, eq in enumerate(as_tuple(expressions)):
if isinstance(eq, MultiStage):
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Would using singledispatch here aid extensibility going forward? See concretize_subdims for an example

Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I believe it's done now.

time_int = eq.method
stage_eqs = time_int.expand_stages(eq.eq, eq_num=i)
lowered.extend(stage_eqs)
else:
lowered.append(eq)

return lowered

@classmethod
@timed_pass(name='lowering.Expressions')
def _lower_exprs(cls, expressions, **kwargs):
Expand Down
79 changes: 79 additions & 0 deletions tests/test_multistage.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,79 @@
import numpy as np
import sympy as sym
import matplotlib.pyplot as plt

import devito as dv
from examples.seismic import Receiver, TimeAxis

from devito.operator.new_classes import RK, MultiStage
# Set logging level for debugging
dv.configuration['log-level'] = 'DEBUG'

# Parameters
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Wants to use pytest going forward. If you look at the existing tests, you will see that a lot of them are very granular, which aids bugfixing and test coverage.

It's worth noting that tests (especially the very elementary ones) can be physically nonsensical, in order to check that some aspect of the API or compiler functions as intended. For example, it would be useful to test that a range of Butcher tableaus get correctly assembled into the corresponding timestepping schemes, even if those timestepping schemes themselves are meaningless.

Examples of simplifications used for unit tests include using 1D where possible and using trivial Eqs (i.e. Eq(f, 1))

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Other tests that we include for essentially every new class include ensuring that it can be pickled (see test_pickle.py) and that it can be rebuilt correctly using its _rebuild method (also used for pickling). This is what the __rargs__ and __rkwargs__ attached to various classes are for.

Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I’ve been working in that direction, but I think it’s still a bit rough

space_order = 2
fd_order = 2
extent = (1000, 1000)
shape = (201, 201)
origin = (0, 0)

# Grid setup
grid = dv.Grid(origin=origin, extent=extent, shape=shape, dtype=np.float64)
x, y = grid.dimensions
dt = grid.stepping_dim.spacing
t = grid.time_dim
dx = extent[0] / (shape[0] - 1)

# Medium velocity model
vel = dv.Function(name="vel", grid=grid, space_order=space_order, dtype=np.float64)
vel.data[:] = 1.0
vel.data[150:, :] = 1.3

# Define wavefield unknowns: u (displacement) and v (velocity)
fun_labels = ['u', 'v']
U = [dv.TimeFunction(name=name, grid=grid, space_order=space_order,
time_order=1, dtype=np.float64) for name in fun_labels]

# Time axis
t0, tn = 0.0, 500.0
dt0 = np.max(vel.data) / dx**2
nt = int((tn - t0) / dt0)
dt0 = tn / nt
time_range = TimeAxis(start=t0, stop=tn, num=nt + 1)

# Receiver setup
rec = Receiver(name='rec', grid=grid, npoint=3, time_range=time_range)
rec.coordinates.data[:, 0] = np.linspace(0, 1, 3)
rec.coordinates.data[:, 1] = 0.5
rec = rec.interpolate(expr=U[0].forward)

# Source definition
src_spatial = dv.Function(name="src_spat", grid=grid, space_order=space_order, dtype=np.float64)
src_spatial.data[100, 100] = 1 / dx**2

f0 = 0.01
src_temporal = (1 - 2 * (np.pi * f0 * (t * dt - 1/f0))**2) * sym.exp(-(np.pi * f0 * (t * dt - 1/f0))**2)

# PDE system (2D acoustic)
system_eqs = [U[1],
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Maybe a simple heat equation would be a good example/test?

Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I incorporated for the tests

(dv.Derivative(U[0], (x, 2), fd_order=fd_order) +
dv.Derivative(U[0], (y, 2), fd_order=fd_order) +
src_spatial * src_temporal) * vel**2]

# Time integration scheme
rk = RK.RK44()

# MultiStage object
pdes = [MultiStage(dv.Eq(U[i], system_eqs[i]), rk) for i in range(2)]

# Construct and run operator
op = dv.Operator(pdes + [rec], subs=grid.spacing_map)
op(dt=dt0, time=nt)

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

In the generated C code, I noticed the following:

        k00[x + 2][y + 2] = v[t0][x + 2][y + 2];
        k01[x + 2][y + 2] = v[t0][x + 2][y + 2];
        k02[x + 2][y + 2] = v[t0][x + 2][y + 2];
        k03[x + 2][y + 2] = v[t0][x + 2][y + 2];

Is this intended? It doesn't seem optimal since they are all assigned the same value

# Plot final wavefield
plt.imshow(U[0].data[1, :], cmap="seismic")
plt.colorbar(label="Amplitude")
plt.title("Wavefield snapshot (t = final)")
plt.xlabel("x")
plt.ylabel("y")
plt.tight_layout()
plt.show()