Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
113 changes: 85 additions & 28 deletions devito/ir/clusters/algorithms.py
Original file line number Diff line number Diff line change
Expand Up @@ -399,6 +399,83 @@ class HaloComms(Queue):
def process(self, clusters):
return self._process_fatd(clusters, 1, seen=set())

def _derive_halo_schemes(self, c):
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

this is new

hs = HaloScheme(c.exprs, c.ispace)

# 95% of the times we will just return `hs` as is as there are no guards
if not c.guards:
yield hs, c
return

# This is a more contrived situation in which we might need halo exchanges
# from multiple so called loc-indices -- let's check this out
candidates = []
for f, hse in hs.fmapper.items():
reads = c.scope.reads[f]

for d in hse.loc_indices:
if not d._defines & set(c.guards):
continue

candidates.append(as_mapper(reads, key=lambda i: i[d]).values())

# 4% of the times we will just return `hs` as is
# E.g., we end up here when taking space derivatives of one or more saved
# TimeFunctions in equations evaluating gradients that are controlled by
# a ConditionalDimension (otherwise we would have exited earlier)
if any(len(g) <= 1 for g in candidates):
yield hs, c
return

# 1% of the times, finally, we end up here...
# At this point we have to create a mock Cluster for each loc-index,
# containing all and only the accesses to `f` at a given loc-index
# E.g., a mock Cluster at `loc_index=t0` containing the accesses
# `[u(t0, x + 8, ...), u(t0, x + 7, ...)], another mock Cluster at
# `loc_index=t1` containing the accesses `[u(t1, x + 5, ...),
# u(t1, x + 6, ...)]`, and so on
for unordered_groups in candidates:
# Sort for deterministic code generation
groups = sorted(unordered_groups, key=str)
for group in groups:
pointset = sympy.Function('pointset')
v = pointset(*[i.access for i in group])
exprs = [e.func(rhs=v) for e in c.exprs]

c1 = c.rebuild(exprs=exprs)

hs = HaloScheme(c1.exprs, c.ispace)

yield hs, c1

def _make_halo_touch(self, hs, c, prefix):
points = set()
for f in hs.fmapper:
for a in c.scope.getreads(f):
points.add(a.access)

# We also add all written symbols to ultimately create mock WARs
# with `c`, which will prevent the newly created HaloTouch from
# ever being rescheduled
points.update(a.access for a in c.scope.accesses if a.is_write)

# Sort for determinism
# NOTE: not sorting might impact code generation. The order of
# the args is important because that's what search functions honor!
points = sorted(points, key=str)

# Construct the HaloTouch Cluster
expr = Eq(self.B, HaloTouch(*points, halo_scheme=hs))

key = lambda i: i in prefix[:-1] or i in hs.loc_indices
ispace = c.ispace.project(key)
# HaloTouches are not parallel
properties = c.properties.sequentialize()

halo_touch = c.rebuild(exprs=expr, ispace=ispace, properties=properties)

return halo_touch

def callback(self, clusters, prefix, seen=None):
if not prefix:
return clusters
Expand All @@ -412,38 +489,18 @@ def callback(self, clusters, prefix, seen=None):
c in seen:
continue

hs = HaloScheme(c.exprs, c.ispace)
if hs.is_void or \
not d._defines & hs.distributed_aindices:
continue

points = set()
for f in hs.fmapper:
for a in c.scope.getreads(f):
points.add(a.access)

# We also add all written symbols to ultimately create mock WARs
# with `c`, which will prevent the newly created HaloTouch to ever
# be rescheduled after `c` upon topological sorting
points.update(a.access for a in c.scope.accesses if a.is_write)
seen.add(c)

# Sort for determinism
# NOTE: not sorting might impact code generation. The order of
# the args is important because that's what search functions honor!
points = sorted(points, key=str)

# Construct the HaloTouch Cluster
expr = Eq(self.B, HaloTouch(*points, halo_scheme=hs))
for hs, c1 in self._derive_halo_schemes(c):
if hs.is_void or \
not d._defines & hs.distributed_aindices:
continue

key = lambda i: i in prefix[:-1] or i in hs.loc_indices
ispace = c.ispace.project(key)
# HaloTouches are not parallel
properties = c.properties.sequentialize()
halo_touch = self._make_halo_touch(hs, c1, prefix)

halo_touch = c.rebuild(exprs=expr, ispace=ispace, properties=properties)
processed.append(halo_touch)

processed.append(halo_touch)
seen.update({halo_touch, c})
seen.add(halo_touch)

