Skip to content

Commit 91879f5

Browse files
committed
fix fro pre-commit
1 parent 2fc9ffd commit 91879f5

File tree

3 files changed

+613
-217
lines changed

3 files changed

+613
-217
lines changed

x2paddle/op_mapper/dygraph/pytorch2paddle/prim.py

+123-51
Original file line numberDiff line numberDiff line change
@@ -37,20 +37,24 @@ def prim_Constant(mapper, graph, node):
3737
tensor_value = value
3838
value = "{}".format(value)
3939
if "tensor" in value:
40-
if isinstance(tensor_value, list) or isinstance(tensor_value, tuple):
40+
if isinstance(tensor_value, list) or isinstance(tensor_value,
41+
tuple):
4142
name_dict = dict()
4243
for i, tv in enumerate(tensor_value):
43-
output_name_i = "{}_p{}".format(output_name,i)
44+
output_name_i = "{}_p{}".format(output_name, i)
4445
key_i = "input{}".format(i)
45-
mapper.paddle_params[output_name_i] = tv.cpu().detach().numpy()
46+
mapper.paddle_params[output_name_i] = tv.cpu().detach(
47+
).numpy()
4648
graph.add_layer(
4749
"self.create_parameter",
4850
inputs={},
4951
outputs=[output_name_i],
5052
scope_name=scope_name,
51-
dtype=string(str(mapper.paddle_params[output_name_i].dtype)),
52-
shape = mapper.paddle_params[output_name_i].shape,
53-
default_initializer="paddle.nn.initializer.Constant(value=0.0)")
53+
dtype=string(
54+
str(mapper.paddle_params[output_name_i].dtype)),
55+
shape=mapper.paddle_params[output_name_i].shape,
56+
default_initializer="paddle.nn.initializer.Constant(value=0.0)"
57+
)
5458
name_dict[key_i] = output_name_i
5559
graph.add_layer(
5660
"prim.list",
@@ -59,16 +63,18 @@ def prim_Constant(mapper, graph, node):
5963
scope_name=scope_name)
6064
return [], [output_name]
6165
else:
62-
# mapper.pytorch_params[output_name] = tensor_value.cpu().detach().numpy()
63-
mapper.paddle_params[output_name] = tensor_value.cpu().detach().numpy()
66+
# mapper.pytorch_params[output_name] = tensor_value.cpu().detach().numpy()
67+
mapper.paddle_params[output_name] = tensor_value.cpu().detach(
68+
).numpy()
6469
graph.add_layer(
65-
"self.create_parameter",
66-
inputs={},
67-
outputs=[output_name],
68-
scope_name=scope_name,
69-
dtype=string(str(mapper.paddle_params[output_name].dtype)),
70-
shape = mapper.paddle_params[output_name].shape,
71-
default_initializer="paddle.nn.initializer.Constant(value=0.0)")
70+
"self.create_parameter",
71+
inputs={},
72+
outputs=[output_name],
73+
scope_name=scope_name,
74+
dtype=string(str(mapper.paddle_params[output_name].dtype)),
75+
shape=mapper.paddle_params[output_name].shape,
76+
default_initializer="paddle.nn.initializer.Constant(value=0.0)"
77+
)
7278
return [], [output_name]
7379
if "inf" in str(value):
7480
t = str(type(value)).split("'")[1]
@@ -81,7 +87,11 @@ def prim_Constant(mapper, graph, node):
8187
value = int(math.pow(2, 31) - 1)
8288
mapper.attrs[output_name] = value
8389
graph.add_layer(
84-
"prim.constant", inputs={}, outputs=[output_name], scope_name=scope_name, value=value)
90+
"prim.constant",
91+
inputs={},
92+
outputs=[output_name],
93+
scope_name=scope_name,
94+
value=value)
8595
return [], [output_name]
8696

8797

@@ -105,18 +115,23 @@ def prim_data(mapper, graph, node):
105115
# 获取当前节点输出的list
106116
current_outputs = [output_name]
107117
# 处理输入0,即%4336
108-
mapper._check_input(graph, inputs_node[0], inputs_name[0], current_outputs, scope_name)
118+
mapper._check_input(graph, inputs_node[0], inputs_name[0], current_outputs,
119+
scope_name)
109120
layer_inputs["input"] = inputs_name[0]
110121
# 获取当前节点输入的list
111122
current_inputs = list(layer_inputs.values())
112123

