Skip to content
This repository was archived by the owner on Jan 24, 2024. It is now read-only.

Commit 6fc6a2b

Browse files
committed
fix paddle duplicate var problem
1 parent 4f6836a commit 6fc6a2b

File tree

1 file changed

+60
-57
lines changed

1 file changed

+60
-57
lines changed

python/tests/conv2d_utils.py

Lines changed: 60 additions & 57 deletions
Original file line numberDiff line numberDiff line change
@@ -5,67 +5,70 @@
55

66

77
def conv2d_native(inputs_data, input_shape, filter_size, attrs, is_depthwise):
8-
padding = [0, 0]
9-
stride = [1, 1]
10-
dilation = [1, 1]
11-
data_format = "NCHW"
12-
groups = 1
13-
for key in attrs.attr_store:
14-
if key == "stride":
15-
stride = attrs.get_attr("stride")
16-
elif key == "padding":
17-
padding = attrs.get_attr("padding")
18-
elif key == "dilation":
19-
dilation = attrs.get_attr("dilation")
20-
elif key == "groups":
21-
groups = attrs.get_attr("groups")
22-
elif key == "data_format":
23-
data_format = attrs.get_attr("data_format")
24-
else:
25-
raise ValueError("attr_store {} is not supported".format(key))
8+
main_program = fluid.Program()
9+
with fluid.program_guard(main_program, fluid.Program()):
10+
padding = [0, 0]
11+
stride = [1, 1]
12+
dilation = [1, 1]
13+
data_format = "NCHW"
14+
groups = 1
15+
for key in attrs.attr_store:
16+
if key == "stride":
17+
stride = attrs.get_attr("stride")
18+
elif key == "padding":
19+
padding = attrs.get_attr("padding")
20+
elif key == "dilation":
21+
dilation = attrs.get_attr("dilation")
22+
elif key == "groups":
23+
groups = attrs.get_attr("groups")
24+
elif key == "data_format":
25+
data_format = attrs.get_attr("data_format")
26+
else:
27+
raise ValueError("attr_store {} is not supported".format(key))
2628

27-
img = fluid.layers.data(name='img', shape=input_shape[1:], dtype='float32')
28-
if is_depthwise:
29-
if data_format == "NCHW":
30-
cin_index = 1
29+
img = fluid.layers.data(name='img', shape=input_shape[1:], dtype='float32')
30+
if is_depthwise:
31+
if data_format == "NCHW":
32+
cin_index = 1
33+
else:
34+
cin_index = 3
35+
filter_size_new = [
36+
filter_size[1] * input_shape[cin_index], filter_size[0] // groups,
37+
filter_size[2], filter_size[3]
38+
]
3139
else:
32-
cin_index = 3
33-
filter_size_new = [
34-
filter_size[1] * input_shape[cin_index], filter_size[0] // groups,
35-
filter_size[2], filter_size[3]
36-
]
37-
else:
38-
filter_size_new = filter_size
39-
param = fluid.initializer.NumpyArrayInitializer(
40-
np.array(inputs_data[1]).reshape(filter_size_new).astype("float32"))
41-
# filter: (c_out, c_in // group, kernel_h, kernel_w)
42-
filter_hw = list(filter_size_new[2:4])
43-
if data_format == "NHWC":
44-
filter_hw = list(filter_size_new[1:3])
45-
if isinstance(stride, int):
46-
stride = [stride.copy(), stride.copy()]
47-
if isinstance(padding, int):
48-
padding = [padding.copy(), padding.copy()]
49-
if isinstance(dilation, int):
50-
dilation = [dilation.copy(), dilation.copy()]
40+
filter_size_new = filter_size
41+
param = fluid.initializer.NumpyArrayInitializer(
42+
np.array(inputs_data[1]).reshape(filter_size_new).astype("float32"))
43+
# filter: (c_out, c_in // group, kernel_h, kernel_w)
44+
filter_hw = list(filter_size_new[2:4])
45+
if data_format == "NHWC":
46+
filter_hw = list(filter_size_new[1:3])
47+
if isinstance(stride, int):
48+
stride = [stride.copy(), stride.copy()]
49+
if isinstance(padding, int):
50+
padding = [padding.copy(), padding.copy()]
51+
if isinstance(dilation, int):
52+
dilation = [dilation.copy(), dilation.copy()]
53+
54+
res = fluid.layers.conv2d(
55+
input=img,
56+
num_filters=filter_size_new[0],
57+
filter_size=filter_hw,
58+
stride=stride,
59+
padding=padding,
60+
dilation=dilation,
61+
groups=groups,
62+
param_attr=param,
63+
data_format=data_format)
64+
exe = fluid.Executor(fluid.CPUPlace())
65+
exe.run(fluid.default_startup_program())
5166

52-
res = fluid.layers.conv2d(
53-
input=img,
54-
num_filters=filter_size_new[0],
55-
filter_size=filter_hw,
56-
stride=stride,
57-
padding=padding,
58-
dilation=dilation,
59-
groups=groups,
60-
param_attr=param,
61-
data_format=data_format)
62-
exe = fluid.Executor(fluid.CPUPlace())
63-
exe.run(fluid.default_startup_program())
67+
x = np.array(inputs_data[0]).reshape(input_shape).astype("float32")
68+
output = exe.run(feed={"img": x}, fetch_list=[res])
69+
output = np.array(output)
70+
print("output's shape is:", output.shape)
6471

65-
x = np.array(inputs_data[0]).reshape(input_shape).astype("float32")
66-
output = exe.run(feed={"img": x}, fetch_list=[res])
67-
output = np.array(output)
68-
print("output's shape is:", output.shape)
6972
res_shape = output.shape[1:]
7073
pad_shape = list(input_shape)
7174
dialtion_shape = list(filter_size_new)

0 commit comments

Comments
 (0)