Skip to content

Commit 3c9dd7b

Browse files
authored
1.add compare result, 2.add seed for paddlegan (#493)
* 1.add compare result, 2.add seed for paddlegan * 1.add compare result, 2.add seed for paddlegan * 1.add compare result, 2.add seed for paddlegan * 1.add compare result, 2.add seed for paddlegan
1 parent 3d4fa14 commit 3c9dd7b

22 files changed

+206
-58
lines changed

ppgan/datasets/firstorder_dataset.py

Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -123,7 +123,7 @@ def read_video(name: Path, frame_shape=tuple([256, 256, 3]), saveto='folder'):
123123
except FileExistsError:
124124
pass
125125
for idx, img in enumerate(video_array_reshape):
126-
cv2.imwrite(sub_dir.joinpath('%i.png' % idx), img)
126+
cv2.imwrite(str(sub_dir.joinpath('%i.png' % idx)), img[:,:,[2,1,0]])
127127
name.unlink()
128128
return video_array_reshape
129129
else:
@@ -207,7 +207,6 @@ def __getitem__(self, idx):
207207
num_frames, replace=True,
208208
size=2)) if self.is_train else range(num_frames)
209209
video_array = [video_array[i] for i in frame_idx]
210-
211210
# convert to 3-channel image
212211
if video_array[0].shape[-1] == 4:
213212
video_array = [i[..., :3] for i in video_array]
@@ -218,7 +217,6 @@ def __getitem__(self, idx):
218217
np.tile(i[..., np.newaxis], (1, 1, 3)) for i in video_array
219218
]
220219
out = {}
221-
222220
if self.is_train:
223221
if self.transform is not None: #modify
224222
t = self.transform(tuple(video_array))

ppgan/utils/options.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -67,6 +67,12 @@ def parse_args():
6767
default=None,
6868
help='The option of profiler, which should be in format \"key1=value1;key2=value2;key3=value3\".'
6969
)
70+
# fix random numbers by setting seed
71+
parser.add_argument('--seed',
72+
type=int,
73+
default=None,
74+
help='fix random numbers by setting seed\".'
75+
)
7076
args = parser.parse_args()
7177

7278
return args

ppgan/utils/setup.py

Lines changed: 9 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -15,10 +15,10 @@
1515
import os
1616
import time
1717
import paddle
18-
18+
import numpy as np
19+
import random
1920
from .logger import setup_logger
2021

21-
2222
def setup(args, cfg):
2323
if args.evaluate_only:
2424
cfg.is_train = False
@@ -44,3 +44,10 @@ def setup(args, cfg):
4444
paddle.set_device('gpu')
4545
else:
4646
paddle.set_device('cpu')
47+
48+
if args.seed:
49+
paddle.seed(args.seed)
50+
random.seed(args.seed)
51+
np.random.seed(args.seed)
52+
paddle.framework.random._manual_program_seed(args.seed)
53+

test_tipc/compare_results.py

Lines changed: 55 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,55 @@
1+
import numpy as np
2+
import os
3+
import subprocess
4+
import json
5+
import argparse
6+
import glob
7+
8+
9+
def init_args():
10+
parser = argparse.ArgumentParser()
11+
# params for testing assert allclose
12+
parser.add_argument("--atol", type=float, default=1e-3)
13+
parser.add_argument("--rtol", type=float, default=1e-3)
14+
parser.add_argument("--gt_file", type=str, default="")
15+
parser.add_argument("--log_file", type=str, default="")
16+
parser.add_argument("--precision", type=str, default="fp32")
17+
return parser
18+
19+
def parse_args():
20+
parser = init_args()
21+
return parser.parse_args()
22+
23+
def load_from_file(gt_file):
24+
if not os.path.exists(gt_file):
25+
raise ValueError("The log file {} does not exists!".format(gt_file))
26+
with open(gt_file, 'r') as f:
27+
data = f.readlines()
28+
f.close()
29+
parser_gt = {}
30+
for line in data:
31+
metric_name, result = line.strip("\n").split(":")
32+
parser_gt[metric_name] = float(result)
33+
return parser_gt
34+
35+
if __name__ == "__main__":
36+
# Usage:
37+
# python3.7 test_tipc/compare_results.py --gt_file=./test_tipc/results/*.txt --log_file=./test_tipc/output/*/*.txt
38+
39+
args = parse_args()
40+
41+
gt_collection = load_from_file(args.gt_file)
42+
pre_collection = load_from_file(args.log_file)
43+
44+
for metric in pre_collection.keys():
45+
try:
46+
np.testing.assert_allclose(
47+
np.array(pre_collection[metric]), np.array(gt_collection[metric]), atol=args.atol, rtol=args.rtol)
48+
print(
49+
"Assert allclose passed! The results of {} are consistent!".
50+
format(metric))
51+
except Exception as E:
52+
print(E)
53+
raise ValueError(
54+
"The results of {} are inconsistent!".
55+
format(metric))

