-
Notifications
You must be signed in to change notification settings - Fork 206
【Hackathon 8th No.11】DrivAerNet++ 论文复现 #1062
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from 14 commits
8c3438a
4ca8511
a3768f7
2a6d9dd
b9caca5
2d2d13f
e798771
7c98038
3b872d2
0e0acc0
657b0ab
79facfa
7125eaa
4714ac8
beb8e38
bb45ad0
917ea4e
9c0cc68
a47e529
a3f6a67
18c0be5
33c73bd
7a2a852
40c61d9
8d3f03b
212ce32
3bbe0cc
184722c
ca3a33e
ec454cd
41f9726
8d3c888
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -38,6 +38,7 @@ | |
- LNO | ||
- TGCN | ||
- RegDGCNN | ||
- RegPointNet | ||
- IFMMLP | ||
show_root_heading: true | ||
heading_level: 3 |
Large diffs are not rendered by default.
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,77 @@ | ||
defaults: | ||
- ppsci_default | ||
- TRAIN: train_default | ||
- TRAIN/ema: ema_default | ||
- TRAIN/swa: swa_default | ||
- EVAL: eval_default | ||
- INFER: infer_default | ||
- hydra/job/config/override_dirname/exclude_keys: exclude_keys_default | ||
- _self_ | ||
|
||
hydra: | ||
run: | ||
dir: outputs_drivaernetplusplus/${now:%Y-%m-%d}/${now:%H-%M-%S}/${hydra.job.override_dirname} | ||
job: | ||
name: ${mode} | ||
chdir: false | ||
callbacks: | ||
init_callback: | ||
_target_: ppsci.utils.callbacks.InitCallback | ||
sweep: | ||
dir: ${hydra.run.dir} | ||
subdir: ./ | ||
|
||
# general settings | ||
mode: eval | ||
seed: 1 | ||
output_dir: ${hydra:run.dir} | ||
log_freq: 100 | ||
|
||
# model settings | ||
MODEL: | ||
input_keys: ["vertices"] | ||
output_keys: ["cd_value"] | ||
weight_keys: ["weight_keys"] | ||
dropout: 0.0 | ||
emb_dims: 1024 | ||
channels: [6, 64, 128, 256, 512, 1024] | ||
linear_sizes: [128, 64, 32, 16] | ||
k: 40 | ||
output_channels: 1 | ||
|
||
# training settings | ||
TRAIN: | ||
iters_per_epoch: 5399 | ||
epochs: 200 | ||
num_points: 100000 | ||
num_workers: 32 | ||
eval_during_train: True | ||
train_ids_file: "train_design_ids.txt" | ||
eval_ids_file: "val_design_ids.txt" | ||
batch_size: 32 | ||
scheduler: | ||
mode: "min" | ||
patience: 20 | ||
factor: 0.1 | ||
verbose: True | ||
|
||
# evaluation settings | ||
EVAL: | ||
num_points: 100000 | ||
batch_size: 1 | ||
pretrained_model_path: "https://dataset.bj.bcebos.com/PaddleScience/DNNFluid-Car/DrivAer%2B%2B/DragPrediction_DrivAerNet_PointNet_r2_batchsize16_200epochs_100kpoints_tsne_NeurIPS_best_model.pdparams" | ||
eval_with_no_grad: True | ||
ids_file: "test_design_ids.txt" | ||
num_workers: 8 | ||
|
||
# optimizer settings | ||
optimizer: | ||
weight_decay: 0.0001 | ||
lr: 0.001 | ||
optimizer: 'adam' | ||
|
||
ARGS: | ||
# dataset settings | ||
dataset_path: 'data/DrivAerNetPlusPlus_Processed_Point_Clouds_100k_paddle' | ||
aero_coeff: 'data/DrivAerNetPlusPlus_Drag_8k.csv' | ||
subset_dir: 'data/subset_dir' | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. ARGS这个字段感觉有点多余,可以删除掉吗? There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. 数据集的字段还是要的,删除之后无法正常运行 There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
哦我的意思是 There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. 已修改 |
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,200 @@ | ||
# Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved. | ||
|
||
# Licensed under the Apache License, Version 2.0 (the "License"); | ||
# you may not use this file except in compliance with the License. | ||
# You may obtain a copy of the License at | ||
|
||
# http://www.apache.org/licenses/LICENSE-2.0 | ||
|
||
# Unless required by applicable law or agreed to in writing, software | ||
# distributed under the License is distributed on an "AS IS" BASIS, | ||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
# See the License for the specific language governing permissions and | ||
# limitations under the License. | ||
|
||
|
||
import warnings | ||
from functools import partial | ||
|
||
import hydra | ||
import paddle | ||
from omegaconf import DictConfig | ||
|
||
import ppsci | ||
|
||
|
||
def train(cfg: DictConfig): | ||
# set model | ||
model = ppsci.arch.RegPointNet( | ||
input_keys=cfg.MODEL.input_keys, | ||
label_keys=cfg.MODEL.output_keys, | ||
weight_keys=cfg.MODEL.weight_keys, | ||
args=cfg.MODEL, | ||
) | ||
|
||
train_dataloader_cfg = { | ||
"dataset": { | ||
"name": "DrivAerNetPlusPlusDataset", | ||
"root_dir": cfg.ARGS.dataset_path, | ||
"input_keys": (cfg.MODEL.input_keys), | ||
"label_keys": (cfg.MODEL.output_keys), | ||
"weight_keys": (cfg.MODEL.weight_keys), | ||
LilaKen marked this conversation as resolved.
Show resolved
Hide resolved
|
||
"subset_dir": cfg.ARGS.subset_dir, | ||
"ids_file": cfg.TRAIN.train_ids_file, | ||
"csv_file": cfg.ARGS.aero_coeff, | ||
"num_points": cfg.TRAIN.num_points, | ||
}, | ||
"batch_size": cfg.TRAIN.batch_size, | ||
"num_workers": cfg.TRAIN.num_workers, | ||
} | ||
|
||
drivaernetplusplus_constraint = ppsci.constraint.SupervisedConstraint( | ||
train_dataloader_cfg, | ||
ppsci.loss.MSELoss("mean"), | ||
name="DrivAerNetplusplus_constraint", | ||
) | ||
|
||
constraint = {drivaernetplusplus_constraint.name: drivaernetplusplus_constraint} | ||
|
||
valid_dataloader_cfg = { | ||
"dataset": { | ||
"name": "DrivAerNetPlusPlusDataset", | ||
"root_dir": cfg.ARGS.dataset_path, | ||
"input_keys": (cfg.MODEL.input_keys), | ||
"label_keys": (cfg.MODEL.output_keys), | ||
"weight_keys": (cfg.MODEL.weight_keys), | ||
LilaKen marked this conversation as resolved.
Show resolved
Hide resolved
|
||
"subset_dir": cfg.ARGS.subset_dir, | ||
"ids_file": cfg.TRAIN.eval_ids_file, | ||
"csv_file": cfg.ARGS.aero_coeff, | ||
"num_points": cfg.TRAIN.num_points, | ||
}, | ||
"batch_size": cfg.TRAIN.batch_size, | ||
"num_workers": cfg.TRAIN.num_workers, | ||
} | ||
|
||
drivaernetplusplus_valid = ppsci.validate.SupervisedValidator( | ||
valid_dataloader_cfg, | ||
loss=ppsci.loss.MSELoss("mean"), | ||
metric={"MSE": ppsci.metric.MSE()}, | ||
name="DrivAerNetplusplus_valid", | ||
) | ||
|
||
validator = {drivaernetplusplus_valid.name: drivaernetplusplus_valid} | ||
|
||
# set optimizer | ||
lr_scheduler = ppsci.optimizer.lr_scheduler.ReduceOnPlateau( | ||
epochs=cfg.TRAIN.epochs, | ||
iters_per_epoch=( | ||
cfg.TRAIN.iters_per_epoch | ||
// (paddle.distributed.get_world_size() * cfg.TRAIN.batch_size) | ||
+ 1 | ||
), | ||
learning_rate=cfg.optimizer.lr, | ||
mode=cfg.TRAIN.scheduler.mode, | ||
patience=cfg.TRAIN.scheduler.patience, | ||
factor=cfg.TRAIN.scheduler.factor, | ||
verbose=cfg.TRAIN.scheduler.verbose, | ||
)() | ||
|
||
optimizer = ( | ||
ppsci.optimizer.Adam(lr_scheduler, weight_decay=cfg.optimizer.weight_decay)( | ||
model | ||
) | ||
if cfg.optimizer.optimizer == "adam" | ||
else ppsci.optimizer.SGD(lr_scheduler, weight_decay=cfg.optimizer.weight_decay)( | ||
model | ||
) | ||
) | ||
|
||
# initialize solver | ||
solver = ppsci.solver.Solver( | ||
model=model, | ||
iters_per_epoch=( | ||
cfg.TRAIN.iters_per_epoch | ||
// (paddle.distributed.get_world_size() * cfg.TRAIN.batch_size) | ||
+ 1 | ||
), | ||
constraint=constraint, | ||
output_dir=cfg.output_dir, | ||
optimizer=optimizer, | ||
lr_scheduler=lr_scheduler, | ||
epochs=cfg.TRAIN.epochs, | ||
validator=validator, | ||
eval_during_train=cfg.TRAIN.eval_during_train, | ||
eval_with_no_grad=cfg.EVAL.eval_with_no_grad, | ||
) | ||
|
||
lr_scheduler.step = partial(lr_scheduler.step, metrics=solver.cur_metric) | ||
solver.lr_scheduler = lr_scheduler | ||
|
||
# train model | ||
solver.train() | ||
|
||
solver.eval() | ||
|
||
|
||
def evaluate(cfg: DictConfig): | ||
# set model | ||
model = ppsci.arch.RegPointNet( | ||
input_keys=cfg.MODEL.input_keys, | ||
label_keys=cfg.MODEL.output_keys, | ||
weight_keys=cfg.MODEL.weight_keys, | ||
args=cfg.MODEL, | ||
) | ||
|
||
valid_dataloader_cfg = { | ||
"dataset": { | ||
"name": "DrivAerNetPlusPlusDataset", | ||
"root_dir": cfg.ARGS.dataset_path, | ||
"input_keys": (cfg.MODEL.input_keys), | ||
"label_keys": (cfg.MODEL.output_keys), | ||
"weight_keys": (cfg.MODEL.weight_keys), | ||
LilaKen marked this conversation as resolved.
Show resolved
Hide resolved
|
||
"subset_dir": cfg.ARGS.subset_dir, | ||
"ids_file": cfg.EVAL.ids_file, | ||
"csv_file": cfg.ARGS.aero_coeff, | ||
"num_points": cfg.EVAL.num_points, | ||
}, | ||
"batch_size": cfg.EVAL.batch_size, | ||
"num_workers": cfg.EVAL.num_workers, | ||
} | ||
|
||
drivaernetplusplus_valid = ppsci.validate.SupervisedValidator( | ||
valid_dataloader_cfg, | ||
loss=ppsci.loss.MSELoss("mean"), | ||
metric={ | ||
"MSE": ppsci.metric.MSE(), | ||
"MAE": ppsci.metric.MAE(), | ||
"Max AE": ppsci.metric.MaxAE(), | ||
"R²": ppsci.metric.R2Score(), | ||
}, | ||
name="DrivAerNetPlusPlus_valid", | ||
) | ||
|
||
validator = {drivaernetplusplus_valid.name: drivaernetplusplus_valid} | ||
|
||
solver = ppsci.solver.Solver( | ||
model=model, | ||
validator=validator, | ||
pretrained_model_path=cfg.EVAL.pretrained_model_path, | ||
eval_with_no_grad=cfg.EVAL.eval_with_no_grad, | ||
) | ||
|
||
# evaluate model | ||
solver.eval() | ||
|
||
|
||
@hydra.main( | ||
version_base=None, config_path="./conf", config_name="drivaernetplusplus.yaml" | ||
) | ||
def main(cfg: DictConfig): | ||
warnings.filterwarnings("ignore") | ||
if cfg.mode == "train": | ||
train(cfg) | ||
elif cfg.mode == "eval": | ||
evaluate(cfg) | ||
else: | ||
raise ValueError(f"cfg.mode should in ['train', 'eval'], but got '{cfg.mode}'") | ||
|
||
|
||
if __name__ == "__main__": | ||
main() |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
整个文档是否能用 vscode 的markdown插件格式化一下?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
已格式化