@@ -37,20 +37,24 @@ def prim_Constant(mapper, graph, node):
37
37
tensor_value = value
38
38
value = "{}" .format (value )
39
39
if "tensor" in value :
40
- if isinstance (tensor_value , list ) or isinstance (tensor_value , tuple ):
40
+ if isinstance (tensor_value , list ) or isinstance (tensor_value ,
41
+ tuple ):
41
42
name_dict = dict ()
42
43
for i , tv in enumerate (tensor_value ):
43
- output_name_i = "{}_p{}" .format (output_name ,i )
44
+ output_name_i = "{}_p{}" .format (output_name , i )
44
45
key_i = "input{}" .format (i )
45
- mapper .paddle_params [output_name_i ] = tv .cpu ().detach ().numpy ()
46
+ mapper .paddle_params [output_name_i ] = tv .cpu ().detach (
47
+ ).numpy ()
46
48
graph .add_layer (
47
49
"self.create_parameter" ,
48
50
inputs = {},
49
51
outputs = [output_name_i ],
50
52
scope_name = scope_name ,
51
- dtype = string (str (mapper .paddle_params [output_name_i ].dtype )),
52
- shape = mapper .paddle_params [output_name_i ].shape ,
53
- default_initializer = "paddle.nn.initializer.Constant(value=0.0)" )
53
+ dtype = string (
54
+ str (mapper .paddle_params [output_name_i ].dtype )),
55
+ shape = mapper .paddle_params [output_name_i ].shape ,
56
+ default_initializer = "paddle.nn.initializer.Constant(value=0.0)"
57
+ )
54
58
name_dict [key_i ] = output_name_i
55
59
graph .add_layer (
56
60
"prim.list" ,
@@ -59,16 +63,18 @@ def prim_Constant(mapper, graph, node):
59
63
scope_name = scope_name )
60
64
return [], [output_name ]
61
65
else :
62
- # mapper.pytorch_params[output_name] = tensor_value.cpu().detach().numpy()
63
- mapper .paddle_params [output_name ] = tensor_value .cpu ().detach ().numpy ()
66
+ # mapper.pytorch_params[output_name] = tensor_value.cpu().detach().numpy()
67
+ mapper .paddle_params [output_name ] = tensor_value .cpu ().detach (
68
+ ).numpy ()
64
69
graph .add_layer (
65
- "self.create_parameter" ,
66
- inputs = {},
67
- outputs = [output_name ],
68
- scope_name = scope_name ,
69
- dtype = string (str (mapper .paddle_params [output_name ].dtype )),
70
- shape = mapper .paddle_params [output_name ].shape ,
71
- default_initializer = "paddle.nn.initializer.Constant(value=0.0)" )
70
+ "self.create_parameter" ,
71
+ inputs = {},
72
+ outputs = [output_name ],
73
+ scope_name = scope_name ,
74
+ dtype = string (str (mapper .paddle_params [output_name ].dtype )),
75
+ shape = mapper .paddle_params [output_name ].shape ,
76
+ default_initializer = "paddle.nn.initializer.Constant(value=0.0)"
77
+ )
72
78
return [], [output_name ]
73
79
if "inf" in str (value ):
74
80
t = str (type (value )).split ("'" )[1 ]
@@ -81,7 +87,11 @@ def prim_Constant(mapper, graph, node):
81
87
value = int (math .pow (2 , 31 ) - 1 )
82
88
mapper .attrs [output_name ] = value
83
89
graph .add_layer (
84
- "prim.constant" , inputs = {}, outputs = [output_name ], scope_name = scope_name , value = value )
90
+ "prim.constant" ,
91
+ inputs = {},
92
+ outputs = [output_name ],
93
+ scope_name = scope_name ,
94
+ value = value )
85
95
return [], [output_name ]
86
96
87
97
@@ -105,18 +115,23 @@ def prim_data(mapper, graph, node):
105
115
# 获取当前节点输出的list
106
116
current_outputs = [output_name ]
107
117
# 处理输入0,即%4336
108
- mapper ._check_input (graph , inputs_node [0 ], inputs_name [0 ], current_outputs , scope_name )
118
+ mapper ._check_input (graph , inputs_node [0 ], inputs_name [0 ], current_outputs ,
119
+ scope_name )
109
120
layer_inputs ["input" ] = inputs_name [0 ]
110
121
# 获取当前节点输入的list
111
122
current_inputs = list (layer_inputs .values ())
112
123
113
- graph .add_layer ("prim.equal" , inputs = layer_inputs , outputs = layer_outputs , scope_name = scope_name )
124
+ graph .add_layer (
125
+ "prim.equal" ,
126
+ inputs = layer_inputs ,
127
+ outputs = layer_outputs ,
128
+ scope_name = scope_name )
114
129
return current_inputs , current_outputs
115
130
116
131
117
132
def prim_DictConstruct (mapper , graph , node ):
118
133
""" 构建dict。
119
-
134
+
120
135
TorchScript示例:
121
136
%32 : Dict(str, Tensor) = prim::DictConstruct(%30, %23, %31, %29)
122
137
参数含义:
@@ -136,22 +151,22 @@ def prim_DictConstruct(mapper, graph, node):
136
151
current_outputs = [output_name ]
137
152
# 处理每个输入
138
153
for i , input_name in enumerate (inputs_name ):
139
- if i % 2 == 0 :
140
- layer_attrs ["key{}" .format (int (i / 2 ))] = mapper .attrs [input_name ]
154
+ if i % 2 == 0 :
155
+ layer_attrs ["key{}" .format (int (i / 2 ))] = mapper .attrs [input_name ]
141
156
else :
142
- layer_inputs ["value{}" .format (int (i / 2 ))] = input_name
157
+ layer_inputs ["value{}" .format (int (i / 2 ))] = input_name
143
158
# 获取当前节点输入的list
144
159
current_inputs = list (layer_inputs .values ())
145
160
146
- graph .add_layer ("prim.dict_construct" ,
147
- inputs = layer_inputs ,
148
- outputs = layer_outputs ,
149
- scope_name = scope_name ,
150
- ** layer_attrs )
161
+ graph .add_layer (
162
+ "prim.dict_construct" ,
163
+ inputs = layer_inputs ,
164
+ outputs = layer_outputs ,
165
+ scope_name = scope_name ,
166
+ ** layer_attrs )
151
167
return current_inputs , current_outputs
152
168
153
169
154
-
155
170
def prim_GetAttr (mapper , graph , node ):
156
171
""" 获取attribute信息。
157
172
@@ -212,8 +227,13 @@ def prim_If(mapper, graph, node):
212
227
input_node = list (node .inputs ())[0 ].node ()
213
228
script_input_unique_id = list (node .inputs ())[0 ].unique ()
214
229
input_node_name = mapper .outputs_info [script_input_unique_id ]
215
- mapper ._check_input (graph , input_node , input_node_name , current_outputs , scope_name )
216
- graph .add_layer ("prim.if" , inputs = {'input' : input_node_name }, outputs = node_outputs , scope_name = scope_name )
230
+ mapper ._check_input (graph , input_node , input_node_name , current_outputs ,
231
+ scope_name )
232
+ graph .add_layer (
233
+ "prim.if" ,
234
+ inputs = {'input' : input_node_name },
235
+ outputs = node_outputs ,
236
+ scope_name = scope_name )
217
237
current_layer = list (graph .layers .values ())[- 1 ]
218
238
block0 = list (node .blocks ())[0 ]
219
239
block0_graph , graph_inputs0 = mapper .traverse (block0 , current_layer )
@@ -249,12 +269,17 @@ def prim_ListConstruct(mapper, graph, node):
249
269
current_outputs = [output_name ]
250
270
# 处理每个输入
251
271
for i , input_name in enumerate (inputs_name ):
252
- mapper ._check_input (graph , inputs_node [i ], input_name , current_outputs , scope_name )
272
+ mapper ._check_input (graph , inputs_node [i ], input_name , current_outputs ,
273
+ scope_name )
253
274
layer_inputs ["input{}" .format (i )] = input_name
254
275
# 获取当前节点输入的list
255
276
current_inputs = list (layer_inputs .values ())
256
277
257
- layer_id = graph .add_layer ("prim.list" , inputs = layer_inputs , outputs = layer_outputs , scope_name = scope_name )
278
+ layer_id = graph .add_layer (
279
+ "prim.list" ,
280
+ inputs = layer_inputs ,
281
+ outputs = layer_outputs ,
282
+ scope_name = scope_name )
258
283
mapper .output2id [output_name ] = layer_id
259
284
return current_inputs , current_outputs
260
285
@@ -277,13 +302,17 @@ def prim_ListUnpack(mapper, graph, node):
277
302
# 获取当前节点输出的list
278
303
current_outputs = layer_outputs .copy ()
279
304
# 处理输入0,即%4354
280
- mapper ._check_input (graph , inputs_node [0 ], inputs_name [0 ], current_outputs , scope_name )
305
+ mapper ._check_input (graph , inputs_node [0 ], inputs_name [0 ], current_outputs ,
306
+ scope_name )
281
307
layer_inputs ["input" ] = inputs_name [0 ]
282
308
# 获取当前节点输入的list
283
309
current_inputs = list (layer_inputs .values ())
284
310
285
311
graph .add_layer (
286
- "prim.list_unpack" , inputs = layer_inputs , outputs = layer_outputs , scope_name = scope_name )
312
+ "prim.list_unpack" ,
313
+ inputs = layer_inputs ,
314
+ outputs = layer_outputs ,
315
+ scope_name = scope_name )
287
316
mapper .split_len [list (layer_inputs .values ())[0 ]] = len (layer_outputs )
288
317
return current_inputs , current_outputs
289
318
@@ -342,7 +371,11 @@ def prim_Loop(mapper, graph, node):
342
371
scope_name = scope_name )
343
372
node_outputs .append (block_input_node_name )
344
373
345
- graph .add_layer ("prim.loop" , inputs = loop_inputs , outputs = loop_outputs , scope_name = scope_name )
374
+ graph .add_layer (
375
+ "prim.loop" ,
376
+ inputs = loop_inputs ,
377
+ outputs = loop_outputs ,
378
+ scope_name = scope_name )
346
379
current_layer = list (graph .layers .values ())[- 1 ]
347
380
block_graph , graph_inputs = mapper .traverse (block , current_layer )
348
381
for i , input_name in enumerate (graph_inputs ):
@@ -370,12 +403,17 @@ def prim_min(mapper, graph, node):
370
403
# 获取当前节点输出的list
371
404
current_outputs = [output_name ]
372
405
# 处理输入0,即%86
373
- mapper ._check_input (graph , inputs_node [0 ], inputs_name [0 ], current_outputs , scope_name )
406
+ mapper ._check_input (graph , inputs_node [0 ], inputs_name [0 ], current_outputs ,
407
+ scope_name )
374
408
layer_inputs ["input" ] = inputs_name [0 ]
375
409
# 获取当前节点输入的list
376
410
current_inputs = list (layer_inputs .values ())
377
411
378
- graph .add_layer ("prim.min" , inputs = layer_inputs , outputs = layer_outputs , scope_name = scope_name )
412
+ graph .add_layer (
413
+ "prim.min" ,
414
+ inputs = layer_inputs ,
415
+ outputs = layer_outputs ,
416
+ scope_name = scope_name )
379
417
return current_inputs , current_outputs
380
418
381
419
@@ -397,14 +435,19 @@ def prim_NumToTensor(mapper, graph, node):
397
435
# 获取当前节点输出的list
398
436
current_outputs = [output_name ]
399
437
# 处理输入0,即%86
400
- mapper ._check_input (graph , inputs_node [0 ], inputs_name [0 ], current_outputs , scope_name )
401
- inputs_inputs_name , inputs_inputs_node = mapper ._get_inputs_name (inputs_node [0 ])
438
+ mapper ._check_input (graph , inputs_node [0 ], inputs_name [0 ], current_outputs ,
439
+ scope_name )
440
+ inputs_inputs_name , inputs_inputs_node = mapper ._get_inputs_name (
441
+ inputs_node [0 ])
402
442
if inputs_node [0 ].kind () == "aten::size" and len (inputs_inputs_name ) > 1 :
403
443
layer_inputs ["input" ] = inputs_name [0 ]
404
444
# 获取当前节点输入的list
405
445
current_inputs = list (layer_inputs .values ())
406
446
graph .add_layer (
407
- "prim_equal" , inputs = layer_inputs , outputs = layer_outputs , scope_name = scope_name )
447
+ "prim_equal" ,
448
+ inputs = layer_inputs ,
449
+ outputs = layer_outputs ,
450
+ scope_name = scope_name )
408
451
else :
409
452
layer_inputs ["fill_value" ] = inputs_name [0 ]
410
453
# 获取当前节点输入的list
@@ -437,13 +480,17 @@ def prim_RaiseException(mapper, graph, node):
437
480
# 获取当前节点输出的list
438
481
current_outputs = [output_name ]
439
482
# 处理输入0,即%76
440
- mapper ._check_input (graph , inputs_node [0 ], inputs_name [0 ], current_outputs , scope_name )
483
+ mapper ._check_input (graph , inputs_node [0 ], inputs_name [0 ], current_outputs ,
484
+ scope_name )
441
485
layer_inputs ["input" ] = inputs_name [0 ]
442
486
# 获取当前节点输入的list
443
487
current_inputs = list (layer_inputs .values ())
444
488
445
489
graph .add_layer (
446
- "prim.exception" , inputs = layer_inputs , outputs = layer_outputs , scope_name = scope_name )
490
+ "prim.exception" ,
491
+ inputs = layer_inputs ,
492
+ outputs = layer_outputs ,
493
+ scope_name = scope_name )
447
494
return current_inputs , current_outputs
448
495
449
496
@@ -464,13 +511,17 @@ def prim_requires_grad(mapper, graph, node):
464
511
# 获取当前节点输出的list
465
512
current_outputs = [output_name ]
466
513
# 处理输入0,即%86
467
- mapper ._check_input (graph , inputs_node [0 ], inputs_name [0 ], current_outputs , scope_name )
514
+ mapper ._check_input (graph , inputs_node [0 ], inputs_name [0 ], current_outputs ,
515
+ scope_name )
468
516
layer_inputs ["input" ] = inputs_name [0 ]
469
517
# 获取当前节点输入的list
470
518
current_inputs = list (layer_inputs .values ())
471
519
472
520
graph .add_layer (
473
- "prim.requires_grad" , inputs = layer_inputs , outputs = layer_outputs , scope_name = scope_name )
521
+ "prim.requires_grad" ,
522
+ inputs = layer_inputs ,
523
+ outputs = layer_outputs ,
524
+ scope_name = scope_name )
474
525
return current_inputs , current_outputs
475
526
476
527
@@ -527,13 +578,17 @@ def prim_shape(mapper, graph, node):
527
578
# 获取当前节点输出的list
528
579
current_outputs = [output_name ]
529
580
# 处理输入0,即%input.8
530
- mapper ._check_input (graph , inputs_node [0 ], inputs_name [0 ], current_outputs , scope_name )
581
+ mapper ._check_input (graph , inputs_node [0 ], inputs_name [0 ], current_outputs ,
582
+ scope_name )
531
583
layer_inputs ["input" ] = inputs_name [0 ]
532
584
# 获取当前节点输入的list
533
585
current_inputs = list (layer_inputs .values ())
534
586
535
587
graph .add_layer (
536
- "paddle.shape" , inputs = layer_inputs , outputs = layer_outputs , scope_name = scope_name )
588
+ "paddle.shape" ,
589
+ inputs = layer_inputs ,
590
+ outputs = layer_outputs ,
591
+ scope_name = scope_name )
537
592
return current_inputs , current_outputs
538
593
539
594
@@ -560,7 +615,11 @@ def prim_TupleConstruct(mapper, graph, node):
560
615
# 获取当前节点输入的list
561
616
current_inputs = list (layer_inputs .values ())
562
617
563
- graph .add_layer ("prim.tuple" , inputs = layer_inputs , outputs = layer_outputs , scope_name = scope_name )
618
+ graph .add_layer (
619
+ "prim.tuple" ,
620
+ inputs = layer_inputs ,
621
+ outputs = layer_outputs ,
622
+ scope_name = scope_name )
564
623
return current_inputs , current_outputs
565
624
566
625
@@ -590,7 +649,11 @@ def prim_TupleUnpack(mapper, graph, node):
590
649
current_inputs = list (layer_inputs .values ())
591
650
592
651
graph .add_layer (
593
- "prim.tuple_unpack" , inputs = layer_inputs , outputs = layer_outputs , scope_name = scope_name , ** layer_attrs )
652
+ "prim.tuple_unpack" ,
653
+ inputs = layer_inputs ,
654
+ outputs = layer_outputs ,
655
+ scope_name = scope_name ,
656
+ ** layer_attrs )
594
657
return current_inputs , current_outputs
595
658
596
659
@@ -614,12 +677,17 @@ def prim_unchecked_cast(mapper, graph, node):
614
677
# 获取当前节点输出的list
615
678
current_outputs = [output_name ]
616
679
# 处理输入0,即%size.63
617
- mapper ._check_input (graph , inputs_node [0 ], inputs_name [0 ], current_outputs , scope_name )
680
+ mapper ._check_input (graph , inputs_node [0 ], inputs_name [0 ], current_outputs ,
681
+ scope_name )
618
682
layer_inputs ["input" ] = inputs_name [0 ]
619
683
# 获取当前节点输入的list
620
684
current_inputs = list (layer_inputs .values ())
621
685
622
- graph .add_layer ("prim.equal" , inputs = layer_inputs , outputs = layer_outputs , scope_name = scope_name )
686
+ graph .add_layer (
687
+ "prim.equal" ,
688
+ inputs = layer_inputs ,
689
+ outputs = layer_outputs ,
690
+ scope_name = scope_name )
623
691
return current_inputs , current_outputs
624
692
625
693
@@ -636,5 +704,9 @@ def prim_Uninitialized(mapper, graph, node):
636
704
output = list (node .outputs ())[0 ]
637
705
mapper .attrs [output_name ] = None
638
706
graph .add_layer (
639
- "prim.constant" , inputs = {}, outputs = [output_name ], scope_name = scope_name , value = None )
707
+ "prim.constant" ,
708
+ inputs = {},
709
+ outputs = [output_name ],
710
+ scope_name = scope_name ,
711
+ value = None )
640
712
return [], [output_name ]
0 commit comments