-
Notifications
You must be signed in to change notification settings - Fork 239
dsl: Introduce abstractions for multi-stage time integrators #2599
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
Conversation
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Hi, many thanks for this!
As you may see from the "Files changed" tab (https://github.com/devitocodes/devito/pull/2599/files), there are several new files that shouldn't be there. The whole .idea folder has probably been pushed erroneously. You will have to rebase your git history and erase it.
The new file new_classes
should have a proper name and be placed in a different folder, @EdCaunt may have some ideas about that.
The test you have added has to be placed within a function def test_....
otherwise it won't be picked up by pytest
. And we are going to need more tests, one isn't enough.
Also, a test is a test, it shouldn't plot anything. Instead, the numerical output should be compared to either an analytical solution or to one or more precomputed ("by hand") norms.
devito/operator/operator.py
Outdated
@@ -271,6 +273,9 @@ def _lower(cls, expressions, **kwargs): | |||
# expression for which a partial or complete lowering is desired | |||
kwargs['rcompile'] = cls._rcompile_wrapper(**kwargs) | |||
|
|||
# [MultiStage] -> [Eqs] | |||
expressions = cls._lower_multistage(expressions, **kwargs) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
this should be called inside _lower_dsl
, which in turn is called within _lower_exprs
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I changed it to be called within _lower_exprs
, but I couldn't find the _lower_dsl
function...
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Lots of comments, but looks like a good first attempt at multistage timestepping!
tests/test_multistage.py
Outdated
# Set logging level for debugging | ||
dv.configuration['log-level'] = 'DEBUG' | ||
|
||
# Parameters |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Wants to use pytest
going forward. If you look at the existing tests, you will see that a lot of them are very granular, which aids bugfixing and test coverage.
It's worth noting that tests (especially the very elementary ones) can be physically nonsensical, in order to check that some aspect of the API or compiler functions as intended. For example, it would be useful to test that a range of Butcher tableaus get correctly assembled into the corresponding timestepping schemes, even if those timestepping schemes themselves are meaningless.
Examples of simplifications used for unit tests include using 1D where possible and using trivial Eq
s (i.e. Eq(f, 1)
)
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Other tests that we include for essentially every new class include ensuring that it can be pickled (see test_pickle.py) and that it can be rebuilt correctly using its _rebuild
method (also used for pickling). This is what the __rargs__
and __rkwargs__
attached to various classes are for.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I’ve been working in that direction, but I think it’s still a bit rough
tests/test_multistage.py
Outdated
# Construct and run operator | ||
op = dv.Operator(pdes + [rec], subs=grid.spacing_map) | ||
op(dt=dt0, time=nt) | ||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
In the generated C code, I noticed the following:
k00[x + 2][y + 2] = v[t0][x + 2][y + 2];
k01[x + 2][y + 2] = v[t0][x + 2][y + 2];
k02[x + 2][y + 2] = v[t0][x + 2][y + 2];
k03[x + 2][y + 2] = v[t0][x + 2][y + 2];
Is this intended? It doesn't seem optimal since they are all assigned the same value
tests/test_multistage.py
Outdated
src_temporal = (1 - 2 * (np.pi * f0 * (t * dt - 1/f0))**2) * sym.exp(-(np.pi * f0 * (t * dt - 1/f0))**2) | ||
|
||
# PDE system (2D acoustic) | ||
system_eqs = [U[1], |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Maybe a simple heat equation would be a good example/test?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I incorporated for the tests
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Thanks for the contribution @fernanvr I think it's a great start!
I'm gonna leave some generic comments to start with and will refine as it gets updated. So here is a few main comments.
-
As @EdCaunt , the
MultiStage
object should inherit fromEq
as it represents an equation, or in this case a set of equations. It should also be moved todevito/types/equations.py
or make a newdevito/types/multistage.py
file. -
Instead of that generic class with a self-referencing constructor, you should instead try to have a top level abstract class then specific cases for it. I.e something like:
class MultiStage(Eq):
def __new__(cls, eq, method, target):
where target
is the target field the same way as solve
so that you can easily have forward and backward equations.
Then for RK you would have
class RK(MultiStage):
a
b
c
a,b,c
need to as class attributes so you can easily create each case as their own class so you can differentiate them if needed. e.g. RK4
, RK3
, etc. This way you can also have a MultiStage
class that is not RK but something else, like Adams-Bashforth or whatever.
class RK4(RK)
a = [0, 1/2, 1/2, 1]
b = [1/6, 1/3, 1/3, 1/6]
c = [0, 1/2, 1/2, 1]
You can also set it up as you did but the
@classmethod
def RK44(cls):
needs to be moved out of the class and made into a plain function.
-
The
MultiStage
should implement.evaluate(expand=False)
then you won't need thelower_multistage
as theevaluate
will directly return the list of equations needed. -
THe "test" you have is more of an example and should be made into a tutorial. The tests you be testing each piece separately such as the plain construction of each object and the results of the evaluation and lowering of it.
-
Optionally. I would abstract it away and directly allow something like
eq= solve(pde, u.forward, method='RK4')
which would automatically create the rightMultiStage
object. That would allow for the user to easily switch between different time steppers.
Feel free to ask questions or discuss on slack as well
…tegrator the merge is necessary to maintaing the development of multistages
Thank you all for the comprehensive review comments. They have been incredibly helpful in improving my coding skills and deepening my understanding of some of Devito’s workflows. Overall, I believe I have incorporated most of the suggestions to enhance the code. However, there are still a few points where I could use some additional guidance. |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
More comments, but generally looking better
devito/ir/equations/algorithms.py
Outdated
* If the object is MultiStage, it creates the stages of the method. | ||
""" | ||
lowered = [] | ||
for i, eq in enumerate(as_tuple(expressions)): |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Rather than having the the as_tuple
here, why not have a dispatch for _lower_multistage
that dispatches on iterable types as per _concretize_subdims
?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
done, I think..
devito/ir/equations/algorithms.py
Outdated
return [expr] | ||
|
||
|
||
@_lower_multistage.register |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I would personally tweak this for consistency with other uses of singledispatch
in the codebase
devito/operations/solve.py
Outdated
@@ -15,7 +17,7 @@ class SolveError(Exception): | |||
pass | |||
|
|||
|
|||
def solve(eq, target, **kwargs): | |||
def solve(eq, target, method = None, eq_num = 0, **kwargs): |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Kwargs should not have spaces around them. Furthermore, can method
and eq_num
simply be folded into **kwargs
?
devito/operations/solve.py
Outdated
@@ -56,9 +58,15 @@ def solve(eq, target, **kwargs): | |||
|
|||
# We need to rebuild the vector/tensor as sympy.solve outputs a tuple of solutions | |||
if len(sols) > 1: | |||
return target.new_from_mat(sols) | |||
sols_temp=target.new_from_mat(sols) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Should have whitespace around operator. Same below
devito/operations/solve.py
Outdated
sols_temp=sols[0] | ||
|
||
if method is not None: | ||
method_cls = MultiStage._resolve_method(method) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
To me, this implies thatMultiStage._resolve_method
should not be a class method. Perhaps the method
kwarg should be dropped, and instead you should have something like:
method_cls = eq._resolve_method() # Possibly make this a property?
if method_cls is None:
return sols_temp
return method_cls(sols_temp, target)._evaluate(eq_num=eq_num)
or even just
return eq._resolve_method(sols_temp, target)._evaluate(eq_num=eq_num)
where _resolve_method
is some abstract method of Eq
and its subclasses which defaults to a no-op or similar.
As a side note, why is the eq_num
kwarg required?
devito/types/multistage.py
Outdated
|
||
# Create temporary Functions to hold each stage | ||
# k = [Array(name=f'k{eq_num}{i}', dimensions=grid.shape, grid=grid, dtype=u.dtype) for i in range(self.s)] # Trying Array | ||
k = [Function(name=f'k{eq_num}{i}', grid=grid, space_order=u.space_order, dtype=u.dtype) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
These are internal to Devito, should not appear in operator arguments, and should not be touched by the user, and so should use Array
, not Function
devito/types/multistage.py
Outdated
- `s` stage equations of the form `k_i = rhs evaluated at intermediate state` | ||
- 1 final update equation of the form `u.forward = u + dt * sum(b_i * k_i)` | ||
""" | ||
base_eq=self.eq |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
With the __new__
specified above, these would just be u = self.lhs.function
etc, which would clean things up
devito/types/multistage.py
Outdated
assert len(self.a) == self.s, f"'a'={a} must have {self.s} rows" | ||
assert len(self.c) == self.s, f"'c'={c} must have {self.s} elements" | ||
|
||
def _evaluate(self, eq_num=0): |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Should have **kwargs
for consistency with Eq._evaluate()
c : list[float] | ||
Time positions of intermediate stages. | ||
""" | ||
a = [[0, 0, 0, 0], |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I would set these as tuples in the __init__
. Definitely should not be mutable if set on a class level.
I would personally instead have a
def __init__(self):
a = (...
b = (...
c = (...
super.__init__(a=a, b=b, c=c)
tests/test_multistage.py
Outdated
from devito.types.multistage import MultiStage | ||
|
||
|
||
def test_multistage_solve(time_int='RK44'): |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
These tests are looking better. I would suggest making some very granular tests to verify particular functionalities within MultiStage
lowering
tests/test_multistage.py
Outdated
src_spatial * src_temporal] | ||
|
||
# Time integration scheme | ||
return [solve(system_eqs_rhs[i] - U[i], U[i], method=time_int, eq_num=i) for i in range(2)] |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
You don’t want to use return
in your pytest tests. Instead, you should use assertions to verify that the result matches what you expect
devito/types/multistage.py
Outdated
# k = [Array(name=f'k{eq_num}{i}', dimensions=grid.shape, grid=grid, dtype=u.dtype) for i in range(self.s)] # Trying Array | ||
k = [Function(name=f'k{eq_num}{i}', grid=grid, space_order=u.space_order, dtype=u.dtype) | ||
# k = [Array(name=f'{stage_id}{i}', dimensions=grid.shape, grid=grid, dtype=u.dtype) for i in range(self.s)] # Trying Array | ||
k = [Function(name=f'{stage_id}{i}', grid=grid, space_order=u.space_order, dtype=u.dtype) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
kwargs.get('sregistry').make_name(prefix='k')
wants to be inside this loop to ensure that all names are unique
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Some more comments. Improving steadily once again
devito/ir/equations/algorithms.py
Outdated
""" | ||
Handle iterables of expressions. | ||
""" | ||
lowered = [] |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Could you do return [_lower_multistage(expr, **kwargs) for i in exprs for expr in i]
?
return sols[0] | ||
sols_temp = sols[0] | ||
|
||
method = kwargs.get("method", None) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Is method a string here, or is it the class for the method? In the latter case, it would remove the need to have the method_registry
mapper. Furthermore, it would allow you to have method.resolve(target, sols_temp)
here, which is tidier
devito/types/multistage.py
Outdated
The right-hand side of the equation to integrate. | ||
target : Function | ||
The time-updated symbol on the left-hand side, e.g., `u` or `u.forward`. | ||
method : str or None |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
This should be either a class or a callable imo. Alternatively, it should be entirely omitted and set by defining some method
/_evaluate
method. Of these two, I prefer the latter as it results in cleaner code and a simpler API.
In general, if you are using a string comparison, there is probably a better (and safer) way to achieve your aim.
devito/types/multistage.py
Outdated
def __init__(self, **kwargs): | ||
self.a, self.b, self.c = self._validate(**kwargs) | ||
|
||
def _validate(self, **kwargs): |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Why not just use concrete args with type hinting rather than kwargs?
devito/types/multistage.py
Outdated
raise ValueError("RK subclass must define class attributes of the Butcher's array a, b, and c") | ||
return a, b, c | ||
|
||
@cached_property |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Nitpick: Probably doesn't need caching
devito/types/multistage.py
Outdated
Number of stages in the RK method, inferred from `b`. | ||
""" | ||
|
||
def __init__(self, a=None, b=None, c=None, **kwargs): |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
This can just be def __init__(self, a: list[list[float | np.number]], b: list[float | np.number], c: list[float | np.number], **kwargs) -> None:
, avoiding the need for the validation function
…compatibility with pickles, and checking numerical convergence of three Runge-Kutta methods
Codecov Report❌ Patch coverage is
Additional details and impacted files@@ Coverage Diff @@
## main #2599 +/- ##
===========================================
- Coverage 92.00% 45.80% -46.20%
===========================================
Files 245 250 +5
Lines 48727 50032 +1305
Branches 4294 4376 +82
===========================================
- Hits 44830 22916 -21914
- Misses 3209 26154 +22945
- Partials 688 962 +274
Flags with carried forward coverage won't be shown. Click here to find out more. ☔ View full report in Codecov by Sentry. 🚀 New features to boost your workflow:
|
devito/types/multistage.py
Outdated
@@ -0,0 +1,357 @@ | |||
from .equation import Eq |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Make these imports absolute (devito.types.equation
)
|
||
method_registry = {} | ||
|
||
def register_method(cls): |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I'm not a fan of using string matching here.
I'm also not sure why this function is needed, especially when the registry itself is just a regular dict
devito/types/multistage.py
Outdated
|
||
class MultiStage(Eq): | ||
""" | ||
Abstract base class for multi-stage time integration methods |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Good docstring, but is it overindented by a level?
devito/types/multistage.py
Outdated
of update expressions for each stage in the integration process. | ||
""" | ||
|
||
def __new__(cls, lhs, rhs=0, subdomain=None, coefficients=None, implicit_dims=None, **kwargs): |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
This __new__
doesn't seem to do anything. Can it be removed?
devito/types/multistage.py
Outdated
Number of stages in the RK method, inferred from `b`. | ||
""" | ||
|
||
def __init__(self, a: list[list[float | number]], b: list[float | number], c: list[float | number], **kwargs) -> None: |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
list[list[float | number]]
should probably be defined as a type within the class to improve readability:
CoeffsBC = list[float | number]
CoeffsA = list[CoeffBC]
def __init__(self, a: CoeffsA, b: CoeffsBC, c: CoeffsBC, **kwargs) -> None:
Returns | ||
------- | ||
list of Eq | ||
A list of SymPy Eq objects representing: |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Nitpick: they will be Devito Eq
objects
c : list[float] | ||
Time positions of intermediate stages. | ||
""" | ||
a = [[0, 0, 0, 0, 0, 0, 0, 0, 0], |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Possibly a massive pain to do, but would it be possible to have some kind of calculation to get these coefficients rather than having all these tableaus defined in separate classes? In that way, you would just need to specify order and number of stages required
devito/types/multistage.py
Outdated
|
||
|
||
@register_method | ||
class HORK(MultiStage): |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
This should just be the default Runge-Kutta unless there's something I'm overlooking imo
devito/types/multistage.py
Outdated
""" | ||
|
||
|
||
def ssprk_alpha(mu=1, **kwargs): |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Either this doesn't need to live in the class, it should refer to self
(perhaps using self.degree
), or it should be a staticmethod
/classmethod
devito/types/multistage.py
Outdated
|
||
for i in range(1, degree): | ||
alpha[i] = 1 / (mu * (i + 1)) * alpha[i - 1] | ||
alpha[1:i] = 1 / (mu * list(range(1, i))) * alpha[:i - 1] |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
(mu * list(range(1, i)))
does this definitely do what it is intended to do?
alpha[1:i] = 1 / (mu * list(range(1, i))) * alpha[:i - 1] | ||
alpha[0] = 1 - sum(alpha[1:i + 1]) | ||
|
||
return alpha |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Doesn't look to be an array getting returned
devito/types/multistage.py
Outdated
- 1 final update equation of the form `u.forward = u + dt * sum(b_i * k_i)` | ||
""" | ||
|
||
u = self.lhs.function |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
There looks to be some repeated code in here vs other _evaluate
methods. Perhaps a generic _evaluate
for MultiStage
or at least RK
classes is in order with a suitable hook (or hooks) to grab the operations specific to that method
devito/types/multistage.py
Outdated
an_eq = range(len(U0)) | ||
|
||
# Compute SSPRK coefficients | ||
alpha = np.array(ssprk_alpha(mu, degree), dtype=np.float64) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Move the array
into ssprk_alpha
. dtype
should be pulled from u
# Time integration scheme | ||
pde = [resolve_method(time_int)(u_multi_stage, eq_rhs)] | ||
op = Operator(pde, subs=grid.spacing_map) | ||
op(dt=0.01, time=1) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Again should assert a norm. Can also be consolidated with the previous test via parameterisation
op(dt=0.01, time=1) | ||
|
||
|
||
@pytest.mark.parametrize('time_int', ['RK44', 'RK32', 'RK97']) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Good! I would use the classes themselves rather than strings
tests/test_multistage.py
Outdated
pdes = [Eq(U[i].forward, solve(Eq(U[i].dt-eq_rhs[i]), U[i].forward)) for i in range(len(fun_labels))] | ||
op = Operator(pdes, subs=grid.spacing_map) | ||
op(dt=dt0, time=nt) | ||
assert max(abs(U[0].data[0,:]-U_multi_stage[0].data[0,:]))<10**-5, "the method is not converging to the solution" |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
A better check would be to compare the L2 norms with np.isclose
. See other tests in the codebase for examples
tests/test_multistage.py
Outdated
|
||
|
||
@pytest.mark.parametrize('time_int', ['RK44', 'RK32', 'RK97']) | ||
def test_multistage_methods_convergence(time_int): |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Slightly misleading test name unless I'm missing something? This sounds like a convergence test, but seems to actually test whether the multistage timestepper generates a solution which matches a more simple timestepping scheme
assert max(abs(U[0].data[0,:]-U_multi_stage[0].data[0,:]))<10**-5, "the method is not converging to the solution" | ||
|
||
|
||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Nitpick: Too many blank lines at end of file
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Some more comments. More granular testing is still needed
@@ -79,6 +83,16 @@ def rhs(self): | |||
"""Return list of right-hand sides.""" | |||
return self._rhs | |||
|
|||
@property | |||
def deg(self): | |||
"""Return list of right-hand sides.""" |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
For immutability, these should be tuple
. Ditto with the next property. They should probably be made into tuple
s as early as possible
|
||
@property | ||
def src(self): | ||
"""Return list of right-hand sides.""" |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Repeated docstring?
@@ -115,7 +129,9 @@ class RK(MultiStage): | |||
Number of stages in the RK method, inferred from `b`. | |||
""" | |||
|
|||
def __init__(self, a: list[list[float | number]], b: list[float | number], c: list[float | number], lhs, rhs, **kwargs) -> None: | |||
CoeffsBC = list[float | number] |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Does number
superclass np.number
?
@@ -132,19 +148,18 @@ def _evaluate(self, **kwargs): | |||
|
|||
Returns | |||
------- | |||
list of Eq | |||
list of Devito Eq objects | |||
A list of SymPy Eq objects representing: | |||
- `s` stage equations of the form `k_i = rhs evaluated at intermediate state` | |||
- 1 final update equation of the form `u.forward = u + dt * sum(b_i * k_i)` | |||
""" | |||
n_eq=len(self.eq) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Needs whitespace for flake8
A list of SymPy Eq objects representing: | ||
- `s` stage equations of the form `k_i = rhs evaluated at intermediate state` | ||
- 1 final update equation of the form `u.forward = u + dt * sum(b_i * k_i)` | ||
""" | ||
n_eq=len(self.eq) | ||
u = [i.function for i in self.lhs] | ||
grid = [u[i].grid for i in range(n_eq)] | ||
t = grid[0].time_dim | ||
t = u[0].grid.time_dim |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
grids = {f.grid for f in u}
if not len(grids) == 1:
raise ValueError("Cannot construct multi-stage time integrator for Functions on disparate grids")
grid = grids.pop()
t = grid.time_dim
would be safer
# Source definition | ||
src_spatial = Function(name="src_spat", grid=grid, space_order=2, dtype=np.float64) | ||
src_spatial.data[100, 100] = 1 | ||
import sympy as sym |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Rogue import
# Time integration scheme | ||
pdes = resolve_method(time_int)(U_multi_stage, system_eqs_rhs, source=src, degree=4) | ||
op = Operator(pdes, subs=grid.spacing_map) | ||
op(dt=0.001, time=2000) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Needs an assertion
assert np.max(np.abs(U[0].data[0,:]-U_multi_stage[0].data[0,:]))<10**-5, "the method is not converging to the solution" | ||
|
||
|
||
def test_multistage_coupled_op_computing_exp(time_int='HORK_EXP'): |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Should pass a class, not a string.
@@ -219,22 +324,22 @@ def test_multistage_coupled_op_computing(time_int='RK44'): | |||
op(dt=0.01, time=1) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Should also assert something
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Should never use pickling for long-term file storage. Remove this file
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
It would be nice (once this has been iterated a bit further and the API concertised) to see an example or demonstrative notebook that showcases the new functionality.
@@ -36,7 +36,6 @@ | |||
disk_layer) | |||
from devito.types.dimension import Thickness | |||
|
|||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Please run the linter (flake8
) 🙂
f"_evaluate() must be implemented in the subclass {self.__class__.__name__}") | ||
|
||
|
||
class RK(MultiStage): |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
class RK(MultiStage): | |
class RungeKutta(MultiStage): |
@register_method | ||
class RK44(RK): |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I think RK4
and indeed all RK methods should be instances of the RK
Class
Then you no longer need all of the boilerplate code below, which is just setting up Butcher tableau
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Or the coefficients should be class attributes and set by the child class
|
||
|
||
@register_method | ||
class HORK(MultiStage): | ||
class HORK_EXP(MultiStage): |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
HighOrderRungeKuttaExplicit
perhaps
Be explicit!
Use CamelCase
for class definitions and names_with_underscores
for variables and functions.
|
||
|
||
@register_method | ||
class HORK(MultiStage): | ||
class HORK_EXP(MultiStage): |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Also if we want to start distinguishing between explicit and implicit timesteppers we should be using a class hierarchy to do so
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I think this file should be moved to somewhere like devito/timestepping/rungekutta.py
or devito/timestepping/explicitmultistage.py
that way additional timesteppers can be contributed as new files. (I'm thinking about implicit multistage, backward difference formulae etc...)
# import matplotlib.pyplot as plt | ||
# import numpy as np | ||
# t=np.linspace(0,2000,1000) | ||
# plt.plot(t,np.exp(1 - 2 * (t - 1)**2)) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Leftovers?
Implementation of the ideas of @mloubout, as discussed in the Slack channel #timestepping.
It were created two new classes:
MultiStage: Each instance represents a combination of a PDE and its associated time integrator.
RK: Instances encapsulate a Butcher Tableau. The class methods define specific Runge-Kutta schemes, and it includes logic to expand a single PDE into multiple stage equations according to the RK method. It only contains the classic RK44, of fourth order and four stages, as a prove of concept.
To integrate the
MultiStage
into Devito’s pipeline, we modified the_lower(...)
function inoperator.py
by adding the line:expressions = cls._lower_multistage(expressions, **kwargs)
,and
modified _sanitize_exprs(cls, expressions, **kwargs)
to recognizeMultiStage
instances.We also created a new function:
_lower_multistage()
,in the same file. This function handles the expansion of
MultiStage
instances into their corresponding stage equations.