Skip to content
This repository was archived by the owner on Jan 24, 2024. It is now read-only.

update file struct #413

Merged
merged 8 commits into from
Oct 10, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
9 changes: 9 additions & 0 deletions sot/opcode_translator/custom_code.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,9 @@
from __future__ import annotations

import types
from typing import NamedTuple


class CustomCode(NamedTuple):
code: types.CodeType | None
disable_eval_frame: bool
216 changes: 216 additions & 0 deletions sot/opcode_translator/executor/executor_cache.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,216 @@
from __future__ import annotations

import traceback
import types
from typing import List, Tuple

from ...profiler import EventGuard, event_register
from ...psdb import NO_FALLBACK_CODES
from ...utils import (
BreakGraphError,
FallbackError,
InnerError,
Singleton,
is_strict_mode,
log,
log_do,
)
from ..custom_code import CustomCode
from .guard import Guard
from .opcode_executor import OpcodeExecutor, OpcodeExecutorBase
from .pycode_generator import PyCodeGen

GuardedFunction = Tuple[CustomCode, Guard]
GuardedFunctions = List[GuardedFunction]

dummy_guard: Guard = lambda frame: True
dummy_guard.expr = "lambda frame: True"
dummy_guard.lambda_expr = "lambda frame: True"


@Singleton
class OpcodeExecutorCache:
"""
A singleton class that implements a cache for translated instructions.
This cache is used to store previously translated instructions along with their corresponding guard functions.

Attributes:
cache (dict): A dictionary that maps code objects to tuples of a cache getter function and a list of guarded functions.
translate_count (int): The count of how many instructions have been translated. It is used to test whether the cache hits.
"""

MAX_CACHE_SIZE = 20
cache: dict[types.CodeType, GuardedFunctions]
translate_count: int

def __init__(self):
self.cache = {}
self.translate_count = 0

def clear(self):
"""
Clears the cache and resets the translate count.
"""
self.cache.clear()
self.translate_count = 0

def __call__(self, frame: types.FrameType, **kwargs) -> CustomCode:
code: types.CodeType = frame.f_code
if code not in self.cache:
log(2, f"[Cache]: Firstly call {code}\n")
new_custom_code, guard_fn = self.translate(frame, **kwargs)
self.cache[code] = [(new_custom_code, guard_fn)]
return new_custom_code
guarded_fns = self.cache[code]
return self.lookup(frame, guarded_fns, **kwargs)

@event_register("lookup")
def lookup(
self, frame: types.FrameType, guarded_fns: GuardedFunctions, **kwargs
) -> CustomCode:
"""
Looks up the cache for a matching code object and returns a custom code object if a matching guard function is found, otherwise None.

Args:
frame (types.FrameType): The frame whose code object needs to be looked up in the cache.
guarded_fns (GuardedFunctions): The list of guarded functions associated with the code object.

Returns:
CustomCode | None: The custom code object if a matching guard function is found, otherwise None.
"""

if len(guarded_fns) >= self.MAX_CACHE_SIZE:
log(2, "[Cache]: Exceed max cache size, skip it\n")
return CustomCode(None, False)

for custom_code, guard_fn in guarded_fns:
try:
with EventGuard("try guard"):
guard_result = guard_fn(frame)
if guard_result:
log(
2,
f"[Cache]: Cache hit, Guard is \n{getattr(guard_fn, 'expr', 'None')}\n",
)
return custom_code
else:
log_do(
4,
self.analyse_guard_global_object(guard_fn),
)
log(
2,
f"[Cache]: Cache miss, Guard is \n{getattr(guard_fn, 'expr', 'None')}\n",
)
log_do(
2,
self.analyse_guard_error(guard_fn, frame),
)
except Exception as e:
log(2, f"[Cache]: Guard function error: {e}\n")
continue

log(2, "[Cache]: all guards missed\n")
new_custom_code, guard_fn = self.translate(frame, **kwargs)
guarded_fns.append((new_custom_code, guard_fn))
return new_custom_code

