Skip to content

Commit efc5cdd

Browse files
committed
add some code for DGMR
1 parent c261eff commit efc5cdd

File tree

8 files changed

+1628
-0
lines changed

8 files changed

+1628
-0
lines changed

docs/zh/api/arch.md

+1
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,7 @@
2424
- USCNN
2525
- NowcastNet
2626
- HEDeepONets
27+
- DGMR
2728
- ChipDeepONets
2829
- AutoEncoder
2930
show_root_heading: true

docs/zh/api/data/dataset.md

+1
Original file line numberDiff line numberDiff line change
@@ -23,4 +23,5 @@
2323
- MeshCylinderDataset
2424
- RadarDataset
2525
- build_dataset
26+
- DGMRDataset
2627
show_root_heading: true

examples/dgmr/conf/dgmr.yaml

+61
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,61 @@
1+
hydra:
2+
run:
3+
# dynamic output directory according to running time and override name
4+
dir: outputs_dgmr/${now:%Y-%m-%d}/${now:%H-%M-%S}/${hydra.job.override_dirname}
5+
job:
6+
name: ${mode} # name of logfile
7+
chdir: false # keep current working direcotry unchaned
8+
config:
9+
override_dirname:
10+
exclude_keys:
11+
- TRAIN.checkpoint_path
12+
- TRAIN.pretrained_model_path
13+
- EVAL.pretrained_model_path
14+
- mode
15+
- output_dir
16+
- log_freq
17+
sweep:
18+
# output directory for multirun
19+
dir: ${hydra.run.dir}
20+
subdir: ./
21+
22+
# general settings
23+
mode: eval # running mode: train/eval
24+
seed: 42
25+
output_dir: ${hydra:run.dir}
26+
27+
# dataset settings
28+
DATASET:
29+
split: validation # train or validation
30+
NUM_INPUT_FRAMES: 4
31+
NUM_TARGET_FRAMES: 18
32+
dataset_path: /workspace/workspace/skillful_nowcasting/openclimatefix/nimrod-uk-1km
33+
number: 10
34+
35+
# dataLoader settings
36+
DATALOADER:
37+
batch_size: 1
38+
shuffle: False
39+
num_workers: 1
40+
drop_last: True
41+
42+
# model settings
43+
MODEL:
44+
forecast_steps: 18
45+
input_channels: 1
46+
output_shape: 256
47+
gen_lr: 5e-05
48+
disc_lr: 0.0002
49+
visualize: False
50+
conv_type: 'standard'
51+
num_samples: 6
52+
grid_lambda: 20.0
53+
beta1: 0.0
54+
beta2: 0.999
55+
latent_channels: 768
56+
context_channels: 384
57+
generation_steps: 6
58+
59+
# evaluation settings
60+
EVAL:
61+
pretrained_model_path: /workspace/workspace/skillful_nowcasting/openclimatefix/paddle/paddle_model.pdparams

examples/dgmr/dgmr.py

