@@ -399,6 +399,65 @@ class HaloComms(Queue):
399
399
def process(self, clusters):
    """
    Process `clusters` by delegating to `_process_fatd`, starting at level 1
    and with a fresh, empty `seen` set.
    """
    # A new set is created at each call, so no state is shared across calls
    visited = set()
    return self._process_fatd(clusters, 1, seen=visited)
401
401
402
def _derive_halo_schemes(self, c):
    """
    Generate the `(HaloScheme, Cluster)` pairs from which the halo exchanges
    required by `c` are constructed.

    If `c` has no guards, a single pair is generated: the HaloScheme built
    from `c.exprs` and `c.ispace`, together with `c` itself.

    Otherwise, one pair is generated for each group of reads sharing the
    same index along a loc-index whose `_defines` intersects the guards of
    `c`. Each pair consists of a rebuilt copy of `c`, whose expressions have
    a mock right-hand side containing only the accesses in the group, and
    the HaloScheme derived from such expressions.

    NOTE(review): a guarded `c` with no loc-index related to its guards
    generates no pairs at all -- confirm this is the intended behavior.
    """
    base_hs = HaloScheme(c.exprs, c.ispace)

    # Simple case: without guards, a single HaloScheme is enough
    if not c.guards:
        yield base_hs, c
        return

    # This is a more contrived situation in which we might need halo exchanges
    # from multiple so-called loc-indices, and each of these will correspond
    # to a different HaloTouch
    guarded = set(c.guards)

    for f, hse in base_hs.fmapper.items():
        reads = c.scope.reads[f]

        for d in hse.loc_indices:
            if d._defines & guarded:
                # Group the reads by their index along `d`
                # E.g., {t0: [u(t0, x + 8, ...), u(t0, x + 7, ...)],
                #        t1: [u(t1, x + 6, ...)]}
                groups = as_mapper(reads, key=lambda r: r[d])

                for group in groups.values():
                    # A mock expression touching only the accesses in `group`
                    accesses = [r.access for r in group]
                    mock_rhs = sympy.Function('foo')(*accesses)
                    mock_exprs = [e.func(rhs=mock_rhs) for e in c.exprs]

                    c1 = c.rebuild(exprs=mock_exprs)

                    yield HaloScheme(c1.exprs, c.ispace), c1
430
+
431
def _make_halo_touch(self, hs, c, prefix):
    """
    Construct the HaloTouch Cluster carrying the halo exchanges `hs`
    required by `c`.

    Parameters
    ----------
    hs : HaloScheme
        The halo exchanges attached to the HaloTouch.
    c : Cluster
        The Cluster whose reads and writes become the HaloTouch arguments.
    prefix : sequence
        The prefix being processed by the caller. All of its entries but
        the last one are retained in the HaloTouch iteration space.

    Returns
    -------
    Cluster
        A rebuilt copy of `c` whose only expression is
        `Eq(self.B, HaloTouch(...))`, with projected iteration space and
        sequentialized properties.
    """
    # All reads of the functions requiring a halo exchange
    points = {a.access for f in hs.fmapper for a in c.scope.getreads(f)}

    # We also add all written symbols to ultimately create mock WARs
    # with `c`, which will prevent the newly created HaloTouch from ever
    # being rescheduled after `c` upon topological sorting
    points.update(a.access for a in c.scope.accesses if a.is_write)

    # Sort for determinism
    # NOTE: not sorting might impact code generation. The order of
    # the args is important because that's what search functions honor!
    points = sorted(points, key=str)

    # Construct the HaloTouch Cluster
    expr = Eq(self.B, HaloTouch(*points, halo_scheme=hs))

    # Retain the outer part of `prefix` as well as the loc-indices of `hs`
    def key(i):
        return i in prefix[:-1] or i in hs.loc_indices

    ispace = c.ispace.project(key)

    # HaloTouches are not parallel
    properties = c.properties.sequentialize()

    return c.rebuild(exprs=expr, ispace=ispace, properties=properties)
460
+
402
461
def callback (self , clusters , prefix , seen = None ):
403
462
if not prefix :
404
463
return clusters
@@ -412,38 +471,18 @@ def callback(self, clusters, prefix, seen=None):
412
471
c in seen :
413
472
continue
414
473
415
- hs = HaloScheme (c .exprs , c .ispace )
416
- if hs .is_void or \
417
- not d ._defines & hs .distributed_aindices :
418
- continue
419
-
420
- points = set ()
421
- for f in hs .fmapper :
422
- for a in c .scope .getreads (f ):
423
- points .add (a .access )
424
-
425
- # We also add all written symbols to ultimately create mock WARs
426
- # with `c`, which will prevent the newly created HaloTouch to ever
427
- # be rescheduled after `c` upon topological sorting
428
- points .update (a .access for a in c .scope .accesses if a .is_write )
474
+ seen .add (c )
429
475
430
- # Sort for determinism
431
- # NOTE: not sorting might impact code generation. The order of
432
- # the args is important because that's what search functions honor!
433
- points = sorted (points , key = str )
434
-
435
- # Construct the HaloTouch Cluster
436
- expr = Eq (self .B , HaloTouch (* points , halo_scheme = hs ))
476
+ for hs , c1 in self ._derive_halo_schemes (c ):
477
+ if hs .is_void or \
478
+ not d ._defines & hs .distributed_aindices :
479
+ continue
437
480
438
- key = lambda i : i in prefix [:- 1 ] or i in hs .loc_indices
439
- ispace = c .ispace .project (key )
440
- # HaloTouches are not parallel
441
- properties = c .properties .sequentialize ()
481
+ halo_touch = self ._make_halo_touch (hs , c1 , prefix )
442
482
443
- halo_touch = c . rebuild ( exprs = expr , ispace = ispace , properties = properties )
483
+ processed . append ( halo_touch )
444
484
445
- processed .append (halo_touch )
446
- seen .update ({halo_touch , c })
485
+ seen .add (halo_touch )
447
486
448
487
processed .extend (clusters )
449
488
0 commit comments