Skip to content

Commit eebf94d

Browse files
authored
[TIPC] add tipc benchmark for msvsr (#672)
* add tipc benchmark for msvsr * update tipc readme img
1 parent 91dcc90 commit eebf94d

12 files changed

+211
-66
lines changed

ppgan/engine/trainer.py

Lines changed: 24 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -32,6 +32,7 @@
3232

3333

3434
class IterLoader:
35+
3536
def __init__(self, dataloader):
3637
self._dataloader = dataloader
3738
self.iter_loader = iter(self._dataloader)
@@ -79,6 +80,7 @@ class Trainer:
7980
# | ||
8081
# save checkpoint (model.nets) \/
8182
"""
83+
8284
def __init__(self, cfg):
8385
# base config
8486
self.logger = logging.getLogger(__name__)
@@ -181,6 +183,22 @@ def train(self):
181183

182184
iter_loader = IterLoader(self.train_dataloader)
183185

186+
# use amp
187+
if self.cfg.amp:
188+
self.logger.info('use AMP to train. AMP level = {}'.format(
189+
self.cfg.amp_level))
190+
assert self.cfg.model.name == 'MultiStageVSRModel', "AMP only support msvsr model"
191+
scaler = paddle.amp.GradScaler(init_loss_scaling=1024)
192+
# need to decorate model and optim if amp_level == 'O2'
193+
if self.cfg.amp_level == 'O2':
194+
# msvsr has only one generator and one optimizer
195+
self.model.nets['generator'], self.optimizers[
196+
'optim'] = paddle.amp.decorate(
197+
models=self.model.nets['generator'],
198+
optimizers=self.optimizers['optim'],
199+
level='O2',
200+
save_dtype='float32')
201+
184202
# set model.is_train = True
185203
self.model.setup_train_mode(is_train=True)
186204
while self.current_iter < (self.total_iters + 1):
@@ -195,7 +213,12 @@ def train(self):
195213
# unpack data from dataset and apply preprocessing
196214
# data input should be dict
197215
self.model.setup_input(data)
198-
self.model.train_iter(self.optimizers)
216+
217+
if self.cfg.amp:
218+
self.model.train_iter_amp(self.optimizers, scaler,
219+
self.cfg.amp_level) # amp train
220+
else:
221+
self.model.train_iter(self.optimizers) # norm train
199222

200223
batch_cost_averager.record(
201224
time.time() - step_start_time,

ppgan/models/msvsr_model.py

Lines changed: 43 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -30,6 +30,7 @@ class MultiStageVSRModel(BaseSRModel):
3030
Paper:
3131
PP-MSVSR: Multi-Stage Video Super-Resolution, 2021
3232
"""
33+
3334
def __init__(self, generator, fix_iter, pixel_criterion=None):
3435
"""Initialize the PP-MSVSR class.
3536
@@ -96,6 +97,48 @@ def train_iter(self, optims=None):
9697

9798
self.current_iter += 1
9899

100+
# amp train with brute force implementation, maybe decorator can simplify this
101+
def train_iter_amp(self, optims=None, scaler=None, amp_level='O1'):
102+
optims['optim'].clear_grad()
103+
if self.fix_iter:
104+
if self.current_iter == 1:
105+
print('Train MSVSR with fixed spynet for', self.fix_iter,
106+
'iters.')
107+
for name, param in self.nets['generator'].named_parameters():
108+
if 'spynet' in name:
109+
param.trainable = False
110+
elif self.current_iter >= self.fix_iter + 1 and self.flag:
111+
print('Train all the parameters.')
112+
for name, param in self.nets['generator'].named_parameters():
113+
param.trainable = True
114+
if 'spynet' in name:
115+
param.optimize_attr['learning_rate'] = 0.25
116+
self.flag = False
117+
for net in self.nets.values():
118+
net.find_unused_parameters = False
119+
120+
# put loss computation in amp context
121+
with paddle.amp.auto_cast(enable=True, level=amp_level):
122+
output = self.nets['generator'](self.lq)
123+
if isinstance(output, (list, tuple)):
124+
out_stage2, output = output
125+
loss_pix_stage2 = self.pixel_criterion(out_stage2, self.gt)
126+
self.losses['loss_pix_stage2'] = loss_pix_stage2
127+
self.visual_items['output'] = output[:, 0, :, :, :]
128+
# pixel loss
129+
loss_pix = self.pixel_criterion(output, self.gt)
130+
self.losses['loss_pix'] = loss_pix
131+
132+
self.loss = sum(_value for _key, _value in self.losses.items()
133+
if 'loss_pix' in _key)
134+
scaled_loss = scaler.scale(self.loss)
135+
self.losses['loss'] = scaled_loss
136+
137+
scaled_loss.backward()
138+
scaler.minimize(optims['optim'], scaled_loss)
139+
140+
self.current_iter += 1
141+
99142
def test_iter(self, metrics=None):
100143
self.gt = self.gt.cpu()
101144
self.nets['generator'].eval()

ppgan/utils/options.py

Lines changed: 23 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -45,9 +45,9 @@ def parse_args():
4545
default=False,
4646
help='skip validation during training')
4747
# config options
48-
parser.add_argument("-o",
49-
"--opt",
50-
nargs='+',
48+
parser.add_argument("-o",
49+
"--opt",
50+
nargs='+',
5151
help="set configuration options")
5252

5353
#for inference
@@ -60,19 +60,31 @@ def parse_args():
6060
help="path to reference images")
6161
parser.add_argument("--model_path", default=None, help="model for loading")
6262

63-
# for profiler
64-
parser.add_argument('-p',
65-
'--profiler_options',
66-
type=str,
67-
default=None,
68-
help='The option of profiler, which should be in format \"key1=value1;key2=value2;key3=value3\".'
63+
# for profiler
64+
parser.add_argument(
65+
'-p',
66+
'--profiler_options',
67+
type=str,
68+
default=None,
69+
help=
70+
'The option of profiler, which should be in format \"key1=value1;key2=value2;key3=value3\".'
6971
)
7072
# fix random numbers by setting seed
7173
parser.add_argument('--seed',
7274
type=int,
7375
default=None,
74-
help='fix random numbers by setting seed\".'
75-
)
76+
help='fix random numbers by setting seed\".')
77+
78+
# add for amp training
79+
parser.add_argument('--amp',
80+
action='store_true',
81+
default=False,
82+
help='whether to enable amp training')
83+
parser.add_argument('--amp_level',
84+
type=str,
85+
default='O1',
86+
choices=['O1', 'O2'],
87+
help='level of amp training; O2 represent pure fp16')
7688
args = parser.parse_args()
7789

7890
return args

ppgan/utils/setup.py

Lines changed: 7 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,7 @@
1919
import random
2020
from .logger import setup_logger
2121

22+
2223
def setup(args, cfg):
2324
if args.evaluate_only:
2425
cfg.is_train = False
@@ -44,10 +45,13 @@ def setup(args, cfg):
4445
paddle.set_device('gpu')
4546
else:
4647
paddle.set_device('cpu')
47-
48+
4849
if args.seed:
4950
paddle.seed(args.seed)
5051
random.seed(args.seed)
51-
np.random.seed(args.seed)
52+
np.random.seed(args.seed)
5253
paddle.framework.random._manual_program_seed(args.seed)
53-
54+
55+
# add amp and amp_level args into cfg
56+
cfg['amp'] = args.amp
57+
cfg['amp_level'] = args.amp_level

test_tipc/readme.md renamed to test_tipc/README.md

Lines changed: 3 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -57,9 +57,8 @@ test_tipc/
5757

5858
### 测试流程
5959
使用本工具,可以测试不同功能的支持情况,以及预测结果是否对齐,测试流程如下:
60-
<div align="center">
61-
<img src="docs/test.png" width="800">
62-
</div>
60+
61+
![img](https://user-images.githubusercontent.com/79366697/185377097-a0f852a8-2d78-45ae-84ba-ae71b799d738.png)
6362

6463
1. 运行prepare.sh准备测试所需数据和模型;
6564
2. 运行要测试的功能对应的测试脚本`test_*.sh`,产出log,由log可以看到不同配置是否运行成功;
@@ -72,4 +71,4 @@ test_tipc/
7271
<a name="more"></a>
7372
#### 更多教程
7473
各功能测试中涉及混合精度、裁剪、量化等训练相关,及mkldnn、Tensorrt等多种预测相关参数配置,请点击下方相应链接了解更多细节和使用教程:
75-
[test_train_inference_python 使用](docs/test_train_inference_python.md)
74+
- [test_train_inference_python 使用](docs/test_train_inference_python.md): 测试基于Python的模型训练、评估、推理等基本功能

test_tipc/benchmark_train.sh

Lines changed: 9 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -4,15 +4,15 @@ source test_tipc/common_func.sh
44
# set env
55
python=python
66
export model_branch=`git symbolic-ref HEAD 2>/dev/null | cut -d"/" -f 3`
7-
export model_commit=$(git log|head -n1|awk '{print $2}')
7+
export model_commit=$(git log|head -n1|awk '{print $2}')
88
export str_tmp=$(echo `pip list|grep paddlepaddle-gpu|awk -F ' ' '{print $2}'`)
99
export frame_version=${str_tmp%%.post*}
1010
export frame_commit=$(echo `${python} -c "import paddle;print(paddle.version.commit)"`)
1111

12-
# run benchmark sh
12+
# run benchmark sh
1313
# Usage:
1414
# bash run_benchmark_train.sh config.txt params
15-
# or
15+
# or
1616
# bash run_benchmark_train.sh config.txt
1717

1818
function func_parser_params(){
@@ -100,6 +100,7 @@ for _flag in ${flags_list[*]}; do
100100
done
101101

102102
# set log_name
103+
BENCHMARK_ROOT=./ # self-test only
103104
repo_name=$(get_repo_name )
104105
SAVE_LOG=${BENCHMARK_LOG_DIR:-$(pwd)} # */benchmark_log
105106
mkdir -p "${SAVE_LOG}/benchmark_log/"
@@ -149,11 +150,11 @@ else
149150
fi
150151

151152
IFS="|"
152-
for batch_size in ${batch_size_list[*]}; do
153+
for batch_size in ${batch_size_list[*]}; do
153154
for precision in ${fp_items_list[*]}; do
154155
for device_num in ${device_num_list[*]}; do
155156
# sed batchsize and precision
156-
#func_sed_params "$FILENAME" "${line_precision}" "$precision"
157+
func_sed_params "$FILENAME" "${line_precision}" "$precision"
157158
func_sed_params "$FILENAME" "${line_batchsize}" "$MODE=$batch_size"
158159
func_sed_params "$FILENAME" "${line_epoch}" "$MODE=$epoch"
159160
gpu_id=$(set_gpu_id $device_num)
@@ -162,7 +163,7 @@ for batch_size in ${batch_size_list[*]}; do
162163
log_path="$SAVE_LOG/profiling_log"
163164
mkdir -p $log_path
164165
log_name="${repo_name}_${model_name}_bs${batch_size}_${precision}_${run_mode}_${device_num}_profiling"
165-
func_sed_params "$FILENAME" "${line_gpuid}" "0" # sed used gpu_id
166+
func_sed_params "$FILENAME" "${line_gpuid}" "0" # sed used gpu_id
166167
# set profile_option params
167168
tmp=`sed -i "${line_profile}s/.*/${profile_option}/" "${FILENAME}"`
168169

@@ -214,7 +215,7 @@ for batch_size in ${batch_size_list[*]}; do
214215
mkdir -p $speed_log_path
215216
log_name="${repo_name}_${model_name}_bs${batch_size}_${precision}_${run_mode}_${device_num}_log"
216217
speed_log_name="${repo_name}_${model_name}_bs${batch_size}_${precision}_${run_mode}_${device_num}_speed"
217-
func_sed_params "$FILENAME" "${line_gpuid}" "$gpu_id" # sed used gpu_id
218+
func_sed_params "$FILENAME" "${line_gpuid}" "$gpu_id" # sed used gpu_id
218219
func_sed_params "$FILENAME" "${line_profile}" "null" # sed --profile_option as null
219220
cmd="bash test_tipc/test_train_inference_python.sh ${FILENAME} benchmark_train > ${log_path}/${log_name} 2>&1 "
220221
echo $cmd
@@ -244,4 +245,4 @@ for batch_size in ${batch_size_list[*]}; do
244245
fi
245246
done
246247
done
247-
done
248+
done
Lines changed: 53 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,53 @@
1+
===========================train_params===========================
2+
model_name:msvsr
3+
python:python3.7
4+
gpu_list:0
5+
##
6+
auto_cast:null
7+
total_iters:lite_train_lite_infer=10|lite_train_whole_infer=10|whole_train_whole_infer=200
8+
output_dir:./output/
9+
dataset.train.batch_size:lite_train_lite_infer=1|whole_train_whole_infer=1
10+
pretrained_model:null
11+
train_model_name:msvsr_reds*/*checkpoint.pdparams
12+
train_infer_img_dir:./data/msvsr_reds/test
13+
null:null
14+
##
15+
trainer:amp_train
16+
amp_train:tools/main.py --amp --amp_level O1 -c configs/msvsr_reds.yaml --seed 123 -o dataset.train.num_workers=0 log_config.interval=1 snapshot_config.interval=5 dataset.train.dataset.num_frames=2
17+
pact_train:null
18+
fpgm_train:null
19+
distill_train:null
20+
null:null
21+
null:null
22+
##
23+
===========================eval_params===========================
24+
eval:null
25+
null:null
26+
##
27+
===========================infer_params===========================
28+
--output_dir:./output/
29+
load:null
30+
norm_export:tools/export_model.py -c configs/msvsr_reds.yaml --inputs_size="1,2,3,180,320" --model_name inference --load
31+
quant_export:null
32+
fpgm_export:null
33+
distill_export:null
34+
export1:null
35+
export2:null
36+
inference_dir:inference
37+
train_model:./inference/msvsr/multistagevsrmodel_generator
38+
infer_export:null
39+
infer_quant:False
40+
inference:tools/inference.py --model_type msvsr -c configs/msvsr_reds.yaml --seed 123 -o dataset.test.num_frames=2 --output_path test_tipc/output/
41+
--device:cpu
42+
null:null
43+
null:null
44+
null:null
45+
null:null
46+
null:null
47+
--model_path:
48+
null:null
49+
null:null
50+
--benchmark:True
51+
null:null
52+
===========================infer_benchmark_params==========================
53+
random_infer_input:[{float32,[2,3,180,320]}]

test_tipc/configs/msvsr/train_infer_python.txt

Lines changed: 10 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -13,22 +13,22 @@ train_infer_img_dir:./data/msvsr_reds/test
1313
null:null
1414
##
1515
trainer:norm_train
16-
norm_train:tools/main.py -c configs/msvsr_reds.yaml --seed 123 -o dataset.train.num_workers=0 log_config.interval=1 snapshot_config.interval=5 dataset.train.dataset.num_frames=2
16+
norm_train:tools/main.py -c configs/msvsr_reds.yaml --seed 123 -o log_config.interval=2 snapshot_config.interval=50 dataset.train.dataset.num_frames=15
1717
pact_train:null
1818
fpgm_train:null
1919
distill_train:null
2020
null:null
2121
null:null
2222
##
23-
===========================eval_params===========================
23+
===========================eval_params===========================
2424
eval:null
2525
null:null
2626
##
2727
===========================infer_params===========================
2828
--output_dir:./output/
2929
load:null
30-
norm_export:tools/export_model.py -c configs/msvsr_reds.yaml --inputs_size="1,2,3,180,320" --model_name inference --load
31-
quant_export:null
30+
norm_export:tools/export_model.py -c configs/msvsr_reds.yaml --inputs_size="1,2,3,180,320" --model_name inference --load
31+
quant_export:null
3232
fpgm_export:null
3333
distill_export:null
3434
export1:null
@@ -49,5 +49,11 @@ null:null
4949
null:null
5050
--benchmark:True
5151
null:null
52+
===========================train_benchmark_params==========================
53+
batch_size:4
54+
fp_items:fp32
55+
total_iters:60
56+
--profiler_options:batch_range=[10,20];state=GPU;tracer_option=Default;profile_path=model.profile
57+
flags:null
5258
===========================infer_benchmark_params==========================
5359
random_infer_input:[{float32,[2,3,180,320]}]

0 commit comments

Comments
 (0)