Skip to content

Commit fc4ea00

Browse files
committed
compiler: fix complex reductions for gnu
1 parent a49020d commit fc4ea00

File tree

8 files changed

+209
-18
lines changed

8 files changed

+209
-18
lines changed

conftest.py

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -35,7 +35,7 @@ def skipif(items, whole_module=False):
3535
accepted.update({'device', 'device-C', 'device-openmp', 'device-openacc',
3636
'device-aomp', 'cpu64-icc', 'cpu64-icx', 'cpu64-nvc',
3737
'noadvisor', 'cpu64-arm', 'cpu64-icpx', 'chkpnt'})
38-
accepted.update({'nodevice'})
38+
accepted.update({'nodevice', 'noomp'})
3939
unknown = sorted(set(items) - accepted)
4040
if unknown:
4141
raise ValueError("Illegal skipif argument(s) `%s`" % unknown)
@@ -86,6 +86,10 @@ def skipif(items, whole_module=False):
8686
not get_advisor_path()):
8787
skipit = "Only `icx+advisor` should be tested here"
8888
break
89+
# Skip if not using openmp
90+
if i == 'noomp' and 'openmp' not in configuration['language']:
91+
skipit = "Must use openmp"
92+
break
8993
# Skip if it won't run on Arm
9094
if i == 'cpu64-arm' and isinstance(configuration['platform'], Arm):
9195
skipit = "Arm doesn't support x86-specific instructions"

devito/mpi/distributed.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -40,7 +40,7 @@
4040
def cleanup():
4141
devito_mpi_finalize()
4242
atexit.register(cleanup)
43-
except ImportError as e:
43+
except (RuntimeError, ImportError) as e:
4444
# Dummy fallback in case mpi4py/MPI aren't available
4545
class NoneMetaclass(type):
4646
def __getattr__(self, name):

devito/passes/iet/languages/C.py

Lines changed: 53 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,16 +1,67 @@
11
import numpy as np
22
from sympy.printing.c import C99CodePrinter
33

4-
from devito.ir import Call, BasePrinter
4+
from devito.exceptions import InvalidOperator
5+
from devito.ir import Call, BasePrinter, List
56
from devito.passes.iet.definitions import DataManager
67
from devito.passes.iet.orchestration import Orchestrator
78
from devito.passes.iet.langbase import LangBB
89
from devito.symbolics import c_complex, c_double_complex
10+
from devito.symbolics.extended_sympy import UnaryOp
911
from devito.tools import dtype_to_cstr
1012

1113
__all__ = ['CBB', 'CDataManager', 'COrchestrator']
1214

1315

16+
class RealExt(UnaryOp):
17+
18+
_op = '__real__ '
19+
20+
21+
class ImagExt(UnaryOp):
22+
23+
_op = '__imag__ '
24+
25+
26+
def atomic_add(i, pragmas, split=False):
27+
# Base case, real reduction
28+
if not split:
29+
return i._rebuild(pragmas=pragmas)
30+
# Complex reduction, split into separate real and imaginary updates
31+
# Transforms lhs += rhs into
32+
# {
33+
# pragmas
34+
# __real__ lhs += __real__ rhs;
35+
# pragmas
36+
# __imag__ lhs += __imag__ rhs;
37+
# }
38+
lhs, rhs = i.expr.lhs, i.expr.rhs
39+
if (np.issubdtype(lhs.dtype, np.complexfloating)
40+
and np.issubdtype(rhs.dtype, np.complexfloating)):
41+
# Complex i, complex j
42+
# Atomic add real and imaginary parts separately
43+
lhsr, rhsr = RealExt(lhs), RealExt(rhs)
44+
lhsi, rhsi = ImagExt(lhs), ImagExt(rhs)
45+
real = i._rebuild(expr=i.expr._rebuild(lhs=lhsr, rhs=rhsr),
46+
pragmas=pragmas)
47+
imag = i._rebuild(expr=i.expr._rebuild(lhs=lhsi, rhs=rhsi),
48+
pragmas=pragmas)
49+
return List(body=[real, imag])
50+
51+
elif (np.issubdtype(lhs.dtype, np.complexfloating)
52+
and not np.issubdtype(rhs.dtype, np.complexfloating)):
53+
# Complex i, real j
54+
# Atomic add j to real part of i
55+
lhsr, rhsr = RealExt(lhs), rhs
56+
real = i._rebuild(expr=i.expr._rebuild(lhs=lhsr, rhs=rhsr),
57+
pragmas=pragmas)
58+
return real
59+
else:
60+
# Real i, complex j
61+
raise InvalidOperator("Atomic add not implemented for real "
62+
"Functions with complex increments")
63+
64+
1465
class CBB(LangBB):
1566

