|
31 | 31 |
|
32 | 32 | from ...profiler import EventGuard
|
33 | 33 | from ...psdb import NO_BREAKGRAPH_CODES
|
| 34 | +from ...symbolic_shape.constraints import LogicalNotConstraintNode |
34 | 35 | from ...utils import (
|
35 | 36 | ENV_MIN_GRAPH_SIZE,
|
36 | 37 | ENV_SOT_FORCE_FALLBACK_SIR_IDS,
|
37 | 38 | BreakGraphError,
|
| 39 | + DataDependencyDynamicShapeBreak, |
38 | 40 | FallbackError,
|
39 | 41 | InnerError,
|
40 | 42 | SotUndefinedVar,
|
@@ -224,12 +226,23 @@ def inner(self: OpcodeExecutorBase, instr: Instruction):
|
224 | 226 | )(res)
|
225 | 227 |
|
226 | 228 | assert isinstance(res, (ConstantVariable, SymbolicVariable))
|
227 |
| - # NOTE(SigureMo): force to constant to trigger fallback to static dim |
228 |
| - # to align with old behavior. In next PR we will support guard value |
229 |
| - # with constraint. |
230 | 229 | if isinstance(res, SymbolicVariable):
|
231 |
| - res = res.to_constant() |
232 |
| - is_jump = res.get_py_value() |
| 230 | + constraint_node, symbolic_vars = res.create_constraint_tree() |
| 231 | + if not all( |
| 232 | + var.value.is_backed() for var in symbolic_vars.values() |
| 233 | + ): |
| 234 | + raise BreakGraphError( |
| 235 | + DataDependencyDynamicShapeBreak( |
| 236 | + f"Symbolic variable {symbolic_vars} is not backed." |
| 237 | + ) |
| 238 | + ) |
| 239 | + is_jump = res.get_example_value() |
| 240 | + if not is_jump: |
| 241 | + constraint_node = LogicalNotConstraintNode(constraint_node) |
| 242 | + for var in symbolic_vars.values(): |
| 243 | + var.add_constraint((constraint_node, symbolic_vars)) |
| 244 | + else: |
| 245 | + is_jump = res.get_py_value() |
233 | 246 | assert isinstance(is_jump, bool)
|
234 | 247 | if is_jump:
|
235 | 248 | assert instr.jump_to is not None
|
|
0 commit comments