@@ -399,6 +399,83 @@ class HaloComms(Queue):
399
399
def process(self, clusters):
    """
    Entry point: process `clusters` via `self._process_fatd`.

    The call is made with `1` as second positional argument and a fresh,
    empty set as `seen`, so no state is carried over between invocations.
    """
    visited = set()
    return self._process_fatd(clusters, 1, seen=visited)
401
401
402
def _derive_halo_schemes(self, c):
    """
    Generate the `(HaloScheme, Cluster)` pairs describing the halo
    exchanges needed by the Cluster `c`.

    Three scenarios are handled:

    * `c` has no guards: the single pair `(hs, c)` is yielded, with `hs`
      built straight from `c`.
    * `c` has guards, but either no loc-index of `hs` is guarded or at
      least one guarded loc-index groups the reads into at most one
      group: again the single pair `(hs, c)` is yielded.
    * Otherwise, one mock Cluster is yielded for each group of reads,
      together with the HaloScheme derived from it.

    Parameters
    ----------
    c : Cluster
        The Cluster to analyze.

    Yields
    ------
    tuple
        A `(HaloScheme, Cluster)` pair.
    """
    hs = HaloScheme(c.exprs, c.ispace)

    # 95% of the times we will just return `hs` as is as there are no guards
    if not c.guards:
        yield hs, c
        return

    # This is a more contrived situation in which we might need halo exchanges
    # from multiple so called loc-indices -- let's check this out
    guarded = set(c.guards)
    candidates = []
    for f, hse in hs.fmapper.items():
        reads = c.scope.reads[f]

        for d in hse.loc_indices:
            if not d._defines & guarded:
                continue

            # Presumably `as_mapper` groups the reads by their index along
            # `d` -- confirm against its definition
            candidates.append(as_mapper(reads, key=lambda i: i[d]).values())

    # 4% of the times we will just return `hs` as is
    # E.g., we end up here when taking space derivatives of one or more saved
    # TimeFunctions in equations evaluating gradients that are controlled by
    # a ConditionalDimension (otherwise we would have exited earlier)
    # NOTE: `any` over an empty list is False, so the case of no guarded
    # loc-indices at all must be checked explicitly, or nothing would be
    # yielded for `c` and its halo exchanges would be silently lost
    if not candidates or any(len(g) <= 1 for g in candidates):
        yield hs, c
        return

    # 1% of the times, finally, we end up here...
    # At this point we have to create a mock Cluster for each loc-index,
    # containing all and only the accesses to `f` at a given loc-index
    # E.g., a mock Cluster at `loc_index=t0` containing the accesses
    # `[u(t0, x + 8, ...), u(t0, x + 7, ...)], another mock Cluster at
    # `loc_index=t1` containing the accesses `[u(t1, x + 5, ...),
    # u(t1, x + 6, ...)]`, and so on
    pointset = sympy.Function('pointset')
    for unordered_groups in candidates:
        # Sort for deterministic code generation
        for group in sorted(unordered_groups, key=str):
            # Every expression gets the same mock right-hand side, carrying
            # all and only the accesses in `group`
            v = pointset(*[i.access for i in group])
            exprs = [e.func(rhs=v) for e in c.exprs]

            c1 = c.rebuild(exprs=exprs)

            yield HaloScheme(c1.exprs, c.ispace), c1
450
+
451
def _make_halo_touch(self, hs, c, prefix):
    """
    Build and return the HaloTouch Cluster for the HaloScheme `hs`.

    The new Cluster is obtained by rebuilding `c` with a single
    expression `Eq(self.B, HaloTouch(...))`, an iteration space projected
    onto the dimensions in `prefix[:-1]` and the loc-indices of `hs`,
    and sequentialized properties.
    """
    # The accesses read from the Functions that `hs` maps
    touched = {r.access for f in hs.fmapper for r in c.scope.getreads(f)}

    # We also add all written symbols to ultimately create mock WARs
    # with `c`, which will prevent the newly created HaloTouch from
    # ever being rescheduled
    touched |= {a.access for a in c.scope.accesses if a.is_write}

    # Sort for determinism
    # NOTE: not sorting might impact code generation. The order of
    # the args is important because that's what search functions honor!
    ordered = sorted(touched, key=str)

    # Construct the HaloTouch Cluster
    touch_eq = Eq(self.B, HaloTouch(*ordered, halo_scheme=hs))

    def retained(i):
        # Keep the outer dimensions and the loc-indices only
        return i in prefix[:-1] or i in hs.loc_indices

    projected = c.ispace.project(retained)

    # HaloTouches are not parallel
    sequential = c.properties.sequentialize()

    return c.rebuild(exprs=touch_eq, ispace=projected, properties=sequential)
478
+
402
479
def callback (self , clusters , prefix , seen = None ):
403
480
if not prefix :
404
481
return clusters
@@ -412,38 +489,18 @@ def callback(self, clusters, prefix, seen=None):
412
489
c in seen :
413
490
continue
414
491
415
- hs = HaloScheme (c .exprs , c .ispace )
416
- if hs .is_void or \
417
- not d ._defines & hs .distributed_aindices :
418
- continue
419
-
420
- points = set ()
421
- for f in hs .fmapper :
422
- for a in c .scope .getreads (f ):
423
- points .add (a .access )
424
-
425
- # We also add all written symbols to ultimately create mock WARs
426
- # with `c`, which will prevent the newly created HaloTouch to ever
427
- # be rescheduled after `c` upon topological sorting
428
- points .update (a .access for a in c .scope .accesses if a .is_write )
492
+ seen .add (c )
429
493
430
- # Sort for determinism
431
- # NOTE: not sorting might impact code generation. The order of
432
- # the args is important because that's what search functions honor!
433
- points = sorted (points , key = str )
434
-
435
- # Construct the HaloTouch Cluster
436
- expr = Eq (self .B , HaloTouch (* points , halo_scheme = hs ))
494
+ for hs , c1 in self ._derive_halo_schemes (c ):
495
+ if hs .is_void or \
496
+ not d ._defines & hs .distributed_aindices :
497
+ continue
437
498
438
- key = lambda i : i in prefix [:- 1 ] or i in hs .loc_indices
439
- ispace = c .ispace .project (key )
440
- # HaloTouches are not parallel
441
- properties = c .properties .sequentialize ()
499
+ halo_touch = self ._make_halo_touch (hs , c1 , prefix )
442
500
443
- halo_touch = c . rebuild ( exprs = expr , ispace = ispace , properties = properties )
501
+ processed . append ( halo_touch )
444
502
445
- processed .append (halo_touch )
446
- seen .update ({halo_touch , c })
503
+ seen .add (halo_touch )
447
504
448
505
processed .extend (clusters )
449
506
0 commit comments