processed.extend(clusters)

Expand Down
12 changes: 9 additions & 3 deletions devito/ir/stree/algorithms.py
Original file line number Diff line number Diff line change
Expand Up @@ -204,9 +204,15 @@ def preprocess(clusters, options=None, **kwargs):
processed.append(c.rebuild(exprs=[], ispace=ispace, syncs=syncs))

if all(c1.ispace.is_subset(c.ispace) for c1 in found):
# 99% of the cases we end up here
hs = HaloScheme.union([c1.halo_scheme for c1 in found])
processed.append(c.rebuild(halo_scheme=hs))
if not any(c1.guards for c1 in found):
# 99% of the cases we end up here
hs = HaloScheme.union([c1.halo_scheme for c1 in found])
processed.append(c.rebuild(halo_scheme=hs))
else:
# We have to keep all HaloSchemes explicitly in separate
# Clusters or we might generate broken code by erroneously
# dropping halo exchanges on the floor
processed.extend((*found, c))
elif options['mpi']:
# We end up here with e.g. `t,x,y,z,f` where `f` is a sequential
# dimension requiring a loc-index in the HaloScheme. The compiler
Expand Down
15 changes: 2 additions & 13 deletions devito/passes/iet/mpi.py
Original file line number Diff line number Diff line change
Expand Up @@ -141,7 +141,7 @@ def _merge_halospots(iet):
mapper = HaloSpotMapper()
for it, halo_spots in iter_mapper.items():
for hs0, hs1 in combinations(halo_spots, r=2):
if _check_control_flow(hs0, hs1, cond_mapper):
if cond_mapper.get(hs0) != cond_mapper.get(hs1):
continue

scope = _derive_scope(it, hs0, hs1)
Expand Down Expand Up @@ -191,7 +191,7 @@ def _hoist_invariant(iet):
mapper = HaloSpotMapper()
for it, halo_spots in iter_mapper.items():
for hs0, hs1 in combinations(halo_spots, r=2):
if _check_control_flow(hs0, hs1, cond_mapper):
if cond_mapper.get(hs0) or cond_mapper.get(hs1):
continue

scope = _derive_scope(it, hs0, hs1)
Expand Down Expand Up @@ -465,17 +465,6 @@ def _derive_scope(it, hs0, hs1):
return Scope(e.expr for e in expressions)


def _check_control_flow(hs0, hs1, cond_mapper):
"""
If there are Conditionals involved, both `hs0` and `hs1` must be
within the same Conditional, otherwise we would break control flow
"""
cond0 = cond_mapper.get(hs0)
cond1 = cond_mapper.get(hs1)

return cond0 != cond1


def _is_iter_carried(hsf, scope):
"""
True if the provided HaloScheme `hsf` is iteration-carried, i.e., it induces
Expand Down
21 changes: 20 additions & 1 deletion tests/test_mpi.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
from devito import (Grid, Constant, Function, TimeFunction, SparseFunction,
SparseTimeFunction, VectorTimeFunction, TensorTimeFunction,
Dimension, ConditionalDimension, div, solve, diag, grad,
SubDimension, SubDomain, Eq, Ne, Inc, NODE, Operator, norm,
SubDimension, SubDomain, Eq, Ne, Gt, Inc, NODE, Operator, norm,
inner, configuration, switchconfig, generic_derivative,
PrecomputedSparseFunction, DefaultDimension, Buffer,
CustomDimension)
Expand Down Expand Up @@ -3128,6 +3128,25 @@ def test_interpolation_at_uforward(self, mode):
assert args[-2].name == 't2'
assert args[-2].origin == t + 1

@pytest.mark.parallel(mode=1)
def test_multiple_loc_indices_inside_conddim(self, mode):
grid = Grid(shape=(10, 10))
time = grid.time_dim

t_sub = ConditionalDimension('t_sub', parent=time, condition=Gt(time % 4))

f = Function(name='f', grid=grid, space_order=4)
u = TimeFunction(name='u', grid=grid, space_order=4)

eqns = [Eq(u.forward, u + 1),
Eq(f, u.dx + u.forward.dx + .2, implicit_dims=t_sub)]

op = Operator(eqns, opt=('advanced', {'openmp': False}))

calls, _ = check_halo_exchanges(op, 2, 2)
assert calls[0].arguments[-1].name == 't0'
assert calls[1].arguments[-1].name == 't1'


def gen_serial_norms(shape, so):
"""
Expand Down
Loading