Skip to content

Commit 51fa137

Browse files
Merge pull request #60 from josephwillard/issue_50
Add ability to index graph depth.
2 parents 968fc01 + c63a0ad commit 51fa137

File tree

2 files changed

+144
-25
lines changed

2 files changed

+144
-25
lines changed

symbolic_pymc/tensorflow/printing.py

Lines changed: 77 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -14,44 +14,95 @@
1414
from symbolic_pymc.tensorflow.meta import TFlowMetaOp
1515

1616

17+
class DepthExceededException(Exception):
18+
pass
19+
20+
1721
class TFlowPrinter(object):
1822
"""A printer that indents and keeps track of already printed subgraphs."""
1923

20-
def __init__(self, formatter, buffer):
24+
def __init__(self, formatter, buffer, depth_lower_idx=0, depth_upper_idx=sys.maxsize):
25+
# The buffer to which results are printed
2126
self.buffer = buffer
27+
# A function used to pre-process printed results
2228
self.formatter = formatter
23-
self.indentation = ""
29+
30+
self.depth_count = 0
31+
self.depth_lower_idx, self.depth_upper_idx = depth_lower_idx, depth_upper_idx
32+
33+
# This is the current indentation string
34+
if self.depth_lower_idx > 0:
35+
self.indentation = "... "
36+
else:
37+
self.indentation = ""
38+
39+
# The set of graphs that have already been printed
2440
self.printed_subgraphs = set()
2541

2642
@contextmanager
2743
def indented(self, indent):
2844
pre_indentation = self.indentation
29-
if isinstance(indent, int):
30-
self.indentation += " " * indent
31-
else:
45+
46+
self.depth_count += 1
47+
48+
if self.depth_lower_idx < self.depth_count <= self.depth_upper_idx:
3249
self.indentation += indent
50+
3351
try:
3452
yield
53+
except DepthExceededException:
54+
pass
3555
finally:
3656
self.indentation = pre_indentation
57+
self.depth_count -= 1
3758

3859
def format(self, obj):
3960
return self.indentation + self.formatter(obj)
4061

41-
def print(self, obj):
42-
self.buffer.write(self.format(obj))
43-
self.buffer.flush()
44-
45-
def println(self, obj):
46-
self.buffer.write(self.format(obj) + "\n")
47-
self.buffer.flush()
62+
def print(self, obj, suffix=""):
63+
if self.depth_lower_idx <= self.depth_count < self.depth_upper_idx:
64+
self.buffer.write(self.format(obj) + suffix)
65+
self.buffer.flush()
66+
elif self.depth_count == self.depth_upper_idx:
67+
# Only print the cut-off indicator at the first occurrence
68+
self.buffer.write(self.format(f"...{suffix}"))
69+
self.buffer.flush()
4870

71+
# Prevent the caller from traversing at this level or higher
72+
raise DepthExceededException()
4973

