Skip to content

[SOT][CINN] Fallback if CINN compile is too slow #72462

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 5 additions & 2 deletions python/paddle/jit/dy2static/pir_partial_program.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,7 @@
from .utils import (
RETURN_NO_VALUE_MAGIC_NUM,
Backend,
TimeCounter,
auto_layout_is_enabled,
backend_guard,
cse_is_enabled,
Expand Down Expand Up @@ -723,6 +724,8 @@ def __init__(
self._backend = kwargs.get('backend', Backend.PHI)
self._grad_var_names = {}

self._compile_time_counter = TimeCounter()

def __call__(self, inputs):
"""
Execute static graph by Interpreter and Return dynamic Tensors.
Expand Down Expand Up @@ -974,12 +977,12 @@ def program_id(self):

@cached_property
def train_program(self) -> RunnableProgram:
with backend_guard(self._backend):
with backend_guard(self._backend), self._compile_time_counter.record():
return self._create_program()

@cached_property
def infer_program(self) -> RunnableProgram:
with backend_guard(self._backend):
with backend_guard(self._backend), self._compile_time_counter.record():
return self._create_program(is_infer_mode=True)

def _verify_program(self, main_program, outputs):
Expand Down
22 changes: 22 additions & 0 deletions python/paddle/jit/dy2static/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@
import sys
import tempfile
import textwrap
import time
import types
import warnings
from contextlib import contextmanager
Expand Down Expand Up @@ -108,6 +109,27 @@ def is_phi(self):
return self == Backend.PHI


class TimeCounter:
    """Collect wall-clock durations of timed code regions.

    Every use of :meth:`record` appends one elapsed time, measured with
    ``time.perf_counter`` (seconds), to an internal history list.
    """

    def __init__(self):
        self._time_history: list[float] = []

    def get_last_time(self):
        """Return the most recently recorded duration, or 0 if none exist."""
        if not self._time_history:
            return 0
        return self._time_history[-1]

    def get_total_time(self):
        """Return the sum of all recorded durations (0 when nothing recorded)."""
        return sum(self._time_history)

    @contextmanager
    def record(self):
        """Time the body of the ``with`` statement and store the duration.

        The elapsed time is appended to the history even when the body
        raises, so time spent in failed attempts is still accounted for.
        Any exception raised by the body propagates unchanged.
        """
        start_time = time.perf_counter()
        try:
            yield
        finally:
            # Record in ``finally`` so an exception in the timed region does
            # not silently drop the measurement.
            self._time_history.append(time.perf_counter() - start_time)


def data_layer_not_check(name, shape, dtype='float32'):
"""
This function creates a Tensor on the global block. The created Tensor
Expand Down
2 changes: 1 addition & 1 deletion python/paddle/jit/sot/opcode_translator/breakpoint.py
Original file line number Diff line number Diff line change
Expand Up @@ -99,7 +99,7 @@ def opcode(self, cur_exe=None):
if cur_exe is None:
cur_exe = self.cur_exe
instr = cur_exe._instructions[cur_exe.vframe.lasti - 1]
message = f"[Translate {cur_exe}]: (line {cur_exe._current_line:>3}) {instr.opname:<12} {instr.argval}, stack is {cur_exe._stack}\n"
message = f"[Translate {cur_exe}] (line {cur_exe._current_line:>3}) {instr.opname:<12} {instr.argval}, stack is {cur_exe._stack}\n"
return message

def bt(self):
Expand Down
60 changes: 53 additions & 7 deletions python/paddle/jit/sot/opcode_translator/executor/executor_cache.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,7 @@
is_strict_mode,
log,
log_do,
log_once,
)
from ..custom_code import CustomCode
from .function_graph import FunctionGraph
Expand Down Expand Up @@ -70,16 +71,20 @@ class OpcodeExecutorCache(metaclass=Singleton):
"""

MAX_CACHE_SIZE = 20
MAX_COMPILE_TIME_PER_CODE = 40
MAX_COMPILE_TIME_TOTAL = 15 * 60
cache: dict[
types.CodeType, tuple[GuardedFunctions, paddle.framework.core.GuardTree]
]
translate_count: int
code_symbolic_inputs: dict[types.CodeType, dict[str, None | dict[int, int]]]
compile_time_stats: dict[types.CodeType, float]

def __init__(self):
    """Start with empty caches, empty statistics and a zero translate count."""
    self.translate_count = 0
    # Each mapping gets its own fresh dict; they must not share one object.
    self.cache, self.code_symbolic_inputs, self.compile_time_stats = {}, {}, {}

def get_symbolic_inputs(
self, code: types.CodeType
Expand All @@ -94,23 +99,26 @@ def clear(self):
self.cache.clear()
self.translate_count = 0
self.code_symbolic_inputs.clear()
self.compile_time_stats.clear()

def dump_state(self):
    """Return a dict holding the current state attributes.

    The values are the live attribute objects themselves, not copies.
    """
    state_fields = (
        "cache",
        "translate_count",
        "code_symbolic_inputs",
        "compile_time_stats",
    )
    return {field: getattr(self, field) for field in state_fields}

def load_state(self, state):
    """Restore the state attributes from a mapping shaped like ``dump_state``'s result.

    Attributes are assigned in a fixed order; a missing key raises
    ``KeyError`` at the point it is looked up.
    """
    for field in (
        "cache",
        "translate_count",
        "code_symbolic_inputs",
        "compile_time_stats",
    ):
        setattr(self, field, state[field])

def __call__(self, frame: types.FrameType, **kwargs) -> CustomCode:
code: types.CodeType = frame.f_code
if code not in self.cache:
log(2, f"[Cache]: Firstly call {code}\n")
log(2, f"[Cache] Firstly call {code}\n")
new_custom_code, guard_fn, guard_chain = self.translate(
frame, **kwargs
)
Expand All @@ -121,14 +129,25 @@ def __call__(self, frame: types.FrameType, **kwargs) -> CustomCode:
], paddle.framework.core.GuardTree([guard_chain])
return new_custom_code
guarded_fns, guard_tree = self.cache[code]
return self.lookup(frame, guarded_fns, guard_tree, **kwargs)
compile_time_for_code = self.compile_time_stats.get(code, 0)
compile_time_total = sum(self.compile_time_stats.values())
return self.lookup(
frame,
guarded_fns,
guard_tree,
compile_time_for_code,
compile_time_total,
**kwargs,
)

@event_register("lookup")
def lookup(
self,
frame: types.FrameType,
guarded_fns: GuardedFunctions,
guard_tree: paddle.framework.core.GuardTree,
compile_time_for_code: float,
compile_time_total: float,
**kwargs,
) -> CustomCode:
"""
Expand All @@ -143,7 +162,7 @@ def lookup(
"""

if len(guarded_fns) >= self.MAX_CACHE_SIZE:
log(2, "[Cache]: Exceed max cache size, skip it\n")
log(2, "[Cache] Exceed max cache size, skip it\n")
return CustomCode(None, False)

enable_strict_guard = ENV_SOT_ENABLE_STRICT_GUARD_CHECK.get()
Expand All @@ -159,7 +178,22 @@ def lookup(
# TODO(zrr1999): add a mapping between custom_code and cache_index
return guarded_fns[cache_index][0]
else:
log(2, "[Cache]: all guards missed (guard tree mode)\n")
log(2, "[Cache] all guards missed (guard tree mode)\n")
if compile_time_for_code >= self.MAX_COMPILE_TIME_PER_CODE:
log(
2,
"[Cache] Exceed max compile time per code, skip it\n",
)
return CustomCode(None, False)
if compile_time_total >= self.MAX_COMPILE_TIME_TOTAL:
log_once(
f"[SOT] Current total compile time is {compile_time_total}, exceed max compile time total {self.MAX_COMPILE_TIME_TOTAL}, fallback new function to dygraph"
)
log(
2,
"[Cache] Exceed max compile time total, skip it\n",
)
return CustomCode(None, False)
new_custom_code, guard_fn, guard_chain = self.translate(
frame, **kwargs
)
Expand Down Expand Up @@ -230,7 +264,19 @@ def lookup(
f"mirror_guard_error: {mirror_guard_error},"
)

log(2, "[Cache]: all guards missed\n")
log(2, "[Cache] all guards missed\n")
if compile_time_for_code >= self.MAX_COMPILE_TIME_PER_CODE:
log(2, "[Cache] Exceed max compile time per code, skip it\n")
return CustomCode(None, False)
if compile_time_total >= self.MAX_COMPILE_TIME_TOTAL:
log_once(
f"[SOT] Current compile time total is {compile_time_total}, exceed max compile time total {self.MAX_COMPILE_TIME_TOTAL}, fallback new function to dygraph"
)
log(
2,
"[Cache] Exceed max compile time total, skip it\n",
)
return CustomCode(None, False)
new_custom_code, guard_fn, guard_chain = self.translate(frame, **kwargs)
if guard_fn is not None:
assert guard_chain is not None
Expand Down Expand Up @@ -288,9 +334,9 @@ def inner():
)
return
if result is False:
print(f"[Cache]: missed at {guard_str}")
print(f"[Cache] missed at {guard_str}")
return
print("[Cache]: missed guard not found.")
print("[Cache] missed guard not found.")

return inner

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -45,6 +45,7 @@
ENV_SOT_ENABLE_GUARD_TREE,
ENV_SOT_ENABLE_STRICT_GUARD_CHECK,
NameGenerator,
SIRToCodeMap,
SotUndefinedVar,
inner_error_default_handler,
is_inplace_api,
Expand Down Expand Up @@ -462,6 +463,7 @@ def compile_graph(self, *ret_vars: VariableBase) -> CompileGraphResult:
OrderedSet(),
OrderedSet(),
)
SIRToCodeMap().register(statement_ir, self.pycode_gen._origin_code)
input_names = statement_ir.inputs
symbolic_inputs = self._find_tensor_inputs(input_names)
compiled_fn = self.sir_ctx.compile_fn(
Expand Down
8 changes: 4 additions & 4 deletions python/paddle/jit/sot/opcode_translator/executor/guard.py
Original file line number Diff line number Diff line change
Expand Up @@ -111,7 +111,7 @@ def __init__(
)
log(
3,
f"[FasterGuard]: transform {original_expr_template} to {expr_template}\n",
f"[FasterGuard] transform {original_expr_template} to {expr_template}\n",
)

super().__init__(expr_template, sub_exprs, free_vars)
Expand Down Expand Up @@ -182,7 +182,7 @@ def make_guard(stringified_guards: list[StringifiedExpression]) -> Guard:

guard = eval(guard_expr, free_vars)

log(3, f"[Guard]: {inlined_guard_expr}\n")
log(3, f"[Guard] {inlined_guard_expr}\n")
guard.inlined_expr = inlined_guard_expr
guard.expr = guard_expr

Expand Down Expand Up @@ -231,7 +231,7 @@ def wrapper(self: CheckGuardInputT) -> list[StringifiedExpression]:
def guard_log():
frame_value_tracer = self.tracker.trace_value_from_frame()
print(
f"[Guard]: guard_fn for {self}, tracker={self.tracker.__class__.__name__}, value={frame_value_tracer.registered_expr}"
f"[Guard] guard_fn for {self}, tracker={self.tracker.__class__.__name__}, value={frame_value_tracer.registered_expr}"
)

log_do(4, guard_log)
Expand All @@ -253,7 +253,7 @@ def wrapper(
def guard_log():
frame_value_tracer = self.tracker.trace_value_from_frame()
print(
f"[Guard Tree]: guard_fn for {self}, tracker={self.tracker.__class__.__name__}, value={frame_value_tracer.registered_expr}"
f"[Guard Tree] guard_fn for {self}, tracker={self.tracker.__class__.__name__}, value={frame_value_tracer.registered_expr}"
)

log_do(4, guard_log)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -626,7 +626,7 @@ def step(self, instr: Instruction):
self._current_line = instr.starts_line
if not hasattr(self, instr.opname):
raise FallbackError(f"opcode: {instr.opname} is not supported.")
log_message = f"[Translate {self._name}]: (line {self._current_line:>3}) {instr.opname:<12} {instr.argval}, stack is {self.stack}\n"
log_message = f"[Translate {self._name}] (line {self._current_line:>3}) {instr.opname:<12} {instr.argval}, stack is {self.stack}\n"
log(3, log_message)
code_file = self.vframe.code.co_filename
code_line = self._current_line
Expand Down Expand Up @@ -2440,7 +2440,7 @@ def create_loop_body():

log(
3,
"[Resumed Function]: break graph in loop create loop body as\n",
"[Resumed Function] break graph in loop create loop body as\n",
)
log_do(3, lambda: dis.dis(loop_body_fn))

Expand Down Expand Up @@ -2685,7 +2685,7 @@ def create_inline_call_fn():

log(
3,
f"[Resumed Function]: Inline call for loop function {inline_call_fn.__code__.co_name}\n",
f"[Resumed Function] Inline call for loop function {inline_call_fn.__code__.co_name}\n",
)
log_do(3, lambda: dis.dis(inline_call_fn))

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -946,29 +946,6 @@ def gen_return(self):
def gen_get_iter(self):
return self.add_instr("GET_ITER")

def gen_operator_only(self, op_name):
    """
    Emit just the operator instruction named ``op_name``; operands are
    not handled here.

    Returns the result of ``self.add_instr``.
    """
    emitted = self.add_instr(op_name)
    return emitted

def gen_operator(self, op_name):
    """
    Emit just the operator instruction named ``op_name``; operands are
    not handled here.

    Returns the result of ``self.add_instr``.
    """
    emitted = self.add_instr(op_name)
    return emitted

def gen_compare(self, cmp_op):
    """
    Emit a ``COMPARE_OP`` instruction for ``cmp_op``; operands are not
    handled here.

    On Python 3.12 and later the argument is shifted left by four bits
    before being handed to ``self.add_instr``; on earlier versions it is
    passed through unchanged.
    """
    needs_shift = sys.version_info >= (3, 12)
    operand = cmp_op << 4 if needs_shift else cmp_op
    return self.add_instr("COMPARE_OP", operand)

def add_instr(self, *args, **kwargs):
instr = gen_instr(*args, **kwargs)
self._instructions.append(instr)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -1947,18 +1947,18 @@ def keys(self):
def get(self, key):
if isinstance(key, VariableBase):
raise InnerError(
f"[{self.__class__.__name__}]: received {key} to get value."
f"[{self.__class__.__name__}] received {key} to get value."
)
return self.proxy.get(key)

def set(self, key, value):
if isinstance(key, VariableBase):
raise InnerError(
f"[{self.__class__.__name__}]: received {key} as key."
f"[{self.__class__.__name__}] received {key} as key."
)
if not isinstance(value, VariableBase):
raise InnerError(
f"[{self.__class__.__name__}]: received {value} to set value."
f"[{self.__class__.__name__}] received {value} to set value."
)
self.proxy.set(key, value)
self.graph.side_effects.record_proxy_variable(self)
Expand Down
Loading
Loading