+123
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,123 @@
1+
# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
2+
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
7+
# http://www.apache.org/licenses/LICENSE-2.0
8+
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
15+
"""
16+
Reference: https://github.com/openclimatefix/skillful_nowcasting
17+
"""
18+
from os import path as osp
19+
20+
import hydra
21+
import matplotlib.pyplot as plt
22+
import numpy as np
23+
import paddle
24+
from omegaconf import DictConfig
25+
26+
import ppsci
27+
from ppsci.utils import logger
28+
29+
30+
def visualize(
31+
cfg: DictConfig,
32+
x: paddle.Tensor,
33+
y: paddle.Tensor,
34+
y_hat: paddle.Tensor,
35+
batch_idx: int,
36+
) -> None:
37+
images = x[0]
38+
future_images = y[0]
39+
generated_images = y_hat[0]
40+
fig, axes = plt.subplots(2, 2)
41+
for i, ax in enumerate(axes.flat):
42+
alpha = images[i][0].numpy()
43+
alpha[alpha < 1] = 0
44+
alpha[alpha > 1] = 1
45+
ax.imshow(images[i].transpose([1, 2, 0]).numpy(), alpha=alpha, cmap="viridis")
46+
ax.axis("off")
47+
plt.subplots_adjust(hspace=0.1, wspace=0.1)
48+
plt.savefig(osp.join(cfg.output_dir, "Input_Image_Stack_Frame.png"))
49+
fig, axes = plt.subplots(3, 3)
50+
for i, ax in enumerate(axes.flat):
51+
alpha = future_images[i][0].numpy()
52+
alpha[alpha < 1] = 0
53+
alpha[alpha > 1] = 1
54+
ax.imshow(
55+
future_images[i].transpose([1, 2, 0]).numpy(), alpha=alpha, cmap="viridis"
56+
)
57+
plt.subplots_adjust(hspace=0.1, wspace=0.1)
58+
plt.savefig(osp.join(cfg.output_dir, "Target_Image_Frame.png"))
59+
fig, axes = plt.subplots(3, 3)
60+
for i, ax in enumerate(axes.flat):
61+
alpha = generated_images[i][0].numpy()
62+
alpha[alpha < 1] = 0
63+
alpha[alpha > 1] = 1
64+
ax.imshow(
65+
generated_images[i].transpose([1, 2, 0]).numpy(),
66+
alpha=alpha,
67+
cmap="viridis",
68+
)
69+
ax.axis("off")
70+
plt.subplots_adjust(hspace=0.1, wspace=0.1)
71+
plt.savefig(osp.join(cfg.output_dir, "Generated_Image_Frame.png"))
72+
73+
74+
def train(cfg: DictConfig):
75+
print("Not supported.")
76+
77+
78+
def evaluate(cfg: DictConfig):
79+
# set model
80+
model = ppsci.arch.DGMR(**cfg.MODEL)
81+
# load evaluate data
82+
dataset = ppsci.data.dataset.DGMRDataset(**cfg.DATASET)
83+
val_loader = paddle.io.DataLoader(dataset, batch_size=cfg.DATALOADER.batch_size)
84+
# initialize solver
85+
solver = ppsci.solver.Solver(
86+
model,
87+
pretrained_model_path=cfg.EVAL.pretrained_model_path,
88+
)
89+
solver.model.eval()
90+
91+
# evaluate pretrained model
92+
d_loss = []
93+
g_loss = []
94+
grid_loss = []
95+
for batch_idx, batch in enumerate(val_loader):
96+
with paddle.no_grad():
97+
out_dict = solver.model.validation_step(batch, batch_idx)
98+
# visualize
99+
images, future_images = batch
100+
images = images.astype(dtype="float32")
101+
future_images = future_images.astype(dtype="float32")
102+
generated_images = solver.model.generator(images)
103+
visualize(cfg, images, future_images, generated_images, batch_idx)
104+
d_loss.append(out_dict[0])
105+
g_loss.append(out_dict[1])
106+
grid_loss.append(out_dict[2])
107+
logger.message(f"d_loss: {np.array(d_loss).mean()}")
108+
logger.message(f"g_loss: {np.array(g_loss).mean()}")
109+
logger.message(f"grid_loss: {np.array(grid_loss).mean()}")
110+
111+
112+
@hydra.main(version_base=None, config_path="./conf", config_name="dgmr.yaml")
113+
def main(cfg: DictConfig):
114+
if cfg.mode == "train":
115+
train(cfg)
116+
elif cfg.mode == "eval":
117+
evaluate(cfg)
118+
else:
119+
raise ValueError(f"cfg.mode should in ['train', 'eval'], but got '{cfg.mode}'")
120+
121+
122+
if __name__ == "__main__":
123+
main()

ppsci/arch/__init__.py

+2
Original file line numberDiff line numberDiff line change
@@ -40,6 +40,7 @@
4040
from ppsci.arch.he_deeponets import HEDeepONets # isort:skip
4141
from ppsci.arch.chip_deeponets import ChipDeepONets # isort:skip
4242
from ppsci.arch.cfdgcn import CFDGCN # isort:skip
43+
from ppsci.arch.dgmr import DGMR # isort:skip
4344
from ppsci.arch.vae import AutoEncoder # isort:skip
4445
from ppsci.utils import logger # isort:skip
4546

@@ -67,6 +68,7 @@
6768
"USCNN",
6869
"HEDeepONets",
6970
"ChipDeepONets",
71+
"DGMR",
7072
"AutoEncoder",
7173
"build_model",
7274
"CFDGCN",

0 commit comments

Comments
 (0)