Skip to content

Commit 60d2e81

Browse files
committed
[Update] load config path
1 parent 3c86694 commit 60d2e81

File tree

2 files changed

+6
-6
lines changed

2 files changed

+6
-6
lines changed

test.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -29,14 +29,14 @@ def main(cfg):
2929

3030
# Load Dataset
3131
data_cfg = cfg['data']
32-
test_ds = load_dataset(**data_cfg)
32+
test_ds = load_dataset(data_cfg)
3333
test_dl = torch.utils.data.DataLoader(test_ds,
3434
batch_size=hp_cfg['batch_size'])
3535
print(f"Load Dataset {data_cfg['dataset']}")
3636

3737
# Load Model
3838
model_cfg = cfg['model']
39-
model = load_model(**model_cfg).to(device)
39+
model = load_model(model_cfg).to(device)
4040
ckpt = torch.load(os.path.join(cfg['save_dir'], cfg['weights_file_name']),
4141
map_location=device, weights_only=False)
4242
model.load_state_dict(ckpt['model'])
@@ -53,7 +53,7 @@ def main(cfg):
5353
parser = argparse.ArgumentParser('Test', parents=[add_args_parser()])
5454
args = parser.parse_args()
5555

56-
with open(f'configs/test.{args.config}.yaml') as f:
56+
with open(f'configs/test/{args.config}.yaml') as f:
5757
cfg = yaml.full_load(f)
5858

5959
main(cfg)

train.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -30,7 +30,7 @@ def main(cfg):
3030

3131
# Load Dataset
3232
data_cfg = cfg['data']
33-
train_ds = load_dataset(**data_cfg)
33+
train_ds = load_dataset(data_cfg)
3434
train_dl = torch.utils.data.DataLoader(train_ds,
3535
shuffle=True,
3636
batch_size=hp_cfg['batch_size'],
@@ -40,7 +40,7 @@ def main(cfg):
4040
# Load Model
4141
model_cfg = cfg['model']
4242
print(model_cfg['name'])
43-
model = load_model(**model_cfg).to(device)
43+
model = load_model(model_cfg).to(device)
4444
if cfg['parallel'] == True:
4545
model = nn.DataParallel(model)
4646

@@ -94,7 +94,7 @@ def main(cfg):
9494
parser = argparse.ArgumentParser('Training', parents=[add_args_parser()])
9595
args = parser.parse_args()
9696

97-
with open(f'configs/train.{args.config}.yaml') as f:
97+
with open(f'configs/train/{args.config}.yaml') as f:
9898
cfg = yaml.full_load(f)
9999

100100
main(cfg)

0 commit comments

Comments
 (0)