Skip to content

Commit 57c712b

Browse files
committed
Add export and inference for amgnet
1 parent e0f0245 commit 57c712b

File tree

4 files changed

+245
-2
lines changed

4 files changed

+245
-2
lines changed

docs/zh/examples/amgnet.md

Lines changed: 74 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -56,6 +56,22 @@
5656
python amgnet_cylinder.py mode=eval EVAL.pretrained_model_path=https://paddle-org.bj.bcebos.com/paddlescience/models/amgnet/amgnet_cylinder_pretrained.pdparams
5757
```
5858

59+
=== "模型导出命令"
60+
61+
=== "amgnet_cylinder"
62+
63+
``` sh
64+
python amgnet_cylinder.py mode=export
65+
```
66+
67+
=== "Python推理命令"
68+
69+
=== "amgnet_cylinder"
70+
71+
``` sh
72+
python amgnet_cylinder.py mode=infer
73+
```
74+
5975
| 预训练模型 | 指标 |
6076
|:--| :--|
6177
| [amgnet_airfoil_pretrained.pdparams](https://paddle-org.bj.bcebos.com/paddlescience/models/amgnet/amgnet_airfoil_pretrained.pdparams) | loss(RMSE_validator): 0.0001 <br> RMSE.RMSE(RMSE_validator): 0.01315 |
@@ -291,6 +307,64 @@ unzip data.zip
291307
--8<--
292308
```
293309

310+
### 3.9 模型导出与推理
311+
312+
训练完成后,我们可以将模型导出为静态图格式,并使用Python推理引擎进行部署。
313+
314+
#### 3.9.1 导出模型
315+
316+
我们首先需要在 `amgnet_cylinder.py` 中实现 `export` 函数,它负责加载训练好的模型,并将其保存为推理所需的格式。
317+
318+
``` py linenums="235"
319+
--8<--
320+
examples/amgnet/amgnet_cylinder.py:235:256
321+
--8<--
322+
```
323+
324+
通过运行以下命令,即可执行导出:
325+
326+
```bash
327+
python amgnet_cylinder.py mode=export
328+
```
329+
330+
导出的模型将包含 `amgnet_cylinder.pdmodel` (模型结构) 和 `amgnet_cylinder.pdiparams` (模型权重) 文件,保存在配置文件 `INFER.export_path` 所指定的目录中。
331+
332+
#### 3.9.2 创建推理器
333+
334+
为了执行推理,我们创建了一个专用的 `AMGNPredictor` 类,存放于 `deploy/python_infer/amgn_predictor.py`。这个类继承自 `ppsci.deploy.base_predictor.Predictor`,并实现了加载模型和执行预测的核心逻辑。
335+
336+
``` py linenums="28"
337+
--8<--
338+
examples/amgnet/deploy/python_infer/amgn_predictor.py:28:87
339+
--8<--
340+
```
341+
342+
#### 3.9.3 执行推理
343+
344+
最后,我们实现 `inference` 函数。该函数会实例化 `AMGNPredictor`,加载数据,并循环执行预测,最后将结果可视化。
345+
346+
``` py linenums="259"
347+
--8<--
348+
examples/amgnet/amgnet_cylinder.py:259:298
349+
--8<--
350+
```
351+
352+
通过以下命令来运行推理:
353+
354+
```bash
355+
python amgnet_cylinder.py mode=infer
356+
```
357+
358+
#### 3.9.4 新增配置
359+
360+
为了支持以上功能,需要在 `conf/amgnet_cylinder.yaml` 中添加 `INFER` 配置项。
361+
362+
``` yaml linenums="65"
363+
--8<--
364+
examples/amgnet/conf/amgnet_cylinder.yaml:65:68
365+
--8<--
366+
```
367+
294368
## 4. 完整代码
295369

296370
=== "airfoil"

examples/amgnet/amgn_predictor.py

Lines changed: 84 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,84 @@
1+
# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
2+
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
7+
# http://www.apache.org/licenses/LICENSE-2.0
8+
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
15+
from __future__ import annotations
16+
17+
from typing import TYPE_CHECKING
18+
from typing import Dict
19+
20+
import numpy as np
21+
22+
from ppsci.deploy.base_predictor import Predictor
23+
from ppsci.utils import logger
24+
25+
if TYPE_CHECKING:
26+
import pgl
27+
from omegaconf import DictConfig
28+
29+
30+
class AMGNPredictor(Predictor):
31+
"""Predictor for AMGNet model.
32+
33+
Args:
34+
cfg (DictConfig): Configuration object.
35+
"""
36+
37+
def __init__(self, cfg: DictConfig):
38+
super().__init__(cfg)
39+
40+
def predict(
41+
self,
42+
input_dict: Dict[str, "pgl.Graph"],
43+
batch_size: int = 64,
44+
) -> Dict[str, np.ndarray]:
45+
"""Predicts the output of the model for a given input.
46+
47+
Args:
48+
input_dict (Dict[str, "pgl.Graph"]): Input data in a dictionary.
49+
batch_size (int, optional): Batch size for prediction. Defaults to 64.
50+
51+
Returns:
52+
Dict[str, np.ndarray]: Predicted output in a dictionary.
53+
"""
54+
# Note: amgnet only supports batch_size=1
55+
if batch_size > 1:
56+
logger.warning(
57+
f"AMGNet predictor only support batch_size=1, but got {batch_size}. "
58+
"Automatically set batch_size to 1."
59+
)
60+
batch_size = 1
61+
62+
output_dict = {}
63+
for key, graph in input_dict.items():
64+
input_names = self.predictor.get_input_names()
65+
for name in input_names:
66+
handle = self.predictor.get_input_handle(name)
67+
data = getattr(graph, name)
68+
handle.copy_from_cpu(data)
69+
70+
self.predictor.run()
71+
output_names = self.predictor.get_output_names()
72+
for name in output_names:
73+
handle = self.predictor.get_output_handle(name)
74+
output = handle.copy_to_cpu()
75+
output_dict[name] = output
76+
77+
# mapping data to cfg.INFER.output_keys
78+
output_dict = {
79+
store_key: output_dict[infer_key]
80+
for store_key, infer_key in zip(
81+
self.output_keys, self.predictor.get_output_names()
82+
)
83+
}
84+
return output_dict

