18
18
import traceback
19
19
from typing import TYPE_CHECKING , List , Tuple
20
20
21
+ import paddle
21
22
from paddle .base .dygraph .base import sot_simulation_mode_guard
22
23
23
24
from ...profiler import EventGuard , event_register
47
48
48
49
GuardedFunction = Tuple [CustomCode , Guard ]
49
50
GuardedFunctions = List [GuardedFunction ]
51
+ GuardChain = List [paddle .framework .core .GuardNode ]
52
+ GuardChainList = List [GuardChain ]
50
53
51
54
dummy_guard : Guard = lambda frame : True
52
55
dummy_guard .expr = "lambda frame: True"
@@ -66,7 +69,9 @@ class OpcodeExecutorCache(metaclass=Singleton):
66
69
"""
67
70
68
71
MAX_CACHE_SIZE = 20
69
- cache : dict [types .CodeType , GuardedFunctions ]
72
+ cache : dict [
73
+ types .CodeType , tuple [GuardedFunctions , paddle .framework .core .GuardTree ]
74
+ ]
70
75
translate_count : int
71
76
code_symbolic_inputs : dict [types .CodeType , dict [str , None | dict [int , int ]]]
72
77
@@ -105,16 +110,25 @@ def __call__(self, frame: types.FrameType, **kwargs) -> CustomCode:
105
110
code : types .CodeType = frame .f_code
106
111
if code not in self .cache :
107
112
log (2 , f"[Cache]: Firstly call { code } \n " )
108
- new_custom_code , guard_fn = self .translate (frame , ** kwargs )
113
+ new_custom_code , guard_fn , guard_chain = self .translate (
114
+ frame , ** kwargs
115
+ )
109
116
assert guard_fn is not None
110
- self .cache [code ] = [(new_custom_code , guard_fn )]
117
+ assert guard_chain is not None
118
+ self .cache [code ] = [
119
+ (new_custom_code , guard_fn )
120
+ ], paddle .framework .core .GuardTree ([guard_chain ])
111
121
return new_custom_code
112
- guarded_fns = self .cache [code ]
113
- return self .lookup (frame , guarded_fns , ** kwargs )
122
+ guarded_fns , guard_tree = self .cache [code ]
123
+ return self .lookup (frame , guarded_fns , guard_tree , ** kwargs )
114
124
115
125
@event_register ("lookup" )
116
126
def lookup (
117
- self , frame : types .FrameType , guarded_fns : GuardedFunctions , ** kwargs
127
+ self ,
128
+ frame : types .FrameType ,
129
+ guarded_fns : GuardedFunctions ,
130
+ guard_tree : paddle .framework .core .GuardTree ,
131
+ ** kwargs ,
118
132
) -> CustomCode :
119
133
"""
120
134
Looks up the cache for a matching code object and returns a custom code object if a matching guard function is found, otherwise None.
@@ -132,8 +146,17 @@ def lookup(
132
146
return CustomCode (None , False )
133
147
134
148
enable_strict_guard = ENV_SOT_ENABLE_STRICT_GUARD_CHECK .get ()
149
+ enable_guard_tree = ENV_SOT_ENABLE_GUARD_TREE .get ()
150
+
151
+ cache_index = None
152
+ if enable_strict_guard or enable_guard_tree :
153
+ cache_index = guard_tree .lookup (frame )
135
154
136
- for custom_code , guard_fn in guarded_fns :
155
+ if not enable_strict_guard and cache_index is not None :
156
+ # TODO(zrr1999): add a mapping between custom_code and cache_index
157
+ return guarded_fns [cache_index ][0 ]
158
+
159
+ for index , (custom_code , guard_fn ) in enumerate (guarded_fns ):
137
160
if enable_strict_guard :
138
161
mirror_guard_error = None
139
162
try :
@@ -157,9 +180,12 @@ def lookup(
157
180
2 ,
158
181
f"[Cache] Cache hit, Guard is \n { getattr (guard_fn , 'expr' , 'None' )} \n " ,
159
182
)
183
+ # TODO(zrr1999): cache_index should be equal to index when enable_strict_guard.
184
+ # assert (
185
+ # cache_index is None or index == cache_index
186
+ # ), f"cache_index({cache_index}) is not equal to index({index})"
160
187
return custom_code
161
- elif not ENV_SOT_ENABLE_GUARD_TREE .get ():
162
- # TODO(zrr1999): remove condition after faster guard tree support error analysis
188
+ else :
163
189
log_do (
164
190
4 ,
165
191
self .analyse_guard_global_object (guard_fn ),
@@ -192,9 +218,11 @@ def lookup(
192
218
)
193
219
194
220
log (2 , "[Cache]: all guards missed\n " )
195
- new_custom_code , guard_fn = self .translate (frame , ** kwargs )
221
+ new_custom_code , guard_fn , guard_chain = self .translate (frame , ** kwargs )
196
222
if guard_fn is not None :
223
+ assert guard_chain is not None
197
224
guarded_fns .append ((new_custom_code , guard_fn ))
225
+ guard_tree .add_guard_chain (guard_chain )
198
226
return new_custom_code
199
227
200
228
def before_translate_hook (self , frame : types .FrameType ):
@@ -203,7 +231,7 @@ def before_translate_hook(self, frame: types.FrameType):
203
231
204
232
def translate (
205
233
self , frame : types .FrameType , ** kwargs
206
- ) -> tuple [CustomCode , Guard | None ]:
234
+ ) -> tuple [CustomCode , Guard | None , GuardChain | None ]:
207
235
"""
208
236
Translates the given frame's code object and returns the cache getter function and a guarded function for the translated code object.
209
237
@@ -215,8 +243,10 @@ def translate(
215
243
"""
216
244
self .before_translate_hook (frame )
217
245
self .translate_count += 1
218
- custom_new_code , guard_fn = start_translate (frame , ** kwargs )
219
- return custom_new_code , guard_fn
246
+ custom_new_code , guard_fn , guard_chain = start_translate (
247
+ frame , ** kwargs
248
+ )
249
+ return custom_new_code , guard_fn , guard_chain
220
250
221
251
def analyse_guard_global_object (self , guard_fn ):
222
252
def inner ():
@@ -255,15 +285,15 @@ def inner():
255
285
def start_translate (
256
286
frame : types .FrameType ,
257
287
** kwargs ,
258
- ) -> tuple [CustomCode , Guard | None ]:
288
+ ) -> tuple [CustomCode , Guard | None , GuardChain | None ]:
259
289
"""
260
- Starts the translation process for the given frame and returns the translated code object and its guard function , or None if translation fails.
290
+ Starts the translation process for the given frame and returns the translated code object, its guard function and its guard tree node , or None if translation fails.
261
291
262
292
Args:
263
293
frame: The frame to be translated.
264
294
265
295
Returns:
266
- tuple[CustomCode, Guard | None]: The translated code object and its guard function , or None if translation fails.
296
+ tuple[CustomCode, Guard | None, GuardChain | None ]: The translated code object, its guard function and its guard tree node , or None if translation fails.
267
297
"""
268
298
graph = FunctionGraph (frame .f_code , frame .f_globals , ** kwargs )
269
299
vframe = VirtualFrame .from_real_frame (frame , graph )
@@ -280,8 +310,10 @@ def start_translate(
280
310
return (
281
311
CustomCode (None , True ),
282
312
None ,
313
+ None ,
283
314
)
284
- return new_custom_code , guard_fn
315
+ guard_chain = simulator .guard_chain
316
+ return new_custom_code , guard_fn , guard_chain
285
317
# TODO(0x45f): handle BreakGraphError to trigger fallback
286
318
except BreakGraphError as e :
287
319
raise RuntimeError (
@@ -299,9 +331,18 @@ def start_translate(
299
331
f"Unsupported Frame is { frame .f_code } , error message is: \n "
300
332
+ "" .join (traceback .format_exception (type (e ), e , e .__traceback__ )),
301
333
)
334
+
335
+ dummy_guard_chain = [
336
+ # TODO(zrr1999): GuardNode should support zero-expr constructor
337
+ paddle .framework .core .GuardNode (
338
+ paddle .framework .core .DummyGuard (),
339
+ [paddle .framework .core .ConstantExprNode (True )],
340
+ )
341
+ ]
302
342
return (
303
343
CustomCode (None , e .disable_eval_frame ),
304
344
dummy_guard ,
345
+ dummy_guard_chain ,
305
346
)
306
347
except Exception as e :
307
348
raise InnerError (OpcodeExecutorBase .error_message_summary (e )) from e
0 commit comments