1
+ from ctypes import POINTER
2
+
1
3
import numpy as np
2
4
from sympy .printing .cxx import CXX11CodePrinter
3
5
4
- from devito .ir import Call , UsingNamespace , BasePrinter
6
+ from devito import Real , Imag
7
+ from devito .exceptions import InvalidOperator
8
+ from devito .ir import Call , UsingNamespace , BasePrinter , DummyExpr , List
5
9
from devito .passes .iet .definitions import DataManager
6
10
from devito .passes .iet .orchestration import Orchestrator
7
11
from devito .passes .iet .langbase import LangBB
8
- from devito .symbolics import c_complex , c_double_complex
9
- from devito .tools import dtype_to_cstr
12
+ from devito .symbolics import c_complex , c_double_complex , IndexedPointer , cast , Byref
13
+ from devito .tools import dtype_to_cstr , dtype_to_ctype
14
+ from devito .types import Pointer
10
15
11
16
__all__ = ['CXXBB' , 'CXXDataManager' , 'CXXOrchestrator' ]
12
17
@@ -65,6 +70,51 @@ def std_arith(prefix=None):
65
70
"""
66
71
67
72
73
+ def atomic_add (i , pragmas , split = False ):
74
+ # Base case, real reduction
75
+ if not split :
76
+ return i ._rebuild (pragmas = pragmas )
77
+ # Complex reduction, split using a temp pointer
78
+ # Transforns lhs += rhs into
79
+ # {
80
+ # float * lhs = reinterpret_cast<float*>(&lhs);
81
+ # pragmas
82
+ # lhs[0] += std::real(rhs);
83
+ # pragmas
84
+ # lhs[1] += std::imag(rhs);
85
+ # }
86
+ # Make a temp pointer
87
+ lhs , rhs = i .expr .lhs , i .expr .rhs
88
+ rdtype = lhs .dtype (0 ).real .__class__
89
+ plhs = Pointer (name = f'p{ lhs .name } ' , dtype = POINTER (dtype_to_ctype (rdtype )))
90
+ peq = DummyExpr (plhs , cast (rdtype , stars = '*' )(Byref (lhs ), reinterpret = True ))
91
+
92
+ if (np .issubdtype (lhs .dtype , np .complexfloating )
93
+ and np .issubdtype (rhs .dtype , np .complexfloating )):
94
+ # Complex i, complex j
95
+ # Atomic add real and imaginary parts separately
96
+ lhsr , rhsr = IndexedPointer (plhs , 0 ), Real (rhs )
97
+ lhsi , rhsi = IndexedPointer (plhs , 1 ), Imag (rhs )
98
+ real = i ._rebuild (expr = i .expr ._rebuild (lhs = lhsr , rhs = rhsr ),
99
+ pragmas = pragmas )
100
+ imag = i ._rebuild (expr = i .expr ._rebuild (lhs = lhsi , rhs = rhsi ),
101
+ pragmas = pragmas )
102
+ return List (body = [peq , real , imag ])
103
+
104
+ elif (np .issubdtype (lhs .dtype , np .complexfloating )
105
+ and not np .issubdtype (rhs .dtype , np .complexfloating )):
106
+ # Complex i, real j
107
+ # Atomic add j to real part of i
108
+ lhsr , rhsr = IndexedPointer (plhs , 0 ), rhs
109
+ real = i ._rebuild (expr = i .expr ._rebuild (lhs = lhsr , rhs = rhsr ),
110
+ pragmas = pragmas )
111
+ return List (body = [peq , real ])
112
+ else :
113
+ # Real i, complex j
114
+ raise InvalidOperator ("Atomic add not implemented for real "
115
+ "Functions with complex increments" )
116
+
117
+
68
118
class CXXBB (LangBB ):
69
119
70
120
mapper = {
@@ -86,7 +136,7 @@ class CXXBB(LangBB):
86
136
'host-free-pin' : lambda i :
87
137
Call ('free' , (i ,)),
88
138
'alloc-global-symbol' : lambda i , j , k :
89
- Call ('memcpy' , (i , j , k )),
139
+ Call ('memcpy' , (i , j , k ))
90
140
}
91
141
92
142
0 commit comments