diff --git a/python/paddle/jit/sot/opcode_translator/executor/opcode_executor.py b/python/paddle/jit/sot/opcode_translator/executor/opcode_executor.py index e221b9cae7abe6..faea899a4b5785 100644 --- a/python/paddle/jit/sot/opcode_translator/executor/opcode_executor.py +++ b/python/paddle/jit/sot/opcode_translator/executor/opcode_executor.py @@ -112,7 +112,7 @@ if TYPE_CHECKING: from .function_graph import CompileGraphResult, FunctionGraph -SUPPORT_COMPARE_OP = { +COMPARE_OP_NAME_TO_FN = { ">": operator.gt, "<": operator.lt, ">=": operator.ge, @@ -1347,7 +1347,7 @@ def COMPARE_OP(self, instr: Instruction): right, left = self.stack.pop(), self.stack.pop() self.stack.push( BuiltinVariable( - SUPPORT_COMPARE_OP[op], self._graph, DanglingTracker() + COMPARE_OP_NAME_TO_FN[op], self._graph, DanglingTracker() )(left, right) ) @@ -1366,7 +1366,7 @@ def IS_OP(self, instr: Instruction): op = "is" if instr.arg == 0 else "is not" self.stack.push( BuiltinVariable( - SUPPORT_COMPARE_OP[op], self._graph, DanglingTracker() + COMPARE_OP_NAME_TO_FN[op], self._graph, DanglingTracker() )(left, right) ) @@ -1583,7 +1583,7 @@ def CONTAINS_OP(self, instr: Instruction): op = "in" if instr.arg == 0 else "not in" self.stack.push( BuiltinVariable( - SUPPORT_COMPARE_OP[op], self._graph, DanglingTracker() + COMPARE_OP_NAME_TO_FN[op], self._graph, DanglingTracker() )(left, right) ) diff --git a/python/paddle/jit/sot/opcode_translator/executor/pycode_generator.py b/python/paddle/jit/sot/opcode_translator/executor/pycode_generator.py index 1d430b5bc320e5..351920707423dc 100644 --- a/python/paddle/jit/sot/opcode_translator/executor/pycode_generator.py +++ b/python/paddle/jit/sot/opcode_translator/executor/pycode_generator.py @@ -33,7 +33,6 @@ FallbackError, InnerError, ResumeFnNameFactory, - is_clean_code, list_contain_by_id, list_find_index_by_id, no_eval_frame, @@ -516,8 +515,6 @@ def gen_disable_eval_frame(self): """ Generates instructions to disable the evaluation frame. """ - if is_clean_code(): - return self.gen_load_object( paddle.framework.core.set_eval_frame, "paddle_set_eval_frame_fn" ) @@ -529,8 +526,6 @@ def gen_enable_eval_frame(self): """ Generates instructions to enable the evaluation frame. """ - if is_clean_code(): - return self.gen_load_object( paddle.framework.core.set_eval_frame, "paddle_set_eval_frame_fn" ) diff --git a/python/paddle/jit/sot/translate.py b/python/paddle/jit/sot/translate.py index 958526346a3b0f..2fdb61b6f3f211 100644 --- a/python/paddle/jit/sot/translate.py +++ b/python/paddle/jit/sot/translate.py @@ -26,7 +26,6 @@ GraphLogger, InfoCollector, StepInfoManager, - StepState, log_do, ) @@ -96,42 +95,23 @@ def symbolic_translate(fn: Callable[P, R], **kwargs) -> Callable[P, R]: def callback(frame): return eval_frame_callback(frame, **kwargs) - def impl_sot(*args: P.args, **kwargs: P.kwargs) -> R: - assert hasattr( - fn, "__code__" - ), "Target function doesn't have code for simulating." - StepInfoManager().sot_step() - GraphLogger().clear() - InfoCollector().clear_step_info() - paddle.framework.core.set_eval_frame(callback) - try: - outs = fn(*args, **kwargs) - except Exception as e: - raise e - finally: - paddle.framework.core.set_eval_frame(None) - - log_do(1, lambda: GraphLogger().print_info()) - InfoCollector().print_step_report() - return outs - - def impl_dynamic(*args: P.args, **kwargs: P.kwargs) -> R: - outs = fn(*args, **kwargs) - return outs - def impl(*args: P.args, **kwargs: P.kwargs) -> R: with StepInfoManager().step_guard(fn.__code__), SotStepProfilerGuard(): - state = StepInfoManager().current_state - - if state == StepState.RUN_SOT: - return impl_sot(*args, **kwargs) - elif state == StepState.RUN_DYN: - return impl_dynamic(*args, **kwargs) - elif state == StepState.COLLECT_INFO: - return StepInfoManager().collect_info( - impl_dynamic, impl_sot, *args, **kwargs - ) - else: - raise RuntimeError("Unknown state.") + assert hasattr( + fn, "__code__" + ), "Target function doesn't have code for simulating." + GraphLogger().clear() + InfoCollector().clear_step_info() + paddle.framework.core.set_eval_frame(callback) + try: + outs = fn(*args, **kwargs) + except Exception as e: + raise e + finally: + paddle.framework.core.set_eval_frame(None) + + log_do(1, lambda: GraphLogger().print_info()) + InfoCollector().print_step_report() + return outs return impl diff --git a/python/paddle/jit/sot/utils/__init__.py b/python/paddle/jit/sot/utils/__init__.py index 8765c81235e46c..3ceccc605ec09c 100644 --- a/python/paddle/jit/sot/utils/__init__.py +++ b/python/paddle/jit/sot/utils/__init__.py @@ -14,8 +14,6 @@ from .call_ast_utils import get_static_function, try_ast_func # noqa: F401 from .envs import ( # noqa: F401 - ENV_CLEAN_CODE, - ENV_COST_MODEL, ENV_MIN_GRAPH_SIZE, ENV_SOT_ALLOW_DYNAMIC_SHAPE, ENV_SOT_ENABLE_FASTER_GUARD, @@ -26,7 +24,6 @@ ENV_SOT_WITH_CONTROL_FLOW, ENV_STRICT_MODE, allow_dynamic_shape_guard, - cost_model_guard, export_guard, faster_guard_guard, guard_tree_guard, @@ -65,7 +62,6 @@ Singleton, SotUndefinedVar, StepInfoManager, - StepState, count_if, current_symbol_registry, execute_time, @@ -78,7 +74,6 @@ in_paddle_module, is_break_graph_api, is_builtin_fn, - is_clean_code, is_comprehensive_name, is_paddle_api, is_strict_mode, diff --git a/python/paddle/jit/sot/utils/envs.py b/python/paddle/jit/sot/utils/envs.py index e84753866d44f4..0af7e46fd86c9d 100644 --- a/python/paddle/jit/sot/utils/envs.py +++ b/python/paddle/jit/sot/utils/envs.py @@ -25,11 +25,9 @@ StringListEnvironmentVariable, ) -ENV_COST_MODEL = BooleanEnvironmentVariable("COST_MODEL", False) ENV_MIN_GRAPH_SIZE = IntegerEnvironmentVariable("MIN_GRAPH_SIZE", 10) ENV_SOT_LOG_LEVEL = IntegerEnvironmentVariable("SOT_LOG_LEVEL", 0) ENV_STRICT_MODE = BooleanEnvironmentVariable("STRICT_MODE", False) -ENV_CLEAN_CODE = BooleanEnvironmentVariable("CLEAN_CODE", False) ENV_SOT_WITH_CONTROL_FLOW = BooleanEnvironmentVariable( "SOT_WITH_CONTROL_FLOW", True ) @@ -60,12 +58,6 @@ ) -@contextmanager -def cost_model_guard(value: bool): - with EnvironmentVariableGuard(ENV_COST_MODEL, value): - yield - - @contextmanager def strict_mode_guard(value: bool): with EnvironmentVariableGuard(ENV_STRICT_MODE, value): diff --git a/python/paddle/jit/sot/utils/utils.py b/python/paddle/jit/sot/utils/utils.py index 2c74acedb9ac51..6cbc05582afcc8 100644 --- a/python/paddle/jit/sot/utils/utils.py +++ b/python/paddle/jit/sot/utils/utils.py @@ -22,7 +22,6 @@ import weakref from collections import OrderedDict from contextlib import contextmanager -from enum import Enum from typing import TYPE_CHECKING, Any, Callable, TypeVar from weakref import WeakValueDictionary @@ -32,8 +31,6 @@ from paddle.utils import flatten, map_structure from .envs import ( - ENV_CLEAN_CODE, - ENV_COST_MODEL, ENV_SOT_LOG_LEVEL, ENV_STRICT_MODE, ) @@ -319,10 +316,6 @@ def is_strict_mode(): return ENV_STRICT_MODE.get() -def is_clean_code() -> bool: - return ENV_CLEAN_CODE.get() - - def list_find_index_by_id(li: list[Any], item: Any) -> int: return [id(it) for it in li].index(id(item)) @@ -409,69 +402,15 @@ def printable(obj): return False -class StepState(Enum): - COLLECT_INFO = 1 - RUN_SOT = 2 - RUN_DYN = 3 - - class StepInfo: - REQUIRED_DYN_INFOS = 10 - REQUIRED_SOT_INFOS = 10 - - USED_DYN_INFOS = 5 - - COLLECT_INFO_MAX_STEP = 50 - CV_BOUNDARY = 0.1 - BACK_TRACE_STEPS = 20 def __init__(self): self.step_count = -1 - self.state = ( - StepState.COLLECT_INFO - if ENV_COST_MODEL.get() - else StepState.RUN_SOT - ) - self.dyn_time_costs = [] - self.avg_dyn_time = 0 - self.sot_time_costs = [] - self.sot_step = -1 - - def add_dynamic_time_info(self, time_cost): - self.dyn_time_costs.append(time_cost) - if len(self.dyn_time_costs) == self.REQUIRED_DYN_INFOS: - self.avg_dyn_time = np.mean( - self.dyn_time_costs[-self.USED_DYN_INFOS :] - ) - - def add_sot_time_info(self, time_cost, current_code): - self.sot_time_costs.append(time_cost) - if len(self.sot_time_costs) == self.REQUIRED_SOT_INFOS: - avg_sot_time = np.mean(self.sot_time_costs) - log( - 1, - f"[Cost Model] sot: {avg_sot_time}, dyn: {self.avg_dyn_time}\n", - ) - if avg_sot_time < self.avg_dyn_time: - log(1, f"[Cost Model] Switch to RUN_SOT: {current_code} \n") - self.state = StepState.RUN_SOT - elif ( - self.step_count > self.COLLECT_INFO_MAX_STEP - or np.std(self.sot_time_costs) / avg_sot_time < self.CV_BOUNDARY - ): - log(1, f"[Cost Model] Switch to RUN_DYN: {current_code}\n") - self.state = StepState.RUN_DYN - else: - log(1, f"[Cost Model] Decision delayed: {current_code}\n") - self.sot_time_costs.clear() def need_back_trace(self): return self.step_count < self.BACK_TRACE_STEPS - def need_dynamic_info(self): - return len(self.dyn_time_costs) < self.REQUIRED_DYN_INFOS - class StepInfoManager(metaclass=Singleton): def __init__(self): @@ -491,34 +430,11 @@ def step_guard(self, code): self.current_step_info = self.step_record[code] self.current_step_info.step_count += 1 - - log( - 2, - f"[Cost Model] New step start, current state is {self.current_state}\n", - ) yield finally: self.current_code = old_code self.current_step_info = old_info - def sot_step(self): - self.current_step_info.sot_step += 1 - - def collect_info(self, impl_dynamic, impl_sot, /, *args, **kwargs): - if self.current_step_info.need_dynamic_info(): - start_time = time.perf_counter() - outs = impl_dynamic(*args, **kwargs) - time_cost = time.perf_counter() - start_time - self.current_step_info.add_dynamic_time_info(time_cost) - else: - start_time = time.perf_counter() - outs = impl_sot(*args, **kwargs) - time_cost = time.perf_counter() - start_time - self.current_step_info.add_sot_time_info( - time_cost, self.current_code - ) - return outs - @property def need_back_trace(self): return self.current_step_info.need_back_trace() @@ -527,10 +443,6 @@ def need_back_trace(self): def current_step(self): return self.current_step_info.step_count - @property - def current_state(self): - return self.current_step_info.state - def clear(self): self.step_record.clear() self.current_code = None diff --git a/test/sot/test_sot_cost_model.py b/test/sot/test_sot_cost_model.py deleted file mode 100644 index eed690a1e77815..00000000000000 --- a/test/sot/test_sot_cost_model.py +++ /dev/null @@ -1,114 +0,0 @@ -# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -import time -import unittest - -from test_case_base import TestCaseBase - -import paddle -from paddle.jit.sot import psdb, symbolic_translate -from paddle.jit.sot.utils import StepInfoManager, StepState, cost_model_guard - - -def dyn_fast(x, net, iter_): - for i in iter_: - x = net(x) - return x - - -def sot_fast_with_single_graph(x, net): - if not psdb.in_sot(): - time.sleep(0.1) - return x + 1 - - -def sot_fast_with_multi_graph(x, net): - if not psdb.in_sot(): - time.sleep(0.1) - x = x + 1 - psdb.breakgraph() - x = x + 2 - return x - - -class Net(paddle.nn.Layer): - def __init__(self): - super().__init__() - self.linear = paddle.nn.Linear(10, 10) - - def forward(self, x): - if not psdb.in_sot(): - time.sleep(0.1) - x = x / 3 - x = x + 5 - x = self.linear(x) - return x - - -class TestCostModel(TestCaseBase): - @cost_model_guard(True) - def test_dyn_fast(self): - x = paddle.rand([10]) - net = paddle.nn.Linear(10, 10) - sot_fn = symbolic_translate(dyn_fast) - for i in range(60): - sot_fn(x, net, iter(range(10))) - - state = StepInfoManager().step_record[dyn_fast.__code__].state - assert state == StepState.RUN_DYN - - @cost_model_guard(True) - def test_sot_fast_with_multi_graph(self): - x = paddle.rand([10]) - net = paddle.nn.Linear(10, 10) - sot_fn = symbolic_translate(sot_fast_with_multi_graph) - for i in range(30): - sot_fn(x, net) - - state = ( - StepInfoManager() - .step_record[sot_fast_with_multi_graph.__code__] - .state - ) - assert state == StepState.RUN_SOT - - @cost_model_guard(True) - def test_sot_fast_with_single_graph(self): - x = paddle.rand([10]) - net = paddle.nn.Linear(10, 10) - for i in range(30): - symbolic_translate(sot_fast_with_single_graph)(x, net) - - state = ( - StepInfoManager() - .step_record[sot_fast_with_single_graph.__code__] - .state - ) - assert state == StepState.RUN_SOT - - @cost_model_guard(True) - def test_net(self): - x = paddle.rand([10]) - net = Net() - net = paddle.jit.to_static(net, full_graph=False) - for i in range(30): - x = net(x) - - state = StepInfoManager().step_record[Net.forward.__code__].state - assert state == StepState.RUN_SOT - - -if __name__ == "__main__": - unittest.main()