def translate(
self, frame: types.FrameType, **kwargs
) -> tuple[CustomCode, Guard]:
"""
Translates the given frame's code object and returns the cache getter function and a guarded function for the translated code object.

Args:
frame (types.FrameType): The frame whose code object needs to be translated.

Returns:
tuple[CustomCode, Guard]: The cache getter function and a guarded function for the translated code object.
"""
code: types.CodeType = frame.f_code
self.translate_count += 1
custom_new_code, guard_fn = start_translate(frame, **kwargs)
return custom_new_code, guard_fn

def analyse_guard_global_object(self, guard_fn):
def inner():
for key in guard_fn.__globals__.keys():
if key.startswith("__object"):
print(
f"[Cache] meet global object: {key} : {guard_fn.__globals__[key]}",
)

return inner

def analyse_guard_error(self, guard_fn, frame):
def inner():
guard_expr = guard_fn.lambda_expr
lambda_head = "lambda frame: "
guard_expr = guard_expr.replace(lambda_head, "")
guards = guard_expr.split(" and ")
for guard_str in guards:
guard = eval(lambda_head + guard_str, guard_fn.__globals__)
result = False
try:
result = guard(frame)
except Exception as e:
print(
f"[Cache]: skip checking {guard_str}\n because error occured {e}"
)
if result is False:
print(f"[Cache]: missed at {guard_str}")
return
print("[Cache]: missed guard not found.")

return inner


def start_translate(frame: types.FrameType, **kwargs) -> GuardedFunction:
"""
Starts the translation process for the given frame and returns the translated code object and its guard function, or None if translation fails.

Args:
frame: The frame to be translated.

Returns:
GuardedFunction | None: The translated code object and its guard function, or None if translation fails.
"""
simulator = OpcodeExecutor(frame, **kwargs)
try:
new_custom_code, guard_fn = simulator.transform()
return new_custom_code, guard_fn
# TODO(zrr1999): InnerError maybe place before (FallbackError, BreakGraphError)
# TODO(0x45f): handle BreakGraphError to trigger fallback
except BreakGraphError as e:
raise RuntimeError(
f"Found BreakGraphError raised, it should not be catch at start_translate!\n{e}"
)
except FallbackError as e:
if simulator._code in NO_FALLBACK_CODES:
raise InnerError(
f"{simulator._code.co_name} should not fallback, but got '{e}'"
)
# if disable_eval_frame is True, it means we want fallback to speedup rather than error occured
if is_strict_mode() and e.disable_eval_frame is False:
raise
log(
2,
f"Unsupport Frame is {frame.f_code}, error message is: \n"
+ "".join(traceback.format_exception(type(e), e, e.__traceback__)),
)

# NOTE: If resume fn need fallback, we should replace NullVariable using NULL otherwise will fail to run
py_codegen = PyCodeGen(frame)
new_code = py_codegen.replace_null_variable()
# simulation not complete, not sure whether this code has sir, set disable_eval_frame = False
guard_fn = (
dummy_guard if e.disable_eval_frame is False else simulator.guard_fn
)
return (
CustomCode(new_code, e.disable_eval_frame),
guard_fn,
)
except Exception as e:
raise InnerError(OpcodeExecutorBase.error_message_summary(e)) from e
finally:
simulator.cleanup()
33 changes: 12 additions & 21 deletions sot/opcode_translator/executor/function_graph.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,13 +11,12 @@
from typing import Any, Callable

