Skip to content

Commit 24e0a83

Browse files
update code
1 parent 9630956 commit 24e0a83

File tree

1 file changed

+46
-30
lines changed

1 file changed

+46
-30
lines changed

ppsci/utils/writer.py

+46-30
Original file line numberDiff line numberDiff line change
@@ -88,7 +88,8 @@ def save_csv_file(
8888
if isinstance(data, paddle.Tensor):
8989
data = data.numpy() # [num_of_samples, ]
9090

91-
data = data.flatten()
91+
if isinstance(data, np.ndarray):
92+
data = data.flatten()
9293
data_fields.append(data)
9394

9495
header.append(key)
@@ -111,58 +112,62 @@ def save_tecplot_file(
111112
filename: str,
112113
data_dict: Dict[str, Union[np.ndarray, "paddle.Tensor"]],
113114
keys: Tuple[str, ...],
115+
num_x: int,
116+
num_y: int,
114117
alias_dict: Optional[Dict[str, str]] = None,
115118
delimiter: str = " ",
116119
encoding: str = "utf-8",
117120
num_timestamps: int = 1,
118121
):
119-
"""Write numpy data to tecplot file.
122+
"""Write numpy or tensor data to tecplot file.
120123
121124
Args:
122125
filename (str): Tecplot file path.
123126
data_dict (Dict[str, Union[np.ndarray, paddle.Tensor]]): Numpy or Tensor data in dict.
124127
keys (Tuple[str, ...]): Target keys to be dumped.
128+
num_x (int): The number of discrete points of the grid in the X-axis. Assuming
129+
the discrete grid size is 20 x 30, then num_x=20.
130+
num_y (int): The number of discrete points of the grid in the Y-axis. Assuming
131+
the discrete grid size is 20 x 30, then num_y=30.
125132
alias_dict (Optional[Dict[str, str]], optional): Alias dict for keys,
126133
i.e. {dump_key: dict_key}. Defaults to None.
127-
delimiter (str, optional): Delemiter for splitting different data field. Defaults to ",".
134+
delimiter (str, optional): Delemiter for splitting different data field. Defaults to " ".
128135
encoding (str, optional): Encoding. Defaults to "utf-8".
129136
num_timestamps (int, optional): Number of timestamp over coord and value. Defaults to 1.
130137
131138
Examples:
132139
>>> import numpy as np
133140
>>> from ppsci.utils import save_tecplot_file
134141
>>> data_dict = {
135-
... "x": np.array([[1.0], [2.0], [10.0], [20.0], [100.0], [200.0]]), # [6, 1]
136-
... "y": np.array([[-1.0], [-2.0], [-10.0], [-20.0], [-100.0], [-200.0]]) # [6, 1]
142+
... "x": np.array([[-1.0], [-1.0], [-1.0], [-1.0], [-1.0], [-1.0]]), # [6, 1]
143+
... "y": np.array([[1.0], [2.0], [3.0], [1.0], [2.0], [3.0]]), # [6, 1]
144+
... "value": np.array([[3], [33], [333], [3333], [33333], [333333]]), # [6, 1]
137145
... }
138146
>>> save_tecplot_file(
139-
... "./test.tec",
147+
... "./test.dat",
140148
... data_dict,
141149
... ("X", "Y"),
150+
... num_x=1,
151+
... num_y=3,
142152
... alias_dict={"X": "x", "Y": "y"},
143-
... num_timestamps=3,
153+
... num_timestamps=2,
144154
... )
145-
146-
>>> # == test_t-0.tec ==
147-
>>> # title="./test_t-0.tec"
155+
>>> # == test_t-0.dat ==
156+
>>> # title = "./test_t-0.dat"
148157
>>> # variables = "X", "Y"
149-
>>> # Zone I = 2, J = 1, F = POINT
158+
>>> # Zone I = 3, J = 1, F = POINT
150159
>>> # -1.0 1.0
151-
>>> # -2.0 2.0
160+
>>> # -1.0 2.0
161+
>>> # -1.0 3.0
152162
153-
>>> # == test_t-1.tec ==
154-
>>> # title="./test_t-1.tec"
155-
>>> # variables = "X", "Y"
156-
>>> # Zone I = 2, J = 1, F = POINT
157-
>>> # -10.0 10.0
158-
>>> # -20.0 20.0
159163
160-
>>> # == test_t-2.tec ==
161-
>>> # title="./test_t-2.tec"
164+
>>> # == test_t-1.dat ==
165+
>>> # title = "./test_t-1.dat"
162166
>>> # variables = "X", "Y"
163-
>>> # Zone I = 2, J = 1, F = POINT
164-
>>> # -100.0 100.0
165-
>>> # -200.0 200.0
167+
>>> # Zone I = 3, J = 1, F = POINT
168+
>>> # -1.0 1.0
169+
>>> # -1.0 2.0
170+
>>> # -1.0 3.0
166171
"""
167172
ntxy = len(next(iter(data_dict.values())))
168173
if ntxy % num_timestamps != 0:
@@ -172,25 +177,31 @@ def save_tecplot_file(
172177
)
173178
nxy = ntxy // num_timestamps
174179

180+
nx, ny = num_x, num_y
181+
assert nx * ny == nxy, f"nx({nx}) * ny({ny}) != nxy({nxy})"
182+
175183
os.makedirs(os.path.dirname(filename), exist_ok=True)
176184

177-
if filename.endswith(".tec"):
185+
if filename.endswith(".dat"):
178186
filename = filename[:-4]
179187

180188
for t in range(num_timestamps):
181189
# write 1 tecplot file for each timestep
182190
if num_timestamps > 1:
183-
dump_filename = f"{filename}_t-{t}.tec"
191+
dump_filename = f"{filename}_t-{t}.dat"
184192
else:
185-
dump_filename = f"{filename}.tec"
193+
dump_filename = f"{filename}.dat"
186194

195+
fetch_keys = [(alias_dict[key] if alias_dict else key) for key in keys]
187196
with open(dump_filename, "w", encoding=encoding) as f:
188197
# write meta information of tec
189-
f.write(f'title="{dump_filename}"\n')
190-
fetch_keys = {(alias_dict[key] if alias_dict else key) for key in keys}
198+
f.write(f'title = "{dump_filename}"\n')
191199
header = ", ".join([f'"{key}"' for key in keys])
192200
f.write(f"variables = {header}\n")
193-
f.write(f"Zone I = {nxy}, J = 1, F = POINT\n")
201+
202+
# NOTE: Tecplot is column-major, so we need to specify I = ny, J = nx,
203+
# which is in contrast to our habits.
204+
f.write(f"Zone I = {ny}, J = {nx}, F = POINT\n")
194205

195206
# write points data into file
196207
data_cur_time_step = [
@@ -200,4 +211,9 @@ def save_tecplot_file(
200211
for items in zip(*data_cur_time_step):
201212
f.write(delimiter.join([str(float(x)) for x in items]) + "\n")
202213

203-
logger.message(f"csv file has been dumped to {dump_filename}")
214+
if num_timestamps > 1:
215+
logger.message(
216+
f"tecplot files are saved to: {filename}_t-0.dat ~ {filename}_t-{num_timestamps - 1}.dat"
217+
)
218+
else:
219+
logger.message(f"tecplot file is saved to: {filename}.dat")

0 commit comments

Comments
 (0)