15
15
from __future__ import annotations
16
16
17
17
import csv
18
+ import os
18
19
from typing import Dict
19
20
from typing import Optional
20
21
from typing import Tuple
31
32
32
33
33
34
def save_csv_file(
    filename: str,
    data_dict: Dict[str, Union[np.ndarray, "paddle.Tensor"]],
    keys: Tuple[str, ...],
    alias_dict: Optional[Dict[str, str]] = None,
    use_header: bool = True,
    delimiter: str = ",",
    encoding: str = "utf-8",
):
    """Write numpy data to csv file.

    Args:
        filename (str): Dump file path.
        data_dict (Dict[str, Union[np.ndarray, paddle.Tensor]]): Numpy data in dict.
        keys (Tuple[str, ...]): Keys for data_dict to be fetched.
        alias_dict (Optional[Dict[str, str]], optional): Alias dict for keys,
            i.e. {dump_key: dict_key}. Defaults to None.
        use_header (bool, optional): Whether save csv with header. Defaults to True.
        delimiter (str, optional): Delimiter for splitting different data field. Defaults to ",".
        encoding (str, optional): Encoding. Defaults to "utf-8".

    Examples:
        >>> import numpy as np
        >>> from ppsci.utils import save_csv_file
        >>> data_dict = {"a": np.array([1, 2, 3]), "b": np.array([4.12, 5.25, 6.337])}
        >>> save_csv_file(
        ...     "test.csv",
        ...     data_dict,
        ...     ("A", "B"),
        ...     alias_dict={"A": "a", "B": "b"},
        ...     use_header=True,
        ...     delimiter=",",
        ...     encoding="utf-8",
        ... )
        >>> # == test.csv ==
        >>> # A,B
        >>> # 1,4.12
        >>> # 2,5.25
        >>> # 3,6.337
    """
    data_fields = []
    header = []
    for key in keys:
        # BUG FIX: `key in alias_dict` raised TypeError when alias_dict was left
        # at its default (None); only consult the alias map when it is provided.
        fetch_key = alias_dict[key] if alias_dict and key in alias_dict else key
        if fetch_key not in data_dict:
            # fail early with the resolved key instead of a bare dict KeyError
            raise KeyError(f"fetch_key({fetch_key}) do not exist in data_dict.")
        data = data_dict[fetch_key]
        if isinstance(data, paddle.Tensor):
            data = data.numpy()  # [num_of_samples, ]

        data = data.flatten()
        data_fields.append(data)
        header.append(key)

    assert len(header) == len(data_fields)

    data_fields = zip(*data_fields)  # transpose col data to row data
    with open(filename, "w", newline="", encoding=encoding) as file:
        writer = csv.writer(file, delimiter=delimiter)

        if use_header:
            writer.writerow(header)

        writer.writerows(data_fields)

    logger.message(f"csv file has been dumped to {filename}")
107
+
108
+
109
def save_tecplot_file(
    filename: str,
    data_dict: Dict[str, Union[np.ndarray, "paddle.Tensor"]],
    keys: Tuple[str, ...],
    alias_dict: Optional[Dict[str, str]] = None,
    delimiter: str = " ",
    encoding: str = "utf-8",
    num_timestamps: int = 1,
):
    """Write numpy data to tecplot file.

    Args:
        filename (str): Tecplot file path.
        data_dict (Dict[str, Union[np.ndarray, paddle.Tensor]]): Numpy or Tensor data in dict.
        keys (Tuple[str, ...]): Target keys to be dumped.
        alias_dict (Optional[Dict[str, str]], optional): Alias dict for keys,
            i.e. {dump_key: dict_key}. Defaults to None.
        delimiter (str, optional): Delimiter for splitting different data field. Defaults to " ".
        encoding (str, optional): Encoding. Defaults to "utf-8".
        num_timestamps (int, optional): Number of timestamp over coord and value. Defaults to 1.

    Raises:
        ValueError: If the number of data points is not a multiple of num_timestamps.

    Examples:
        >>> import numpy as np
        >>> from ppsci.utils import save_tecplot_file
        >>> data_dict = {
        ...     "x": np.array([[1.0], [2.0], [10.0], [20.0], [100.0], [200.0]]),  # [6, 1]
        ...     "y": np.array([[-1.0], [-2.0], [-10.0], [-20.0], [-100.0], [-200.0]])  # [6, 1]
        ... }
        >>> save_tecplot_file(
        ...     "./test.tec",
        ...     data_dict,
        ...     ("X", "Y"),
        ...     alias_dict={"X": "x", "Y": "y"},
        ...     num_timestamps=3,
        ... )
        >>> # == test_t-0.tec ==
        >>> # title="./test_t-0.tec"
        >>> # variables = "X", "Y"
        >>> # Zone I = 2, J = 1, F = POINT
        >>> # 1.0 -1.0
        >>> # 2.0 -2.0

        >>> # == test_t-1.tec ==
        >>> # title="./test_t-1.tec"
        >>> # variables = "X", "Y"
        >>> # Zone I = 2, J = 1, F = POINT
        >>> # 10.0 -10.0
        >>> # 20.0 -20.0

        >>> # == test_t-2.tec ==
        >>> # title="./test_t-2.tec"
        >>> # variables = "X", "Y"
        >>> # Zone I = 2, J = 1, F = POINT
        >>> # 100.0 -100.0
        >>> # 200.0 -200.0
    """
    ntxy = len(next(iter(data_dict.values())))
    if ntxy % num_timestamps != 0:
        raise ValueError(
            f"num_points({ntxy}) must be a multiple of "
            f"num_timestamps({num_timestamps})."
        )
    nxy = ntxy // num_timestamps

    # BUG FIX: os.makedirs("") raises FileNotFoundError when filename has no
    # directory component (e.g. "test.tec"); only create directories if any.
    dirname = os.path.dirname(filename)
    if dirname:
        os.makedirs(dirname, exist_ok=True)

    if filename.endswith(".tec"):
        filename = filename[:-4]

    # BUG FIX: resolve the data_dict key for each dump key ONCE, as an ordered
    # list. The previous set comprehension iterated in arbitrary hash order, so
    # data columns could be written in a different order than the header.
    # Also fall back to the key itself per-key, instead of requiring every key
    # to be present in alias_dict when alias_dict is given.
    fetch_keys = [
        alias_dict[key] if alias_dict and key in alias_dict else key for key in keys
    ]

    for t in range(num_timestamps):
        # write one tecplot file for each timestep
        if num_timestamps > 1:
            dump_filename = f"{filename}_t-{t}.tec"
        else:
            dump_filename = f"{filename}.tec"

        with open(dump_filename, "w", encoding=encoding) as f:
            # write meta information of tec
            f.write(f'title="{dump_filename}"\n')
            header = ", ".join([f'"{key}"' for key in keys])
            f.write(f"variables = {header}\n")
            f.write(f"Zone I = {nxy}, J = 1, F = POINT\n")

            # slice out this timestep's rows for each field, in header order
            # NOTE(review): entries are assumed to be indexable with scalar
            # elements convertible via float(); paddle.Tensor inputs are never
            # converted with .numpy() here — confirm float() works on them.
            data_cur_time_step = [
                data_dict[key][t * nxy : (t + 1) * nxy] for key in fetch_keys
            ]

            for items in zip(*data_cur_time_step):
                f.write(delimiter.join([str(float(x)) for x in items]) + "\n")

        # BUG FIX: message previously said "csv file" for a tecplot dump
        logger.message(f"tecplot file has been dumped to {dump_filename}")
0 commit comments