113-
graph.add_layer("prim.equal", inputs=layer_inputs, outputs=layer_outputs, scope_name=scope_name)
124+
graph.add_layer(
125+
"prim.equal",
126+
inputs=layer_inputs,
127+
outputs=layer_outputs,
128+
scope_name=scope_name)
114129
return current_inputs, current_outputs
115130

116131

117132
def prim_DictConstruct(mapper, graph, node):
118133
""" 构建dict。
119-
134+
120135
TorchScript示例:
121136
%32 : Dict(str, Tensor) = prim::DictConstruct(%30, %23, %31, %29)
122137
参数含义:
@@ -136,22 +151,22 @@ def prim_DictConstruct(mapper, graph, node):
136151
current_outputs = [output_name]
137152
# 处理每个输入
138153
for i, input_name in enumerate(inputs_name):
139-
if i%2 == 0:
140-
layer_attrs["key{}".format(int(i/2))] = mapper.attrs[input_name]
154+
if i % 2 == 0:
155+
layer_attrs["key{}".format(int(i / 2))] = mapper.attrs[input_name]
141156
else:
142-
layer_inputs["value{}".format(int(i/2))] = input_name
157+
layer_inputs["value{}".format(int(i / 2))] = input_name
143158
# 获取当前节点输入的list
144159
current_inputs = list(layer_inputs.values())
145160

146-
graph.add_layer("prim.dict_construct",
147-
inputs=layer_inputs,
148-
outputs=layer_outputs,
149-
scope_name=scope_name,
150-
**layer_attrs)
161+
graph.add_layer(
162+
"prim.dict_construct",
163+
inputs=layer_inputs,
164+
outputs=layer_outputs,
165+
scope_name=scope_name,
166+
**layer_attrs)
151167
return current_inputs, current_outputs
152168

153169

154-
155170
def prim_GetAttr(mapper, graph, node):
156171
""" 获取attribute信息。
157172
@@ -212,8 +227,13 @@ def prim_If(mapper, graph, node):
212227
input_node = list(node.inputs())[0].node()
213228
script_input_unique_id = list(node.inputs())[0].unique()
214229
input_node_name = mapper.outputs_info[script_input_unique_id]
215-
mapper._check_input(graph, input_node, input_node_name, current_outputs, scope_name)
216-
graph.add_layer("prim.if", inputs={'input': input_node_name}, outputs=node_outputs, scope_name=scope_name)
230+
mapper._check_input(graph, input_node, input_node_name, current_outputs,
231+
scope_name)
232+
graph.add_layer(
233+
"prim.if",
234+
inputs={'input': input_node_name},
235+
outputs=node_outputs,
236+
scope_name=scope_name)
217237
current_layer = list(graph.layers.values())[-1]
218238
block0 = list(node.blocks())[0]
219239
block0_graph, graph_inputs0 = mapper.traverse(block0, current_layer)
@@ -249,12 +269,17 @@ def prim_ListConstruct(mapper, graph, node):
249269
current_outputs = [output_name]
250270
# 处理每个输入
251271
for i, input_name in enumerate(inputs_name):
252-
mapper._check_input(graph, inputs_node[i], input_name, current_outputs, scope_name)
272+
mapper._check_input(graph, inputs_node[i], input_name, current_outputs,
273+
scope_name)
253274
layer_inputs["input{}".format(i)] = input_name
254275
# 获取当前节点输入的list
255276
current_inputs = list(layer_inputs.values())
256277

257-
layer_id = graph.add_layer("prim.list", inputs=layer_inputs, outputs=layer_outputs, scope_name=scope_name)
278+
layer_id = graph.add_layer(
279+
"prim.list",
280+
inputs=layer_inputs,
281+
outputs=layer_outputs,
282+
scope_name=scope_name)
258283
mapper.output2id[output_name] = layer_id
259284
return current_inputs, current_outputs
260285

@@ -277,13 +302,17 @@ def prim_ListUnpack(mapper, graph, node):
277302
# 获取当前节点输出的list
278303
current_outputs = layer_outputs.copy()
279304
# 处理输入0,即%4354
280-
mapper._check_input(graph, inputs_node[0], inputs_name[0], current_outputs, scope_name)
305+
mapper._check_input(graph, inputs_node[0], inputs_name[0], current_outputs,
306+
scope_name)
281307
layer_inputs["input"] = inputs_name[0]
282308
# 获取当前节点输入的list
283309
current_inputs = list(layer_inputs.values())
284310

