Skip to content

Commit c849c7e

Browse files
authored
[SOT][DynamicShape] Add basic constraint mechanism (#72250)
1 parent 198114e commit c849c7e

16 files changed

+665
-81
lines changed

python/paddle/jit/sot/infer_meta.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -38,7 +38,7 @@
3838
from paddle.static import InputSpec
3939
from paddle.utils import flatten, is_sequence
4040

41-
from .symbolic_shape import SymbolicInt
41+
from .symbolic_shape.symbolic_value import SymbolicInt
4242
from .utils import (
4343
Cache,
4444
Singleton,

python/paddle/jit/sot/opcode_translator/executor/function_graph.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -39,7 +39,7 @@
3939
from ...profiler import EventGuard, event_register
4040
from ...symbolic.statement_ir import Reference, StatementIR, Symbol
4141
from ...symbolic.symbolic_context import SymbolicTraceContext
42-
from ...symbolic_shape import SYMBOLIC_BINARY_OPS, SYMBOLIC_UNARY_OPS
42+
from ...symbolic_shape.operators import SYMBOLIC_BINARY_OPS, SYMBOLIC_UNARY_OPS
4343
from ...utils import (
4444
ENV_SOT_ALLOW_DYNAMIC_SHAPE,
4545
ENV_SOT_ENABLE_GUARD_TREE,

python/paddle/jit/sot/opcode_translator/executor/opcode_executor.py

+18-5
Original file line numberDiff line numberDiff line change
@@ -31,10 +31,12 @@
3131

3232
from ...profiler import EventGuard
3333
from ...psdb import NO_BREAKGRAPH_CODES
34+
from ...symbolic_shape.constraints import LogicalNotConstraintNode
3435
from ...utils import (
3536
ENV_MIN_GRAPH_SIZE,
3637
ENV_SOT_FORCE_FALLBACK_SIR_IDS,
3738
BreakGraphError,
39+
DataDependencyDynamicShapeBreak,
3840
FallbackError,
3941
InnerError,
4042
SotUndefinedVar,
@@ -224,12 +226,23 @@ def inner(self: OpcodeExecutorBase, instr: Instruction):
224226
)(res)
225227

226228
assert isinstance(res, (ConstantVariable, SymbolicVariable))
227-
# NOTE(SigureMo): force to constant to trigger fallback to static dim
228-
# to align with old behavior. In next PR we will support guard value
229-
# with constraint.
230229
if isinstance(res, SymbolicVariable):
231-
res = res.to_constant()
232-
is_jump = res.get_py_value()
230+
constraint_node, symbolic_vars = res.create_constraint_tree()
231+
if not all(
232+
var.value.is_backed() for var in symbolic_vars.values()
233+
):
234+
raise BreakGraphError(
235+
DataDependencyDynamicShapeBreak(
236+
f"Symbolic variable {symbolic_vars} is not backed."
237+
)
238+
)
239+
is_jump = res.get_example_value()
240+
if not is_jump:
241+
constraint_node = LogicalNotConstraintNode(constraint_node)
242+
for var in symbolic_vars.values():
243+
var.add_constraint((constraint_node, symbolic_vars))
244+
else:
245+
is_jump = res.get_py_value()
233246
assert isinstance(is_jump, bool)
234247
if is_jump:
235248
assert instr.jump_to is not None

python/paddle/jit/sot/opcode_translator/executor/side_effects.py

+3-3
Original file line numberDiff line numberDiff line change
@@ -79,9 +79,9 @@ def restore_state(self, state: SideEffectsState):
7979
assert len(self.data_id_to_proxy.values()) == len(
8080
state.proxy_versions
8181
), "proxy_versions length not match"
82-
assert len(self.mutable_variables) == len(
83-
state.mutable_attrs
84-
), "mutable_attrs length not match"
82+
assert sum(
83+
len(var.mutable_attrs) for var in self.mutable_variables
84+
) == len(state.mutable_attrs), "mutable_attrs length not match"
8585

8686
for proxy, version in zip(
8787
self.data_id_to_proxy.values(), state.proxy_versions

python/paddle/jit/sot/opcode_translator/executor/tracker.py

-7
Original file line numberDiff line numberDiff line change
@@ -137,11 +137,9 @@ def __init__(self, inputs: Sequence[VariableBase], op: UnaryOp | BinaryOp):
137137
self.op = op
138138

139139
def gen_instructions(self, codegen: PyCodeGen):
140-
# TODO(zrr1999): implemented in #68555
141140
raise InnerError("SymbolicOperationTracker has no instructions")
142141

143142
def trace_value_from_frame(self):
144-
# TODO(zrr1999): implemented in #68555
145143
raise InnerError(
146144
"SymbolicOperationTracker can't trace value from frame"
147145
)
@@ -150,11 +148,6 @@ def __repr__(self) -> str:
150148
return f"SymbolicOperationTracker(num_inputs={len(self.inputs)})"
151149

152150
def is_traceable(self):
153-
# TODO(zrr1999): implemented in #68555
154-
return False
155-
156-
def need_guard(self) -> bool:
157-
# TODO(zrr1999): implemented in #68555
158151
return False
159152

160153

python/paddle/jit/sot/opcode_translator/executor/variable_dispatch.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -24,7 +24,7 @@
2424

2525
import paddle
2626

27-
from ...symbolic_shape import (
27+
from ...symbolic_shape.operators import (
2828
SYMBOLIC_BINARY_OPS,
2929
SYMBOLIC_UNARY_OPS,
3030
symbolic_not,

python/paddle/jit/sot/opcode_translator/executor/variables/base.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -275,7 +275,7 @@ def _find_var(key: str = "default") -> VariableBase | None:
275275
return var
276276

277277

278-
def infer_debug_name_from_tracker(tracker: Tracker) -> str:
278+
def infer_debug_name_from_tracker(tracker: Tracker) -> str | None:
279279
res = None
280280
if isinstance(tracker, (LocalTracker, GlobalTracker, BuiltinTracker)):
281281
res = f"{tracker.name}"

0 commit comments

Comments
 (0)