
Commit f2a3405

fix save inference model conditional op (#37579) (#38739)
1 parent 5925b82 commit f2a3405

2 files changed (+167 -25)

paddle/fluid/framework/prune.cc (+19 -25)
@@ -145,6 +145,23 @@ int FindMapByValue(const std::map<int, int>& m, int val) {
   return -1;
 }
 
+// In the other two cases, the op that has feed vars as output vars is dependent:
+// 1. op has subblock, like while/for/ifelse/recurrent
+// 2. op is in subblock
+bool IsSubBlockDependent(const proto::OpDesc& op_desc,
+                         const std::set<std::string>& feed_vars,
+                         int parent_block_id) {
+  for (auto& var : op_desc.outputs()) {
+    for (auto& argu : var.arguments()) {
+      if ((HasSubBlock(op_desc) || parent_block_id != -1) &&
+          feed_vars.count(argu) != 0) {
+        return true;
+      }
+    }
+  }
+  return false;
+}
+
 // block_id is the idx of the current block in the input desc
 // parent_block_id is the idx of the parent of the current block
 // in the output desc, -1 means the current block is global block
@@ -210,7 +227,8 @@ void prune_impl(const proto::ProgramDesc& input, proto::ProgramDesc* output,
     // }
 
     if (IsTarget(op_desc) ||
-        (HasDependentOutputVar(op_desc, *dependent_vars) &&
+        ((HasDependentOutputVar(op_desc, *dependent_vars) ||
+          (IsSubBlockDependent(op_desc, feed_var_names, parent_block_id))) &&
          (GetOpRole(op_desc) & static_cast<int>(OpRole::kOptimize)) == 0)) {
       // NOTE(zhiqiu): since optimize op takes the trainable parameters as
       // inputs and output, it may introduce wrong dependency graph.
@@ -227,30 +245,6 @@ void prune_impl(const proto::ProgramDesc& input, proto::ProgramDesc* output,
       should_run.push_back(true);
     } else {
       should_run.push_back(false);
-      // If the output of an op modifies feed vars, the op should not be clipped.
-      // For example, in the transformer structure, the third parameter returned
-      // by beam_search op is generally assigned to a feed var. Cutting the
-      // assign op will cause an error.
-      if (parent_block_id != -1) {
-        bool flag = false;
-        for (auto& var : op_desc.outputs()) {
-          for (auto& argu : var.arguments()) {
-            if (feed_var_names.count(argu)) {
-              flag = true;
-            }
-          }
-        }
-        if (flag) {
-          should_run.back() = true;
-
-          // If any op should run, then their inputs are dependent_vars
-          for (auto& var : op_desc.inputs()) {
-            for (auto& argu : var.arguments()) {
-              dependent_vars->insert(argu);
-            }
-          }
-        }
-      }
     }
   }
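In short, pruning during inference-model save now treats an op as dependent when it writes a feed var and either owns a sub-block (while/conditional_block/recurrent) or lives inside one, rather than only special-casing ops inside sub-blocks. Below is a minimal sketch of the scenario this protects, using the public 2.x static API; it is illustrative and not part of the commit, and the path prefix and variable names are assumptions.

import paddle

paddle.enable_static()

# Build in the default main program so save_inference_model picks it up.
x = paddle.static.data(name='x', shape=[1], dtype='float32')
y = paddle.full(shape=[1], fill_value=5.0, dtype='float32')
pred = paddle.greater_than(x, y)  # single-element boolean tensor
# paddle.static.nn.cond lowers to conditional_block/select_input ops.
out = paddle.static.nn.cond(pred, lambda: x + 1.0, lambda: x - 1.0)

exe = paddle.static.Executor(paddle.CPUPlace())
exe.run(paddle.static.default_startup_program())

# save_inference_model prunes the program to the feed/fetch subgraph;
# with this fix the conditional ops survive the pruning.
paddle.static.save_inference_model('./cond_model', [x], [out], exe)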

New test file (+148 -0)
@@ -0,0 +1,148 @@
+# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from __future__ import print_function
+
+import os
+import unittest
+import numpy as np
+
+import paddle
+import paddle.fluid as fluid
+import paddle.nn.functional as F
+
+
+def getModelOp(model_path):
+    model_bytes = paddle.static.load_from_file(model_path)
+    pg = paddle.static.deserialize_program(model_bytes)
+    main_block = pg.desc.block(0)
+    size = main_block.op_size()
+
+    result = set()
+    for i in range(0, size):
+        #print(main_block.op(i).type())
+        result.add(main_block.op(i).type())
+
+    return result
+
+
+class WhileNet(paddle.nn.Layer):
+    def __init__(self):
+        super(WhileNet, self).__init__()
+
+    def forward(self, x):
+        y = paddle.rand(shape=[1, 3, 4, 4])
+
+        w1 = paddle.shape(y)[0]
+        w2 = paddle.shape(x)[0]
+
+        while w2 != w1:
+            x = F.avg_pool2d(x, kernel_size=3, padding=1, stride=2)
+            w2 = paddle.shape(x)[0]
+
+        return x + y
+
+
+class ForNet(paddle.nn.Layer):
+    def __init__(self):
+        super(ForNet, self).__init__()
+
+    def forward(self, x):
+        y = paddle.randint(low=0, high=5, shape=[1], dtype='int32')
+        z = paddle.randint(low=0, high=5, shape=[1], dtype='int32')
+        for i in range(0, z):
+            x = x + i
+
+        return x + y
+
+
+class IfElseNet(paddle.nn.Layer):
+    def __init__(self):
+        super(IfElseNet, self).__init__()
+
+    def forward(self, x):
+        y = paddle.to_tensor([5])
+        if x > y:
+            x = x + 1
+        else:
+            x = x - 1
+        return x
+
+
+class TestConditionalOp(unittest.TestCase):
+    def test_while_op(self):
+        paddle.disable_static()
+        net = WhileNet()
+        net = paddle.jit.to_static(
+            net,
+            input_spec=[
+                paddle.static.InputSpec(
+                    shape=[1, 3, 8, 8], dtype='float32')
+            ])
+        paddle.jit.save(net, './while_net')
+
+        right_pdmodel = set([
+            "uniform_random", "shape", "slice", "not_equal", "while",
+            "elementwise_add"
+        ])
+        paddle.enable_static()
+        pdmodel = getModelOp("while_net.pdmodel")
+        #print(len(right_pdmodel.difference(pdmodel)))
+        self.assertTrue(
+            len(right_pdmodel.difference(pdmodel)) == 0,
+            "The while op is pruned by mistake.")
+
+    def test_for_op(self):
+        paddle.disable_static()
+        net = ForNet()
+        net = paddle.jit.to_static(
+            net,
+            input_spec=[paddle.static.InputSpec(
+                shape=[1], dtype='int32')])
+        paddle.jit.save(net, './for_net')
+
+        right_pdmodel = set([
+            "randint", "fill_constant", "cast", "less_than", "while",
+            "elementwise_add"
+        ])
+        paddle.enable_static()
+        pdmodel = getModelOp("for_net.pdmodel")
+        #print(len(right_pdmodel.difference(pdmodel)))
+        self.assertTrue(
+            len(right_pdmodel.difference(pdmodel)) == 0,
+            "The for op is pruned by mistake.")
+
+    def test_if_op(self):
+        paddle.disable_static()
+        net = IfElseNet()
+        net = paddle.jit.to_static(
+            net,
+            input_spec=[paddle.static.InputSpec(
+                shape=[1], dtype='int32')])
+        paddle.jit.save(net, './if_net')
+
+        right_pdmodel = set([
+            "assign_value", "greater_than", "cast", "conditional_block",
+            "logical_not", "select_input"
+        ])
+        paddle.enable_static()
+        pdmodel = getModelOp("if_net.pdmodel")
+        #print(len(right_pdmodel.difference(pdmodel)))
+        self.assertTrue(
+            len(right_pdmodel.difference(pdmodel)) == 0,
+            "The if op is pruned by mistake.")
+
+
+if __name__ == '__main__':
+    unittest.main()
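For completeness, here is a small usage sketch (not part of the commit) of loading one of the models these tests save back for inference; the './while_net' prefix matches the paddle.jit.save call in test_while_op, and the input shape matches its InputSpec.

import numpy as np
import paddle

paddle.enable_static()
exe = paddle.static.Executor(paddle.CPUPlace())

# Reads './while_net.pdmodel' and './while_net.pdiparams'.
program, feed_names, fetch_targets = paddle.static.load_inference_model(
    './while_net', exe)

x = np.random.rand(1, 3, 8, 8).astype('float32')
out = exe.run(program, feed={feed_names[0]: x}, fetch_list=fetch_targets)
print(out[0].shape)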
