@@ -399,6 +399,65 @@ class HaloComms(Queue):
def process(self, clusters):
    """
    Entry point: process ``clusters`` via ``self._process_fatd``.

    The traversal is started at level ``1`` with a fresh, empty ``seen``
    set, so no state is shared between distinct calls to ``process``.
    """
    visited = set()
    return self._process_fatd(clusters, 1, seen=visited)
401401
def _derive_halo_schemes(self, c):
    """
    Generate the ``(HaloScheme, Cluster)`` pairs to be considered for `c`.

    If `c` has no guards, a single pair is produced: the HaloScheme built
    from ``c.exprs`` and ``c.ispace``, together with `c` itself.

    Otherwise, for each Function in the HaloScheme's ``fmapper`` and each
    of its loc indices whose ``_defines`` intersects the guards of `c`,
    the reads of that Function are grouped by the index they use along
    that loc index. For each group, a Cluster is rebuilt from `c` with
    every right-hand side replaced by a placeholder function applied to
    the accesses in the group; the pair yielded is the HaloScheme of that
    rebuilt Cluster and the rebuilt Cluster itself.

    NOTE(review): a guarded Cluster for which no loc index intersects the
    guards yields nothing at all -- confirm this is intended.
    """
    base_scheme = HaloScheme(c.exprs, c.ispace)

    if not c.guards:
        yield base_scheme, c
        return

    guards = set(c.guards)

    # This is a more contrived situation in which we might need halo
    # exchanges from multiple so called loc indices, and each of these will
    # correspond to a different HaloTouch
    for f, hse in base_scheme.fmapper.items():
        reads = c.scope.reads[f]

        for d in hse.loc_indices:
            if not d._defines & guards:
                continue

            # E.g., {t0: [u(t0, x + 8, ...), u(t0, x + 7, ...)],
            #        t1: [u(t1, x + 6, ...)]}
            groups = as_mapper(reads, key=lambda i: i[d])

            for group in groups.values():
                # Placeholder RHS carrying only the accesses in this group
                accesses = [i.access for i in group]
                foo = sympy.Function('foo')(*accesses)

                c1 = c.rebuild(exprs=[e.func(rhs=foo) for e in c.exprs])

                yield HaloScheme(c1.exprs, c.ispace), c1
430+
def _make_halo_touch(self, hs, c, prefix):
    """
    Build the Cluster carrying the HaloTouch for the HaloScheme `hs`.

    Parameters
    ----------
    hs
        The HaloScheme to attach to the HaloTouch; its ``fmapper`` and
        ``loc_indices`` are used here.
    c
        The Cluster from which reads and writes are collected, and which
        is rebuilt into the returned Cluster.
    prefix
        The current prefix; all of its entries but the last are retained
        when projecting the iteration space of `c`.

    Returns
    -------
    The result of ``c.rebuild`` with the HaloTouch expression, the
    projected iteration space and sequentialized properties.
    """
    # All accesses read from the Functions tracked by `hs`
    points = {a.access for f in hs.fmapper for a in c.scope.getreads(f)}

    # We also add all written symbols to ultimately create mock WARs
    # with `c`, which will prevent the newly created HaloTouch to ever
    # be rescheduled after `c` upon topological sorting
    points |= {a.access for a in c.scope.accesses if a.is_write}

    # Sort for determinism
    # NOTE: not sorting might impact code generation. The order of
    # the args is important because that's what search functions honor!
    args = sorted(points, key=str)

    # Construct the HaloTouch Cluster
    expr = Eq(self.B, HaloTouch(*args, halo_scheme=hs))

    # Retain the outer part of the prefix plus the loc indices of `hs`;
    # the slice is loop-invariant, so compute it once
    outer = prefix[:-1]
    ispace = c.ispace.project(lambda i: i in outer or i in hs.loc_indices)

    # HaloTouches are not parallel
    properties = c.properties.sequentialize()

    return c.rebuild(exprs=expr, ispace=ispace, properties=properties)
460+
402461 def callback(self, clusters, prefix, seen=None):
403462 if not prefix:
404463 return clusters
@@ -412,38 +471,18 @@ def callback(self, clusters, prefix, seen=None):
412471 c in seen:
413472 continue
414473
415- hs = HaloScheme(c.exprs, c.ispace)
416- if hs.is_void or \
417- not d._defines & hs.distributed_aindices:
418- continue
419-
420- points = set()
421- for f in hs.fmapper:
422- for a in c.scope.getreads(f):
423- points.add(a.access)
424-
425- # We also add all written symbols to ultimately create mock WARs
426- # with `c`, which will prevent the newly created HaloTouch to ever
427- # be rescheduled after `c` upon topological sorting
428- points.update(a.access for a in c.scope.accesses if a.is_write)
474+ seen.add(c)
429475
430- # Sort for determinism
431- # NOTE: not sorting might impact code generation. The order of
432- # the args is important because that's what search functions honor!
433- points = sorted(points, key=str)
434-
435- # Construct the HaloTouch Cluster
436- expr = Eq(self.B, HaloTouch(*points, halo_scheme=hs))
476+ for hs, c1 in self._derive_halo_schemes(c):
477+ if hs.is_void or \
478+ not d._defines & hs.distributed_aindices:
479+ continue
437480
438- key = lambda i: i in prefix[:-1] or i in hs.loc_indices
439- ispace = c.ispace.project(key)
440- # HaloTouches are not parallel
441- properties = c.properties.sequentialize()
481+ halo_touch = self._make_halo_touch(hs, c1, prefix)
442482
443- halo_touch = c.rebuild(exprs=expr, ispace=ispace, properties=properties )
483+ processed.append(halo_touch )
444484
445- processed.append(halo_touch)
446- seen.update({halo_touch, c})
485+ seen.add(halo_touch)
447486
448487 processed.extend(clusters)
449488
0 commit comments