Commit b09698b

* add ofa
1 parent 8854b71 commit b09698b

File tree

8 files changed: +2086 -0 lines changed


demo/one_shot/ofa_train.py

Lines changed: 127 additions & 0 deletions
@@ -0,0 +1,127 @@
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import numpy as np
import paddle
import paddle.fluid as fluid
import paddle.fluid.dygraph.nn as nn
from paddle.nn import ReLU
from paddleslim.nas.ofa import OFA, RunConfig, DistillConfig
from paddleslim.nas.ofa import supernet


class Model(fluid.dygraph.Layer):
    def __init__(self):
        super(Model, self).__init__()
        # Layers defined inside the supernet context are converted into
        # elastic "super" layers with variable kernel size and width.
        with supernet(
                kernel_size=(3, 5, 7), expand_ratio=[1, 2, 4]) as ofa_super:
            models = []
            models += [nn.Conv2D(1, 6, 3)]
            models += [ReLU()]
            models += [nn.Pool2D(2, 'max', 2)]
            models += [nn.Conv2D(6, 16, 5, padding=0)]
            models += [ReLU()]
            models += [nn.Pool2D(2, 'max', 2)]
            models += [
                nn.Linear(784, 120), nn.Linear(120, 84), nn.Linear(84, 10)
            ]
            models = ofa_super.convert(models)
        self.models = paddle.nn.Sequential(*models)

    def forward(self, inputs, label, depth=None):
        if depth is not None:
            assert isinstance(depth, int)
            assert depth < len(self.models)
            models = self.models[:depth]
        else:
            depth = len(self.models)
            models = self.models[:]

        for idx, layer in enumerate(models):
            if idx == 6:
                # Flatten the feature map before the first Linear layer.
                inputs = fluid.layers.flatten(inputs, 1)
            inputs = layer(inputs)

        inputs = fluid.layers.softmax(inputs)
        return inputs


def test_ofa():
    default_run_config = {
        'train_batch_size': 256,
        'eval_batch_size': 64,
        'n_epochs': [[1], [2, 3], [4, 5]],
        'init_learning_rate': [[0.001], [0.003, 0.001], [0.003, 0.001]],
        'dynamic_batch_size': [1, 1, 1],
        'total_images': 50000,  # 1281167 for the full ImageNet train set
        'elastic_depth': (2, 5, 8)
    }
    run_config = RunConfig(**default_run_config)

    default_distill_config = {
        'lambda_distill': 0.01,
        'teacher_model': Model,
        'mapping_layers': ['models.0.fn']
    }
    distill_config = DistillConfig(**default_distill_config)

    fluid.enable_dygraph()
    model = Model()
    ofa_model = OFA(model, run_config, distill_config=distill_config)

    train_reader = paddle.fluid.io.batch(
        paddle.dataset.mnist.train(), batch_size=256, drop_last=True)

    start_epoch = 0
    # Nested schedule: each outer entry is a training phase, each inner
    # entry a sub-phase with its own epoch budget and learning rate.
    for idx in range(len(run_config.n_epochs)):
        cur_idx = run_config.n_epochs[idx]
        for ph_idx in range(len(cur_idx)):
            cur_lr = run_config.init_learning_rate[idx][ph_idx]
            adam = fluid.optimizer.Adam(
                learning_rate=cur_lr,
                parameter_list=(ofa_model.parameters() + ofa_model.netAs_param))
            for epoch_id in range(start_epoch,
                                  run_config.n_epochs[idx][ph_idx]):
                for batch_id, data in enumerate(train_reader()):
                    dy_x_data = np.array(
                        [x[0].reshape(1, 28, 28)
                         for x in data]).astype('float32')
                    y_data = np.array(
                        [x[1] for x in data]).astype('int64').reshape(-1, 1)

                    img = fluid.dygraph.to_variable(dy_x_data)
                    label = fluid.dygraph.to_variable(y_data)
                    label.stop_gradient = True

                    for model_no in range(run_config.dynamic_batch_size[idx]):
                        output, _ = ofa_model(img, label)
                        loss = fluid.layers.reduce_mean(output)
                        dis_loss = ofa_model.calc_distill_loss()
                        loss += dis_loss
                        loss.backward()

                        if batch_id % 10 == 0:
                            print(
                                'epoch: {}, batch: {}, loss: {}, distill loss: {}'.
                                format(epoch_id, batch_id,
                                       loss.numpy()[0], dis_loss.numpy()[0]))
                    ### accumulate gradients of dynamic_batch_size sub-networks
                    ### for the same batch of data
                    ### NOTE: gradient accumulation still needs to be fixed in PaddlePaddle
                    adam.minimize(loss)
                    adam.clear_gradients()
            start_epoch = run_config.n_epochs[idx][ph_idx]


test_ofa()
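
For readers skimming the diff, here is a minimal sketch of the supernet-conversion step the demo relies on, reduced to a single layer. It uses only names that appear in the demo above; the 1 -> 8 channel sizes are illustrative assumptions, not values from this commit.

# Minimal sketch of the supernet -> convert flow from the demo above.
# Assumption: the Conv2D(1, 8, 3) channel sizes are illustrative only.
import paddle
import paddle.fluid as fluid
import paddle.fluid.dygraph.nn as nn
from paddleslim.nas.ofa import supernet

fluid.enable_dygraph()

# Layers built inside the context are collected; convert() swaps them
# for elastic "super" layers that can shrink kernel size and width.
with supernet(kernel_size=(3, 5, 7), expand_ratio=[1, 2, 4]) as ofa_super:
    layers = ofa_super.convert([nn.Conv2D(1, 8, 3)])
net = paddle.nn.Sequential(*layers)

The OFA wrapper then samples sub-configurations of these elastic layers during training, which is what the dynamic_batch_size loop in the demo iterates over.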

paddleslim/nas/ofa/__init__.py

Lines changed: 17 additions & 0 deletions
@@ -0,0 +1,17 @@
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from .ofa import OFA, RunConfig, DistillConfig
from .convert_super import supernet
from .layers import *
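
With these re-exports, everything the demo imports resolves from the single paddleslim.nas.ofa namespace. A quick smoke test, assuming PaddleSlim at this commit is installed:

# Smoke test: the package surface used by demo/one_shot/ofa_train.py.
from paddleslim.nas.ofa import OFA, RunConfig, DistillConfig, supernet

print(OFA, RunConfig, DistillConfig, supernet)  # all resolve from one module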
