diff --git a/sot/opcode_translator/executor/guard.py b/sot/opcode_translator/executor/guard.py index f0f8106c2..1766dc17d 100644 --- a/sot/opcode_translator/executor/guard.py +++ b/sot/opcode_translator/executor/guard.py @@ -33,11 +33,11 @@ class StringifyExpression: Used to store string based expressions for generating Guard. """ - def __init__(self, str_expr, format_args, free_vars): - expr = str_expr.format(*[arg.expr for arg in format_args]) + def __init__(self, str_expr, sub_exprs, free_vars): + expr = str_expr.format(*[arg.expr for arg in sub_exprs]) self.expr = current_tmp_name_records().add_tmp_var(expr) self.debug_expr = str_expr.format( - *[arg.debug_expr for arg in format_args] + *[arg.debug_expr for arg in sub_exprs] ) self.free_vars = free_vars @@ -92,7 +92,7 @@ def analyse_expresions(stringify_exprs, tmp_names): lambda_string += str_expr.debug_expr + " and " free_vars = union_free_vars(free_vars, str_expr.free_vars) - func_string += f" return {func_result[:-5]}\n" + func_string += f" return {func_result[:-5]}" return func_string, free_vars, lambda_string[:-5] diff --git a/sot/opcode_translator/executor/opcode_executor.py b/sot/opcode_translator/executor/opcode_executor.py index 324db5672..f887a6c0b 100644 --- a/sot/opcode_translator/executor/opcode_executor.py +++ b/sot/opcode_translator/executor/opcode_executor.py @@ -174,7 +174,7 @@ def lookup( if guard_result: log( 2, - f"[Cache]: Cache hit, Guard is {getattr(guard_fn, 'expr', 'None')}\n", + f"[Cache]: Cache hit, Guard is \n{getattr(guard_fn, 'expr', 'None')}\n", ) return custom_code else: @@ -184,7 +184,7 @@ def lookup( ) log( 2, - f"[Cache]: Cache miss, Guard is {getattr(guard_fn, 'expr', 'None')}\n", + f"[Cache]: Cache miss, Guard is \n{getattr(guard_fn, 'expr', 'None')}\n", ) log_do( 2, diff --git a/sot/opcode_translator/transform.py b/sot/opcode_translator/transform.py index b17d6bbb2..70b55f16f 100644 --- a/sot/opcode_translator/transform.py +++ b/sot/opcode_translator/transform.py @@ -47,44 +47,47 @@ def eval_frame_callback(frame, **kwargs) -> CustomCode: return CustomCode(None, True) if need_skip(frame): - return CustomCode(None, False) - - log( - 2, - "[eval_frame_callback] start to translate: " - + str(frame.f_code) - + "\n", - ) - log_do(4, partial(print_locals, frame)) - - log(3, f"[transform] OriginCode: {frame.f_code.co_name}\n") - log_do(3, lambda: dis.dis(frame.f_code)) - - custom_code = InstructionTranslatorCache()(frame, **kwargs) - - if custom_code.code is None: - log( - 3, - "[transform] NewCode (same as origin code): " - + frame.f_code.co_name - + "\n", - ) - used_code = frame.f_code + log(3, f"[eval_frame_callback] skip {frame.f_code}\n") + custom_code = CustomCode(None, False) + new_code = frame.f_code else: log( - 3, - "[transform] NewCode: " + custom_code.code.co_name + "\n", + 2, + "[eval_frame_callback] start to translate: " + + str(frame.f_code) + + "\n", ) - log_do(3, lambda: dis.dis(custom_code.code)) - used_code = custom_code.code + log_do(4, partial(print_locals, frame)) + + log(3, f"[transform] OriginCode: {frame.f_code.co_name}\n") + log_do(3, lambda: dis.dis(frame.f_code)) + + custom_code = InstructionTranslatorCache()(frame, **kwargs) + + if custom_code.code is None: + log( + 3, + "[transform] NewCode (same as origin code): " + + frame.f_code.co_name + + "\n", + ) + new_code = frame.f_code + else: + log( + 3, + "[transform] NewCode: " + custom_code.code.co_name + "\n", + ) + log_do(3, lambda: dis.dis(custom_code.code)) + new_code = custom_code.code # just check those codes which need open eval_frame - if custom_code.disable_eval_frame is False and CodeStatus().check_code( - used_code + if ( + custom_code.disable_eval_frame is False + and CodeStatus().is_code_without_graph(new_code) ): log( 3, - "[transform] Code has found no graph, block it.", + "[eval_frame_callback] Code has no graph, block it.", ) return CustomCode(None, True) diff --git a/sot/utils/code_status.py b/sot/utils/code_status.py index aa498de17..9b4438288 100644 --- a/sot/utils/code_status.py +++ b/sot/utils/code_status.py @@ -1,7 +1,9 @@ import inspect from enum import Enum -from .utils import Singleton +import paddle + +from .utils import Singleton, log class CodeState(Enum): @@ -21,13 +23,26 @@ def __repr__(self): @Singleton class CodeStatus: + WITH_GRAPH_API = [ + paddle.nn.Layer.__call__.__code__, + paddle.nn.Layer._dygraph_call_func.__code__, + ] + def __init__(self): self.code_map = {} + self.setup_code_map() + + def setup_code_map(self): + for code in self.WITH_GRAPH_API: + info = CodeInfo() + info.state = CodeState.WITH_GRAPH + self.code_map[code] = info def clear(self): self.code_map.clear() + self.setup_code_map() - def check_code(self, code): + def is_code_without_graph(self, code): if code not in self.code_map: info = CodeInfo() self.code_map[code] = info @@ -36,16 +51,16 @@ def check_code(self, code): if info.state == CodeState.WITHOUT_GRAPH: return True - elif info.state == CodeState.UNKNOW: - self.visit(code) + if info.state == CodeState.UNKNOW: + info.counter += 1 + if info.counter >= 10: + log( + 3, + f"[CodeStatus] Switch state to WITHOUT_GRAPH for {code}\n", + ) + info.state = CodeState.WITHOUT_GRAPH return False - def visit(self, code): - info = self.code_map[code] - info.counter += 1 - if info.state == CodeState.UNKNOW and info.counter > 10: - info.state = CodeState.WITHOUT_GRAPH - def trace_back_frames(self): frame = inspect.currentframe() while frame.f_back is not None: @@ -53,4 +68,9 @@ def trace_back_frames(self): code = frame.f_code if code in self.code_map: info = self.code_map[code] - info.state = CodeState.WITH_GRAPH + if info.state != CodeState.WITH_GRAPH: + log( + 3, + f"[CodeStatus] Switch state to WITH_GRAPH for {code}\n", + ) + info.state = CodeState.WITH_GRAPH diff --git a/tests/test_code_info.py b/tests/test_code_info.py deleted file mode 100644 index 52cc58769..000000000 --- a/tests/test_code_info.py +++ /dev/null @@ -1,79 +0,0 @@ -import unittest - -from test_case_base import TestCaseBase, strict_mode_guard - -import paddle -import sot -from sot.utils import CodeStatus - - -class SimpleNet1(paddle.nn.Layer): - def __init__(self): - super().__init__() - self.layers = paddle.nn.LayerList( - [paddle.nn.Linear(10, 10) for _ in range(30)] - ) - - def forward(self, x): - for i in range(len(self.layers)): - sot.psdb.breakgraph() - x = self.layers[i](x) - x = self.layers[i](x) - x = self.layers[i](x) - x = self.layers[i](x) - return x - - -class SimpleNet2(paddle.nn.Layer): - def __init__(self): - super().__init__() - self.layers = paddle.nn.LayerList( - [paddle.nn.Linear(10, 10) for _ in range(30)] - ) - - def forward(self, x): - sot.psdb.fallback() - for i in range(len(self.layers)): - x = self.layers[i](x) - x = self.layers[i](x) - x = self.layers[i](x) - x = self.layers[i](x) - return x - - -def run_net(net, x): - for i in range(20): - x = net(x) - return x - - -class TestCodeInfo(TestCaseBase): - def _analyse_code_info(self, code_map): - return {k.co_name: str(v.state) for k, v in code_map.items()} - - def test_case_1(self): - CodeStatus().clear() - net = SimpleNet1() - inp = paddle.rand((10, 10)) - self.assert_results(run_net, net, inp) - code_infos = self._analyse_code_info(CodeStatus().code_map) - states = list(code_infos.values()) - # run_net, forward, loop body, resumed part2 in loop body - assert len([v for v in states if v == "CodeState.WITH_GRAPH"]) == 4 - # resumed part1 in loop body - assert len([v for v in states if v == "CodeState.WITHOUT_GRAPH"]) == 1 - - def test_case_2(self): - with strict_mode_guard(0): - CodeStatus().clear() - net = SimpleNet2() - inp = paddle.rand((10, 10)) - self.assert_results(run_net, net, inp) - code_infos = self._analyse_code_info(CodeStatus().code_map) - states = list(code_infos.values()) - # no graph found because fallback (paddle api will not enter simulate) - assert len([v for v in states if v == "CodeState.WITH_GRAPH"]) == 0 - - -if __name__ == "__main__": - unittest.main() diff --git a/tests/test_code_status.py b/tests/test_code_status.py new file mode 100644 index 000000000..4a24f3050 --- /dev/null +++ b/tests/test_code_status.py @@ -0,0 +1,140 @@ +import unittest + +from test_case_base import TestCaseBase, strict_mode_guard + +import paddle +import sot +from sot.opcode_translator.skip_files import skip_function +from sot.utils.code_status import CodeState, CodeStatus + + +class SimpleNet1(paddle.nn.Layer): + def __init__(self): + super().__init__() + self.layers = paddle.nn.LayerList( + [paddle.nn.Linear(10, 10) for _ in range(30)] + ) + + def forward(self, x): + for i in range(len(self.layers)): + sot.psdb.breakgraph() + x = self.layers[i](x) + x = self.layers[i](x) + x = self.layers[i](x) + x = self.layers[i](x) + return x + + +class SimpleNet2(paddle.nn.Layer): + def __init__(self): + super().__init__() + self.layers = paddle.nn.LayerList( + [paddle.nn.Linear(10, 10) for _ in range(30)] + ) + + def forward(self, x): + sot.psdb.fallback() + for i in range(len(self.layers)): + x = self.layers[i](x) + x = self.layers[i](x) + x = self.layers[i](x) + x = self.layers[i](x) + return x + + +def run_net(net, x): + for i in range(20): + x = net(x) + return x + + +class TestCodeInfo(TestCaseBase): + def test_case_1(self): + CodeStatus().clear() + net = SimpleNet1() + inp = paddle.rand((10, 10)) + self.assert_results(run_net, net, inp) + code_map = CodeStatus().code_map + states = [] + for k, v in code_map.items(): + if k.co_name.startswith("#") or k.co_name.startswith("$"): + states.append(v) + elif k in CodeStatus().WITH_GRAPH_API: + assert v.state == CodeState.WITH_GRAPH + else: + assert v.state == CodeState.WITHOUT_GRAPH + # run_net, forward, loop body, resumed part2 in loop body + assert len([v for v in states if v.state == CodeState.WITH_GRAPH]) == 4 + # resumed part1 in loop body + assert ( + len([v for v in states if v.state == CodeState.WITHOUT_GRAPH]) == 1 + ) + + def test_case_2(self): + with strict_mode_guard(0): + CodeStatus().clear() + net = SimpleNet2() + inp = paddle.rand((10, 10)) + self.assert_results(run_net, net, inp) + code_map = CodeStatus().code_map + states = [] + for k, v in code_map.items(): + if k.co_name.startswith("#") or k.co_name.startswith("$"): + states.append(v) + elif k in CodeStatus().WITH_GRAPH_API: + assert v.state == CodeState.WITH_GRAPH + else: + assert v.state == CodeState.WITHOUT_GRAPH + # no graph found because fallback (paddle api will not enter simulate) + assert ( + len([v for v in states if v.state == CodeState.WITH_GRAPH]) == 0 + ) + + +def no_skip_func_0(x): + return x + 1 + + +def skipped_func_0(): + pass + + +def skipped_func_1(x): + return x + 1 + + +def skipped_func_2(x): + return no_skip_func_0(x) + + +def call_skipped_func_0(x): + for i in range(15): + skipped_func_0() + x = skipped_func_1(x) + x = skipped_func_2(x) + return x + + +skip_function(skipped_func_0) +skip_function(skipped_func_1) +skip_function(skipped_func_2) +skip_function(call_skipped_func_0) + + +class TestDisableSkippedFrame(TestCaseBase): + def test_case_0(self): + CodeStatus().clear() + x = paddle.to_tensor([1]) + self.assert_results(call_skipped_func_0, x) + code_map = CodeStatus().code_map + assert ( + code_map[skipped_func_0.__code__].state == CodeState.WITHOUT_GRAPH + ) + assert ( + code_map[skipped_func_1.__code__].state == CodeState.WITHOUT_GRAPH + ) + assert code_map[skipped_func_2.__code__].state == CodeState.WITH_GRAPH + + +if __name__ == "__main__": + unittest.main()