Skip to content
Open
Show file tree
Hide file tree
Changes from 12 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
38 changes: 36 additions & 2 deletions devito/ir/equations/algorithms.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,14 +6,15 @@
from devito.tools import (Ordering, as_tuple, flatten, filter_sorted, filter_ordered,
frozendict)
from devito.types import (Dimension, Eq, IgnoreDimSort, SubDimension,
ConditionalDimension)
ConditionalDimension, MultiStage)
from devito.types.array import Array
from devito.types.basic import AbstractFunction
from devito.types.dimension import MultiSubDimension, Thickness
from devito.data.allocators import DataReference
from devito.logger import warning

__all__ = ['dimension_sort', 'lower_exprs', 'concretize_subdims']

__all__ = ['dimension_sort', 'lower_multistage', 'lower_exprs', 'concretize_subdims']


def dimension_sort(expr):
Expand Down Expand Up @@ -95,6 +96,39 @@ def handle_indexed(indexed):
return ordering


def lower_multistage(expressions, **kwargs):
"""
Separating the multi-stage time-integrator scheme in stages:
* If the object is MultiStage, it creates the stages of the method.
"""
return _lower_multistage(expressions, **kwargs)


@singledispatch
def _lower_multistage(expr, **kwargs):
"""
Default handler for expressions that are not MultiStage.
Simply return them in a list.
"""
return [expr]


@_lower_multistage.register(MultiStage)
def _(expr, **kwargs):
"""
Specialized handler for MultiStage expressions.
"""
return expr._evaluate(**kwargs)


@_lower_multistage.register(Iterable)
def _(exprs, **kwargs):
"""
Handle iterables of expressions.
"""
return sum([_lower_multistage(expr, **kwargs) for expr in exprs], [])


def lower_exprs(expressions, subs=None, **kwargs):
"""
Lowering an expression consists of the following passes:
Expand Down
9 changes: 7 additions & 2 deletions devito/operations/solve.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,8 @@
from devito.finite_differences.derivative import Derivative
from devito.tools import as_tuple

from devito.types.multistage import resolve_method

__all__ = ['solve', 'linsolve']


Expand Down Expand Up @@ -56,9 +58,12 @@ def solve(eq, target, **kwargs):

# We need to rebuild the vector/tensor as sympy.solve outputs a tuple of solutions
if len(sols) > 1:
return target.new_from_mat(sols)
sols_temp = target.new_from_mat(sols)
else:
return sols[0]
sols_temp = sols[0]

method = kwargs.get("method", None)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Is method a string here, or is it the class for the method? In the latter case, it would remove the need to have the method_registry mapper. Furthermore, it would allow you to have method.resolve(target, sols_temp) here, which is tidier

return sols_temp if method is None else resolve_method(method)(target, sols_temp)


def linsolve(expr, target, **kwargs):
Expand Down
5 changes: 3 additions & 2 deletions devito/operator/operator.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@
InvalidOperator)
from devito.logger import (debug, info, perf, warning, is_log_enabled_for,
switch_log_level)
from devito.ir.equations import LoweredEq, lower_exprs, concretize_subdims
from devito.ir.equations import LoweredEq, lower_multistage, lower_exprs, concretize_subdims
from devito.ir.clusters import ClusterGroup, clusterize
from devito.ir.iet import (Callable, CInterface, EntryFunction, FindSymbols,
MetaCall, derive_parameters, iet_build)
Expand All @@ -36,7 +36,6 @@
disk_layer)
from devito.types.dimension import Thickness


Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Please run the linter (flake8) 🙂

__all__ = ['Operator']


Expand Down Expand Up @@ -327,6 +326,8 @@ def _lower_exprs(cls, expressions, **kwargs):
* Apply substitution rules;
* Shift indices for domain alignment.
"""
expressions = lower_multistage(expressions, **kwargs)

expand = kwargs['options'].get('expand', True)

# Specialization is performed on unevaluated expressions
Expand Down
2 changes: 2 additions & 0 deletions devito/types/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,3 +22,5 @@
from .relational import * # noqa
from .sparse import * # noqa
from .tensor import * # noqa

from .multistage import * # noqa
Loading
Loading