Skip to content
This repository was archived by the owner on Jan 24, 2024. It is now read-only.

Commit bc983ee

Browse files
27421957590x45f
and
0x45f
authored
add fallback wrapper for speed up (#100)
Co-authored-by: 0x45f <wangzhen45@baidu.com>
1 parent 66589ed commit bc983ee

File tree

1 file changed

+29
-2
lines changed

1 file changed

+29
-2
lines changed

symbolic_trace/symbolic/compile_cache.py

Lines changed: 29 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,31 @@
44
from .interpreter import compile_sir
55

66

7+
def clear_eager_tensor_name(output_tensors):
8+
for output_tensor in output_tensors:
9+
output_tensor.name = ""
10+
11+
12+
class FallbackWrapper:
13+
def __init__(self, compile_sir):
14+
self.compile_sir = compile_sir
15+
self.partial_program_layer = None
16+
17+
def __call__(self, *args, **kwargs):
18+
frame_callback = paddle.fluid.core.set_eval_frame(None)
19+
if self.partial_program_layer is None:
20+
outputs = self.compile_sir(*args, **kwargs)
21+
self.partial_program_layer = self.compile_sir.get_concrete_program(
22+
*args, **kwargs
23+
)[1]
24+
else:
25+
# Speed up Resnet from 0.0068 --> 0.0057
26+
outputs = self.partial_program_layer(*args, **kwargs)
27+
clear_eager_tensor_name(outputs)
28+
paddle.fluid.core.set_eval_frame(frame_callback)
29+
return outputs
30+
31+
732
@Singleton
833
class CompileSIRCache(Cache):
934
def __init__(self):
@@ -16,6 +41,8 @@ def key_fn(self, context, sir_name):
1641
return hash_key
1742

1843
def value_fn(self, context, sir_name):
19-
return paddle.jit.to_static(
20-
compile_sir(context, sir_name), enable_fallback=False
44+
return FallbackWrapper(
45+
paddle.jit.to_static(
46+
compile_sir(context, sir_name), enable_fallback=False
47+
)
2148
)

0 commit comments

Comments
 (0)