285311
graph.add_layer(
286-
"prim.list_unpack", inputs=layer_inputs, outputs=layer_outputs, scope_name=scope_name)
312+
"prim.list_unpack",
313+
inputs=layer_inputs,
314+
outputs=layer_outputs,
315+
scope_name=scope_name)
287316
mapper.split_len[list(layer_inputs.values())[0]] = len(layer_outputs)
288317
return current_inputs, current_outputs
289318

@@ -342,7 +371,11 @@ def prim_Loop(mapper, graph, node):
342371
scope_name=scope_name)
343372
node_outputs.append(block_input_node_name)
344373

345-
graph.add_layer("prim.loop", inputs=loop_inputs, outputs=loop_outputs, scope_name=scope_name)
374+
graph.add_layer(
375+
"prim.loop",
376+
inputs=loop_inputs,
377+
outputs=loop_outputs,
378+
scope_name=scope_name)
346379
current_layer = list(graph.layers.values())[-1]
347380
block_graph, graph_inputs = mapper.traverse(block, current_layer)
348381
for i, input_name in enumerate(graph_inputs):
@@ -370,12 +403,17 @@ def prim_min(mapper, graph, node):
370403
# 获取当前节点输出的list
371404
current_outputs = [output_name]
372405
# 处理输入0,即%86
373-
mapper._check_input(graph, inputs_node[0], inputs_name[0], current_outputs, scope_name)
406+
mapper._check_input(graph, inputs_node[0], inputs_name[0], current_outputs,
407+
scope_name)
374408
layer_inputs["input"] = inputs_name[0]
375409
# 获取当前节点输入的list
376410
current_inputs = list(layer_inputs.values())
377411

378-
graph.add_layer("prim.min", inputs=layer_inputs, outputs=layer_outputs, scope_name=scope_name)
412+
graph.add_layer(
413+
"prim.min",
414+
inputs=layer_inputs,
415+
outputs=layer_outputs,
416+
scope_name=scope_name)
379417
return current_inputs, current_outputs
380418

381419

@@ -397,14 +435,19 @@ def prim_NumToTensor(mapper, graph, node):
397435
# 获取当前节点输出的list
398436
current_outputs = [output_name]
399437
# 处理输入0,即%86
400-
mapper._check_input(graph, inputs_node[0], inputs_name[0], current_outputs, scope_name)
401-
inputs_inputs_name, inputs_inputs_node = mapper._get_inputs_name(inputs_node[0])
438+
mapper._check_input(graph, inputs_node[0], inputs_name[0], current_outputs,
439+
scope_name)
440+
inputs_inputs_name, inputs_inputs_node = mapper._get_inputs_name(
441+
inputs_node[0])
402442
if inputs_node[0].kind() == "aten::size" and len(inputs_inputs_name) > 1:
403443
layer_inputs["input"] = inputs_name[0]
404444
# 获取当前节点输入的list
405445
current_inputs = list(layer_inputs.values())
406446
graph.add_layer(
407-
"prim_equal", inputs=layer_inputs, outputs=layer_outputs, scope_name=scope_name)
447+
"prim_equal",
448+
inputs=layer_inputs,
449+
outputs=layer_outputs,
450+
scope_name=scope_name)
408451
else:
409452
layer_inputs["fill_value"] = inputs_name[0]
410453
# 获取当前节点输入的list
@@ -437,13 +480,17 @@ def prim_RaiseException(mapper, graph, node):
437480
# 获取当前节点输出的list
438481
current_outputs = [output_name]
439482
# 处理输入0,即%76
440-
mapper._check_input(graph, inputs_node[0], inputs_name[0], current_outputs, scope_name)
483+
mapper._check_input(graph, inputs_node[0], inputs_name[0], current_outputs,
484+
scope_name)
441485
layer_inputs["input"] = inputs_name[0]
442486
# 获取当前节点输入的list
443487
current_inputs = list(layer_inputs.values())
444488

445489
graph.add_layer(
446-
"prim.exception", inputs=layer_inputs, outputs=layer_outputs, scope_name=scope_name)
490+
"prim.exception",
491+
inputs=layer_inputs,
492+
outputs=layer_outputs,
493+
scope_name=scope_name)
447494
return current_inputs, current_outputs
448495

449496

