Skip to content

Conversation

fernanvr
Copy link

@fernanvr fernanvr commented May 5, 2025

Implementation of the ideas of @mloubout, as discussed in the Slack channel #timestepping.

It were created two new classes:

  • MultiStage: Each instance represents a combination of a PDE and its associated time integrator.

  • RK: Instances encapsulate a Butcher Tableau. The class methods define specific Runge-Kutta schemes, and it includes logic to expand a single PDE into multiple stage equations according to the RK method. It only contains the classic RK44, of fourth order and four stages, as a prove of concept.

To integrate the MultiStage into Devito’s pipeline, we modified the _lower(...) function in operator.py by adding the line:
expressions = cls._lower_multistage(expressions, **kwargs),
and modified _sanitize_exprs(cls, expressions, **kwargs) to recognize MultiStage instances.

We also created a new function:
_lower_multistage(),
in the same file. This function handles the expansion of MultiStage instances into their corresponding stage equations.

Copy link
Contributor

@FabioLuporini FabioLuporini left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Hi, many thanks for this!

As you may see from the "Files changed" tab (https://github.com/devitocodes/devito/pull/2599/files), there are several new files that shouldn't be there. The whole .idea folder has probably been pushed erroneously. You will have to rebase your git history and erase it.

The new file new_classes should have a proper name and be placed in a different folder, @EdCaunt may have some ideas about that.

The test you have added has to be placed within a function def test_.... otherwise it won't be picked up by pytest. And we are going to need more tests, one isn't enough.
Also, a test is a test, it shouldn't plot anything. Instead, the numerical output should be compared to either an analytical solution or to one or more precomputed ("by hand") norms.

@@ -271,6 +273,9 @@ def _lower(cls, expressions, **kwargs):
# expression for which a partial or complete lowering is desired
kwargs['rcompile'] = cls._rcompile_wrapper(**kwargs)

# [MultiStage] -> [Eqs]
expressions = cls._lower_multistage(expressions, **kwargs)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

this should be called inside _lower_dsl, which in turn is called within _lower_exprs

Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I changed it to be called within _lower_exprs, but I couldn't find the _lower_dsl function...

@FabioLuporini FabioLuporini added the API api (symbolics, types, ...) label May 6, 2025
@EdCaunt EdCaunt changed the title implementation of multi-stage time integrators #timestepping dsl: Introduce abstractions for multi-stage time integrators May 6, 2025
Copy link
Contributor

@EdCaunt EdCaunt left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Lots of comments, but looks like a good first attempt at multistage timestepping!

# Set logging level for debugging
dv.configuration['log-level'] = 'DEBUG'

# Parameters
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Wants to use pytest going forward. If you look at the existing tests, you will see that a lot of them are very granular, which aids bugfixing and test coverage.

It's worth noting that tests (especially the very elementary ones) can be physically nonsensical, in order to check that some aspect of the API or compiler functions as intended. For example, it would be useful to test that a range of Butcher tableaus get correctly assembled into the corresponding timestepping schemes, even if those timestepping schemes themselves are meaningless.

Examples of simplifications used for unit tests include using 1D where possible and using trivial Eqs (i.e. Eq(f, 1))

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Other tests that we include for essentially every new class include ensuring that it can be pickled (see test_pickle.py) and that it can be rebuilt correctly using its _rebuild method (also used for pickling). This is what the __rargs__ and __rkwargs__ attached to various classes are for.

Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I’ve been working in that direction, but I think it’s still a bit rough

# Construct and run operator
op = dv.Operator(pdes + [rec], subs=grid.spacing_map)
op(dt=dt0, time=nt)

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

In the generated C code, I noticed the following:

        k00[x + 2][y + 2] = v[t0][x + 2][y + 2];
        k01[x + 2][y + 2] = v[t0][x + 2][y + 2];
        k02[x + 2][y + 2] = v[t0][x + 2][y + 2];
        k03[x + 2][y + 2] = v[t0][x + 2][y + 2];

Is this intended? It doesn't seem optimal since they are all assigned the same value

src_temporal = (1 - 2 * (np.pi * f0 * (t * dt - 1/f0))**2) * sym.exp(-(np.pi * f0 * (t * dt - 1/f0))**2)

# PDE system (2D acoustic)
system_eqs = [U[1],
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Maybe a simple heat equation would be a good example/test?

Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I incorporated for the tests

Copy link
Contributor

@mloubout mloubout left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks for the contribution @fernanvr I think it's a great start!

I'm gonna leave some generic comments to start with and will refine as it gets updated. So here is a few main comments.

  1. As @EdCaunt , the MultiStage object should inherit from Eq as it represents an equation, or in this case a set of equations. It should also be moved to devito/types/equations.py or make a new devito/types/multistage.py file.

  2. Instead of that generic class with a self-referencing constructor, you should instead try to have a top level abstract class then specific cases for it. I.e something like:

class MultiStage(Eq):
    def __new__(cls, eq, method, target):

where target is the target field the same way as solve so that you can easily have forward and backward equations.

Then for RK you would have

class RK(MultiStage):
    a
    b
    c

a,b,c need to as class attributes so you can easily create each case as their own class so you can differentiate them if needed. e.g. RK4, RK3, etc. This way you can also have a MultiStage class that is not RK but something else, like Adams-Bashforth or whatever.

class RK4(RK)
    a = [0, 1/2, 1/2, 1]
    b = [1/6, 1/3, 1/3, 1/6]
    c = [0, 1/2, 1/2, 1]

You can also set it up as you did but the

   @classmethod
    def RK44(cls):

needs to be moved out of the class and made into a plain function.

  1. The MultiStage should implement .evaluate(expand=False) then you won't need the lower_multistage as the evaluate will directly return the list of equations needed.

  2. THe "test" you have is more of an example and should be made into a tutorial. The tests you be testing each piece separately such as the plain construction of each object and the results of the evaluation and lowering of it.

  3. Optionally. I would abstract it away and directly allow something like eq= solve(pde, u.forward, method='RK4') which would automatically create the right MultiStage object. That would allow for the user to easily switch between different time steppers.

Feel free to ask questions or discuss on slack as well

@fernanvr
Copy link
Author

Thank you all for the comprehensive review comments. They have been incredibly helpful in improving my coding skills and deepening my understanding of some of Devito’s workflows. Overall, I believe I have incorporated most of the suggestions to enhance the code. However, there are still a few points where I could use some additional guidance.

Copy link
Contributor

@EdCaunt EdCaunt left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

More comments, but generally looking better

* If the object is MultiStage, it creates the stages of the method.
"""
lowered = []
for i, eq in enumerate(as_tuple(expressions)):
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Rather than having the the as_tuple here, why not have a dispatch for _lower_multistage that dispatches on iterable types as per _concretize_subdims?

Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

done, I think..

return [expr]


@_lower_multistage.register
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I would personally tweak this for consistency with other uses of singledispatch in the codebase

@@ -15,7 +17,7 @@ class SolveError(Exception):
pass


def solve(eq, target, **kwargs):
def solve(eq, target, method = None, eq_num = 0, **kwargs):
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Kwargs should not have spaces around them. Furthermore, can method and eq_num simply be folded into **kwargs?

@@ -56,9 +58,15 @@ def solve(eq, target, **kwargs):

# We need to rebuild the vector/tensor as sympy.solve outputs a tuple of solutions
if len(sols) > 1:
return target.new_from_mat(sols)
sols_temp=target.new_from_mat(sols)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Should have whitespace around operator. Same below

sols_temp=sols[0]

if method is not None:
method_cls = MultiStage._resolve_method(method)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

To me, this implies thatMultiStage._resolve_method should not be a class method. Perhaps the method kwarg should be dropped, and instead you should have something like:

method_cls = eq._resolve_method()  # Possibly make this a property?
if method_cls is None:
    return sols_temp

return method_cls(sols_temp, target)._evaluate(eq_num=eq_num)

or even just

return eq._resolve_method(sols_temp, target)._evaluate(eq_num=eq_num)

where _resolve_method is some abstract method of Eq and its subclasses which defaults to a no-op or similar.

As a side note, why is the eq_num kwarg required?


# Create temporary Functions to hold each stage
# k = [Array(name=f'k{eq_num}{i}', dimensions=grid.shape, grid=grid, dtype=u.dtype) for i in range(self.s)] # Trying Array
k = [Function(name=f'k{eq_num}{i}', grid=grid, space_order=u.space_order, dtype=u.dtype)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

These are internal to Devito, should not appear in operator arguments, and should not be touched by the user, and so should use Array, not Function

- `s` stage equations of the form `k_i = rhs evaluated at intermediate state`
- 1 final update equation of the form `u.forward = u + dt * sum(b_i * k_i)`
"""
base_eq=self.eq
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

With the __new__ specified above, these would just be u = self.lhs.function etc, which would clean things up

assert len(self.a) == self.s, f"'a'={a} must have {self.s} rows"
assert len(self.c) == self.s, f"'c'={c} must have {self.s} elements"

def _evaluate(self, eq_num=0):
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Should have **kwargs for consistency with Eq._evaluate()

c : list[float]
Time positions of intermediate stages.
"""
a = [[0, 0, 0, 0],
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I would set these as tuples in the __init__. Definitely should not be mutable if set on a class level.

I would personally instead have a

def __init__(self):
    a = (...
    b = (...
    c = (...
    super.__init__(a=a, b=b, c=c)

from devito.types.multistage import MultiStage


def test_multistage_solve(time_int='RK44'):
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

These tests are looking better. I would suggest making some very granular tests to verify particular functionalities within MultiStage lowering

src_spatial * src_temporal]

# Time integration scheme
return [solve(system_eqs_rhs[i] - U[i], U[i], method=time_int, eq_num=i) for i in range(2)]
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

You don’t want to use return in your pytest tests. Instead, you should use assertions to verify that the result matches what you expect

# k = [Array(name=f'k{eq_num}{i}', dimensions=grid.shape, grid=grid, dtype=u.dtype) for i in range(self.s)] # Trying Array
k = [Function(name=f'k{eq_num}{i}', grid=grid, space_order=u.space_order, dtype=u.dtype)
# k = [Array(name=f'{stage_id}{i}', dimensions=grid.shape, grid=grid, dtype=u.dtype) for i in range(self.s)] # Trying Array
k = [Function(name=f'{stage_id}{i}', grid=grid, space_order=u.space_order, dtype=u.dtype)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

kwargs.get('sregistry').make_name(prefix='k') wants to be inside this loop to ensure that all names are unique

Copy link
Contributor

@EdCaunt EdCaunt left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Some more comments. Improving steadily once again

"""
Handle iterables of expressions.
"""
lowered = []
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Could you do return [_lower_multistage(expr, **kwargs) for i in exprs for expr in i]?

return sols[0]
sols_temp = sols[0]

method = kwargs.get("method", None)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Is method a string here, or is it the class for the method? In the latter case, it would remove the need to have the method_registry mapper. Furthermore, it would allow you to have method.resolve(target, sols_temp) here, which is tidier

The right-hand side of the equation to integrate.
target : Function
The time-updated symbol on the left-hand side, e.g., `u` or `u.forward`.
method : str or None
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This should be either a class or a callable imo. Alternatively, it should be entirely omitted and set by defining some method/_evaluate method. Of these two, I prefer the latter as it results in cleaner code and a simpler API.

In general, if you are using a string comparison, there is probably a better (and safer) way to achieve your aim.

def __init__(self, **kwargs):
self.a, self.b, self.c = self._validate(**kwargs)

def _validate(self, **kwargs):
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Why not just use concrete args with type hinting rather than kwargs?

raise ValueError("RK subclass must define class attributes of the Butcher's array a, b, and c")
return a, b, c

@cached_property
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Nitpick: Probably doesn't need caching

Number of stages in the RK method, inferred from `b`.
"""

def __init__(self, a=None, b=None, c=None, **kwargs):
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This can just be def __init__(self, a: list[list[float | np.number]], b: list[float | np.number], c: list[float | np.number], **kwargs) -> None:, avoiding the need for the validation function

fernanvr added 2 commits June 26, 2025 18:10
…compatibility with pickles, and checking numerical convergence of three Runge-Kutta methods
Copy link

codecov bot commented Jul 10, 2025

Codecov Report

❌ Patch coverage is 25.52632% with 283 lines in your changes missing coverage. Please review.
✅ Project coverage is 45.80%. Comparing base (ed3585a) to head (ac1da7e).
⚠️ Report is 235 commits behind head on main.

Files with missing lines Patch % Lines
tests/test_multistage.py 10.09% 196 Missing ⚠️
devito/types/multistage.py 40.27% 86 Missing ⚠️
devito/ir/equations/algorithms.py 91.66% 1 Missing ⚠️

❗ There is a different number of reports uploaded between BASE (ed3585a) and HEAD (ac1da7e). Click for more details.

HEAD has 15 uploads less than BASE
Flag BASE (ed3585a) HEAD (ac1da7e)
17 4
pytest-gpu-nvc-nvidiaX 1 0
pytest-gpu-aomp-amdgpuX 1 0
Additional details and impacted files
@@             Coverage Diff             @@
##             main    #2599       +/-   ##
===========================================
- Coverage   92.00%   45.80%   -46.20%     
===========================================
  Files         245      250        +5     
  Lines       48727    50032     +1305     
  Branches     4294     4376       +82     
===========================================
- Hits        44830    22916    -21914     
- Misses       3209    26154    +22945     
- Partials      688      962      +274     
Flag Coverage Δ
pytest-gpu-aomp-amdgpuX ?
pytest-gpu-nvc-nvidiaX ?

Flags with carried forward coverage won't be shown. Click here to find out more.

☔ View full report in Codecov by Sentry.
📢 Have feedback on the report? Share it here.

🚀 New features to boost your workflow:
  • ❄️ Test Analytics: Detect flaky tests, report on failures, and find test suite problems.

@@ -0,0 +1,357 @@
from .equation import Eq
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Make these imports absolute (devito.types.equation)


method_registry = {}

def register_method(cls):
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I'm not a fan of using string matching here.

I'm also not sure why this function is needed, especially when the registry itself is just a regular dict


class MultiStage(Eq):
"""
Abstract base class for multi-stage time integration methods
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Good docstring, but is it overindented by a level?

of update expressions for each stage in the integration process.
"""

def __new__(cls, lhs, rhs=0, subdomain=None, coefficients=None, implicit_dims=None, **kwargs):
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This __new__ doesn't seem to do anything. Can it be removed?

Number of stages in the RK method, inferred from `b`.
"""

def __init__(self, a: list[list[float | number]], b: list[float | number], c: list[float | number], **kwargs) -> None:
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

list[list[float | number]] should probably be defined as a type within the class to improve readability:

CoeffsBC = list[float | number]
CoeffsA = list[CoeffBC]
def __init__(self, a: CoeffsA, b: CoeffsBC, c: CoeffsBC, **kwargs) -> None:

Returns
-------
list of Eq
A list of SymPy Eq objects representing:
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Nitpick: they will be Devito Eq objects

c : list[float]
Time positions of intermediate stages.
"""
a = [[0, 0, 0, 0, 0, 0, 0, 0, 0],
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Possibly a massive pain to do, but would it be possible to have some kind of calculation to get these coefficients rather than having all these tableaus defined in separate classes? In that way, you would just need to specify order and number of stages required



@register_method
class HORK(MultiStage):
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This should just be the default Runge-Kutta unless there's something I'm overlooking imo

"""


def ssprk_alpha(mu=1, **kwargs):
Copy link
Contributor

@EdCaunt EdCaunt Jul 11, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Either this doesn't need to live in the class, it should refer to self (perhaps using self.degree), or it should be a staticmethod/classmethod


for i in range(1, degree):
alpha[i] = 1 / (mu * (i + 1)) * alpha[i - 1]
alpha[1:i] = 1 / (mu * list(range(1, i))) * alpha[:i - 1]
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

(mu * list(range(1, i))) does this definitely do what it is intended to do?

alpha[1:i] = 1 / (mu * list(range(1, i))) * alpha[:i - 1]
alpha[0] = 1 - sum(alpha[1:i + 1])

return alpha
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Doesn't look to be an array getting returned

- 1 final update equation of the form `u.forward = u + dt * sum(b_i * k_i)`
"""

u = self.lhs.function
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

There looks to be some repeated code in here vs other _evaluate methods. Perhaps a generic _evaluate for MultiStage or at least RK classes is in order with a suitable hook (or hooks) to grab the operations specific to that method

an_eq = range(len(U0))

# Compute SSPRK coefficients
alpha = np.array(ssprk_alpha(mu, degree), dtype=np.float64)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Move the array into ssprk_alpha. dtype should be pulled from u

# Time integration scheme
pde = [resolve_method(time_int)(u_multi_stage, eq_rhs)]
op = Operator(pde, subs=grid.spacing_map)
op(dt=0.01, time=1)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Again should assert a norm. Can also be consolidated with the previous test via parameterisation

op(dt=0.01, time=1)


@pytest.mark.parametrize('time_int', ['RK44', 'RK32', 'RK97'])
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Good! I would use the classes themselves rather than strings

pdes = [Eq(U[i].forward, solve(Eq(U[i].dt-eq_rhs[i]), U[i].forward)) for i in range(len(fun_labels))]
op = Operator(pdes, subs=grid.spacing_map)
op(dt=dt0, time=nt)
assert max(abs(U[0].data[0,:]-U_multi_stage[0].data[0,:]))<10**-5, "the method is not converging to the solution"
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

A better check would be to compare the L2 norms with np.isclose. See other tests in the codebase for examples



@pytest.mark.parametrize('time_int', ['RK44', 'RK32', 'RK97'])
def test_multistage_methods_convergence(time_int):
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Slightly misleading test name unless I'm missing something? This sounds like a convergence test, but seems to actually test whether the multistage timestepper generates a solution which matches a more simple timestepping scheme

assert max(abs(U[0].data[0,:]-U_multi_stage[0].data[0,:]))<10**-5, "the method is not converging to the solution"



Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Nitpick: Too many blank lines at end of file

Copy link
Contributor

@EdCaunt EdCaunt left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Some more comments. More granular testing is still needed

@@ -79,6 +83,16 @@ def rhs(self):
"""Return list of right-hand sides."""
return self._rhs

@property
def deg(self):
"""Return list of right-hand sides."""
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

For immutability, these should be tuple. Ditto with the next property. They should probably be made into tuples as early as possible


@property
def src(self):
"""Return list of right-hand sides."""
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Repeated docstring?

@@ -115,7 +129,9 @@ class RK(MultiStage):
Number of stages in the RK method, inferred from `b`.
"""

def __init__(self, a: list[list[float | number]], b: list[float | number], c: list[float | number], lhs, rhs, **kwargs) -> None:
CoeffsBC = list[float | number]
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Does number superclass np.number?

@@ -132,19 +148,18 @@ def _evaluate(self, **kwargs):

Returns
-------
list of Eq
list of Devito Eq objects
A list of SymPy Eq objects representing:
- `s` stage equations of the form `k_i = rhs evaluated at intermediate state`
- 1 final update equation of the form `u.forward = u + dt * sum(b_i * k_i)`
"""
n_eq=len(self.eq)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Needs whitespace for flake8

A list of SymPy Eq objects representing:
- `s` stage equations of the form `k_i = rhs evaluated at intermediate state`
- 1 final update equation of the form `u.forward = u + dt * sum(b_i * k_i)`
"""
n_eq=len(self.eq)
u = [i.function for i in self.lhs]
grid = [u[i].grid for i in range(n_eq)]
t = grid[0].time_dim
t = u[0].grid.time_dim
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

grids = {f.grid for f in u}
if not len(grids) == 1:
    raise ValueError("Cannot construct multi-stage time integrator for Functions on disparate grids")
grid = grids.pop()
t = grid.time_dim

would be safer

# Source definition
src_spatial = Function(name="src_spat", grid=grid, space_order=2, dtype=np.float64)
src_spatial.data[100, 100] = 1
import sympy as sym
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Rogue import

# Time integration scheme
pdes = resolve_method(time_int)(U_multi_stage, system_eqs_rhs, source=src, degree=4)
op = Operator(pdes, subs=grid.spacing_map)
op(dt=0.001, time=2000)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Needs an assertion

assert np.max(np.abs(U[0].data[0,:]-U_multi_stage[0].data[0,:]))<10**-5, "the method is not converging to the solution"


def test_multistage_coupled_op_computing_exp(time_int='HORK_EXP'):
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Should pass a class, not a string.

@@ -219,22 +324,22 @@ def test_multistage_coupled_op_computing(time_int='RK44'):
op(dt=0.01, time=1)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Should also assert something

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Should never use pickling for long-term file storage. Remove this file

Copy link
Contributor

@JDBetteridge JDBetteridge left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It would be nice (once this has been iterated a bit further and the API concertised) to see an example or demonstrative notebook that showcases the new functionality.

@@ -36,7 +36,6 @@
disk_layer)
from devito.types.dimension import Thickness


Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Please run the linter (flake8) 🙂

f"_evaluate() must be implemented in the subclass {self.__class__.__name__}")


class RK(MultiStage):
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
class RK(MultiStage):
class RungeKutta(MultiStage):

Comment on lines +183 to +184
@register_method
class RK44(RK):
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think RK4 and indeed all RK methods should be instances of the RK Class

Then you no longer need all of the boilerplate code below, which is just setting up Butcher tableau

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Or the coefficients should be class attributes and set by the child class



@register_method
class HORK(MultiStage):
class HORK_EXP(MultiStage):
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

HighOrderRungeKuttaExplicit perhaps

Be explicit!

Use CamelCase for class definitions and names_with_underscores for variables and functions.



@register_method
class HORK(MultiStage):
class HORK_EXP(MultiStage):
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Also if we want to start distinguishing between explicit and implicit timesteppers we should be using a class hierarchy to do so

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think this file should be moved to somewhere like devito/timestepping/rungekutta.py or devito/timestepping/explicitmultistage.py that way additional timesteppers can be contributed as new files. (I'm thinking about implicit multistage, backward difference formulae etc...)

Comment on lines +94 to +97
# import matplotlib.pyplot as plt
# import numpy as np
# t=np.linspace(0,2000,1000)
# plt.plot(t,np.exp(1 - 2 * (t - 1)**2))
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Leftovers?

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
API api (symbolics, types, ...)
Projects
None yet
Development

Successfully merging this pull request may close these issues.

6 participants