|
| 1 | +# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved |
| 2 | +# |
| 3 | +# Licensed under the Apache License, Version 2.0 (the "License"); |
| 4 | +# you may not use this file except in compliance with the License. |
| 5 | +# You may obtain a copy of the License at |
| 6 | +# |
| 7 | +# http://www.apache.org/licenses/LICENSE-2.0 |
| 8 | +# |
| 9 | +# Unless required by applicable law or agreed to in writing, software |
| 10 | +# distributed under the License is distributed on an "AS IS" BASIS, |
| 11 | +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
| 12 | +# See the License for the specific language governing permissions and |
| 13 | +# limitations under the License. |
| 14 | + |
| 15 | +# api: paddle.fluid.framework.Program._prune_with_input() |
| 16 | +# env: local |
| 17 | +# device: cpu |
| 18 | +# text:prune-with-input |
| 19 | + |
| 20 | +import paddle.fluid as fluid |
| 21 | +import paddle.fluid.optimizer as optimizer |
| 22 | +import numpy as np |
| 23 | + |
| 24 | +def sample_data(): |
| 25 | + res = [] |
| 26 | + for i in range(2): |
| 27 | + data = np.random.normal(size=(2,)) |
| 28 | + label = np.random.randint(2, size=(1,)) |
| 29 | + res.append((data, label)) |
| 30 | + return res |
| 31 | + |
| 32 | +x = fluid.layers.data(name='x', shape=[2], dtype='float32') |
| 33 | +label = fluid.layers.data(name="label", shape=[1], dtype="int64") |
| 34 | + |
| 35 | +# define net here |
| 36 | +y = fluid.layers.fc(input=[x], size=2, act="softmax") |
| 37 | +loss = fluid.layers.cross_entropy(input=y, label=label) |
| 38 | +loss = fluid.layers.mean(x=loss) |
| 39 | + |
| 40 | +sgd = fluid.optimizer.SGD(learning_rate=0.01) |
| 41 | +sgd.minimize(loss) |
| 42 | + |
| 43 | +with open("original_program", "w") as f: |
| 44 | + f.write(str(fluid.default_main_program())) |
| 45 | + |
| 46 | +pruned_program = fluid.default_main_program()._prune_with_input( |
| 47 | + feeded_var_names=[y.name, label.name], |
| 48 | + targets = [loss]) |
| 49 | + |
| 50 | +with open("pruned_program", "w") as f: |
| 51 | + f.write(str(pruned_program)) |
0 commit comments