@@ -464,13 +511,17 @@ def prim_requires_grad(mapper, graph, node):
464511
# 获取当前节点输出的list
465512
current_outputs = [output_name]
466513
# 处理输入0,即%86
467-
mapper._check_input(graph, inputs_node[0], inputs_name[0], current_outputs, scope_name)
514+
mapper._check_input(graph, inputs_node[0], inputs_name[0], current_outputs,
515+
scope_name)
468516
layer_inputs["input"] = inputs_name[0]
469517
# 获取当前节点输入的list
470518
current_inputs = list(layer_inputs.values())
471519

472520
graph.add_layer(
473-
"prim.requires_grad", inputs=layer_inputs, outputs=layer_outputs, scope_name=scope_name)
521+
"prim.requires_grad",
522+
inputs=layer_inputs,
523+
outputs=layer_outputs,
524+
scope_name=scope_name)
474525
return current_inputs, current_outputs
475526

476527

@@ -527,13 +578,17 @@ def prim_shape(mapper, graph, node):
527578
# 获取当前节点输出的list
528579
current_outputs = [output_name]
529580
# 处理输入0,即%input.8
530-
mapper._check_input(graph, inputs_node[0], inputs_name[0], current_outputs, scope_name)
581+
mapper._check_input(graph, inputs_node[0], inputs_name[0], current_outputs,
582+
scope_name)
531583
layer_inputs["input"] = inputs_name[0]
532584
# 获取当前节点输入的list
533585
current_inputs = list(layer_inputs.values())
534586

535587
graph.add_layer(
536-
"paddle.shape", inputs=layer_inputs, outputs=layer_outputs, scope_name=scope_name)
588+
"paddle.shape",
589+
inputs=layer_inputs,
590+
outputs=layer_outputs,
591+
scope_name=scope_name)
537592
return current_inputs, current_outputs
538593

539594

@@ -560,7 +615,11 @@ def prim_TupleConstruct(mapper, graph, node):
560615
# 获取当前节点输入的list
561616
current_inputs = list(layer_inputs.values())
562617

563-
graph.add_layer("prim.tuple", inputs=layer_inputs, outputs=layer_outputs, scope_name=scope_name)
618+
graph.add_layer(
619+
"prim.tuple",
620+
inputs=layer_inputs,
621+
outputs=layer_outputs,
622+
scope_name=scope_name)
564623
return current_inputs, current_outputs
565624

566625

@@ -590,7 +649,11 @@ def prim_TupleUnpack(mapper, graph, node):
590649
current_inputs = list(layer_inputs.values())
591650

592651
graph.add_layer(
593-
"prim.tuple_unpack", inputs=layer_inputs, outputs=layer_outputs, scope_name=scope_name, **layer_attrs)
652+
"prim.tuple_unpack",
653+
inputs=layer_inputs,
654+
outputs=layer_outputs,
655+
scope_name=scope_name,
656+
**layer_attrs)
594657
return current_inputs, current_outputs
595658

596659

@@ -614,12 +677,17 @@ def prim_unchecked_cast(mapper, graph, node):
614677
# 获取当前节点输出的list
615678
current_outputs = [output_name]
616679
# 处理输入0,即%size.63
617-
mapper._check_input(graph, inputs_node[0], inputs_name[0], current_outputs, scope_name)
680+
mapper._check_input(graph, inputs_node[0], inputs_name[0], current_outputs,
681+
scope_name)
618682
layer_inputs["input"] = inputs_name[0]
619683
# 获取当前节点输入的list
620684
current_inputs = list(layer_inputs.values())
621685

622-
graph.add_layer("prim.equal", inputs=layer_inputs, outputs=layer_outputs, scope_name=scope_name)
686+
graph.add_layer(
687+
"prim.equal",
688+
inputs=layer_inputs,
689+
outputs=layer_outputs,
690+
scope_name=scope_name)
623691
return current_inputs, current_outputs
624692

625693

@@ -636,5 +704,9 @@ def prim_Uninitialized(mapper, graph, node):
636704
output = list(node.outputs())[0]
637705
mapper.attrs[output_name] = None
638706
graph.add_layer(
639-
"prim.constant", inputs={}, outputs=[output_name], scope_name=scope_name, value=None)
707+
"prim.constant",
708+
inputs={},
709+
outputs=[output_name],
710+
scope_name=scope_name,
711+
value=None)
640712
return [], [output_name]

0 commit comments

Comments
 (0)