Skip to content

Commit 4649f05

Browse files
committed
compiler: Fix MPI with CondDim + multiple loc_indices
1 parent cfc5a08 commit 4649f05

File tree

4 files changed

+116
-45
lines changed

4 files changed

+116
-45
lines changed

devito/ir/clusters/algorithms.py

Lines changed: 85 additions & 28 deletions
Original file line numberDiff line numberDiff line change
@@ -399,6 +399,83 @@ class HaloComms(Queue):
399399
def process(self, clusters):
400400
return self._process_fatd(clusters, 1, seen=set())
401401

402+
def _derive_halo_schemes(self, c):
403+
hs = HaloScheme(c.exprs, c.ispace)
404+
405+
# 95% of the time we will just return `hs` as is, since there are no guards
406+
if not c.guards:
407+
yield hs, c
408+
return
409+
410+
# This is a more contrived situation in which we might need halo exchanges
411+
# from multiple so-called loc-indices -- let's check this out
412+
candidates = []
413+
for f, hse in hs.fmapper.items():
414+
reads = c.scope.reads[f]
415+
416+
for d in hse.loc_indices:
417+
if not d._defines & set(c.guards):
418+
continue
419+
420+
candidates.append(as_mapper(reads, key=lambda i: i[d]).values())
421+
422+
# 4% of the time we will just return `hs` as is
423+
# E.g., we end up here when taking space derivatives of one or more saved
424+
# TimeFunctions in equations evaluating gradients that are controlled by
425+
# a ConditionalDimension (otherwise we would have exited earlier)
426+
if any(len(g) <= 1 for g in candidates):
427+
yield hs, c
428+
return
429+
430+
# 1% of the time, finally, we end up here...
431+
# At this point we have to create a mock Cluster for each loc-index,
432+
# containing all and only the accesses to `f` at a given loc-index
433+
# E.g., a mock Cluster at `loc_index=t0` containing the accesses
434+
# `[u(t0, x + 8, ...), u(t0, x + 7, ...)]`, another mock Cluster at
435+
# `loc_index=t1` containing the accesses `[u(t1, x + 5, ...),
436+
# u(t1, x + 6, ...)]`, and so on
437+
for unordered_groups in candidates:
438+
# Sort for deterministic code generation
439+
groups = sorted(unordered_groups, key=str)
440+
for group in groups:
441+
pointset = sympy.Function('pointset')
442+
v = pointset(*[i.access for i in group])
443+
exprs = [e.func(rhs=v) for e in c.exprs]
444+
445+
c1 = c.rebuild(exprs=exprs)
446+
447+
hs = HaloScheme(c1.exprs, c.ispace)
448+
449+
yield hs, c1
450+
451+
def _make_halo_touch(self, hs, c, prefix):
452+
points = set()
453+
for f in hs.fmapper:
454+
for a in c.scope.getreads(f):
455+
points.add(a.access)
456+
457+
# We also add all written symbols to ultimately create mock WARs
458+
# with `c`, which will prevent the newly created HaloTouch from
459+
# ever being rescheduled
460+
points.update(a.access for a in c.scope.accesses if a.is_write)
461+
462+
# Sort for determinism
463+
# NOTE: not sorting might impact code generation. The order of
464+
# the args is important because that's what search functions honor!
465+
points = sorted(points, key=str)
466+
467+
# Construct the HaloTouch Cluster
468+
expr = Eq(self.B, HaloTouch(*points, halo_scheme=hs))
469+
470+
key = lambda i: i in prefix[:-1] or i in hs.loc_indices
471+
ispace = c.ispace.project(key)
472+
# HaloTouches are not parallel
473+
properties = c.properties.sequentialize()
474+
475+
halo_touch = c.rebuild(exprs=expr, ispace=ispace, properties=properties)
476+
477+
return halo_touch
478+
402479
def callback(self, clusters, prefix, seen=None):
403480
if not prefix:
404481
return clusters
@@ -412,38 +489,18 @@ def callback(self, clusters, prefix, seen=None):
412489
c in seen:
413490
continue
414491

415-
hs = HaloScheme(c.exprs, c.ispace)
416-
if hs.is_void or \
417-
not d._defines & hs.distributed_aindices:
418-
continue
419-
420-
points = set()
421-
for f in hs.fmapper:
422-
for a in c.scope.getreads(f):
423-
points.add(a.access)
424-
425-
# We also add all written symbols to ultimately create mock WARs
426-
# with `c`, which will prevent the newly created HaloTouch to ever
427-
# be rescheduled after `c` upon topological sorting
428-
points.update(a.access for a in c.scope.accesses if a.is_write)
492+
seen.add(c)
429493

430-
# Sort for determinism
431-
# NOTE: not sorting might impact code generation. The order of
432-
# the args is important because that's what search functions honor!
433-
points = sorted(points, key=str)
434-
435-
# Construct the HaloTouch Cluster
436-
expr = Eq(self.B, HaloTouch(*points, halo_scheme=hs))
494+
for hs, c1 in self._derive_halo_schemes(c):
495+
if hs.is_void or \
496+
not d._defines & hs.distributed_aindices:
497+
continue
437498