test_tipc/configs/basicvsr/train_infer_python.txt

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@ python:python3.7
44
gpu_list:0
55
##
66
auto_cast:null
7-
total_iters:lite_train_lite_infer=5|whole_train_whole_infer=200
7+
total_iters:lite_train_lite_infer=10|whole_train_whole_infer=200
88
output_dir:./output/
99
dataset.train.batch_size:lite_train_lite_infer=1|whole_train_whole_infer=1
1010
pretrained_model:null
@@ -13,7 +13,7 @@ train_infer_img_dir:./data/basicvsr_reds/test
1313
null:null
1414
##
1515
trainer:norm_train
16-
norm_train:tools/main.py -c configs/basicvsr_reds.yaml -o dataset.train.dataset.num_clips=2
16+
norm_train:tools/main.py -c configs/basicvsr_reds.yaml --seed 123 -o dataset.train.dataset.num_clips=2 dataset.train.num_workers=0 log_config.interval=1 snapshot_config.interval=5
1717
pact_train:null
1818
fpgm_train:null
1919
distill_train:null
@@ -37,7 +37,7 @@ inference_dir:basicvsrmodel_generator
3737
train_model:./inference/basicvsr/basicvsrmodel_generator
3838
infer_export:null
3939
infer_quant:False
40-
inference:tools/inference.py --model_type basicvsr -c configs/basicvsr_reds.yaml -o dataset.test.num_clips=2 dataset.test.number_frames=6
40+
inference:tools/inference.py --model_type basicvsr -c configs/basicvsr_reds.yaml --seed 123 -o dataset.test.num_clips=2 dataset.test.number_frames=6 --output_path test_tipc/output/
4141
--device:gpu
4242
null:null
4343
null:null

test_tipc/configs/cyclegan/train_infer_python.txt

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@ python:python3.7
44
gpu_list:0|0,1
55
##
66
auto_cast:null
7-
epochs:lite_train_lite_infer=5|whole_train_whole_infer=200
7+
epochs:lite_train_lite_infer=1|whole_train_whole_infer=200
88
output_dir:./output/
99
dataset.train.batch_size:lite_train_lite_infer=1|whole_train_whole_infer=1
1010
pretrained_model:null
@@ -13,7 +13,7 @@ train_infer_img_dir:./data/horse2zebra/test
1313
null:null
1414
##
1515
trainer:norm_train
16-
norm_train:tools/main.py -c configs/cyclegan_horse2zebra.yaml -o
16+
norm_train:tools/main.py -c configs/cyclegan_horse2zebra.yaml --seed 123 -o log_config.interval=10 snapshot_config.interval=1
1717
pact_train:null
1818
fpgm_train:null
1919
distill_train:null
@@ -37,7 +37,7 @@ inference_dir:cycleganmodel_netG_A
3737
train_model:./inference/cyclegan_horse2zebra/cycleganmodel_netG_A
3838
infer_export:null
3939
infer_quant:False
40-
inference:tools/inference.py --model_type cyclegan -c configs/cyclegan_horse2zebra.yaml
40+
inference:tools/inference.py --model_type cyclegan --seed 123 -c configs/cyclegan_horse2zebra.yaml --output_path test_tipc/output/
4141
--device:gpu
4242
null:null
4343
null:null