examples/amgnet/amgnet_cylinder.py

Lines changed: 79 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,7 @@
2323
import utils
2424
from omegaconf import DictConfig
2525
from paddle.nn import functional as F
26+
from paddle.static import InputSpec
2627

2728
import ppsci
2829
from ppsci.utils import logger
@@ -212,14 +213,91 @@ def evaluate(cfg: DictConfig):
212213
)
213214

214215

216+
def export(cfg: DictConfig):
217+
"""Export the model for inference."""
218+
# initialize logger
219+
logger.init_logger("ppsci", osp.join(cfg.output_dir, "export.log"), "info")
220+
221+
# set model
222+
model = ppsci.arch.AMGNet(**cfg.MODEL)
223+
224+
# initialize solver
225+
solver = ppsci.solver.Solver(
226+
model,
227+
pretrained_model_path=cfg.INFER.pretrained_model_path,
228+
)
229+
230+
# export model
231+
input_spec = [
232+
{
233+
"input": {
234+
"node_feat": InputSpec(
235+
[None, cfg.MODEL.input_dim], "float32", name="node_feat"
236+
),
237+
"edge_feat": InputSpec([None, 2], "float32", name="edge_feat"),
238+
}
239+
},
240+
]
241+
solver.export(input_spec, cfg.INFER.export_path, skip_prune=True)
242+
243+
244+
def inference(cfg: DictConfig):
245+
"""Run inference with the exported model."""
246+
from deploy.python_infer import amgn_predictor
247+
248+
# initialize logger
249+
logger.init_logger("ppsci", osp.join(cfg.output_dir, "infer.log"), "info")
250+
251+
# set model predictor
252+
predictor = amgn_predictor.AMGNPredictor(cfg)
253+
254+
# set dataloader
255+
eval_dataloader_cfg = {
256+
"dataset": {
257+
"name": "MeshCylinderDataset",
258+
"input_keys": ("input",),
259+
"label_keys": ("label",),
260+
"data_dir": cfg.EVAL_DATA_DIR,
261+
"mesh_graph_path": cfg.EVAL_MESH_GRAPH_PATH,
262+
},
263+
"batch_size": 1,
264+
"sampler": {
265+
"name": "BatchSampler",
266+
"drop_last": False,
267+
"shuffle": False,
268+
},
269+
}
270+
eval_dataloader = ppsci.data.build_dataloader(**eval_dataloader_cfg)
271+
272+
# run inference
273+
logger.message("Now running inference, please wait...")
274+
for index, (input_, label, _) in enumerate(eval_dataloader):
275+
output_dict = predictor.predict(input_, cfg.INFER.batch_size)
276+
truefield = label["label"].y
277+
utils.log_images(
278+
input_["input"].pos,
279+
output_dict["pred"],
280+
truefield,
281+
eval_dataloader.dataset.elems_list,
282+
index,
283+
"cylinder_infer",
284+
)
285+
286+
215287
@hydra.main(version_base=None, config_path="./conf", config_name="amgnet_cylinder.yaml")
216288
def main(cfg: DictConfig):
217289
if cfg.mode == "train":
218290
train(cfg)
219291
elif cfg.mode == "eval":
220292
evaluate(cfg)
293+
elif cfg.mode == "export":
294+
export(cfg)
295+
elif cfg.mode == "infer":
296+
inference(cfg)
221297
else:
222-
raise ValueError(f"cfg.mode should in ['train', 'eval'], but got '{cfg.mode}'")
298+
raise ValueError(
299+
f"cfg.mode should in ['train', 'eval', 'export', 'infer'], but got '{cfg.mode}'"
300+
)
223301

224302

225303
if __name__ == "__main__":

examples/amgnet/conf/amgnet_cylinder.yaml

Lines changed: 8 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -63,5 +63,12 @@ TRAIN:
6363
# evaluation settings
6464
EVAL:
6565
batch_size: 1
66-
pretrained_model_path: null
66+
# NOTE: The following path is a placeholder, please replace it with your own model path
67+
pretrained_model_path: https://paddle-org.bj.bcebos.com/paddlescience/models/amgnet/amgnet_cylinder_pretrained.pdparams
6768
eval_with_no_grad: true
69+
70+
# inference settings
71+
INFER:
72+
batch_size: 1
73+
pretrained_model_path: ${EVAL.pretrained_model_path}
74+
export_path: ./inference/amgnet_cylinder

0 commit comments

Comments
 (0)