438-
key = lambda i: i in prefix[:-1] or i in hs.loc_indices
439-
ispace = c.ispace.project(key)
440-
# HaloTouches are not parallel
441-
properties = c.properties.sequentialize()
499+
halo_touch = self._make_halo_touch(hs, c1, prefix)
442500

443-
halo_touch = c.rebuild(exprs=expr, ispace=ispace, properties=properties)
501+
processed.append(halo_touch)
444502

445-
processed.append(halo_touch)
446-
seen.update({halo_touch, c})
503+
seen.add(halo_touch)
447504

448505
processed.extend(clusters)
449506

devito/ir/stree/algorithms.py

Lines changed: 9 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -204,9 +204,15 @@ def preprocess(clusters, options=None, **kwargs):
204204
processed.append(c.rebuild(exprs=[], ispace=ispace, syncs=syncs))
205205

206206
if all(c1.ispace.is_subset(c.ispace) for c1 in found):
207-
# 99% of the cases we end up here
208-
hs = HaloScheme.union([c1.halo_scheme for c1 in found])
209-
processed.append(c.rebuild(halo_scheme=hs))
207+
if not any(c1.guards for c1 in found):
208+
# 99% of the cases we end up here
209+
hs = HaloScheme.union([c1.halo_scheme for c1 in found])
210+
processed.append(c.rebuild(halo_scheme=hs))
211+
else:
212+
# We have to keep all HaloSchemes explicitly in separate
213+
# Clusters or we might generate broken code by erroneously
214+
# dropping halo exchanges on the floor
215+
processed.extend((*found, c))
210216
elif options['mpi']:
211217
# We end up here with e.g. `t,x,y,z,f` where `f` is a sequential
212218
# dimension requiring a loc-index in the HaloScheme. The compiler

devito/passes/iet/mpi.py

Lines changed: 2 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -141,7 +141,7 @@ def _merge_halospots(iet):
141141
mapper = HaloSpotMapper()
142142
for it, halo_spots in iter_mapper.items():
143143
for hs0, hs1 in combinations(halo_spots, r=2):
144-
if _check_control_flow(hs0, hs1, cond_mapper):
144+
if cond_mapper.get(hs0) != cond_mapper.get(hs1):
145145
continue
146146

147147
scope = _derive_scope(it, hs0, hs1)
@@ -191,7 +191,7 @@ def _hoist_invariant(iet):
191191
mapper = HaloSpotMapper()
192192
for it, halo_spots in iter_mapper.items():
193193
for hs0, hs1 in combinations(halo_spots, r=2):
194-
if _check_control_flow(hs0, hs1, cond_mapper):
194+
if cond_mapper.get(hs0) or cond_mapper.get(hs1):
195195
continue
196196

197197
scope = _derive_scope(it, hs0, hs1)
@@ -465,17 +465,6 @@ def _derive_scope(it, hs0, hs1):
465465
return Scope(e.expr for e in expressions)
466466

467467

468-
def _check_control_flow(hs0, hs1, cond_mapper):
469-
"""
470-
If there are Conditionals involved, both `hs0` and `hs1` must be
471-
within the same Conditional, otherwise we would break control flow
472-
"""
473-
cond0 = cond_mapper.get(hs0)
474-
cond1 = cond_mapper.get(hs1)
475-
476-
return cond0 != cond1
477-
478-
479468
def _is_iter_carried(hsf, scope):
480469
"""
481470
True if the provided HaloScheme `hsf` is iteration-carried, i.e., it induces

tests/test_mpi.py

Lines changed: 20 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,7 @@
66
from devito import (Grid, Constant, Function, TimeFunction, SparseFunction,
77
SparseTimeFunction, VectorTimeFunction, TensorTimeFunction,
88
Dimension, ConditionalDimension, div, solve, diag, grad,
9-
SubDimension, SubDomain, Eq, Ne, Inc, NODE, Operator, norm,
9+
SubDimension, SubDomain, Eq, Ne, Gt, Inc, NODE, Operator, norm,
1010
inner, configuration, switchconfig, generic_derivative,
1111
PrecomputedSparseFunction, DefaultDimension, Buffer,
1212
CustomDimension)
@@ -3128,6 +3128,25 @@ def test_interpolation_at_uforward(self, mode):
31283128
assert args[-2].name == 't2'
31293129
assert args[-2].origin == t + 1
31303130

3131+
@pytest.mark.parallel(mode=1)
3132+
def test_multiple_loc_indices_inside_conddim(self, mode):
3133+
grid = Grid(shape=(10, 10))
3134+
time = grid.time_dim
3135+
3136+
t_sub = ConditionalDimension('t_sub', parent=time, condition=Gt(time % 4))
3137+
3138+
f = Function(name='f', grid=grid, space_order=4)
3139+
u = TimeFunction(name='u', grid=grid, space_order=4)
3140+
3141+
eqns = [Eq(u.forward, u + 1),
3142+
Eq(f, u.dx + u.forward.dx + .2, implicit_dims=t_sub)]
3143+
3144+
op = Operator(eqns, opt=('advanced', {'openmp': False}))
3145+
3146+
calls, _ = check_halo_exchanges(op, 2, 2)
3147+
assert calls[0].arguments[-1].name == 't0'
3148+
assert calls[1].arguments[-1].name == 't1'
3149+
31313150

31323151
def gen_serial_norms(shape, so):
31333152
"""

0 commit comments

Comments
 (0)