File tree Expand file tree Collapse file tree 2 files changed +6
-6
lines changed Expand file tree Collapse file tree 2 files changed +6
-6
lines changed Original file line number Diff line number Diff line change @@ -29,14 +29,14 @@ def main(cfg):
29
29
30
30
# Load Dataset
31
31
data_cfg = cfg ['data' ]
32
- test_ds = load_dataset (** data_cfg )
32
+ test_ds = load_dataset (data_cfg )
33
33
test_dl = torch .utils .data .DataLoader (test_ds ,
34
34
batch_size = hp_cfg ['batch_size' ])
35
35
print (f"Load Dataset { data_cfg ['dataset' ]} " )
36
36
37
37
# Load Model
38
38
model_cfg = cfg ['model' ]
39
- model = load_model (** model_cfg ).to (device )
39
+ model = load_model (model_cfg ).to (device )
40
40
ckpt = torch .load (os .path .join (cfg ['save_dir' ], cfg ['weights_file_name' ]),
41
41
map_location = device , weights_only = False )
42
42
model .load_state_dict (ckpt ['model' ])
@@ -53,7 +53,7 @@ def main(cfg):
53
53
parser = argparse .ArgumentParser ('Test' , parents = [add_args_parser ()])
54
54
args = parser .parse_args ()
55
55
56
- with open (f'configs/test. { args .config } .yaml' ) as f :
56
+ with open (f'configs/test/ { args .config } .yaml' ) as f :
57
57
cfg = yaml .full_load (f )
58
58
59
59
main (cfg )
Original file line number Diff line number Diff line change @@ -30,7 +30,7 @@ def main(cfg):
30
30
31
31
# Load Dataset
32
32
data_cfg = cfg ['data' ]
33
- train_ds = load_dataset (** data_cfg )
33
+ train_ds = load_dataset (data_cfg )
34
34
train_dl = torch .utils .data .DataLoader (train_ds ,
35
35
shuffle = True ,
36
36
batch_size = hp_cfg ['batch_size' ],
@@ -40,7 +40,7 @@ def main(cfg):
40
40
# Load Model
41
41
model_cfg = cfg ['model' ]
42
42
print (model_cfg ['name' ])
43
- model = load_model (** model_cfg ).to (device )
43
+ model = load_model (model_cfg ).to (device )
44
44
if cfg ['parallel' ] == True :
45
45
model = nn .DataParallel (model )
46
46
@@ -94,7 +94,7 @@ def main(cfg):
94
94
parser = argparse .ArgumentParser ('Training' , parents = [add_args_parser ()])
95
95
args = parser .parse_args ()
96
96
97
- with open (f'configs/train. { args .config } .yaml' ) as f :
97
+ with open (f'configs/train/ { args .config } .yaml' ) as f :
98
98
cfg = yaml .full_load (f )
99
99
100
100
main (cfg )
You can’t perform that action at this time.
0 commit comments