@@ -38,12 +38,12 @@ def train(cfg: DictConfig):
38
38
train_dataloader_cfg = {
39
39
"dataset" : {
40
40
"name" : "MatDataset" ,
41
- "file_path" : "./VIV_Training_Neta100.mat" ,
41
+ "file_path" : cfg . VIV_DATA_PATH ,
42
42
"input_keys" : ("t_f" ,),
43
43
"label_keys" : ("eta" , "f" ),
44
44
"weight_dict" : {"eta" : 100 },
45
45
},
46
- "batch_size" : 100 ,
46
+ "batch_size" : cfg . TRAIN . batch_size ,
47
47
"sampler" : {
48
48
"name" : "BatchSampler" ,
49
49
"drop_last" : False ,
@@ -64,18 +64,18 @@ def train(cfg: DictConfig):
64
64
}
65
65
66
66
# set optimizer
67
- lr_scheduler = ppsci .optimizer .lr_scheduler .Step (** cfg .TRAIN .optimizer )()
67
+ lr_scheduler = ppsci .optimizer .lr_scheduler .Step (** cfg .TRAIN .lr_scheduler )()
68
68
optimizer = ppsci .optimizer .Adam (lr_scheduler )((model ,) + tuple (equation .values ()))
69
69
70
70
# set validator
71
71
valid_dataloader_cfg = {
72
72
"dataset" : {
73
73
"name" : "MatDataset" ,
74
- "file_path" : "./VIV_Training_Neta100.mat" ,
74
+ "file_path" : cfg . VIV_DATA_PATH ,
75
75
"input_keys" : ("t_f" ,),
76
76
"label_keys" : ("eta" , "f" ),
77
77
},
78
- "batch_size" : 32 ,
78
+ "batch_size" : cfg . EVAL . batch_size ,
79
79
"sampler" : {
80
80
"name" : "BatchSampler" ,
81
81
"drop_last" : False ,
@@ -93,7 +93,7 @@ def train(cfg: DictConfig):
93
93
94
94
# set visualizer(optional)
95
95
visu_mat = ppsci .utils .reader .load_mat_file (
96
- "./VIV_Training_Neta100.mat" ,
96
+ cfg . VIV_DATA_PATH ,
97
97
("t_f" , "eta_gt" , "f_gt" ),
98
98
alias_dict = {"eta_gt" : "eta" , "f_gt" : "f" },
99
99
)
@@ -121,7 +121,7 @@ def train(cfg: DictConfig):
121
121
lr_scheduler ,
122
122
cfg .TRAIN .epochs ,
123
123
cfg .TRAIN .iters_per_epoch ,
124
- eval_during_train = True ,
124
+ eval_during_train = cfg . TRAIN . eval_during_train ,
125
125
eval_freq = cfg .TRAIN .eval_freq ,
126
126
equation = equation ,
127
127
validator = validator ,
@@ -153,11 +153,11 @@ def evaluate(cfg: DictConfig):
153
153
valid_dataloader_cfg = {
154
154
"dataset" : {
155
155
"name" : "MatDataset" ,
156
- "file_path" : "./VIV_Training_Neta100.mat" ,
156
+ "file_path" : cfg . VIV_DATA_PATH ,
157
157
"input_keys" : ("t_f" ,),
158
158
"label_keys" : ("eta" , "f" ),
159
159
},
160
- "batch_size" : 32 ,
160
+ "batch_size" : cfg . EVAL . batch_size ,
161
161
"sampler" : {
162
162
"name" : "BatchSampler" ,
163
163
"drop_last" : False ,
@@ -175,7 +175,7 @@ def evaluate(cfg: DictConfig):
175
175
176
176
# set visualizer(optional)
177
177
visu_mat = ppsci .utils .reader .load_mat_file (
178
- "./VIV_Training_Neta100.mat" ,
178
+ cfg . VIV_DATA_PATH ,
179
179
("t_f" , "eta_gt" , "f_gt" ),
180
180
alias_dict = {"eta_gt" : "eta" , "f_gt" : "f" },
181
181
)
@@ -205,9 +205,9 @@ def evaluate(cfg: DictConfig):
205
205
pretrained_model_path = cfg .EVAL .pretrained_model_path ,
206
206
)
207
207
208
- # evaluate after finished training
208
+ # evaluate
209
209
solver .eval ()
210
- # visualize prediction after finished training
210
+ # visualize prediction
211
211
solver .visualize ()
212
212
213
213
0 commit comments