1667
mapper = {
@@ -29,7 +80,7 @@ class CBB(LangBB):
2980
'host-free-pin': lambda i:
3081
Call('free', (i,)),
3182
'alloc-global-symbol': lambda i, j, k:
32-
Call('memcpy', (i, j, k)),
83+
Call('memcpy', (i, j, k))
3384
}
3485

3586

devito/passes/iet/languages/CXX.py

Lines changed: 54 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,12 +1,17 @@
1+
from ctypes import POINTER
2+
13
import numpy as np
24
from sympy.printing.cxx import CXX11CodePrinter
35

4-
from devito.ir import Call, UsingNamespace, BasePrinter
6+
from devito import Real, Imag
7+
from devito.exceptions import InvalidOperator
8+
from devito.ir import Call, UsingNamespace, BasePrinter, DummyExpr, List
59
from devito.passes.iet.definitions import DataManager
610
from devito.passes.iet.orchestration import Orchestrator
711
from devito.passes.iet.langbase import LangBB
8-
from devito.symbolics import c_complex, c_double_complex
9-
from devito.tools import dtype_to_cstr
12+
from devito.symbolics import c_complex, c_double_complex, IndexedPointer, cast, Byref
13+
from devito.tools import dtype_to_cstr, dtype_to_ctype
14+
from devito.types import Pointer
1015

1116
__all__ = ['CXXBB', 'CXXDataManager', 'CXXOrchestrator']
1217

