From 0a86f0e69a94b89e42b68977267fbcbf72a2a336 Mon Sep 17 00:00:00 2001 From: SigureMo Date: Thu, 24 Apr 2025 20:21:34 +0800 Subject: [PATCH] [SOT][CINN] Fallback if CINN compile is too slow --- .../jit/dy2static/pir_partial_program.py | 7 ++- python/paddle/jit/dy2static/utils.py | 22 +++++++ .../jit/sot/opcode_translator/breakpoint.py | 2 +- .../executor/executor_cache.py | 60 ++++++++++++++++--- .../executor/function_graph.py | 2 + .../sot/opcode_translator/executor/guard.py | 8 +-- .../executor/opcode_executor.py | 6 +- .../executor/pycode_generator.py | 23 ------- .../executor/variables/basic.py | 6 +- .../executor/variables/container.py | 24 ++++---- .../paddle/jit/sot/symbolic/compile_cache.py | 18 +++++- python/paddle/jit/sot/utils/__init__.py | 2 + python/paddle/jit/sot/utils/info_collector.py | 8 ++- python/paddle/jit/sot/utils/utils.py | 17 ++++++ 14 files changed, 144 insertions(+), 61 deletions(-) diff --git a/python/paddle/jit/dy2static/pir_partial_program.py b/python/paddle/jit/dy2static/pir_partial_program.py index cd2d1de9234605..02e24667c75a55 100644 --- a/python/paddle/jit/dy2static/pir_partial_program.py +++ b/python/paddle/jit/dy2static/pir_partial_program.py @@ -35,6 +35,7 @@ from .utils import ( RETURN_NO_VALUE_MAGIC_NUM, Backend, + TimeCounter, auto_layout_is_enabled, backend_guard, cse_is_enabled, @@ -723,6 +724,8 @@ def __init__( self._backend = kwargs.get('backend', Backend.PHI) self._grad_var_names = {} + self._compile_time_counter = TimeCounter() + def __call__(self, inputs): """ Execute static graph by Interpreter and Return dynamic Tensors. @@ -974,12 +977,12 @@ def program_id(self): @cached_property def train_program(self) -> RunnableProgram: - with backend_guard(self._backend): + with backend_guard(self._backend), self._compile_time_counter.record(): return self._create_program() @cached_property def infer_program(self) -> RunnableProgram: - with backend_guard(self._backend): + with backend_guard(self._backend), self._compile_time_counter.record(): return self._create_program(is_infer_mode=True) def _verify_program(self, main_program, outputs): diff --git a/python/paddle/jit/dy2static/utils.py b/python/paddle/jit/dy2static/utils.py index 86eca6f5ec1d55..832b16bc412ec9 100644 --- a/python/paddle/jit/dy2static/utils.py +++ b/python/paddle/jit/dy2static/utils.py @@ -25,6 +25,7 @@ import sys import tempfile import textwrap +import time import types import warnings from contextlib import contextmanager @@ -108,6 +109,27 @@ def is_phi(self): return self == Backend.PHI +class TimeCounter: + def __init__(self): + self._time_history: list[float] = [] + + def get_last_time(self): + if len(self._time_history) == 0: + return 0 + return self._time_history[-1] + + def get_total_time(self): + return sum(self._time_history) + + @contextmanager + def record(self): + start_time = time.perf_counter() + yield + end_time = time.perf_counter() + elapsed_time = end_time - start_time + self._time_history.append(elapsed_time) + + def data_layer_not_check(name, shape, dtype='float32'): """ This function creates a Tensor on the global block. The created Tensor diff --git a/python/paddle/jit/sot/opcode_translator/breakpoint.py b/python/paddle/jit/sot/opcode_translator/breakpoint.py index 4adc6039e08919..e3be2f736bd9e0 100644 --- a/python/paddle/jit/sot/opcode_translator/breakpoint.py +++ b/python/paddle/jit/sot/opcode_translator/breakpoint.py @@ -99,7 +99,7 @@ def opcode(self, cur_exe=None): if cur_exe is None: cur_exe = self.cur_exe instr = cur_exe._instructions[cur_exe.vframe.lasti - 1] - message = f"[Translate {cur_exe}]: (line {cur_exe._current_line:>3}) {instr.opname:<12} {instr.argval}, stack is {cur_exe._stack}\n" + message = f"[Translate {cur_exe}] (line {cur_exe._current_line:>3}) {instr.opname:<12} {instr.argval}, stack is {cur_exe._stack}\n" return message def bt(self): diff --git a/python/paddle/jit/sot/opcode_translator/executor/executor_cache.py b/python/paddle/jit/sot/opcode_translator/executor/executor_cache.py index e0a8bede116842..7a56e5bad36d33 100644 --- a/python/paddle/jit/sot/opcode_translator/executor/executor_cache.py +++ b/python/paddle/jit/sot/opcode_translator/executor/executor_cache.py @@ -37,6 +37,7 @@ is_strict_mode, log, log_do, + log_once, ) from ..custom_code import CustomCode from .function_graph import FunctionGraph @@ -70,16 +71,20 @@ class OpcodeExecutorCache(metaclass=Singleton): """ MAX_CACHE_SIZE = 20 + MAX_COMPILE_TIME_PER_CODE = 40 + MAX_COMPILE_TIME_TOTAL = 15 * 60 cache: dict[ types.CodeType, tuple[GuardedFunctions, paddle.framework.core.GuardTree] ] translate_count: int code_symbolic_inputs: dict[types.CodeType, dict[str, None | dict[int, int]]] + compile_time_stats: dict[types.CodeType, float] def __init__(self): self.cache = {} self.translate_count = 0 self.code_symbolic_inputs = {} + self.compile_time_stats = {} def get_symbolic_inputs( self, code: types.CodeType @@ -94,23 +99,26 @@ def clear(self): self.cache.clear() self.translate_count = 0 self.code_symbolic_inputs.clear() + self.compile_time_stats.clear() def dump_state(self): return { "cache": self.cache, "translate_count": self.translate_count, "code_symbolic_inputs": self.code_symbolic_inputs, + "compile_time_stats": self.compile_time_stats, } def load_state(self, state): self.cache = state["cache"] self.translate_count = state["translate_count"] self.code_symbolic_inputs = state["code_symbolic_inputs"] + self.compile_time_stats = state["compile_time_stats"] def __call__(self, frame: types.FrameType, **kwargs) -> CustomCode: code: types.CodeType = frame.f_code if code not in self.cache: - log(2, f"[Cache]: Firstly call {code}\n") + log(2, f"[Cache] Firstly call {code}\n") new_custom_code, guard_fn, guard_chain = self.translate( frame, **kwargs ) @@ -121,7 +129,16 @@ def __call__(self, frame: types.FrameType, **kwargs) -> CustomCode: ], paddle.framework.core.GuardTree([guard_chain]) return new_custom_code guarded_fns, guard_tree = self.cache[code] - return self.lookup(frame, guarded_fns, guard_tree, **kwargs) + compile_time_for_code = self.compile_time_stats.get(code, 0) + compile_time_total = sum(self.compile_time_stats.values()) + return self.lookup( + frame, + guarded_fns, + guard_tree, + compile_time_for_code, + compile_time_total, + **kwargs, + ) @event_register("lookup") def lookup( @@ -129,6 +146,8 @@ def lookup( frame: types.FrameType, guarded_fns: GuardedFunctions, guard_tree: paddle.framework.core.GuardTree, + compile_time_for_code: float, + compile_time_total: float, **kwargs, ) -> CustomCode: """ @@ -143,7 +162,7 @@ def lookup( """ if len(guarded_fns) >= self.MAX_CACHE_SIZE: - log(2, "[Cache]: Exceed max cache size, skip it\n") + log(2, "[Cache] Exceed max cache size, skip it\n") return CustomCode(None, False) enable_strict_guard = ENV_SOT_ENABLE_STRICT_GUARD_CHECK.get() @@ -159,7 +178,22 @@ def lookup( # TODO(zrr1999): add a mapping between custom_code and cache_index return guarded_fns[cache_index][0] else: - log(2, "[Cache]: all guards missed (guard tree mode)\n") + log(2, "[Cache] all guards missed (guard tree mode)\n") + if compile_time_for_code >= self.MAX_COMPILE_TIME_PER_CODE: + log( + 2, + "[Cache] Exceed max compile time per code, skip it\n", + ) + return CustomCode(None, False) + if compile_time_total >= self.MAX_COMPILE_TIME_TOTAL: + log_once( + f"[SOT] Current total compile time is {compile_time_total}, exceed max compile time total {self.MAX_COMPILE_TIME_TOTAL}, fallback new function to dygraph" + ) + log( + 2, + "[Cache] Exceed max compile time total, skip it\n", + ) + return CustomCode(None, False) new_custom_code, guard_fn, guard_chain = self.translate( frame, **kwargs ) @@ -230,7 +264,19 @@ def lookup( f"mirror_guard_error: {mirror_guard_error}," ) - log(2, "[Cache]: all guards missed\n") + log(2, "[Cache] all guards missed\n") + if compile_time_for_code >= self.MAX_COMPILE_TIME_PER_CODE: + log(2, "[Cache] Exceed max compile time per code, skip it\n") + return CustomCode(None, False) + if compile_time_total >= self.MAX_COMPILE_TIME_TOTAL: + log_once( + f"[SOT] Current compile time total is {compile_time_total}, exceed max compile time total {self.MAX_COMPILE_TIME_TOTAL}, fallback new function to dygraph" + ) + log( + 2, + "[Cache] Exceed max compile time total, skip it\n", + ) + return CustomCode(None, False) new_custom_code, guard_fn, guard_chain = self.translate(frame, **kwargs) if guard_fn is not None: assert guard_chain is not None @@ -288,9 +334,9 @@ def inner(): ) return if result is False: - print(f"[Cache]: missed at {guard_str}") + print(f"[Cache] missed at {guard_str}") return - print("[Cache]: missed guard not found.") + print("[Cache] missed guard not found.") return inner diff --git a/python/paddle/jit/sot/opcode_translator/executor/function_graph.py b/python/paddle/jit/sot/opcode_translator/executor/function_graph.py index 8cb71efb140ec7..2b0ba139426f9e 100644 --- a/python/paddle/jit/sot/opcode_translator/executor/function_graph.py +++ b/python/paddle/jit/sot/opcode_translator/executor/function_graph.py @@ -45,6 +45,7 @@ ENV_SOT_ENABLE_GUARD_TREE, ENV_SOT_ENABLE_STRICT_GUARD_CHECK, NameGenerator, + SIRToCodeMap, SotUndefinedVar, inner_error_default_handler, is_inplace_api, @@ -462,6 +463,7 @@ def compile_graph(self, *ret_vars: VariableBase) -> CompileGraphResult: OrderedSet(), OrderedSet(), ) + SIRToCodeMap().register(statement_ir, self.pycode_gen._origin_code) input_names = statement_ir.inputs symbolic_inputs = self._find_tensor_inputs(input_names) compiled_fn = self.sir_ctx.compile_fn( diff --git a/python/paddle/jit/sot/opcode_translator/executor/guard.py b/python/paddle/jit/sot/opcode_translator/executor/guard.py index 8e85c37eff3187..5c21c3667ccf28 100644 --- a/python/paddle/jit/sot/opcode_translator/executor/guard.py +++ b/python/paddle/jit/sot/opcode_translator/executor/guard.py @@ -111,7 +111,7 @@ def __init__( ) log( 3, - f"[FasterGuard]: transform {original_expr_template} to {expr_template}\n", + f"[FasterGuard] transform {original_expr_template} to {expr_template}\n", ) super().__init__(expr_template, sub_exprs, free_vars) @@ -182,7 +182,7 @@ def make_guard(stringified_guards: list[StringifiedExpression]) -> Guard: guard = eval(guard_expr, free_vars) - log(3, f"[Guard]: {inlined_guard_expr}\n") + log(3, f"[Guard] {inlined_guard_expr}\n") guard.inlined_expr = inlined_guard_expr guard.expr = guard_expr @@ -231,7 +231,7 @@ def wrapper(self: CheckGuardInputT) -> list[StringifiedExpression]: def guard_log(): frame_value_tracer = self.tracker.trace_value_from_frame() print( - f"[Guard]: guard_fn for {self}, tracker={self.tracker.__class__.__name__}, value={frame_value_tracer.registered_expr}" + f"[Guard] guard_fn for {self}, tracker={self.tracker.__class__.__name__}, value={frame_value_tracer.registered_expr}" ) log_do(4, guard_log) @@ -253,7 +253,7 @@ def wrapper( def guard_log(): frame_value_tracer = self.tracker.trace_value_from_frame() print( - f"[Guard Tree]: guard_fn for {self}, tracker={self.tracker.__class__.__name__}, value={frame_value_tracer.registered_expr}" + f"[Guard Tree] guard_fn for {self}, tracker={self.tracker.__class__.__name__}, value={frame_value_tracer.registered_expr}" ) log_do(4, guard_log) diff --git a/python/paddle/jit/sot/opcode_translator/executor/opcode_executor.py b/python/paddle/jit/sot/opcode_translator/executor/opcode_executor.py index fafaa02200fdd1..a75f8fcc46dd06 100644 --- a/python/paddle/jit/sot/opcode_translator/executor/opcode_executor.py +++ b/python/paddle/jit/sot/opcode_translator/executor/opcode_executor.py @@ -626,7 +626,7 @@ def step(self, instr: Instruction): self._current_line = instr.starts_line if not hasattr(self, instr.opname): raise FallbackError(f"opcode: {instr.opname} is not supported.") - log_message = f"[Translate {self._name}]: (line {self._current_line:>3}) {instr.opname:<12} {instr.argval}, stack is {self.stack}\n" + log_message = f"[Translate {self._name}] (line {self._current_line:>3}) {instr.opname:<12} {instr.argval}, stack is {self.stack}\n" log(3, log_message) code_file = self.vframe.code.co_filename code_line = self._current_line @@ -2440,7 +2440,7 @@ def create_loop_body(): log( 3, - "[Resumed Function]: break graph in loop create loop body as\n", + "[Resumed Function] break graph in loop create loop body as\n", ) log_do(3, lambda: dis.dis(loop_body_fn)) @@ -2685,7 +2685,7 @@ def create_inline_call_fn(): log( 3, - f"[Resumed Function]: Inline call for loop function {inline_call_fn.__code__.co_name}\n", + f"[Resumed Function] Inline call for loop function {inline_call_fn.__code__.co_name}\n", ) log_do(3, lambda: dis.dis(inline_call_fn)) diff --git a/python/paddle/jit/sot/opcode_translator/executor/pycode_generator.py b/python/paddle/jit/sot/opcode_translator/executor/pycode_generator.py index 351920707423dc..3c689c05bb71be 100644 --- a/python/paddle/jit/sot/opcode_translator/executor/pycode_generator.py +++ b/python/paddle/jit/sot/opcode_translator/executor/pycode_generator.py @@ -946,29 +946,6 @@ def gen_return(self): def gen_get_iter(self): return self.add_instr("GET_ITER") - def gen_operator_only(self, op_name): - """ - only generator operator instruction, do nothing for - operands. - """ - return self.add_instr(op_name) - - def gen_operator(self, op_name): - """ - only generator operator instruction, do nothing for - operands. - """ - return self.add_instr(op_name) - - def gen_compare(self, cmp_op): - """ - only generator operator instruction, do nothing for - operands. - """ - if sys.version_info >= (3, 12): - cmp_op <<= 4 - return self.add_instr("COMPARE_OP", cmp_op) - def add_instr(self, *args, **kwargs): instr = gen_instr(*args, **kwargs) self._instructions.append(instr) diff --git a/python/paddle/jit/sot/opcode_translator/executor/variables/basic.py b/python/paddle/jit/sot/opcode_translator/executor/variables/basic.py index dbfcfa952cc9f6..cbe4b4cbc24928 100644 --- a/python/paddle/jit/sot/opcode_translator/executor/variables/basic.py +++ b/python/paddle/jit/sot/opcode_translator/executor/variables/basic.py @@ -1947,18 +1947,18 @@ def keys(self): def get(self, key): if isinstance(key, VariableBase): raise InnerError( - f"[{self.__class__.__name__}]: received {key} to get value." + f"[{self.__class__.__name__}] received {key} to get value." ) return self.proxy.get(key) def set(self, key, value): if isinstance(key, VariableBase): raise InnerError( - f"[{self.__class__.__name__}]: received {key} as key." + f"[{self.__class__.__name__}] received {key} as key." ) if not isinstance(value, VariableBase): raise InnerError( - f"[{self.__class__.__name__}]: received {value} to set value." + f"[{self.__class__.__name__}] received {value} to set value." ) self.proxy.set(key, value) self.graph.side_effects.record_proxy_variable(self) diff --git a/python/paddle/jit/sot/opcode_translator/executor/variables/container.py b/python/paddle/jit/sot/opcode_translator/executor/variables/container.py index 05acb98e7533ec..26b8bbc012ba05 100644 --- a/python/paddle/jit/sot/opcode_translator/executor/variables/container.py +++ b/python/paddle/jit/sot/opcode_translator/executor/variables/container.py @@ -260,7 +260,7 @@ def getitem(self, key): def setitem(self, key, value): if not isinstance(value, VariableBase): raise InnerError( - f"[{self.__class__.__name__}]: received {value} to set value." + f"[{self.__class__.__name__}] received {value} to set value." ) if isinstance(key, int): self.proxy.set(key, value) @@ -297,7 +297,7 @@ def __delitem__(self, key): def delitem(self, key): if isinstance(key, VariableBase): raise InnerError( - f"[{self.__class__.__name__}]: received {key} as key to delete." + f"[{self.__class__.__name__}] received {key} as key to delete." ) self.proxy.delete(key) self.graph.side_effects.record_proxy_variable(self) @@ -601,17 +601,13 @@ def getitem(self, key): ) def setitem(self, key, value): - raise InnerError( - f"[{self.__class__.__name__}]: setitem is not allowed." - ) + raise InnerError(f"[{self.__class__.__name__}] setitem is not allowed.") def __delitem__(self, key): return self.delitem(key) def delitem(self, key): - raise InnerError( - f"[{self.__class__.__name__}]: delitem is not allowed." - ) + raise InnerError(f"[{self.__class__.__name__}] delitem is not allowed.") def concat(self, tuple_): assert isinstance(tuple_, TupleVariable) @@ -850,7 +846,7 @@ def _reconstruct(self, codegen: PyCodeGen): for key in self.proxy.get_all().keys(): if not isinstance(key, ConstTypes): raise InnerError( - f"[{self.__class__.__name__}]: received {key} as key." + f"[{self.__class__.__name__}] received {key} as key." ) key_var = ConstantVariable.wrap_literal(key, self.graph) value_var = self[key] @@ -879,7 +875,7 @@ def get_wrapped_items(self): for key in self.proxy.get_all().keys(): if not isinstance(key, ConstTypes): raise InnerError( - f"[{self.__class__.__name__}]: received {key} as key." + f"[{self.__class__.__name__}] received {key} as key." ) items[key] = self[key] return items @@ -908,7 +904,7 @@ def get(self, key, default=None): self.graph.add_global_guarded_variable(self) if isinstance(key, VariableBase): raise InnerError( - f"[{self.__class__.__name__}]: received {key} to get value." + f"[{self.__class__.__name__}] received {key} to get value." ) if default is None: @@ -928,12 +924,12 @@ def getitem(self, key): def setitem(self, key, value): if isinstance(key, VariableBase): raise InnerError( - f"[{self.__class__.__name__}]: received {key} as key." + f"[{self.__class__.__name__}] received {key} as key." ) if not isinstance(value, VariableBase): raise InnerError( - f"[{self.__class__.__name__}]: received {value} to set value." + f"[{self.__class__.__name__}] received {value} to set value." ) self.proxy.set(key, value) @@ -954,7 +950,7 @@ def __delitem__(self, key): def delitem(self, key): if isinstance(key, VariableBase): raise InnerError( - f"[{self.__class__.__name__}]: received {key} as key to delete." + f"[{self.__class__.__name__}] received {key} as key to delete." ) self.proxy.delete(key) self.graph.side_effects.record_proxy_variable(self) diff --git a/python/paddle/jit/sot/symbolic/compile_cache.py b/python/paddle/jit/sot/symbolic/compile_cache.py index 98e0af04149f84..c92f3272cb4772 100644 --- a/python/paddle/jit/sot/symbolic/compile_cache.py +++ b/python/paddle/jit/sot/symbolic/compile_cache.py @@ -30,6 +30,7 @@ InfoCollector, NewSymbolHitRateInfo, Singleton, + SIRToCodeMap, StepInfoManager, SubGraphInfo, SubGraphRelationInfo, @@ -90,7 +91,6 @@ def allocate(self, tensor): if not hasattr(tensor, self.TENSOR_ID_ATTR): setattr(tensor, self.TENSOR_ID_ATTR, self._id_generator()) return getattr(tensor, self.TENSOR_ID_ATTR) - # return tensor._get_tensor_ptr() class FallbackWrapper: @@ -228,6 +228,21 @@ def collect_subgraph_info(self, program: Program): self.SIR.name, ) + def update_compile_time_info(self, SIR, partial_program_layer): + if not self.is_first_call: + return + from ..opcode_translator.executor.executor_cache import ( + OpcodeExecutorCache, + ) + + code = SIRToCodeMap().get(SIR) + assert code is not None, f"Cannot find code for SIR: {SIR}" + + OpcodeExecutorCache().compile_time_stats.setdefault(code, 0) + OpcodeExecutorCache().compile_time_stats[ + code + ] += partial_program_layer._compile_time_counter.get_total_time() + def __call__(self, *args, **kwargs): with EventGuard(f"FallbackWrapper: {self.SIR.name}"): if StepInfoManager().need_back_trace: @@ -265,6 +280,7 @@ def __call__(self, *args, **kwargs): self.collect_new_symbol_hit_rate(args, outputs) self.collect_subgraph_relation(args, outputs, self.partial_program) self.collect_subgraph_info(self.concrete_program.main_program) + self.update_compile_time_info(self.SIR, self.partial_program) if ENV_SOT_EXPORT.get() != "" and not self.exported: export(self.SIR, ENV_SOT_EXPORT.get()) self.exported = True diff --git a/python/paddle/jit/sot/utils/__init__.py b/python/paddle/jit/sot/utils/__init__.py index effce738dc25ad..468611bf4c3f38 100644 --- a/python/paddle/jit/sot/utils/__init__.py +++ b/python/paddle/jit/sot/utils/__init__.py @@ -75,6 +75,7 @@ NameGenerator, ResumeFnNameFactory, Singleton, + SIRToCodeMap, SotUndefinedVar, StepInfoManager, count_if, @@ -100,6 +101,7 @@ log_do, log_enabled, log_format, + log_once, map_if, map_if_extend, meta_str, diff --git a/python/paddle/jit/sot/utils/info_collector.py b/python/paddle/jit/sot/utils/info_collector.py index fd6a83f4666c16..9e36c785ac2567 100644 --- a/python/paddle/jit/sot/utils/info_collector.py +++ b/python/paddle/jit/sot/utils/info_collector.py @@ -391,7 +391,9 @@ def __init__(self, graph: str, op_num: int, sir_name: str): self.sir_name = sir_name def __str__(self): - return f"[SIR Name]: {self.sir_name} [OpNum]: {self.op_num}\n{self.graph}" + return ( + f"[SIR Name] {self.sir_name} [OpNum] {self.op_num}\n{self.graph}" + ) @classmethod def summary(cls, history: list[Self]) -> str: @@ -406,12 +408,12 @@ def summary(cls, history: list[Self]) -> str: if need_details: details = "\n".join( [ - f"[SubGraphIdx]: {idx} {info}" + f"[SubGraphIdx] {idx} {info}" for idx, info in enumerate(map(str, history)) ] ) - summary = f"[Number of subgraph]: {num_of_subgraph} [Sum of opnum]: {sum_of_op_num}" + summary = f"[Number of subgraph] {num_of_subgraph} [Sum of opnum] {sum_of_op_num}" return f"{summary}\n{details}" diff --git a/python/paddle/jit/sot/utils/utils.py b/python/paddle/jit/sot/utils/utils.py index 4d00862f4f558c..d43c0bb2509e55 100644 --- a/python/paddle/jit/sot/utils/utils.py +++ b/python/paddle/jit/sot/utils/utils.py @@ -23,6 +23,7 @@ import weakref from collections import OrderedDict from contextlib import contextmanager +from functools import lru_cache from typing import TYPE_CHECKING, Any, Callable, TypeVar from weakref import WeakValueDictionary @@ -127,6 +128,17 @@ def next(self): return name +class SIRToCodeMap(metaclass=Singleton): + def __init__(self): + self._map = {} + + def register(self, sir, code): + self._map[sir.name] = code + + def get(self, sir): + return self._map.get(sir.name) + + def log(level, *args): cur_level = ENV_SOT_LOG_LEVEL.get() if level <= cur_level: @@ -149,6 +161,11 @@ def log_enabled(level): return level <= ENV_SOT_LOG_LEVEL.get() +@lru_cache +def log_once(msg): + print(msg, flush=True) + + def no_eval_frame(func): def no_eval_frame_func(*args, **kwargs): old_cb = paddle.framework.core.set_eval_frame(None)