Skip to content
This repository was archived by the owner on Jan 24, 2024. It is now read-only.

Commit 7091f07

Browse files
Merge pull request #1 from mapingshuo/prune_with_input
add prune_with_input example
2 parents ec22004 + f26d0d1 commit 7091f07

File tree

1 file changed

+51
-0
lines changed

1 file changed

+51
-0
lines changed
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,51 @@
1+
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# http://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
15+
# api: paddle.fluid.framework.Program._prune_with_input()
16+
# env: local
17+
# device: cpu
18+
# text:prune-with-input
19+
20+
import paddle.fluid as fluid
21+
import paddle.fluid.optimizer as optimizer
22+
import numpy as np
23+
24+
def sample_data():
25+
res = []
26+
for i in range(2):
27+
data = np.random.normal(size=(2,))
28+
label = np.random.randint(2, size=(1,))
29+
res.append((data, label))
30+
return res
31+
32+
x = fluid.layers.data(name='x', shape=[2], dtype='float32')
33+
label = fluid.layers.data(name="label", shape=[1], dtype="int64")
34+
35+
# define net here
36+
y = fluid.layers.fc(input=[x], size=2, act="softmax")
37+
loss = fluid.layers.cross_entropy(input=y, label=label)
38+
loss = fluid.layers.mean(x=loss)
39+
40+
sgd = fluid.optimizer.SGD(learning_rate=0.01)
41+
sgd.minimize(loss)
42+
43+
with open("original_program", "w") as f:
44+
f.write(str(fluid.default_main_program()))
45+
46+
pruned_program = fluid.default_main_program()._prune_with_input(
47+
feeded_var_names=[y.name, label.name],
48+
targets = [loss])
49+
50+
with open("pruned_program", "w") as f:
51+
f.write(str(pruned_program))

0 commit comments

Comments
 (0)