Skip to content
This repository was archived by the owner on Jan 24, 2024. It is now read-only.

Commit 994e37e

Browse files
authored
update call compile fn (#409)
1 parent 4371abe commit 994e37e

File tree

5 files changed

+81
-34
lines changed

5 files changed

+81
-34
lines changed

sot/opcode_translator/executor/function_graph.py

+3-6
Original file line numberDiff line numberDiff line change
@@ -301,7 +301,7 @@ def start_compile(self, *ret_vars: VariableBase):
301301
found = False
302302
for variable in self.input_variables:
303303
if (
304-
isinstance(variable, (TensorVariable, PaddleLayerVariable))
304+
isinstance(variable, TensorVariable)
305305
and variable.get_symbol().name == name
306306
):
307307
variable.tracker.gen_instructions(self.pycode_gen)
@@ -426,15 +426,12 @@ def call_layer(
426426
"""
427427

428428
def infer_meta_fn(layer, *metas, **kwmetas):
429-
metas = metas[1:]
430429
metas = LayerInferMetaCache()(layer.value, *metas, **kwmetas)
431430
return metas
432431

433432
def compute_fn(layer, inputs, outputs, stacks):
434-
inputs = (layer.get_symbol(), *inputs)
435-
inputs = inputs[1:]
436433
self.sir_ctx.call_LAYER(
437-
layer.value.__class__.__name__,
434+
layer.value,
438435
inputs=inputs,
439436
outputs=outputs,
440437
stacks=stacks,
@@ -444,7 +441,7 @@ def message_handler(*args, **kwargs):
444441
return f"Call paddle layer error: {layer}, may be not a valid paddle layer ?"
445442

446443
return inner_error_default_handler(self.symbolic_call, message_handler)(
447-
infer_meta_fn, compute_fn, layer, *[layer, *args], **kwargs
444+
infer_meta_fn, compute_fn, layer, *args, **kwargs
448445
)
449446

450447
@event_register("symbolic_call", event_level=2)

sot/opcode_translator/executor/variables/callable.py

+1-8
Original file line numberDiff line numberDiff line change
@@ -9,10 +9,8 @@
99
import paddle
1010

1111
from .... import psdb
12-
from ....symbolic.statement_ir import Symbol
1312
from ....utils import (
1413
EventGuard,
15-
NameGenerator,
1614
is_break_graph_api,
1715
is_break_graph_tensor_methods,
1816
is_builtin_fn,
@@ -503,18 +501,13 @@ class PaddleLayerVariable(LayerVariable):
503501
tracker(Tracker): The Tracker object that tracks the information of this variable.
504502
"""
505503

506-
layer_name_generator = NameGenerator("layer_")
507-
508504
def __init__(
509505
self, layer: paddle.nn.Layer, graph: FunctionGraph, tracker: Tracker
510506
):
511507
super().__init__(layer, graph, tracker)
512-
self.name = self.layer_name_generator.next()
513-
514-
def get_symbol(self) -> Symbol:
515-
return Symbol(self.name)
516508

517509
def call_function(self, /, *args, **kwargs):
510+
self.graph.add_global_guarded_variable(self)
518511
return self.graph.call_layer(self, *args, **kwargs)
519512

520513
def make_stringify_guard(self) -> list[StringifyExpression]:

sot/symbolic/interpreter.py

+6-8
Original file line numberDiff line numberDiff line change
@@ -120,27 +120,25 @@ def _set(v, s):
120120
return replace_symbol(SIR.outputs, state)
121121

122122
def call(self, stmt: Statement, inputs):
123-
SIR = self.get_sir(stmt.name)
123+
SIR = self.get_sir(stmt.sir_name)
124124
state = prepare_state(SIR, inputs)
125-
return self.run_sir(stmt.name, state)
125+
return self.run_sir(stmt.sir_name, state)
126126

127127
def api(self, stmt, inputs):
128128
args, kwargs = inputs
129-
return stmt.name(*args, **kwargs)
129+
return stmt.api(*args, **kwargs)
130130

131131
def method(self, stmt, inputs):
132132
args, kwargs = inputs
133133
var = args[0]
134-
return getattr(var, stmt.name)(*args[1:], **kwargs)
134+
return getattr(var, stmt.method)(*args[1:], **kwargs)
135135

136136
def layer(self, stmt, inputs):
137137
args, kwargs = inputs
138-
layer, args = args[0], args[1:]
138+
layer = stmt.layer()
139+
assert layer is not None, "SIR bound layer is None."
139140
return layer(*args, **kwargs)
140141

141-
def delete(self, stmt, inputs):
142-
pass
143-
144142

145143
def compile_sir(context: SymbolicTraceContext, name: str):
146144
"""

sot/symbolic/statement_ir.py

+57-6
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,10 @@
55
"""
66
from __future__ import annotations
77

8+
import weakref
9+
from typing import Callable
10+
11+
import paddle
812
from paddle.utils import is_sequence, map_structure
913

1014
from ..utils import NameGenerator, OrderedSet, Singleton, flatten_extend
@@ -69,22 +73,69 @@ def to_string(inps):
6973
inps = (x.__str__() for x in inps)
7074
return ", ".join(inps)
7175

72-
name = (
73-
self.name
74-
if isinstance(self.name, str)
75-
else "paddle." + self.name.__name__
76-
)
7776
return "{} || {} = {} ({}) ".format(
7877
self.type + " " * (10 - len(self.type)),
7978
to_string(self.outputs),
80-
name,
79+
self.name,
8180
to_string(self.inputs),
8281
)
8382

8483
def __repr__(self):
8584
return self.__str__()
8685

8786

87+
class CallStatement(Statement):
88+
def __init__(
89+
self,
90+
name: str,
91+
inputs: list[Symbol],
92+
outputs: list[Symbol],
93+
stacks: list[str],
94+
):
95+
super().__init__("call", name, inputs, outputs, stacks)
96+
self.sir_name = name
97+
98+
99+
class ApiStatement(Statement):
100+
def __init__(
101+
self,
102+
api: Callable,
103+
inputs: list[Symbol],
104+
outputs: list[Symbol],
105+
stacks: list[str],
106+
):
107+
super().__init__(
108+
"api", "paddle." + api.__name__, inputs, outputs, stacks
109+
)
110+
self.api = api
111+
112+
113+
class MethodStatement(Statement):
114+
def __init__(
115+
self,
116+
name: str,
117+
inputs: list[Symbol],
118+
outputs: list[Symbol],
119+
stacks: list[str],
120+
):
121+
super().__init__("method", name, inputs, outputs, stacks)
122+
self.method = name
123+
124+
125+
class LayerStatement(Statement):
126+
def __init__(
127+
self,
128+
layer: paddle.nn.Layer,
129+
inputs: list[Symbol],
130+
outputs: list[Symbol],
131+
stacks: list[str],
132+
):
133+
super().__init__(
134+
"layer", layer.__class__.__name__, inputs, outputs, stacks
135+
)
136+
self.layer = weakref.ref(layer)
137+
138+
88139
class StatementIR:
89140
"""
90141
StatementIR is the carrier that records the code for building the neural network model.It is

sot/symbolic/symbolic_context.py

+14-6
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,15 @@
22

33
from ..utils import event_register, log
44
from .compile_cache import CompileSIRCache
5-
from .statement_ir import Statement, StatementIR, StatementIRFactory, Symbol
5+
from .statement_ir import (
6+
ApiStatement,
7+
CallStatement,
8+
LayerStatement,
9+
MethodStatement,
10+
StatementIR,
11+
StatementIRFactory,
12+
Symbol,
13+
)
614

715

816
class SymbolicTraceContext:
@@ -41,7 +49,7 @@ def call_SIR(self, sirname, inputs, outputs, stacks):
4149
Call a SIR, which is a subgraph.
4250
"""
4351

44-
stmt = Statement("call", sirname, inputs, outputs, stacks)
52+
stmt = CallStatement(sirname, inputs, outputs, stacks)
4553
self.TOS.add_statement(stmt)
4654

4755
@event_register("call_API", event_level=2)
@@ -51,7 +59,7 @@ def call_API(self, api, inputs, outputs, stacks):
5159
"""
5260

5361
assert callable(api), "call_API must receive a paddle api."
54-
stmt = Statement("api", api, inputs, outputs, stacks)
62+
stmt = ApiStatement(api, inputs, outputs, stacks)
5563
self.TOS.add_statement(stmt)
5664

5765
@event_register("call_METHOD", event_level=2)
@@ -65,15 +73,15 @@ def call_METHOD(self, method_name, inputs, outputs, stacks):
6573
assert isinstance(
6674
inputs[0][0], Symbol
6775
), "call_METHOD must first augument must be Symbol Variable."
68-
stmt = Statement("method", method_name, inputs, outputs, stacks)
76+
stmt = MethodStatement(method_name, inputs, outputs, stacks)
6977
self.TOS.add_statement(stmt)
7078

7179
@event_register("call_LAYER", event_level=2)
72-
def call_LAYER(self, layer_name, inputs, outputs, stacks):
80+
def call_LAYER(self, layer, inputs, outputs, stacks):
7381
"""
7482
Call a layer of a api.
7583
"""
76-
stmt = Statement("layer", layer_name, inputs, outputs, stacks)
84+
stmt = LayerStatement(layer, inputs, outputs, stacks)
7785
self.TOS.add_statement(stmt)
7886

7987
def get_sir(self, name: str):

0 commit comments

Comments
 (0)