Skip to content

Commit 7c6f6aa

Browse files
authored
[Add]MRMSDataset (#810)
* [Add]MRMSDataset * fix code and style
1 parent 3e1d0ad commit 7c6f6aa

File tree

2 files changed

+252
-0
lines changed

2 files changed

+252
-0
lines changed

ppsci/data/dataset/__init__.py

+4
Original file line numberDiff line numberDiff line change
@@ -25,6 +25,8 @@
2525
from ppsci.data.dataset.era5_dataset import ERA5SampledDataset
2626
from ppsci.data.dataset.mat_dataset import IterableMatDataset
2727
from ppsci.data.dataset.mat_dataset import MatDataset
28+
from ppsci.data.dataset.mrms_dataset import MRMSDataset
29+
from ppsci.data.dataset.mrms_dataset import MRMSSampledDataset
2830
from ppsci.data.dataset.npz_dataset import IterableNPZDataset
2931
from ppsci.data.dataset.npz_dataset import NPZDataset
3032
from ppsci.data.dataset.radar_dataset import RadarDataset
@@ -47,6 +49,8 @@
4749
"ERA5SampledDataset",
4850
"IterableMatDataset",
4951
"MatDataset",
52+
"MRMSDataset",
53+
"MRMSSampledDataset",
5054
"IterableNPZDataset",
5155
"NPZDataset",
5256
"CylinderDataset",

ppsci/data/dataset/mrms_dataset.py

+248
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,248 @@
1+
# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
2+
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
7+
# http://www.apache.org/licenses/LICENSE-2.0
8+
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
15+
from __future__ import annotations
16+
17+
import glob
18+
import os.path as osp
19+
from datetime import datetime
20+
from datetime import timedelta
21+
from typing import Dict
22+
from typing import List
23+
from typing import Optional
24+
from typing import Tuple
25+
26+
import h5py
27+
import numpy as np
28+
import paddle
29+
from paddle import io
30+
from paddle import vision
31+
32+
33+
class MRMSDataset(io.Dataset):
    """Class for MRMS dataset. One day's MRMS data is stored in a single .h5 file.
    Each file includes keys "date"/"time_interval"/"dataset".

    Args:
        file_path (str): Dataset path. Either a single ``.h5`` file or a directory
            containing per-day files whose names end with ``_<YYYYMMDD>.h5``.
        input_keys (Tuple[str, ...]): Input keys, usually there is only one, such as ("input",).
        label_keys (Tuple[str, ...]): Output keys, usually there is only one, such as ("output",).
        weight_dict (Optional[Dict[str, float]]): Weight dictionary. Defaults to None.
        date_period (Tuple[str, ...], optional): Dates of data. Scale is [start_date, end_date]
            with format "%Y%m%d". Defaults to ("20230101", "20230101").
        num_input_timestamps (int, optional): Number of timestamps of input. Defaults to 1.
        num_label_timestamps (int, optional): Number of timestamps of label. Defaults to 1.
        stride (int, optional): Stride of sampling data. Defaults to 1.
        transforms (Optional[vision.Compose]): Composed transform functor(s). Defaults to None.

    Examples:
        >>> import ppsci
        >>> dataset = ppsci.data.dataset.MRMSDataset(
        ...     file_path="/path/to/MRMSDataset",
        ...     input_keys=("input",),
        ...     label_keys=("output",),
        ...     date_period=("20230101", "20230131"),
        ...     num_input_timestamps=9,
        ...     num_label_timestamps=20,
        ...     transforms=transform,
        ...     stride=1,
        ... )  # doctest: +SKIP
    """

    # Whether support batch indexing for speeding up fetching process.
    batch_index: bool = False

    def __init__(
        self,
        file_path: str,
        input_keys: Tuple[str, ...],
        label_keys: Tuple[str, ...],
        weight_dict: Optional[Dict[str, float]] = None,
        date_period: Tuple[str, ...] = ("20230101", "20230101"),
        num_input_timestamps: int = 1,
        num_label_timestamps: int = 1,
        stride: int = 1,
        transforms: Optional[vision.Compose] = None,
    ):
        super().__init__()
        self.file_path = file_path
        self.input_keys = input_keys
        self.label_keys = label_keys

        # Default every label's weight to 1.0, then let user-provided entries
        # override. (No weight_dict at all means an empty weight dict.)
        self.weight_dict = {}
        if weight_dict is not None:
            self.weight_dict = {key: 1.0 for key in self.label_keys}
            self.weight_dict.update(weight_dict)

        self.date_list = self._get_date_strs(date_period)
        self.num_input_timestamps = num_input_timestamps
        self.num_label_timestamps = num_label_timestamps
        self.stride = stride
        self.transforms = transforms

        self.files = self._read_data(file_path)
        # Assumes every per-day file holds the same number of samples as the
        # first one (axis 0 of "dataset").
        self.num_samples_per_day = self.files[0].shape[0]
        self.num_samples = self.num_samples_per_day * len(self.date_list)

    def _get_date_strs(self, date_period: Tuple[str, ...]) -> List[str]:
        """Get a string list of all dates within given period (inclusive).

        Args:
            date_period (Tuple[str, ...]): Dates of data. Scale is
                [start_date, end_date] with format "%Y%m%d".
        """
        start_time = datetime.strptime(date_period[0], "%Y%m%d")
        end_time = datetime.strptime(date_period[1], "%Y%m%d")
        results = []
        current_time = start_time
        while current_time <= end_time:
            date_str = current_time.strftime("%Y%m%d")
            results.append(date_str)
            current_time += timedelta(days=1)
        return results

    def _read_data(self, path: str):
        """Open the "dataset" array of each requested day's .h5 file, sorted by filename.

        Raises:
            FileNotFoundError: If the number of matched files differs from the
                number of requested dates.
        """
        if path.endswith(".h5"):
            paths = [path]
        else:
            # Keep only files whose trailing "_<YYYYMMDD>" token is a requested date.
            paths = [
                _path
                for _path in glob.glob(osp.join(path, "*.h5"))
                if _path.split(".h5")[0].split("_")[-1] in self.date_list
            ]
        # Raise instead of assert so the validation survives `python -O`.
        if len(paths) != len(self.date_list):
            raise FileNotFoundError(
                f"Data of {len(self.date_list)} days wanted but only {len(paths)} days be found"
            )
        paths.sort()

        files = [h5py.File(_path, "r")["dataset"] for _path in paths]
        return files

    def __len__(self):
        """Number of valid starting positions given stride and window length."""
        return (
            self.num_samples // self.stride
            - self.num_input_timestamps
            - self.num_label_timestamps
            + 1
        )

    def __getitem__(self, global_idx):
        """Return (input_item, label_item, weight_item) dicts for one window.

        Consecutive timestamps within a window are spaced `stride` raw samples
        apart; the first `num_input_timestamps` frames form the input and the
        remainder form the label.
        """
        global_idx *= self.stride
        _samples = np.empty(
            (
                self.num_input_timestamps + self.num_label_timestamps,
                *self.files[0].shape[1:],
            ),
            dtype=paddle.get_default_dtype(),
        )
        for idx in range(self.num_input_timestamps + self.num_label_timestamps):
            sample_idx = global_idx + idx * self.stride
            # Map the flat sample index onto (day file, index within that day).
            day_idx = sample_idx // self.num_samples_per_day
            local_idx = sample_idx % self.num_samples_per_day
            _samples[idx] = self.files[day_idx][local_idx]

        input_item = {self.input_keys[0]: _samples[: self.num_input_timestamps]}
        label_item = {self.label_keys[0]: _samples[self.num_input_timestamps :]}

        # Broadcastable all-ones-shaped weights carrying the configured scalar value.
        weight_shape = [1] * len(next(iter(label_item.values())).shape)
        weight_item = {
            key: np.full(weight_shape, value, paddle.get_default_dtype())
            for key, value in self.weight_dict.items()
        }

        if self.transforms is not None:
            input_item, label_item, weight_item = self.transforms(
                input_item, label_item, weight_item
            )

        return input_item, label_item, weight_item
167+
168+
169+
class MRMSSampledDataset(io.Dataset):
    """Class for MRMS sampled dataset. One sample's data is stored in a single .h5 file.
    Each file includes keys "date"/"time_interval"/"dataset".
    The class just returns data by input_item; label_item and weight_item are empty.

    Args:
        file_path (str): Dataset path (directory of per-sample ``.h5`` files).
        input_keys (Tuple[str, ...]): Input keys, such as ("input",).
        label_keys (Tuple[str, ...]): Output keys, such as ("output",).
        weight_dict (Optional[Dict[str, float]]): Weight dictionary. Defaults to None.
        num_total_timestamps (int, optional): Number of timestamps of input+label. Defaults to 1.
        transforms (Optional[vision.Compose]): Composed transform functor(s). Defaults to None.

    Examples:
        >>> import ppsci
        >>> dataset = ppsci.data.dataset.MRMSSampledDataset(
        ...     file_path="/path/to/MRMSSampledDataset",
        ...     input_keys=("input",),
        ...     label_keys=("output",),
        ...     num_total_timestamps=29,
        ... )  # doctest: +SKIP
        >>> # get the length of the dataset
        >>> dataset_size = len(dataset)  # doctest: +SKIP
        >>> # get the first sample of the data
        >>> first_sample = dataset[0]  # doctest: +SKIP
        >>> print("First sample:", first_sample)  # doctest: +SKIP
    """

    def __init__(
        self,
        file_path: str,
        input_keys: Tuple[str, ...],
        label_keys: Tuple[str, ...],
        weight_dict: Optional[Dict[str, float]] = None,
        num_total_timestamps: int = 1,
        transforms: Optional[vision.Compose] = None,
    ):
        super().__init__()
        self.file_path = file_path
        self.input_keys = input_keys
        self.label_keys = label_keys

        # Default every label's weight to 1.0, then let user-provided entries
        # override. (No weight_dict at all means an empty weight dict.)
        self.weight_dict = {}
        if weight_dict is not None:
            self.weight_dict = {key: 1.0 for key in self.label_keys}
            self.weight_dict.update(weight_dict)

        self.num_total_timestamps = num_total_timestamps
        self.transforms = transforms

        self.files = self._read_data(file_path)
        self.num_samples = len(self.files)

    def _read_data(self, path: str):
        """Open the "dataset" array of every .h5 file under `path`, sorted by filename."""
        paths = glob.glob(osp.join(path, "*.h5"))
        paths.sort()
        files = [h5py.File(_path, "r")["dataset"] for _path in paths]
        return files

    def __len__(self):
        """Number of valid starting positions for a full window of timestamps."""
        return self.num_samples - self.num_total_timestamps + 1

    def __getitem__(self, global_idx):
        """Return (input_item, {}, {}) where the input stacks `num_total_timestamps`
        consecutive per-file samples along a new leading axis."""
        _samples = []
        for idx in range(global_idx, global_idx + self.num_total_timestamps):
            _samples.append(np.expand_dims(self.files[idx], axis=0))

        input_item = {
            self.input_keys[0]: np.concatenate(_samples, axis=0).astype(
                paddle.get_default_dtype()
            )
        }
        # This dataset is inference-oriented: labels and weights stay empty.
        label_item = {}
        weight_item = {}

        if self.transforms is not None:
            input_item, label_item, weight_item = self.transforms(
                input_item, label_item, weight_item
            )

        return input_item, label_item, weight_item

0 commit comments

Comments
 (0)