@@ -399,6 +399,65 @@ class HaloComms(Queue):
399399 def process (self , clusters ):
400400 return self ._process_fatd (clusters , 1 , seen = set ())
401401
402+ def _derive_halo_schemes (self , c ):
403+ hs = HaloScheme (c .exprs , c .ispace )
404+
405+ if not c .guards :
406+ yield hs , c
407+ return
408+
409+ # This is a more contrieved situation in which we might need halo exchanes
410+ # from multiple so called loc indices, and each of these will correspond
411+ # to a different HaloTouch
412+ for f , hse in hs .fmapper .items ():
413+ reads = c .scope .reads [f ]
414+
415+ for d in hse .loc_indices :
416+ if not d ._defines & set (c .guards ):
417+ continue
418+
419+ # E.g., {t0: [u(t0, x + 8, ...), u(t0, x + 7, ...)],
420+ # t1: [u(t1, x + 6, ...)]}
421+ for group in as_mapper (reads , key = lambda i : i [d ]).values ():
422+ foo = sympy .Function ('foo' )(* [i .access for i in group ])
423+ exprs = [e .func (rhs = foo ) for e in c .exprs ]
424+
425+ c1 = c .rebuild (exprs = exprs )
426+
427+ hs = HaloScheme (c1 .exprs , c .ispace )
428+
429+ yield hs , c1
430+
431+ def _make_halo_touch (self , hs , c , prefix ):
432+ d = prefix [- 1 ].dim
433+
434+ points = set ()
435+ for f in hs .fmapper :
436+ for a in c .scope .getreads (f ):
437+ points .add (a .access )
438+
439+ # We also add all written symbols to ultimately create mock WARs
440+ # with `c`, which will prevent the newly created HaloTouch to ever
441+ # be rescheduled after `c` upon topological sorting
442+ points .update (a .access for a in c .scope .accesses if a .is_write )
443+
444+ # Sort for determinism
445+ # NOTE: not sorting might impact code generation. The order of
446+ # the args is important because that's what search functions honor!
447+ points = sorted (points , key = str )
448+
449+ # Construct the HaloTouch Cluster
450+ expr = Eq (self .B , HaloTouch (* points , halo_scheme = hs ))
451+
452+ key = lambda i : i in prefix [:- 1 ] or i in hs .loc_indices
453+ ispace = c .ispace .project (key )
454+ # HaloTouches are not parallel
455+ properties = c .properties .sequentialize ()
456+
457+ halo_touch = c .rebuild (exprs = expr , ispace = ispace , properties = properties )
458+
459+ return halo_touch
460+
402461 def callback (self , clusters , prefix , seen = None ):
403462 if not prefix :
404463 return clusters
@@ -412,38 +471,18 @@ def callback(self, clusters, prefix, seen=None):
412471 c in seen :
413472 continue
414473
415- hs = HaloScheme (c .exprs , c .ispace )
416- if hs .is_void or \
417- not d ._defines & hs .distributed_aindices :
418- continue
419-
420- points = set ()
421- for f in hs .fmapper :
422- for a in c .scope .getreads (f ):
423- points .add (a .access )
424-
425- # We also add all written symbols to ultimately create mock WARs
426- # with `c`, which will prevent the newly created HaloTouch to ever
427- # be rescheduled after `c` upon topological sorting
428- points .update (a .access for a in c .scope .accesses if a .is_write )
474+ seen .add (c )
429475
430- # Sort for determinism
431- # NOTE: not sorting might impact code generation. The order of
432- # the args is important because that's what search functions honor!
433- points = sorted (points , key = str )
434-
435- # Construct the HaloTouch Cluster
436- expr = Eq (self .B , HaloTouch (* points , halo_scheme = hs ))
476+ for hs , c1 in self ._derive_halo_schemes (c ):
477+ if hs .is_void or \
478+ not d ._defines & hs .distributed_aindices :
479+ continue
437480
438- key = lambda i : i in prefix [:- 1 ] or i in hs .loc_indices
439- ispace = c .ispace .project (key )
440- # HaloTouches are not parallel
441- properties = c .properties .sequentialize ()
481+ halo_touch = self ._make_halo_touch (hs , c1 , prefix )
442482
443- halo_touch = c . rebuild ( exprs = expr , ispace = ispace , properties = properties )
483+ processed . append ( halo_touch )
444484
445- processed .append (halo_touch )
446- seen .update ({halo_touch , c })
485+ seen .add (halo_touch )
447486
448487 processed .extend (clusters )
449488
0 commit comments