
Commit 2c64117

fix bug

Move the hard-coded VIV data path and batch sizes into the Hydra config (VIV_DATA_PATH, TRAIN.batch_size, EVAL.batch_size) and rename the misplaced TRAIN.optimizer section to TRAIN.lr_scheduler; update examples/fsi/viv.py and the docs to read these values from cfg.

Parent: fad95be

File tree: 3 files changed (+21 lines, -18 lines)


docs/zh/examples/viv.md

Lines changed: 2 additions & 2 deletions

@@ -119,9 +119,9 @@ examples/fsi/viv.py:61:64
 
 Next we specify the number of training epochs and the learning rate. Based on experimental experience, we train for 100,000 epochs and evaluate the model accuracy every 1000 epochs.
 
-``` yaml linenums="39"
+``` yaml linenums="40"
 --8<--
-examples/fsi/conf/viv.yaml:39:40
+examples/fsi/conf/viv.yaml:40:45
 --8<--
 ```
 
examples/fsi/conf/viv.yaml

Lines changed: 7 additions & 4 deletions

@@ -25,6 +25,7 @@ seed: 42
 output_dir: ${hydra:run.dir}
 log_freq: 20
 
+VIV_DATA_PATH: "./VIV_Training_Neta100.mat"
 
 # model settings
 MODEL:
@@ -37,13 +38,14 @@ MODEL:
 # training settings
 TRAIN:
   epochs: 100000
-  eval_freq: 1000
   iters_per_epoch: 1
   save_freq: 1
   eval_during_train: true
-  optimizer:
-    epochs: 100000
-    iters_per_epoch: 1
+  eval_freq: 1000
+  batch_size: 100
+  lr_scheduler:
+    epochs: ${TRAIN.epochs}
+    iters_per_epoch: ${TRAIN.iters_per_epoch}
     learning_rate: 0.001
     step_size: 20000
     gamma: 0.9
@@ -53,3 +55,4 @@ TRAIN:
 # evaluation settings
 EVAL:
   pretrained_model_path: null
+  batch_size: 32
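
The new TRAIN.lr_scheduler block pulls epochs and iters_per_epoch from the top-level TRAIN keys via OmegaConf interpolation, so those values only have to be edited in one place. A minimal sketch of how the references resolve (illustrative only, not part of this commit; assumes omegaconf is installed and the snippet is run from the repository root):

from omegaconf import OmegaConf

# Load the updated config; interpolations such as ${TRAIN.epochs}
# are resolved lazily when the value is accessed.
cfg = OmegaConf.load("examples/fsi/conf/viv.yaml")

assert cfg.TRAIN.lr_scheduler.epochs == cfg.TRAIN.epochs                    # 100000
assert cfg.TRAIN.lr_scheduler.iters_per_epoch == cfg.TRAIN.iters_per_epoch  # 1
print(cfg.VIV_DATA_PATH, cfg.TRAIN.batch_size, cfg.EVAL.batch_size)         # ./VIV_Training_Neta100.mat 100 32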

examples/fsi/viv.py

Lines changed: 12 additions & 12 deletions

@@ -38,12 +38,12 @@ def train(cfg: DictConfig):
     train_dataloader_cfg = {
         "dataset": {
             "name": "MatDataset",
-            "file_path": "./VIV_Training_Neta100.mat",
+            "file_path": cfg.VIV_DATA_PATH,
             "input_keys": ("t_f",),
             "label_keys": ("eta", "f"),
             "weight_dict": {"eta": 100},
         },
-        "batch_size": 100,
+        "batch_size": cfg.TRAIN.batch_size,
         "sampler": {
             "name": "BatchSampler",
             "drop_last": False,
@@ -64,18 +64,18 @@ def train(cfg: DictConfig):
     }
 
     # set optimizer
-    lr_scheduler = ppsci.optimizer.lr_scheduler.Step(**cfg.TRAIN.optimizer)()
+    lr_scheduler = ppsci.optimizer.lr_scheduler.Step(**cfg.TRAIN.lr_scheduler)()
     optimizer = ppsci.optimizer.Adam(lr_scheduler)((model,) + tuple(equation.values()))
 
     # set validator
     valid_dataloader_cfg = {
         "dataset": {
             "name": "MatDataset",
-            "file_path": "./VIV_Training_Neta100.mat",
+            "file_path": cfg.VIV_DATA_PATH,
             "input_keys": ("t_f",),
             "label_keys": ("eta", "f"),
         },
-        "batch_size": 32,
+        "batch_size": cfg.EVAL.batch_size,
         "sampler": {
             "name": "BatchSampler",
             "drop_last": False,
@@ -93,7 +93,7 @@ def train(cfg: DictConfig):
 
     # set visualizer(optional)
     visu_mat = ppsci.utils.reader.load_mat_file(
-        "./VIV_Training_Neta100.mat",
+        cfg.VIV_DATA_PATH,
         ("t_f", "eta_gt", "f_gt"),
         alias_dict={"eta_gt": "eta", "f_gt": "f"},
     )
@@ -121,7 +121,7 @@ def train(cfg: DictConfig):
         lr_scheduler,
         cfg.TRAIN.epochs,
         cfg.TRAIN.iters_per_epoch,
-        eval_during_train=True,
+        eval_during_train=cfg.TRAIN.eval_during_train,
         eval_freq=cfg.TRAIN.eval_freq,
         equation=equation,
         validator=validator,
@@ -153,11 +153,11 @@ def evaluate(cfg: DictConfig):
     valid_dataloader_cfg = {
         "dataset": {
             "name": "MatDataset",
-            "file_path": "./VIV_Training_Neta100.mat",
+            "file_path": cfg.VIV_DATA_PATH,
            "input_keys": ("t_f",),
             "label_keys": ("eta", "f"),
         },
-        "batch_size": 32,
+        "batch_size": cfg.EVAL.batch_size,
         "sampler": {
             "name": "BatchSampler",
             "drop_last": False,
@@ -175,7 +175,7 @@ def evaluate(cfg: DictConfig):
 
     # set visualizer(optional)
     visu_mat = ppsci.utils.reader.load_mat_file(
-        "./VIV_Training_Neta100.mat",
+        cfg.VIV_DATA_PATH,
         ("t_f", "eta_gt", "f_gt"),
         alias_dict={"eta_gt": "eta", "f_gt": "f"},
     )
@@ -205,9 +205,9 @@ def evaluate(cfg: DictConfig):
         pretrained_model_path=cfg.EVAL.pretrained_model_path,
     )
 
-    # evaluate after finished training
+    # evaluate
     solver.eval()
-    # visualize prediction after finished training
+    # visualize prediction
     solver.visualize()
 
 
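For context, train() and evaluate() above receive cfg from a Hydra entry point in viv.py. The sketch below is a hypothetical reconstruction of that wiring; the decorator arguments and the cfg.mode switch are assumptions, since the entry point itself is not touched by this diff:

import hydra
from omegaconf import DictConfig

# Hypothetical entry point: config_path/config_name and the "mode" key are assumptions.
@hydra.main(version_base=None, config_path="./conf", config_name="viv.yaml")
def main(cfg: DictConfig):
    if cfg.mode == "train":
        train(cfg)
    elif cfg.mode == "eval":
        evaluate(cfg)
    else:
        raise ValueError(f"cfg.mode should be 'train' or 'eval', but got '{cfg.mode}'")

if __name__ == "__main__":
    main()

With a layout like this, the relocated settings can also be overridden from the command line, e.g. python viv.py mode=eval EVAL.batch_size=64 VIV_DATA_PATH=/path/to/VIV_Training_Neta100.mat.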