Skip to content

Commit a9910b1

Browse files
[Fea] Add writer module (#719)
* (WIP) add utils.writer module for saving data into disk * add complete code * update save_csv_file for compatible with different dtype
1 parent de84248 commit a9910b1

File tree

6 files changed

+195
-0
lines changed

6 files changed

+195
-0
lines changed

docs/zh/api/utils/writer.md

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,9 @@
1+
# Utils.reader(文件保存函数) 模块
2+
3+
::: ppsci.utils.reader
4+
handler: python
5+
options:
6+
members:
7+
- save_csv_file
8+
show_root_heading: True
9+
heading_level: 3

docs/zh/overview.md

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -108,6 +108,7 @@ Utils 模块内部存放了一些适用于多种场景下的工具类、函数
108108
| [ppsci.utils.logger](./api/utils/logger.md)| 日志打印模块 |
109109
| [ppsci.utils.misc](./api/utils/misc.md)| 存放通用函数 |
110110
| [ppsci.utils.reader](./api/utils/reader.md)| 文件读取模块 |
111+
| [ppsci.utils.writer](./api/utils/writer.md)| 文件写入模块 |
111112
| [ppsci.utils.save_load](./api/utils/save_load.md)| 模型参数保存与加载 |
112113
| [ppsci.utils.symbolic](./api/utils/symbolic.md)| sympy 符号计算功能相关 |
113114

ppsci/utils/__init__.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,7 @@
1717
from ppsci.utils import misc
1818
from ppsci.utils import profiler
1919
from ppsci.utils import reader
20+
from ppsci.utils import writer
2021
from ppsci.utils.checker import dynamic_import_to_globals
2122
from ppsci.utils.checker import run_check
2223
from ppsci.utils.checker import run_check_mesh
@@ -33,6 +34,7 @@
3334
from ppsci.utils.save_load import load_pretrain
3435
from ppsci.utils.save_load import save_checkpoint
3536
from ppsci.utils.symbolic import lambdify
37+
from ppsci.utils.writer import save_csv_file
3638

3739
__all__ = [
3840
"AttrDict",
@@ -42,12 +44,14 @@
4244
"logger",
4345
"misc",
4446
"reader",
47+
"writer",
4548
"profiler",
4649
"load_csv_file",
4750
"load_mat_file",
4851
"load_npz_file",
4952
"load_vtk_file",
5053
"load_vtk_with_time_file",
54+
"save_csv_file",
5155
"dynamic_import_to_globals",
5256
"run_check",
5357
"run_check_mesh",

ppsci/utils/writer.py

Lines changed: 108 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,108 @@
1+
# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
2+
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
7+
# http://www.apache.org/licenses/LICENSE-2.0
8+
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
15+
from __future__ import annotations
16+
17+
import csv
18+
from typing import Dict
19+
from typing import Optional
20+
from typing import Tuple
21+
from typing import Union
22+
23+
import numpy as np
24+
import paddle
25+
26+
from ppsci.utils import logger
27+
28+
__all__ = [
29+
"save_csv_file",
30+
]
31+
32+
33+
def save_csv_file(
34+
file_path: str,
35+
data_dict: Dict[str, Union[np.ndarray, paddle.Tensor]],
36+
keys: Tuple[str, ...],
37+
alias_dict: Optional[Dict[str, str]] = None,
38+
use_header: bool = True,
39+
delimiter: str = ",",
40+
encoding: str = "utf-8",
41+
):
42+
"""Write numpy data to csv file.
43+
44+
Args:
45+
file_path (str): Dump file path.
46+
data_dict (Dict[str, Union[np.ndarray, paddle.Tensor]]): Numpy data in dict.
47+
keys (Tuple[str, ...]): Keys for data_dict to be fetched.
48+
alias_dict (Optional[Dict[str, str]], optional): Alias dict for keys,
49+
i.e. {dict_key: dump_key}. Defaults to None.
50+
use_header (bool, optional): Whether save csv with header. Defaults to True.
51+
delimiter (str, optional): Delemiter for splitting different data field. Defaults to ",".
52+
encoding (str, optional): Encoding. Defaults to "utf-8".
53+
54+
Examples:
55+
>>> import numpy as np
56+
>>> from ppsci.utils import save_csv_file
57+
>>> data_dict = {
58+
... "a": np.array([[1], [2], [3]]).astype("int64"), # [3, 1]
59+
... "b": np.array([[4.12], [5.25], [6.3370]]).astype("float32"), # [3, 1]
60+
... }
61+
>>> save_csv_file(
62+
... "test.csv",
63+
... data_dict,
64+
... ("a", "b"),
65+
... alias_dict={"a": "A", "b": "B"},
66+
... use_header=True,
67+
... delimiter=",",
68+
... encoding="utf-8",
69+
... )
70+
>>> # == test.csv ==
71+
>>> # a,b
72+
>>> # 1,4.12
73+
>>> # 2,5.25
74+
>>> # 3,6.337
75+
"""
76+
77+
if alias_dict is None:
78+
alias_dict = {}
79+
80+
# convert to numpy array
81+
data_fields = []
82+
header = []
83+
for key in keys:
84+
if key not in data_dict:
85+
raise KeyError(f"key({key}) do not exist in data_dict.")
86+
87+
data = data_dict[key]
88+
if isinstance(data, paddle.Tensor):
89+
data = data.numpy() # [num_of_samples, ]
90+
91+
data = data.flatten()
92+
data_fields.append(data)
93+
94+
dump_key = alias_dict[key] if key in alias_dict else key
95+
header.append(dump_key)
96+
97+
assert len(header) == len(data_fields)
98+
99+
data_fields = zip(*data_fields)
100+
with open(file_path, "w", newline="", encoding=encoding) as file:
101+
writer = csv.writer(file, delimiter=delimiter)
102+
103+
if use_header:
104+
writer.writerow(header)
105+
106+
writer.writerows(data_fields)
107+
108+
logger.message(f"csv file has been dumped to {file_path}")

test/utils/test_symbolic.py

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,17 @@
1+
# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
2+
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
7+
# http://www.apache.org/licenses/LICENSE-2.0
8+
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
115
import numpy as np
216
import paddle
317
import pytest

test/utils/test_writer.py

Lines changed: 59 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,59 @@
1+
# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
2+
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
7+
# http://www.apache.org/licenses/LICENSE-2.0
8+
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
15+
16+
import numpy as np
17+
import pytest
18+
19+
from ppsci.utils import reader
20+
from ppsci.utils import writer
21+
22+
23+
def test_save_csv_file():
24+
keys = ["x1", "y1", "z1"]
25+
alias_dict = (
26+
{
27+
"x": "x1",
28+
"y": "y1",
29+
"z": "z1",
30+
},
31+
)
32+
data_dict = {
33+
keys[0]: np.random.randint(0, 255, (10, 1)),
34+
keys[1]: np.random.rand(10, 1),
35+
keys[2]: np.random.rand(10, 1),
36+
}
37+
file_path = "test_writer.csv"
38+
writer.save_csv_file(
39+
file_path,
40+
data_dict,
41+
keys,
42+
alias_dict=alias_dict,
43+
use_header=True,
44+
)
45+
46+
reload_data_dict = reader.load_csv_file(
47+
file_path,
48+
keys,
49+
alias_dict,
50+
)
51+
52+
assert data_dict.keys() == reload_data_dict.keys()
53+
for k in reload_data_dict:
54+
assert reload_data_dict[k].shape == data_dict[k].shape
55+
assert np.allclose(reload_data_dict[k], data_dict[k])
56+
57+
58+
if __name__ == "__main__":
59+
pytest.main()

0 commit comments

Comments
 (0)