# limitations under the License.

import copy
+ import logging

import paddle
+ from paddle.fluid.log_helper import get_logger
+
+ _logger = get_logger(
+     __name__, logging.INFO, fmt='%(asctime)s-%(levelname)s: %(message)s'
+ )


class OperatorStatsUnit:
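
For readers unfamiliar with `paddle.fluid.log_helper`, here is a minimal standalone sketch of what the configured logger emits, assuming `get_logger` simply wires a `logging.Logger` to the format string above (the `_make_logger` helper and the sample operator/variable names are illustrative only):

    import logging

    # Stand-in for paddle.fluid.log_helper.get_logger (an assumption about its
    # behavior): attach a StreamHandler using the same format string as the patch.
    def _make_logger(name, level, fmt):
        logger = logging.getLogger(name)
        logger.setLevel(level)
        handler = logging.StreamHandler()
        handler.setFormatter(logging.Formatter(fmt))
        logger.addHandler(handler)
        return logger

    _logger = _make_logger(
        'amp_debugging_demo', logging.INFO, '%(asctime)s-%(levelname)s: %(message)s'
    )
    # Prints something like:
    # 2024-01-01 12:00:00,000-WARNING: Operator < cast > gets input < X : x0 > error!
    _logger.warning(
        "Operator < {} > gets {} < {} : {} > error!".format('cast', 'input', 'X', 'x0')
    )
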
@@ -76,7 +82,7 @@ def _get_var_dtype_from_block(block, op, arg_name, is_input):
        var = block._var_recursive(var_name)
        return var.dtype
    except:
-         print(
+         _logger.warning(
            "Operator < {} > gets {} < {} : {} > error!".format(
                op.type, "input" if is_input else "output", arg_name, var_name
            )
@@ -99,7 +105,7 @@ def _extract_compute_dtype(op, block):
            if _is_floating_point(compute_dtype) and _is_floating_point(
                var_dtype
            ):
-                 print(
+                 _logger.warning(
                    "Operator < {} > has different input data types, input_names = {}, output_names = {}.".format(
                        op.type, op.input_names, op.output_names
                    )
@@ -125,7 +131,7 @@ def _extract_compute_dtype(op, block):
            if _is_floating_point(compute_dtype) and _is_floating_point(
                var_dtype
            ):
-                 print(
+                 _logger.warning(
                    "Operator < {} > has different input / output data types, input_names = {}, output_names = {}.".format(
                        op.type, op.input_names, op.output_names
                    )
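
Because these messages now go through a named logger instead of bare `print`, callers can filter or silence them with the standard `logging` API. A minimal sketch, assuming the module is importable as `paddle.static.amp.debugging` (inferred from the docstring example below; the exact module path is an assumption):

    import logging

    # Hide the data-type mismatch warnings emitted during program parsing,
    # keeping only ERROR and above from this module's logger.
    logging.getLogger('paddle.static.amp.debugging').setLevel(logging.ERROR)
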
@@ -145,6 +151,15 @@ def _merge_op_stats(op_stats_list):


def _get_op_stats_list(program):
+     def _is_special_ops_with_input_x(op_type):
+         # Operators that have input X and whose inputs may have different dtypes.
+         special_op_list = ['cast', 'batch_norm', 'instance_norm', 'layer_norm']
+         if op_type in special_op_list:
+             return True
+         if op_type.replace("_grad", "") in special_op_list:
+             return True
+         return False
+
    op_stats_list = []
    for block in program.blocks:
        block_op_stats_dict = {}
@@ -161,13 +176,7 @@ def _get_op_stats_list(program):
                'create_double_buffer_reader',
            ]:
                compute_dtype = None
-             elif op.type in [
-                 'cast',
-                 'layer_norm',
-                 'layer_norm_grad',
-                 'batch_norm',
-                 'batch_norm_grad',
-             ]:
+             elif _is_special_ops_with_input_x(op.type):
                # Do not check the input and output dtype difference for these operators.
                compute_dtype = _get_var_dtype_from_block(block, op, 'X', True)
            elif "Param" in op.input_names:
@@ -183,6 +192,78 @@ def _get_op_stats_list(program):


def collect_operator_stats(program=None, print_subblocks=False):
+     """
+     Collect the number of operators for different data types through parsing
+     the program. The statistical data are categorized according to four data
+     types, namely float32, float16, bfloat16 and others.
+
+     Args:
+         program(Program, optional): The program to parse. Default None, and the default main_program will be parsed.
+         print_subblocks(bool, optional): Whether to print the operator stats for each subblock. Default False.
+
+     Examples:
+
+         .. code-block:: python
+
+             import paddle
+
+             paddle.enable_static()
+
+             class SimpleConvNet(paddle.nn.Layer):
+                 def __init__(self):
+                     super().__init__()
+                     self.conv = paddle.nn.Conv2D(in_channels=1, out_channels=6, kernel_size=3)
+                     self.linear = paddle.nn.Linear(in_features=26, out_features=10)
+
+                 def forward(self, x):
+                     out = self.conv(x)
+                     out = paddle.nn.functional.relu(out)
+                     out = self.linear(out)
+                     out = paddle.nn.functional.softmax(out)
+                     return out
+
+             main_program = paddle.static.Program()
+             startup_program = paddle.static.Program()
+             with paddle.utils.unique_name.guard():
+                 with paddle.static.program_guard(main_program, startup_program):
+                     model = SimpleConvNet()
+                     x = paddle.static.data(
+                         name='input', shape=[None, 1, 28, 28], dtype='float32'
+                     )
+                     out = model(x)
+                     loss = paddle.mean(out)
+                     optimizer = paddle.optimizer.AdamW()
+                     optimizer = paddle.static.amp.decorate(optimizer)
+                     optimizer.minimize(loss)
+             paddle.static.amp.debugging.collect_operator_stats(main_program)
+             # <------------------------------------------------ op list of all blocks ------------------------------------------------->
+             # <------------------------------------------------------- op list -------------------------------------------------------->
+             # <--------------- Op Name ---------------- | -- FP16 Calls --- | -- BF16 Calls --- | --- FP32 Calls--- | -- Other Calls -->
+             # adamw                                     |  0                |  0                |  4                |  0
+             # cast                                      |  5                |  0                |  6                |  0
+             # check_finite_and_unscale                  |  0                |  0                |  1                |  0
+             # conv2d                                    |  1                |  0                |  0                |  0
+             # conv2d_grad                               |  1                |  0                |  0                |  0
+             # elementwise_add                           |  2                |  0                |  0                |  0
+             # elementwise_add_grad                      |  2                |  0                |  0                |  0
+             # elementwise_mul                           |  0                |  0                |  1                |  0
+             # elementwise_mul_grad                      |  0                |  0                |  1                |  0
+             # fill_constant                             |  0                |  0                |  1                |  0
+             # matmul_v2                                 |  1                |  0                |  0                |  0
+             # matmul_v2_grad                            |  1                |  0                |  0                |  0
+             # memcpy                                    |  0                |  0                |  0                |  1
+             # reduce_mean                               |  0                |  0                |  1                |  0
+             # reduce_mean_grad                          |  0                |  0                |  1                |  0
+             # relu                                      |  1                |  0                |  0                |  0
+             # relu_grad                                 |  1                |  0                |  0                |  0
+             # reshape2                                  |  0                |  0                |  1                |  0
+             # reshape2_grad                             |  0                |  0                |  1                |  0
+             # softmax                                   |  0                |  0                |  1                |  0
+             # softmax_grad                              |  0                |  0                |  1                |  0
+             # update_loss_scaling                       |  0                |  0                |  1                |  0
+             # <----------------------------------------------------- op count: 22 ----------------------------------------------------->
+     """
+
    def _convert_to_list(op_stats_unit_dict):
        for key, value in op_stats_unit_dict.items():
            op_stats_unit_dict[key] = value.convert_to_list()