from ...infer_meta import InferMetaCache, LayerInferMetaCache, MetaInfo
from ...profiler import EventGuard, event_register
from ...symbolic.statement_ir import Symbol
from ...symbolic.symbolic_context import SymbolicTraceContext
from ...utils import (
EventGuard,
NameGenerator,
OrderedSet,
event_register,
inner_error_default_handler,
is_inplace_api,
is_paddle_api,
Expand Down Expand Up @@ -162,19 +161,16 @@ def save_memo(self) -> FunctionGraph.Memo:
NOTE:
Why don't use __deepcopy__, because memo is not a deepcopy, i.e inner_out is only a shallow copy, SIR is a deepcopy.
"""
with EventGuard(
f"Save SIR Checkpoint: len({len(self.sir_ctx.TOS)})", event_level=2
):
saved_stmt_ir = deepcopy(self.sir_ctx.TOS)
return FunctionGraph.Memo(
inner_out=set(self.inner_out),
input_variables=list(self.input_variables),
stmt_ir=saved_stmt_ir,
global_guards=OrderedSet(self._global_guarded_variables),
side_effects_state=self.side_effects.get_state(),
print_variables=list(self._print_variables),
inplace_tensors=OrderedSet(self._inplace_tensors),
)
saved_stmt_ir = deepcopy(self.sir_ctx.TOS)
return FunctionGraph.Memo(
inner_out=set(self.inner_out),
input_variables=list(self.input_variables),
stmt_ir=saved_stmt_ir,
global_guards=OrderedSet(self._global_guarded_variables),
side_effects_state=self.side_effects.get_state(),
print_variables=list(self._print_variables),
inplace_tensors=OrderedSet(self._inplace_tensors),
)

def restore_memo(self, memo: FunctionGraph.Memo):
"""
Expand Down Expand Up @@ -333,7 +329,6 @@ def start_compile(self, *ret_vars: VariableBase):

view_tracker(list(ret_vars), tracker_output_path, format="png")

@event_register("call_paddle_api", event_level=2)
def call_paddle_api(
self,
func: Callable[..., Any],
Expand All @@ -359,7 +354,6 @@ def message_handler(*args, **kwargs):
InferMetaCache(), self.sir_ctx.call_API, func, *args, **kwargs
)

@event_register("call_tensor_method", event_level=2)
def call_tensor_method(
self, method_name: str, *args: VariableBase, **kwargs
):
Expand Down Expand Up @@ -411,7 +405,6 @@ def get_opcode_executor_stack():
stack.append(f' {code_line}')
return stack

@event_register("call_layer", event_level=2)
def call_layer(
self,
layer: PaddleLayerVariable,
Expand Down Expand Up @@ -444,7 +437,6 @@ def message_handler(*args, **kwargs):
infer_meta_fn, compute_fn, layer, *args, **kwargs
)

@event_register("symbolic_call", event_level=2)
def symbolic_call(self, infer_meta_fn, compute_fn, func, *args, **kwargs):
"""
Using infer_meta_fn and compute_fn convert func to symbolic function.
Expand All @@ -459,8 +451,7 @@ def symbolic_call(self, infer_meta_fn, compute_fn, func, *args, **kwargs):
metas = convert_to_meta(args)
kwmetas = convert_to_meta(kwargs)

with EventGuard("infer_meta"):
out_metas = infer_meta_fn(func, *metas, **kwmetas)
out_metas = infer_meta_fn(func, *metas, **kwmetas)
inputs_symbols = (
convert_to_symbol(args),
convert_to_symbol(kwargs),
Expand Down
11 changes: 3 additions & 8 deletions sot/opcode_translator/executor/guard.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,13 +4,8 @@
import weakref
from typing import TYPE_CHECKING, Any, Callable, TypeVar

from ...utils import (
EventGuard,
InnerError,
current_tmp_name_records,
log,
log_do,
)
from ...profiler import EventGuard
from ...utils import InnerError, current_tmp_name_records, log, log_do

Guard = Callable[[types.FrameType], bool]

Expand Down Expand Up @@ -71,7 +66,7 @@ def make_guard(stringify_guards: list[StringifyExpression]) -> Guard:
Args:
stringify_guards: a list of StringifyExpression.
"""
with EventGuard(f"make_guard: ({len(stringify_guards)})"):
with EventGuard("make_guard"):
num_guards = len(stringify_guards)
if not num_guards:
guard = lambda frame: True
Expand Down
Loading