test_tipc/configs/fom/train_infer_python.txt

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@ python:python3.7
44
gpu_list:0
55
##
66
auto_cast:null
7-
epochs:lite_train_lite_infer=10|whole_train_whole_infer=100
7+
epochs:lite_train_lite_infer=1|whole_train_whole_infer=100
88
output_dir:./output/
99
dataset.train.batch_size:lite_train_lite_infer=8|whole_train_whole_infer=8
1010
pretrained_model:null
@@ -13,7 +13,7 @@ train_infer_img_dir:./data/firstorder_vox_256/test
1313
null:null
1414
##
1515
trainer:norm_train
16-
norm_train:tools/main.py -c configs/firstorder_vox_256.yaml -o
16+
norm_train:tools/main.py -c configs/firstorder_vox_256.yaml --seed 123 -o dataset.train.num_workers=0 log_config.interval=1 snapshot_config.interval=1 dataset.train.num_repeats=1 dataset.train.id_sampling=False
1717
pact_train:null
1818
fpgm_train:null
1919
distill_train:null
@@ -37,7 +37,7 @@ inference_dir:fom_dy2st
3737
train_model:./inference/fom_dy2st/
3838
infer_export:null
3939
infer_quant:False
40-
inference:tools/fom_infer.py --driving_path data/first_order/Voxceleb/test --output_path infer_output/fom
40+
inference:tools/fom_infer.py --driving_path data/first_order/Voxceleb/test --output_path test_tipc/output/fom/
4141
--device:gpu
4242
null:null
4343
null:null

test_tipc/configs/pix2pix/train_infer_python.txt

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,10 @@
11
===========================train_params===========================
22
model_name:pix2pix
33
python:python3.7
4-
gpu_list:0|0,1
4+
gpu_list:0
55
##
66
auto_cast:null
7-
epochs:lite_train_lite_infer=5|whole_train_whole_infer=200
7+
epochs:lite_train_lite_infer=10|whole_train_whole_infer=200
88
output_dir:./output/
99
dataset.train.batch_size:lite_train_lite_infer=1|whole_train_whole_infer=1
1010
pretrained_model:null
@@ -13,7 +13,7 @@ train_infer_img_dir:./data/facades/test
1313
null:null
1414
##
1515
trainer:norm_train
16-
norm_train:tools/main.py -c configs/pix2pix_facades.yaml -o
16+
norm_train:tools/main.py -c configs/pix2pix_facades.yaml --seed 123 -o dataset.train.num_workers=0 log_config.interval=5
1717
pact_train:null
1818
fpgm_train:null
1919
distill_train:null
@@ -27,7 +27,7 @@ null:null
2727
===========================infer_params===========================
2828
--output_dir:./output/
2929
load:null
30-
norm_export:tools/export_model.py -c configs/pix2pix_facades.yaml --inputs_size="-1,3,-1,-1" --load
30+
norm_export:tools/export_model.py -c configs/pix2pix_facades.yaml --inputs_size="-1,3,-1,-1" --load
3131
quant_export:null
3232
fpgm_export:null
3333
distill_export:null
@@ -37,7 +37,7 @@ inference_dir:pix2pixmodel_netG
3737
train_model:./inference/pix2pix_facade/pix2pixmodel_netG
3838
infer_export:null
3939
infer_quant:False
40-
inference:tools/inference.py --model_type pix2pix -c configs/pix2pix_facades.yaml
40+
inference:tools/inference.py --model_type pix2pix --seed 123 -c configs/pix2pix_facades.yaml --output_path test_tipc/output/
4141
--device:cpu
4242
null:null
4343
null:null

test_tipc/configs/stylegan2/train_infer_python.txt

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@ python:python3.7
44
gpu_list:0
55
##
66
auto_cast:null
7-
total_iters::lite_train_lite_infer=10|whole_train_whole_infer=800
7+
total_iters:lite_train_lite_infer=10|whole_train_whole_infer=800
88
output_dir:./output/
99
dataset.train.batch_size:lite_train_lite_infer=3|whole_train_whole_infer=3
1010
pretrained_model:null
@@ -13,7 +13,7 @@ train_infer_img_dir:null
1313
null:null
1414
##
1515
trainer:norm_train
16-
norm_train:tools/main.py -c configs/stylegan_v2_256_ffhq.yaml -o
16+
norm_train:tools/main.py -c configs/stylegan_v2_256_ffhq.yaml --seed 123 -o dataset.train.num_workers=0 log_config.interval=1 snapshot_config.interval=10
1717
pact_train:null
1818
fpgm_train:null
1919
distill_train:null
@@ -37,7 +37,7 @@ inference_dir:stylegan2model_gen
3737
train_model:./inference/stylegan2/stylegan2model_gen
3838
infer_export:null
3939
infer_quant:False
40-
inference:tools/inference.py --model_type stylegan2 -c configs/stylegan_v2_256_ffhq.yaml
40+
inference:tools/inference.py --model_type stylegan2 --seed 123 -c configs/stylegan_v2_256_ffhq.yaml --output_path test_tipc/output/
4141
--device:gpu
4242
null:null
4343
null:null

test_tipc/docs/compare_right.png

50.9 KB
Loading

0 commit comments

Comments
 (0)