50-
def tf_dprint(obj, printer=None):
51-
"""Print a textual representation of a TF graph.
74+
def println(self, obj):
75+
self.print(obj, suffix="\n")
76+
77+
def subgraph_add(self, obj):
78+
if self.depth_lower_idx <= self.depth_count < self.depth_upper_idx:
79+
# Only track printed subgraphs when they're actually printed
80+
self.printed_subgraphs.add(obj)
81+
82+
def __repr__(self): # pragma: no cover
83+
return (
84+
"TFlowPrinter\n"
85+
f"\tdepth_lower_idx={self.depth_lower_idx},\tdepth_upper_idx={self.depth_upper_idx}\n"
86+
f"\tindentation='{self.indentation}',\tdepth_count={self.depth_count}"
87+
)
88+
89+
90+
def tf_dprint(obj, depth_lower=0, depth_upper=10, printer=None):
91+
"""Print a textual representation of a TF graph. The output roughly follows the format of `theano.printing.debugprint`.
92+
93+
Parameters
94+
----------
95+
obj : Tensorflow object
96+
Tensorflow graph object to be represented.
97+
depth_lower : int
98+
Used to index specific portions of the graph.
99+
depth_upper : int
100+
Used to index specific portions of the graph.
101+
printer : optional
102+
Backend used to display the output.
52103
53-
The output roughly follows the format of `theano.printing.debugprint`.
54104
"""
105+
55106
if isinstance(obj, tf.Tensor):
56107
try:
57108
obj.op
@@ -63,7 +114,7 @@ def tf_dprint(obj, printer=None):
63114
)
64115

65116
if printer is None:
66-
printer = TFlowPrinter(str, sys.stdout)
117+
printer = TFlowPrinter(str, sys.stdout, depth_lower, depth_upper)
67118

68119
_tf_dprint(obj, printer)
69120

@@ -75,28 +126,22 @@ def _tf_dprint(obj, printer):
75126

76127
@_tf_dprint.register(tf.Tensor)
77128
@_tf_dprint.register(TFlowMetaTensor)
78-
def _(obj, printer):
129+
def _tf_dprint_TFlowMetaTensor(obj, printer):
79130

80131
try:
81132
shape_str = str(obj.shape.as_list())
82133
except (ValueError, AttributeError):
83134
shape_str = "Unknown"
84135

85136
prefix = f'Tensor({getattr(obj.op, "type", obj.op)}):{obj.value_index},\tdtype={getattr(obj.dtype, "name", obj.dtype)},\tshape={shape_str},\t"{obj.name}"'
137+
86138
_tf_dprint(prefix, printer)
87139

88140
if isvar(obj.op):
89141
return
90142
elif isvar(obj.op.inputs):
91143
with printer.indented("| "):
92144
_tf_dprint(f"{obj.op.inputs}", printer)
93-
elif len(obj.op.inputs) > 0:
94-
with printer.indented("| "):
95-
if obj not in printer.printed_subgraphs:
96-
printer.printed_subgraphs.add(obj)
97-
_tf_dprint(obj.op, printer)
98-
else:
99-
_tf_dprint("...", printer)
100145
elif obj.op.type == "Const":
101146
with printer.indented("| "):
102147
if isinstance(obj, tf.Tensor):
@@ -110,10 +155,17 @@ def _(obj, printer):
110155
_tf_dprint(
111156
np.array2string(numpy_val, threshold=20, prefix=printer.indentation), printer
112157
)
158+
elif len(obj.op.inputs) > 0:
159+
with printer.indented("| "):
160+
if obj in printer.printed_subgraphs:
161+
_tf_dprint("...", printer)
162+
else:
163+
printer.subgraph_add(obj)
164+
_tf_dprint(obj.op, printer)
113165

114166

115167
@_tf_dprint.register(tf.Operation)
116168
@_tf_dprint.register(TFlowMetaOp)
117-
def _(obj, printer):
169+
def _tf_dprint_TFlowMetaOp(obj, printer):
118170
for op_input in obj.inputs:
119171
_tf_dprint(op_input, printer)

tests/tensorflow/test_printing.py

Lines changed: 67 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -144,3 +144,70 @@ def test_numpy():
144144
)
145145

146146
assert std_out.getvalue() == expected_out.lstrip()
147+
148+
149+
@run_in_graph_mode
150+
def test_depth_indexing():
151+
"""Make sure graph indexing functions as expected."""
152+
153+
A = tf.compat.v1.placeholder("float", name="A", shape=tf.TensorShape([None, None]))
154+
x = tf.compat.v1.placeholder("float", name="x", shape=tf.TensorShape([None, 1]))
155+
y = tf.multiply(1.0, x, name="y")
156+
157+
z = tf.matmul(A, tf.add(y, y, name="x_p_y"), name="A_dot")
158+
159+
std_out = io.StringIO()
160+
with redirect_stdout(std_out):
161+
tf_dprint(z, depth_upper=3)
162+
163+
expected_out = textwrap.dedent(
164+
"""
165+
Tensor(MatMul):0,\tdtype=float32,\tshape=[None, 1],\t"A_dot:0"
166+
| Tensor(Placeholder):0,\tdtype=float32,\tshape=[None, None],\t"A:0"
167+
| Tensor(Add):0,\tdtype=float32,\tshape=[None, 1],\t"x_p_y:0"
168+
| | Tensor(Mul):0,\tdtype=float32,\tshape=[None, 1],\t"y:0"
169+
| | | ...
170+
| | Tensor(Mul):0,\tdtype=float32,\tshape=[None, 1],\t"y:0"
171+
| | | ...
172+
"""
173+
)
174+
175+
assert std_out.getvalue() == expected_out.lstrip()
176+
177+
std_out = io.StringIO()
178+
with redirect_stdout(std_out):
179+
tf_dprint(z, depth_lower=1)
180+
181+
expected_out = textwrap.dedent(
182+
"""
183+
... Tensor(Placeholder):0,\tdtype=float32,\tshape=[None, None],\t"A:0"
184+
... Tensor(Add):0,\tdtype=float32,\tshape=[None, 1],\t"x_p_y:0"
185+
... | Tensor(Mul):0,\tdtype=float32,\tshape=[None, 1],\t"y:0"
186+
... | | Tensor(Const):0,\tdtype=float32,\tshape=[],\t"y/x:0"
187+
... | | | 1.
188+
... | | Tensor(Placeholder):0,\tdtype=float32,\tshape=[None, 1],\t"x:0"
189+
... | Tensor(Mul):0,\tdtype=float32,\tshape=[None, 1],\t"y:0"
190+
... | | ...
191+
"""
192+
)
193+
194+
assert std_out.getvalue() == expected_out.lstrip()
195+
196+
std_out = io.StringIO()
197+
with redirect_stdout(std_out):
198+
tf_dprint(z, depth_lower=1, depth_upper=4)
199+
200+
expected_out = textwrap.dedent(
201+
"""
202+
... Tensor(Placeholder):0,\tdtype=float32,\tshape=[None, None],\t"A:0"
203+
... Tensor(Add):0,\tdtype=float32,\tshape=[None, 1],\t"x_p_y:0"
204+
... | Tensor(Mul):0,\tdtype=float32,\tshape=[None, 1],\t"y:0"
205+
... | | Tensor(Const):0,\tdtype=float32,\tshape=[],\t"y/x:0"
206+
... | | | ...
207+
... | | Tensor(Placeholder):0,\tdtype=float32,\tshape=[None, 1],\t"x:0"
208+
... | Tensor(Mul):0,\tdtype=float32,\tshape=[None, 1],\t"y:0"
209+
... | | ...
210+
"""
211+
)
212+
213+
assert std_out.getvalue() == expected_out.lstrip()

0 commit comments

Comments
 (0)