Skip to content

Commit a8d5757

Browse files
committed
compiler: prevent hosted per-thread arrays are dereferenced within partree at read
1 parent 2d32230 commit a8d5757

File tree

6 files changed

+34
-13
lines changed

6 files changed

+34
-13
lines changed

devito/finite_differences/differentiable.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -259,6 +259,12 @@ def __rfloordiv__(self, other):
259259
from .elementary import floor
260260
return floor(other / self)
261261

262+
def safe_inv(self, ref, safe=False):
263+
if safe:
264+
return SafeInv(self, ref or self)
265+
else:
266+
return 1 / self
267+
262268
def __mod__(self, other):
263269
return Mod(self, other)
264270

devito/passes/iet/parpragma.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -317,6 +317,9 @@ def _make_parregion(self, partree, parrays):
317317
i = n.write
318318
if not (i.is_Array or i.is_TempFunction):
319319
continue
320+
elif partree.dim in i.dimensions:
321+
# Non-local Array (full iteration space): no need to vector-expand
322+
continue
320323
elif i in parrays:
321324
pi = parrays[i]
322325
else:

devito/symbolics/inspection.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@
44
from sympy import (Function, Indexed, Integer, Mul, Number,
55
Pow, S, Symbol, Tuple)
66
from sympy.core.numbers import ImaginaryUnit
7+
from sympy.core.function import Application
78

89
from devito.finite_differences import Derivative
910
from devito.finite_differences.differentiable import IndexDerivative
@@ -116,7 +117,7 @@ def estimate_cost(exprs, estimate=False):
116117
estimate_values = {
117118
'elementary': 100,
118119
'pow': 50,
119-
'SafeInv': 10,
120+
'SafeInv': 50,
120121
'div': 5,
121122
'Abs': 5,
122123
'floor': 1,
@@ -211,6 +212,7 @@ def _(expr, estimate, seen):
211212

212213

213214
@_estimate_cost.register(Function)
215+
@_estimate_cost.register(Application)
214216
def _(expr, estimate, seen):
215217
if q_routine(expr):
216218
flops, _ = zip(*[_estimate_cost(a, estimate, seen) for a in expr.args])
@@ -227,6 +229,7 @@ def _(expr, estimate, seen):
227229
flops += 1
228230
else:
229231
flops = 0
232+
230233
return flops, False
231234

232235

devito/types/basic.py

Lines changed: 9 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -849,7 +849,7 @@ def __init_finalize__(self, *args, **kwargs):
849849

850850
# Averaging mode for off the grid evaluation
851851
self._avg_mode = kwargs.get('avg_mode', 'arithmetic')
852-
if self._avg_mode not in ['arithmetic', 'harmonic']:
852+
if self._avg_mode not in ['arithmetic', 'harmonic', 'safe_harmonic']:
853853
raise ValueError("Invalid averaging mode_mode %s, accepted values are"
854854
" arithmetic or harmonic" % self._avg_mode)
855855

@@ -989,7 +989,7 @@ def c0(self):
989989
def _eval_deriv(self):
990990
return self
991991

992-
@cached_property
992+
@property
993993
def _grid_map(self):
994994
"""
995995
Mapper of off-grid interpolation points indices for each dimension.
@@ -1049,14 +1049,13 @@ def _evaluate(self, **kwargs):
10491049
return self
10501050

10511051
io = self.interp_order
1052-
if self._avg_mode == 'harmonic':
1053-
retval = 1 / self
1054-
else:
1055-
retval = self
1052+
retval = self.subs({i.subs(subs): self.indices_ref[d]
1053+
for d, i in mapper.items()})
1054+
if 'harmonic' in self._avg_mode:
1055+
retval = retval.safe_inv(retval, safe='safe' in self._avg_mode)
10561056

10571057
# Apply interpolation from inner most dim
10581058
for d, i in mapper.items():
1059-
retval = retval._subs(i.subs(subs), self.indices_ref[d])
10601059
retval = retval.diff(d, deriv_order=0, fd_order=io, x0={d: i})
10611060

10621061
# Evaluate. Since we used `self.function` it will be on the grid when
@@ -1065,9 +1064,9 @@ def _evaluate(self, **kwargs):
10651064
retval = retval.subs(subs)
10661065

10671066
# If harmonic averaging, invert at the end
1068-
if self._avg_mode == 'harmonic':
1069-
from devito.finite_differences.differentiable import SafeInv
1070-
retval = SafeInv(retval, self.function.subs(subs))
1067+
if 'harmonic' in self._avg_mode:
1068+
retval = retval.safe_inv(self.function.subs(subs),
1069+
safe='safe' in self._avg_mode)
10711070

10721071
return retval
10731072

devito/types/dense.py

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1535,6 +1535,15 @@ def _time_buffering(self):
15351535
def _time_buffering_default(self):
15361536
return self._time_buffering and not isinstance(self.save, Buffer)
15371537

1538+
def _evaluate(self, **kwargs):
1539+
retval = super()._evaluate(**kwargs)
1540+
if not self._time_buffering and not retval.is_Function:
1541+
# Saved TimeFunction might need streaming, expand interpolations
1542+
# for easier processing.
1543+
return retval.evaluate
1544+
else:
1545+
return retval
1546+
15381547
def _arg_check(self, args, intervals, **kwargs):
15391548
super()._arg_check(args, intervals, **kwargs)
15401549

tests/test_differentiable.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -119,7 +119,7 @@ def test_avg_mode(ndim, io):
119119

120120
a0 = Function(name="a0", grid=grid, **kw)
121121
a = Function(name="a", grid=grid, **kw)
122-
b = Function(name="b", grid=grid, avg_mode='harmonic', **kw)
122+
b = Function(name="b", grid=grid, avg_mode='safe_harmonic', **kw)
123123

124124
a0_avg = a0._eval_at(v)
125125
a_avg = a._eval_at(v).evaluate.simplify()
@@ -141,7 +141,8 @@ def test_avg_mode(ndim, io):
141141
assert sympy.simplify(a_avg - expected) == 0
142142

143143
# Harmonic average, h(a[.5]) = 1/(.5/a[0] + .5/a[1])
144-
expected = (sum(c / b.subs(arg) for c, arg in zip(ndcoeffs.flatten(), args)))
144+
expected = (sum(c * SafeInv(b.subs(arg), b.subs(arg))
145+
for c, arg in zip(ndcoeffs.flatten(), args)))
145146
assert sympy.simplify(b_avg.args[0] - expected) == 0
146147
assert isinstance(b_avg, SafeInv)
147148
assert b_avg.base == b

0 commit comments

Comments
 (0)