@@ -65,6 +70,51 @@ def std_arith(prefix=None):
6570
"""
6671

6772

73+
def atomic_add(i, pragmas, split=False):
74+
# Base case, real reduction
75+
if not split:
76+
return i._rebuild(pragmas=pragmas)
77+
# Complex reduction, split using a temp pointer
78+
# Transforms lhs += rhs into
79+
# {
80+
# float * plhs = reinterpret_cast<float*>(&lhs);
81+
# pragmas
82+
# plhs[0] += std::real(rhs);
83+
# pragmas
84+
# plhs[1] += std::imag(rhs);
85+
# }
86+
# Make a temp pointer
87+
lhs, rhs = i.expr.lhs, i.expr.rhs
88+
rdtype = lhs.dtype(0).real.__class__
89+
plhs = Pointer(name=f'p{lhs.name}', dtype=POINTER(dtype_to_ctype(rdtype)))
90+
peq = DummyExpr(plhs, cast(rdtype, stars='*')(Byref(lhs), reinterpret=True))
91+
92+
if (np.issubdtype(lhs.dtype, np.complexfloating)
93+
and np.issubdtype(rhs.dtype, np.complexfloating)):
94+
# Complex i, complex j
95+
# Atomic add real and imaginary parts separately
96+
lhsr, rhsr = IndexedPointer(plhs, 0), Real(rhs)
97+
lhsi, rhsi = IndexedPointer(plhs, 1), Imag(rhs)
98+
real = i._rebuild(expr=i.expr._rebuild(lhs=lhsr, rhs=rhsr),
99+
pragmas=pragmas)
100+
imag = i._rebuild(expr=i.expr._rebuild(lhs=lhsi, rhs=rhsi),
101+
pragmas=pragmas)
102+
return List(body=[peq, real, imag])
103+
104+
elif (np.issubdtype(lhs.dtype, np.complexfloating)
105+
and not np.issubdtype(rhs.dtype, np.complexfloating)):
106+
# Complex i, real j
107+
# Atomic add j to real part of i
108+
lhsr, rhsr = IndexedPointer(plhs, 0), rhs
109+
real = i._rebuild(expr=i.expr._rebuild(lhs=lhsr, rhs=rhsr),
110+
pragmas=pragmas)
111+
return List(body=[peq, real])
112+
else:
113+
# Real i, complex j
114+
raise InvalidOperator("Atomic add not implemented for real "
115+
"Functions with complex increments")
116+
117+
68118
class CXXBB(LangBB):
69119

70120
mapper = {
@@ -86,7 +136,7 @@ class CXXBB(LangBB):
86136
'host-free-pin': lambda i:
87137
Call('free', (i,)),
88138
'alloc-global-symbol': lambda i, j, k:
89-
Call('memcpy', (i, j, k)),
139+
Call('memcpy', (i, j, k))
90140
}
91141

92142

devito/passes/iet/languages/openacc.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -74,7 +74,7 @@ class AccBB(PragmaLangBB):
7474
Call('acc_set_device_num', args),
7575
# Pragmas
7676
'atomic':
77-
Pragma('acc atomic update'),
77+
lambda i, s: i._rebuild(pragmas=Pragma('acc atomic update')),
7878
'map-enter-to': lambda f, imask:
7979
PragmaTransfer('acc enter data copyin(%s%s)', f, imask=imask),
8080
'map-enter-to-async': lambda f, imask, a:

devito/passes/iet/languages/openmp.py

Lines changed: 28 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,7 @@
55
from sympy import And, Ne, Not
66

77
from devito.arch import AMDGPUX, NVIDIAX, INTELGPUX, PVC
8-
from devito.arch.compiler import GNUCompiler, NvidiaCompiler
8+
from devito.arch.compiler import GNUCompiler, NvidiaCompiler, CustomCompiler
99
from devito.ir import (Call, Conditional, DeviceCall, List, Pragma, Prodder,
1010
ParallelBlock, PointerCast, While, FindSymbols)
1111
from devito.passes.iet.definitions import DataManager, DeviceAwareDataManager
@@ -15,8 +15,8 @@
1515
PragmaDeviceAwareTransformer, PragmaLangBB,
1616
PragmaIteration, PragmaTransfer)
1717
from devito.passes.iet.languages.utils import joins
18-
from devito.passes.iet.languages.C import CBB
19-
from devito.passes.iet.languages.CXX import CXXBB
18+
from devito.passes.iet.languages.C import CBB, atomic_add as c_atomic_add
19+
from devito.passes.iet.languages.CXX import CXXBB, atomic_add as cxx_atomic_add
2020
from devito.symbolics import CondEq, DefFunction
2121
from devito.tools import filter_ordered
2222

@@ -133,8 +133,7 @@ class AbstractOmpBB(LangBB):
133133
Pragma('omp simd'),
134134
'simd-for-aligned': lambda n, *a:
135135
SimdForAligned('omp simd aligned(%s:%d)', arguments=(n, *a)),
136-
'atomic':
137-
Pragma('omp atomic update')
136+
'atomic': lambda i, s: i._rebuild(pragmas=Pragma('omp atomic update'))
138137
}
139138

140139
Region = OmpRegion
@@ -144,11 +143,20 @@ class AbstractOmpBB(LangBB):
144143

145144

146145
class OmpBB(AbstractOmpBB):
147-
mapper = {**AbstractOmpBB.mapper, **CBB.mapper}
146+
147+
mapper = {
148+
**AbstractOmpBB.mapper,
149+
**CBB.mapper,
150+
'atomic': lambda i, s: c_atomic_add(i, Pragma('omp atomic update'), split=s)
151+
}
148152

149153

150154
class CXXOmpBB(AbstractOmpBB):
151-
mapper = {**AbstractOmpBB.mapper, **CXXBB.mapper}
155+
mapper = {
156+
**AbstractOmpBB.mapper,
157+
**CXXBB.mapper,
158+
'atomic': lambda i, s: cxx_atomic_add(i, Pragma('omp atomic update'), split=s)
159+
}
152160

153161

154162
class DeviceOmpBB(OmpBB, PragmaLangBB):
@@ -230,6 +238,9 @@ class AbstractOmpizer(PragmaShmTransformer):
230238

231239
@classmethod
232240
def _support_array_reduction(cls, compiler):
241+
# In case we have a CustomCompiler
242+
if isinstance(compiler, CustomCompiler):
243+
compiler = compiler._base()
233244
# Not all backend compilers support array reduction!
234245
# Here are the known unsupported ones:
235246
if isinstance(compiler, GNUCompiler) and \
@@ -241,6 +252,16 @@ def _support_array_reduction(cls, compiler):
241252
else:
242253
return True
243254

255+
@classmethod
256+
def _support_complex_reduction(cls, compiler):
257+
# In case we have a CustomCompiler
258+
if isinstance(compiler, CustomCompiler):
259+
compiler = compiler._base()
260+
if isinstance(compiler, GNUCompiler):
261+
# GCC doesn't support complex reductions
262+
return False
263+
return True
264+
244265

245266
class Ompizer(AbstractOmpizer):
246267
langbb = OmpBB

devito/passes/iet/parpragma.py

Lines changed: 7 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -42,6 +42,10 @@ class PragmaSimdTransformer(PragmaTransformer):
4242
def _support_array_reduction(cls, compiler):
4343
return True
4444

45+
@classmethod
46+
def _support_complex_reduction(cls, compiler):
47+
return False
48+
4549
@property
4650
def simd_reg_nbytes(self):
4751
return self.platform.simd_reg_nbytes
@@ -238,8 +242,9 @@ def _make_reductions(self, partree):
238242
# Implement reduction
239243
mapper = {partree.root: partree.root._rebuild(reduction=reductions)}
240244
elif all(i is OpInc for _, _, i in reductions):
241-
# Use atomic increments
242-
mapper = {i: i._rebuild(pragmas=self.langbb['atomic']) for i in exprs}
245+
test2 = not self._support_complex_reduction(self.compiler) and \
246+
any(np.iscomplexobj(i.dtype(0)) for i, _, _ in reductions)
247+
mapper = {i: self.langbb['atomic'](i, test2) for i in exprs}
243248
else:
244249
raise NotImplementedError
245250

tests/test_dtypes.py

Lines changed: 60 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2,17 +2,26 @@
22
import pytest
33
import sympy
44

5+
try:
6+
from ..conftest import skipif
7+
except ImportError:
8+
from conftest import skipif
9+
510
from devito import (
611
Constant, Eq, Function, Grid, Operator, exp, log, sin, configuration
712
)
13+
from devito.arch.compiler import GNUCompiler
14+
from devito.exceptions import InvalidOperator
815
from devito.ir.cgen.printer import BasePrinter
916
from devito.passes.iet.langbase import LangBB
1017
from devito.passes.iet.languages.C import CBB, CPrinter
1118
from devito.passes.iet.languages.openacc import AccBB, AccPrinter
1219
from devito.passes.iet.languages.openmp import OmpBB
1320
from devito.symbolics.extended_dtypes import ctypes_vector_mapper
21+
from devito.tools import dtype_to_cstr
1422
from devito.types.basic import Basic, Scalar, Symbol
1523
from devito.types.dense import TimeFunction
24+
from devito.types.sparse import SparseTimeFunction
1625

1726
# Mappers for language-specific types and headers
1827
_languages: dict[str, type[LangBB]] = {
@@ -274,3 +283,54 @@ def test_complex_space_deriv(dtype: np.dtype[np.complexfloating]) -> None:
274283
dfdy = h.data.T[1:-1, 1:-1]
275284
assert np.allclose(dfdx, np.ones((5, 5), dtype=dtype))
276285
assert np.allclose(dfdy, np.ones((5, 5), dtype=dtype))
286+
287+
288+
@skipif(['noomp', 'device'])
289+
@pytest.mark.parametrize('dtypeu', [np.float32, np.complex64, np.complex128])
290+
def test_complex_reduction(dtypeu: np.dtype[np.complexfloating]) -> None:
291+
"""
292+
Tests reductions over complex-valued functions.
293+
"""
294+
grid = Grid((11, 11))
295+
296+
u = TimeFunction(name="u", grid=grid, space_order=2, time_order=1, dtype=dtypeu)
297+
for dtypes in [dtypeu, dtypeu(0).real.__class__]:
298+
u.data.fill(0)
299+
s = SparseTimeFunction(name="s", grid=grid, npoint=1, nt=10, dtype=dtypes)
300+
if np.issubdtype(dtypes, np.complexfloating):
301+
s.data[:] = 1 + 2j
302+
expected = 8. + 16.j
303+
else:
304+
s.data[:] = 1
305+
expected = 8.
306+
s.coordinates.data[:] = [.5, .5]
307+
308+
# s complex and u real should error
309+
if np.issubdtype(dtypeu, np.floating) and \
310+
np.issubdtype(dtypes, np.complexfloating):
311+
with pytest.raises(InvalidOperator):
312+
op = Operator([Eq(u.forward, u)] + s.inject(u.forward, expr=s))
313+
continue
314+
else:
315+
op = Operator([Eq(u.forward, u)] + s.inject(u.forward, expr=s))
316+
op()
317+
318+
if op._options['linearize']:
319+
ustr = 'uL0(t1, rsx + posx + 2, rsy + posy + 2)'
320+
else:
321+
ustr = 'u[t1][rsx + posx + 2][rsy + posy + 2]'
322+
323+
if isinstance(configuration['compiler'], GNUCompiler) and \
324+
np.issubdtype(dtypeu, np.complexfloating):
325+
if 'CXX' in op._language:
326+
rd = dtype_to_cstr(dtypeu(0).real.__class__)
327+
assert f'{rd} * p{u.name} = reinterpret_cast<{rd}*>(&uL0' in str(op)
328+
assert f'p{u.name}[0] += std::real(r0)' in str(op)
329+
assert f'p{u.name}[1] += std::imag(r0)' in str(op)
330+
else:
331+
assert f'__real__ {ustr} += __real__ r0' in str(op)
332+
assert f'__imag__ {ustr} += __imag__ r0' in str(op)
333+
else:
334+
assert f'{ustr} += r0' in str(op)
335+
336+
assert np.isclose(u.data[0, 5, 5], expected)

0 commit comments

Comments
 (0)