Skip to content

Commit fad95be

Browse files
committed
【hydra No.4】Adapt VIV to hydra
1 parent 6f02602 commit fad95be

File tree

3 files changed

+180
-46
lines changed

3 files changed

+180
-46
lines changed

docs/zh/examples/viv.md

Lines changed: 27 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,13 @@
1616

1717
![VIV_1D_SpringDamper](https://paddle-org.bj.bcebos.com/paddlescience/docs/ViV/VIV_1D_SpringDamper.png)
1818

19+
20+
=== "模型训练命令"
21+
22+
``` sh
23+
python viv.py
24+
```
25+
1926
## 2. 问题定义
2027

2128
本问题涉及的控制方程涉及三个物理量:$λ_1$、$λ_2$ 和 $ρ$,分别表示自然阻尼、结构特性刚度和质量,控制方程定义如下所示:
@@ -41,9 +48,9 @@ $$
4148

4249
上式中 $g$ 即为 MLP 模型本身,用 PaddleScience 代码表示如下
4350

44-
``` py linenums="28"
51+
``` py linenums="31"
4552
--8<--
46-
examples/fsi/viv.py:28:29
53+
examples/fsi/viv.py:31:32
4754
--8<--
4855
```
4956

@@ -56,9 +63,9 @@ examples/fsi/viv.py:28:29
5663

5764
由于 VIV 使用的是 VIV 方程,因此可以直接使用 PaddleScience 内置的 `VIV`
5865

59-
``` py linenums="31"
66+
``` py linenums="34"
6067
--8<--
61-
examples/fsi/viv.py:31:32
68+
examples/fsi/viv.py:34:35
6269
--8<--
6370
```
6471

@@ -76,19 +83,19 @@ examples/fsi/viv.py:31:32
7683

7784
在定义约束之前,需要给监督约束指定文件路径等数据读取配置。
7885

79-
``` py linenums="34"
86+
``` py linenums="37"
8087
--8<--
81-
examples/fsi/viv.py:34:50
88+
examples/fsi/viv.py:37:52
8289
--8<--
8390
```
8491

8592
#### 3.4.1 监督约束
8693

8794
由于我们以监督学习方式进行训练,此处采用监督约束 `SupervisedConstraint`
8895

89-
``` py linenums="51"
96+
``` py linenums="54"
9097
--8<--
91-
examples/fsi/viv.py:51:57
98+
examples/fsi/viv.py:54:60
9299
--8<--
93100
```
94101

@@ -102,29 +109,29 @@ examples/fsi/viv.py:51:57
102109

103110
在监督约束构建完毕之后,以我们刚才的命名为关键字,封装到一个字典中,方便后续访问。
104111

105-
``` py linenums="58"
112+
``` py linenums="61"
106113
--8<--
107-
examples/fsi/viv.py:58:61
114+
examples/fsi/viv.py:61:64
108115
--8<--
109116
```
110117

111118
### 3.5 超参数设定
112119

113120
接下来我们需要指定训练轮数和学习率,此处我们按实验经验,使用十万轮训练轮数,并每隔1000个epochs评估一次模型精度。
114121

115-
``` py linenums="63"
122+
``` yaml linenums="39"
116123
--8<--
117-
examples/fsi/viv.py:63:65
124+
examples/fsi/conf/viv.yaml:39:40
118125
--8<--
119126
```
120127

121128
### 3.6 优化器构建
122129

123130
训练过程会调用优化器来更新模型参数,此处选择较为常用的 `Adam` 优化器和 `Step` 间隔衰减学习率。
124131

125-
``` py linenums="67"
132+
``` py linenums="66"
126133
--8<--
127-
examples/fsi/viv.py:67:71
134+
examples/fsi/viv.py:66:68
128135
--8<--
129136
```
130137

@@ -136,9 +143,9 @@ examples/fsi/viv.py:67:71
136143

137144
在训练过程中通常会按一定轮数间隔,用验证集(测试集)评估当前模型的训练情况,因此使用 `ppsci.validate.SupervisedValidator` 构建评估器。
138145

139-
``` py linenums="73"
146+
``` py linenums="70"
140147
--8<--
141-
examples/fsi/viv.py:73:95
148+
examples/fsi/viv.py:70:92
142149
--8<--
143150
```
144151

@@ -152,19 +159,19 @@ examples/fsi/viv.py:73:95
152159

153160
本文需要可视化的数据是 $t-\eta$ 和 $t-f$ 两组关系图,假设每个时刻 $t$ 的坐标是 $t_i$,则对应网络输出为 $\eta_i$,升力为 $f_i$,因此我们只需要将评估过程中产生的所有 $(t_i, \eta_i, f_i)$ 保存成图片即可。代码如下:
154161

155-
``` py linenums="97"
162+
``` py linenums="94"
156163
--8<--
157-
examples/fsi/viv.py:97:116
164+
examples/fsi/viv.py:94:113
158165
--8<--
159166
```
160167

161168
### 3.9 模型训练、评估与可视化
162169

163170
完成上述设置之后,只需要将上述实例化的对象按顺序传递给 `ppsci.solver.Solver`,然后启动训练、评估、可视化。
164171

165-
``` py linenums="118"
172+
``` py linenums="115"
166173
--8<--
167-
examples/fsi/viv.py:118:
174+
examples/fsi/viv.py:115:136
168175
--8<--
169176
```
170177

examples/fsi/conf/viv.yaml

Lines changed: 55 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,55 @@
1+
hydra:
2+
run:
3+
# dynamic output directory according to running time and override name
4+
dir: outputs_VIV/${now:%Y-%m-%d}/${now:%H-%M-%S}/${hydra.job.override_dirname}
5+
job:
6+
name: ${mode} # name of logfile
7+
chdir: false # keep current working direcotry unchaned
8+
config:
9+
override_dirname:
10+
exclude_keys:
11+
- TRAIN.checkpoint_path
12+
- TRAIN.pretrained_model_path
13+
- EVAL.pretrained_model_path
14+
- mode
15+
- output_dir
16+
- log_freq
17+
sweep:
18+
# output directory for multirun
19+
dir: ${hydra.run.dir}
20+
subdir: ./
21+
22+
# general settings
23+
mode: train # running mode: train/eval
24+
seed: 42
25+
output_dir: ${hydra:run.dir}
26+
log_freq: 20
27+
28+
29+
# model settings
30+
MODEL:
31+
input_keys: ["t_f"]
32+
output_keys: ["eta"]
33+
num_layers: 5
34+
hidden_size: 50
35+
activation: "tanh"
36+
37+
# training settings
38+
TRAIN:
39+
epochs: 100000
40+
eval_freq: 1000
41+
iters_per_epoch: 1
42+
save_freq: 1
43+
eval_during_train: true
44+
optimizer:
45+
epochs: 100000
46+
iters_per_epoch: 1
47+
learning_rate: 0.001
48+
step_size: 20000
49+
gamma: 0.9
50+
pretrained_model_path: null
51+
checkpoint_path: null
52+
53+
# evaluation settings
54+
EVAL:
55+
pretrained_model_path: null

examples/fsi/viv.py

Lines changed: 98 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -12,27 +12,29 @@
1212
# See the License for the specific language governing permissions and
1313
# limitations under the License.
1414

15+
import os
16+
17+
import hydra
18+
from omegaconf import DictConfig
19+
1520
import ppsci
16-
from ppsci.utils import config
1721
from ppsci.utils import logger
1822

19-
if __name__ == "__main__":
20-
args = config.parse_args()
23+
24+
def train(cfg: DictConfig):
2125
# set random seed for reproducibility
22-
ppsci.utils.misc.set_random_seed(42)
26+
ppsci.utils.misc.set_random_seed(cfg.seed)
27+
2328
# set output directory
24-
OUTPUT_DIR = "./output_viv" if args.output_dir is None else args.output_dir
25-
# initialize logger
26-
logger.init_logger("ppsci", f"{OUTPUT_DIR}/train.log", "info")
29+
logger.init_logger("ppsci", os.path.join(cfg.output_dir, "train.log"), "info")
2730

2831
# set model
29-
model = ppsci.arch.MLP(("t_f",), ("eta",), 5, 50, "tanh")
32+
model = ppsci.arch.MLP(**cfg.MODEL)
3033

3134
# set equation
3235
equation = {"VIV": ppsci.equation.Vibration(2, -4, 0)}
3336

3437
# set dataloader config
35-
ITERS_PER_EPOCH = 1
3638
train_dataloader_cfg = {
3739
"dataset": {
3840
"name": "MatDataset",
@@ -48,6 +50,7 @@
4850
"shuffle": True,
4951
},
5052
}
53+
5154
# set constraint
5255
sup_constraint = ppsci.constraint.SupervisedConstraint(
5356
train_dataloader_cfg,
@@ -60,14 +63,8 @@
6063
sup_constraint.name: sup_constraint,
6164
}
6265

63-
# set training hyper-parameters
64-
EPOCHS = 100000 if args.epochs is None else args.epochs
65-
EVAL_FREQ = 1000
66-
6766
# set optimizer
68-
lr_scheduler = ppsci.optimizer.lr_scheduler.Step(
69-
EPOCHS, ITERS_PER_EPOCH, 0.001, step_size=20000, gamma=0.9
70-
)()
67+
lr_scheduler = ppsci.optimizer.lr_scheduler.Step(**cfg.TRAIN.optimizer)()
7168
optimizer = ppsci.optimizer.Adam(lr_scheduler)((model,) + tuple(equation.values()))
7269

7370
# set validator
@@ -119,35 +116,110 @@
119116
solver = ppsci.solver.Solver(
120117
model,
121118
constraint,
122-
OUTPUT_DIR,
119+
cfg.output_dir,
123120
optimizer,
124121
lr_scheduler,
125-
EPOCHS,
126-
ITERS_PER_EPOCH,
122+
cfg.TRAIN.epochs,
123+
cfg.TRAIN.iters_per_epoch,
127124
eval_during_train=True,
128-
eval_freq=EVAL_FREQ,
125+
eval_freq=cfg.TRAIN.eval_freq,
129126
equation=equation,
130127
validator=validator,
131128
visualizer=visualizer,
132129
)
130+
133131
# train model
134132
solver.train()
135133
# evaluate after finished training
136134
solver.eval()
137135
# visualize prediction after finished training
138136
solver.visualize()
139137

140-
# directly evaluate model from pretrained_model_path(optional)
141-
logger.init_logger("ppsci", f"{OUTPUT_DIR}/eval.log", "info")
138+
139+
def evaluate(cfg: DictConfig):
140+
# set random seed for reproducibility
141+
ppsci.utils.misc.set_random_seed(cfg.seed)
142+
143+
# set output directory
144+
logger.init_logger("ppsci", os.path.join(cfg.output_dir, "eval.log"), "info")
145+
146+
# set model
147+
model = ppsci.arch.MLP(**cfg.MODEL)
148+
149+
# set equation
150+
equation = {"VIV": ppsci.equation.Vibration(2, -4, 0)}
151+
152+
# set validator
153+
valid_dataloader_cfg = {
154+
"dataset": {
155+
"name": "MatDataset",
156+
"file_path": "./VIV_Training_Neta100.mat",
157+
"input_keys": ("t_f",),
158+
"label_keys": ("eta", "f"),
159+
},
160+
"batch_size": 32,
161+
"sampler": {
162+
"name": "BatchSampler",
163+
"drop_last": False,
164+
"shuffle": False,
165+
},
166+
}
167+
eta_mse_validator = ppsci.validate.SupervisedValidator(
168+
valid_dataloader_cfg,
169+
ppsci.loss.MSELoss("mean"),
170+
{"eta": lambda out: out["eta"], **equation["VIV"].equations},
171+
metric={"MSE": ppsci.metric.MSE()},
172+
name="eta_mse",
173+
)
174+
validator = {eta_mse_validator.name: eta_mse_validator}
175+
176+
# set visualizer(optional)
177+
visu_mat = ppsci.utils.reader.load_mat_file(
178+
"./VIV_Training_Neta100.mat",
179+
("t_f", "eta_gt", "f_gt"),
180+
alias_dict={"eta_gt": "eta", "f_gt": "f"},
181+
)
182+
183+
visualizer = {
184+
"visualize_u": ppsci.visualize.VisualizerScatter1D(
185+
visu_mat,
186+
("t_f",),
187+
{
188+
r"$\eta$": lambda d: d["eta"], # plot with latex title
189+
r"$\eta_{gt}$": lambda d: d["eta_gt"], # plot with latex title
190+
r"$f$": equation["VIV"].equations["f"], # plot with latex title
191+
r"$f_{gt}$": lambda d: d["f_gt"], # plot with latex title
192+
},
193+
num_timestamps=1,
194+
prefix="viv_pred",
195+
)
196+
}
197+
198+
# initialize solver
142199
solver = ppsci.solver.Solver(
143200
model,
144-
constraint,
145-
OUTPUT_DIR,
201+
output_dir=cfg.output_dir,
146202
equation=equation,
147203
validator=validator,
148204
visualizer=visualizer,
149-
pretrained_model_path=f"{OUTPUT_DIR}/checkpoints/latest",
205+
pretrained_model_path=cfg.EVAL.pretrained_model_path,
150206
)
207+
208+
# evaluate after finished training
151209
solver.eval()
152-
# visualize prediction from pretrained_model_path(optional)
210+
# visualize prediction after finished training
153211
solver.visualize()
212+
213+
214+
@hydra.main(version_base=None, config_path="./conf", config_name="viv.yaml")
215+
def main(cfg: DictConfig):
216+
if cfg.mode == "train":
217+
train(cfg)
218+
elif cfg.mode == "eval":
219+
evaluate(cfg)
220+
else:
221+
raise ValueError(f"cfg.mode should in ['train', 'eval'], but got '{cfg.mode}'")
222+
223+
224+
if __name__ == "__main__":
225+
main()

0 commit comments

Comments
 (0)