14
14
from symbolic_pymc .tensorflow .meta import TFlowMetaOp
15
15
16
16
17
+ class DepthExceededException (Exception ):
18
+ pass
19
+
20
+
17
21
class TFlowPrinter (object ):
18
22
"""A printer that indents and keeps track of already printed subgraphs."""
19
23
20
- def __init__ (self , formatter , buffer ):
24
+ def __init__ (self , formatter , buffer , depth_lower_idx = 0 , depth_upper_idx = sys .maxsize ):
25
+ # The buffer to which results are printed
21
26
self .buffer = buffer
27
+ # A function used to pre-process printed results
22
28
self .formatter = formatter
23
- self .indentation = ""
29
+
30
+ self .depth_count = 0
31
+ self .depth_lower_idx , self .depth_upper_idx = depth_lower_idx , depth_upper_idx
32
+
33
+ # This is the current indentation string
34
+ if self .depth_lower_idx > 0 :
35
+ self .indentation = "... "
36
+ else :
37
+ self .indentation = ""
38
+
39
+ # The set of graphs that have already been printed
24
40
self .printed_subgraphs = set ()
25
41
26
42
@contextmanager
27
43
def indented (self , indent ):
28
44
pre_indentation = self .indentation
29
- if isinstance (indent , int ):
30
- self .indentation += " " * indent
31
- else :
45
+
46
+ self .depth_count += 1
47
+
48
+ if self .depth_lower_idx < self .depth_count <= self .depth_upper_idx :
32
49
self .indentation += indent
50
+
33
51
try :
34
52
yield
53
+ except DepthExceededException :
54
+ pass
35
55
finally :
36
56
self .indentation = pre_indentation
57
+ self .depth_count -= 1
37
58
38
59
def format (self , obj ):
39
60
return self .indentation + self .formatter (obj )
40
61
41
- def print (self , obj ):
42
- self .buffer .write (self .format (obj ))
43
- self .buffer .flush ()
44
-
45
- def println (self , obj ):
46
- self .buffer .write (self .format (obj ) + "\n " )
47
- self .buffer .flush ()
62
+ def print (self , obj , suffix = "" ):
63
+ if self .depth_lower_idx <= self .depth_count < self .depth_upper_idx :
64
+ self .buffer .write (self .format (obj ) + suffix )
65
+ self .buffer .flush ()
66
+ elif self .depth_count == self .depth_upper_idx :
67
+ # Only print the cut-off indicator at the first occurrence
68
+ self .buffer .write (self .format (f"...{ suffix } " ))
69
+ self .buffer .flush ()
48
70
71
+ # Prevent the caller from traversing at this level or higher
72
+ raise DepthExceededException ()
49
73
50
- def tf_dprint (obj , printer = None ):
51
- """Print a textual representation of a TF graph.
74
+ def println (self , obj ):
75
+ self .print (obj , suffix = "\n " )
76
+
77
+ def subgraph_add (self , obj ):
78
+ if self .depth_lower_idx <= self .depth_count < self .depth_upper_idx :
79
+ # Only track printed subgraphs when they're actually printed
80
+ self .printed_subgraphs .add (obj )
81
+
82
+ def __repr__ (self ): # pragma: no cover
83
+ return (
84
+ "TFlowPrinter\n "
85
+ f"\t depth_lower_idx={ self .depth_lower_idx } ,\t depth_upper_idx={ self .depth_upper_idx } \n "
86
+ f"\t indentation='{ self .indentation } ',\t depth_count={ self .depth_count } "
87
+ )
88
+
89
+
90
+ def tf_dprint (obj , depth_lower = 0 , depth_upper = 10 , printer = None ):
91
+ """Print a textual representation of a TF graph. The output roughly follows the format of `theano.printing.debugprint`.
92
+
93
+ Parameters
94
+ ----------
95
+ obj : Tensorflow object
96
+ Tensorflow graph object to be represented.
97
+ depth_lower : int
98
+ Used to index specific portions of the graph.
99
+ depth_upper : int
100
+ Used to index specific portions of the graph.
101
+ printer : optional
102
+ Backend used to display the output.
52
103
53
- The output roughly follows the format of `theano.printing.debugprint`.
54
104
"""
105
+
55
106
if isinstance (obj , tf .Tensor ):
56
107
try :
57
108
obj .op
@@ -63,7 +114,7 @@ def tf_dprint(obj, printer=None):
63
114
)
64
115
65
116
if printer is None :
66
- printer = TFlowPrinter (str , sys .stdout )
117
+ printer = TFlowPrinter (str , sys .stdout , depth_lower , depth_upper )
67
118
68
119
_tf_dprint (obj , printer )
69
120
@@ -75,28 +126,22 @@ def _tf_dprint(obj, printer):
75
126
76
127
@_tf_dprint .register (tf .Tensor )
77
128
@_tf_dprint .register (TFlowMetaTensor )
78
- def _ (obj , printer ):
129
+ def _tf_dprint_TFlowMetaTensor (obj , printer ):
79
130
80
131
try :
81
132
shape_str = str (obj .shape .as_list ())
82
133
except (ValueError , AttributeError ):
83
134
shape_str = "Unknown"
84
135
85
136
prefix = f'Tensor({ getattr (obj .op , "type" , obj .op )} ):{ obj .value_index } ,\t dtype={ getattr (obj .dtype , "name" , obj .dtype )} ,\t shape={ shape_str } ,\t "{ obj .name } "'
137
+
86
138
_tf_dprint (prefix , printer )
87
139
88
140
if isvar (obj .op ):
89
141
return
90
142
elif isvar (obj .op .inputs ):
91
143
with printer .indented ("| " ):
92
144
_tf_dprint (f"{ obj .op .inputs } " , printer )
93
- elif len (obj .op .inputs ) > 0 :
94
- with printer .indented ("| " ):
95
- if obj not in printer .printed_subgraphs :
96
- printer .printed_subgraphs .add (obj )
97
- _tf_dprint (obj .op , printer )
98
- else :
99
- _tf_dprint ("..." , printer )
100
145
elif obj .op .type == "Const" :
101
146
with printer .indented ("| " ):
102
147
if isinstance (obj , tf .Tensor ):
@@ -110,10 +155,17 @@ def _(obj, printer):
110
155
_tf_dprint (
111
156
np .array2string (numpy_val , threshold = 20 , prefix = printer .indentation ), printer
112
157
)
158
+ elif len (obj .op .inputs ) > 0 :
159
+ with printer .indented ("| " ):
160
+ if obj in printer .printed_subgraphs :
161
+ _tf_dprint ("..." , printer )
162
+ else :
163
+ printer .subgraph_add (obj )
164
+ _tf_dprint (obj .op , printer )
113
165
114
166
115
167
@_tf_dprint .register (tf .Operation )
116
168
@_tf_dprint .register (TFlowMetaOp )
117
- def _ (obj , printer ):
169
+ def _tf_dprint_TFlowMetaOp (obj , printer ):
118
170
for op_input in obj .inputs :
119
171
_tf_dprint (op_input , printer )
0 commit comments