
Commit 3b1d05d

fix lorenz/rossler export and infer (#805)
* add embedding model to PhysformerGPT2 for infer
* modify export and inference code of lorenz
* fix export command
* fix export and infer of rossler
* fix doc
* fix error message in generate method
* fix docstring
1 parent 237ff01 commit 3b1d05d

5 files changed (+45, -36 lines)

docs/zh/examples/lorenz.md

Lines changed: 2 additions & 2 deletions
@@ -31,7 +31,7 @@
 === "模型导出命令"
 
     ``` sh
-    python train_enn.py mode=export
+    python train_transformer.py mode=export EMBEDDING_MODEL_PATH=https://paddle-org.bj.bcebos.com/paddlescience/models/lorenz/lorenz_pretrained.pdparams
     ```
 
 === "模型推理命令"
@@ -43,7 +43,7 @@
     # windows
     # curl https://paddle-org.bj.bcebos.com/paddlescience/datasets/transformer_physx/lorenz_training_rk.hdf5 --output ./datasets/lorenz_training_rk.hdf5
     # curl https://paddle-org.bj.bcebos.com/paddlescience/datasets/transformer_physx/lorenz_valid_rk.hdf5 --output ./datasets/lorenz_valid_rk.hdf5
-    python train_transformer.py mode=infer EMBEDDING_MODEL_PATH=https://paddle-org.bj.bcebos.com/paddlescience/models/lorenz/lorenz_pretrained.pdparams
+    python train_transformer.py mode=infer
     ```
 
 | 模型 | MSE |

docs/zh/examples/rossler.md

Lines changed: 2 additions & 2 deletions
@@ -31,7 +31,7 @@
 === "模型导出命令"
 
     ``` sh
-    python train_transformer.py mode=export
+    python train_transformer.py mode=export EMBEDDING_MODEL_PATH=https://paddle-org.bj.bcebos.com/paddlescience/models/rossler/rossler_pretrained.pdparams
     ```
 
 === "模型推理命令"
@@ -43,7 +43,7 @@
     # windows
     # curl https://paddle-org.bj.bcebos.com/paddlescience/datasets/transformer_physx/rossler_training.hdf5 --output ./datasets/rossler_training.hdf5
     # curl https://paddle-org.bj.bcebos.com/paddlescience/datasets/transformer_physx/rossler_valid.hdf5 --output ./datasets/rossler_valid.hdf5
-    python train_transformer.py mode=infer EMBEDDING_MODEL_PATH=https://paddle-org.bj.bcebos.com/paddlescience/models/rossler/rossler_pretrained.pdparams
+    python train_transformer.py mode=infer
     ```
 
 | 模型 | MSE |

examples/lorenz/train_transformer.py

Lines changed: 14 additions & 16 deletions
@@ -247,7 +247,14 @@ def evaluate(cfg: DictConfig):
 
 def export(cfg: DictConfig):
     # set model
-    model = ppsci.arch.PhysformerGPT2(**cfg.MODEL)
+    embedding_model = build_embedding_model(cfg.EMBEDDING_MODEL_PATH)
+    model_cfg = {
+        **cfg.MODEL,
+        "embedding_model": embedding_model,
+        "input_keys": ["states"],
+        "output_keys": ["pred_states"],
+    }
+    model = ppsci.arch.PhysformerGPT2(**model_cfg)
 
     # initialize solver
     solver = ppsci.solver.Solver(
@@ -259,7 +266,7 @@ def export(cfg: DictConfig):
 
     input_spec = [
         {
-            key: InputSpec([None, 256, 32], "float32", name=key)
+            key: InputSpec([None, 255, 3], "float32", name=key)
             for key in model.input_keys
         },
     ]
@@ -272,42 +279,33 @@ def inference(cfg: DictConfig):
 
     predictor = pinn_predictor.PINNPredictor(cfg)
 
-    embedding_model = build_embedding_model(cfg.EMBEDDING_MODEL_PATH)
-    output_transform = OutputTransform(embedding_model)
     dataset_cfg = {
         "name": "LorenzDataset",
         "file_path": cfg.VALID_FILE_PATH,
         "input_keys": cfg.MODEL.input_keys,
         "label_keys": cfg.MODEL.output_keys,
         "block_size": cfg.VALID_BLOCK_SIZE,
         "stride": 1024,
-        "embedding_model": embedding_model,
     }
 
     dataset = ppsci.data.dataset.build_dataset(dataset_cfg)
 
     input_dict = {
-        "embeds": dataset.embedding_data[: cfg.VIS_DATA_NUMS, :-1, :],
+        "states": dataset.data[: cfg.VIS_DATA_NUMS, :-1, :],
     }
-
-    output_dict = predictor.predict(
-        {key: input_dict[key] for key in cfg.MODEL.input_keys}, cfg.INFER.batch_size
-    )
+    output_dict = predictor.predict(input_dict, cfg.INFER.batch_size)
 
     # mapping data to cfg.INFER.output_keys
+    output_keys = ["pred_states"]
     output_dict = {
-        store_key: paddle.to_tensor(output_dict[infer_key])
-        for store_key, infer_key in zip(cfg.MODEL.output_keys, output_dict.keys())
+        store_key: output_dict[infer_key]
+        for store_key, infer_key in zip(output_keys, output_dict.keys())
     }
 
     input_dict = {
         "states": dataset.data[: cfg.VIS_DATA_NUMS, 1:, :],
     }
 
-    output_dict = {
-        "pred_states": output_transform(output_dict).numpy(),
-    }
-
     data_dict = {**input_dict, **output_dict}
     for i in range(cfg.VIS_DATA_NUMS):
         ppsci.visualize.save_plot_from_3d_dict(
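
Since the embedding model is now part of the exported graph, the Lorenz inference path feeds raw states and gets predictions back in physical space, so the separate OutputTransform step is gone. Below is a minimal sketch (not PaddleScience code) of the new input/output contract implied by this hunk; DummyPredictor and the "output_0" key are placeholders standing in for pinn_predictor.PINNPredictor and the exported graph's native output name.

```python
import numpy as np

class DummyPredictor:
    """Placeholder for pinn_predictor.PINNPredictor (illustration only)."""

    def predict(self, input_dict, batch_size):
        # The real predictor runs the exported graph; here we just echo shapes.
        return {"output_0": np.zeros_like(input_dict["states"])}

predictor = DummyPredictor()

# Raw physical states, shaped to match the new InputSpec([None, 255, 3]).
input_dict = {"states": np.zeros((8, 255, 3), dtype="float32")}
raw_out = predictor.predict(input_dict, batch_size=8)

# Rename the graph's native output keys to the documented ones, as the diff does.
output_keys = ["pred_states"]
output_dict = {
    store_key: raw_out[infer_key]
    for store_key, infer_key in zip(output_keys, raw_out.keys())
}
assert output_dict["pred_states"].shape == (8, 255, 3)
```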

examples/rossler/train_transformer.py

Lines changed: 14 additions & 15 deletions
@@ -246,7 +246,14 @@ def evaluate(cfg: DictConfig):
 
 def export(cfg: DictConfig):
     # set model
-    model = ppsci.arch.PhysformerGPT2(**cfg.MODEL)
+    embedding_model = build_embedding_model(cfg.EMBEDDING_MODEL_PATH)
+    model_cfg = {
+        **cfg.MODEL,
+        "embedding_model": embedding_model,
+        "input_keys": ["states"],
+        "output_keys": ["pred_states"],
+    }
+    model = ppsci.arch.PhysformerGPT2(**model_cfg)
 
     # initialize solver
     solver = ppsci.solver.Solver(
@@ -258,7 +265,7 @@ def export(cfg: DictConfig):
 
     input_spec = [
         {
-            key: InputSpec([None, 256, 32], "float32", name=key)
+            key: InputSpec([None, 255, 3], "float32", name=key)
             for key in model.input_keys
         },
     ]
@@ -271,42 +278,34 @@ def inference(cfg: DictConfig):
 
     predictor = pinn_predictor.PINNPredictor(cfg)
 
-    embedding_model = build_embedding_model(cfg.EMBEDDING_MODEL_PATH)
-    output_transform = OutputTransform(embedding_model)
     dataset_cfg = {
         "name": "RosslerDataset",
         "file_path": cfg.VALID_FILE_PATH,
         "input_keys": cfg.MODEL.input_keys,
         "label_keys": cfg.MODEL.output_keys,
         "block_size": cfg.VALID_BLOCK_SIZE,
         "stride": 1024,
-        "embedding_model": embedding_model,
     }
 
     dataset = ppsci.data.dataset.build_dataset(dataset_cfg)
 
     input_dict = {
-        "embeds": dataset.embedding_data[: cfg.VIS_DATA_NUMS, :-1, :],
+        "states": dataset.data[: cfg.VIS_DATA_NUMS, :-1, :],
     }
 
-    output_dict = predictor.predict(
-        {key: input_dict[key] for key in cfg.MODEL.input_keys}, cfg.INFER.batch_size
-    )
+    output_dict = predictor.predict(input_dict, cfg.INFER.batch_size)
 
     # mapping data to cfg.INFER.output_keys
+    output_keys = ["pred_states"]
     output_dict = {
-        store_key: paddle.to_tensor(output_dict[infer_key])
-        for store_key, infer_key in zip(cfg.MODEL.output_keys, output_dict.keys())
+        store_key: output_dict[infer_key]
+        for store_key, infer_key in zip(output_keys, output_dict.keys())
    }
 
     input_dict = {
         "states": dataset.data[: cfg.VIS_DATA_NUMS, 1:, :],
     }
 
-    output_dict = {
-        "pred_states": output_transform(output_dict).numpy(),
-    }
-
     data_dict = {**input_dict, **output_dict}
     for i in range(cfg.VIS_DATA_NUMS):
         ppsci.visualize.save_plot_from_3d_dict(
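
The InputSpec change from [None, 256, 32] to [None, 255, 3] follows from moving the encoder into the exported model: the graph now takes the three raw Rossler (or Lorenz) state variables over a 255-step window rather than 32-dimensional embeddings. A quick shape sketch, assuming a validation block size of 256 (an assumption inferred from the old spec's sequence length; the slicing mirrors the inference() code above):

```python
import numpy as np

# (batch, block_size, state_dim); block_size = 256 is an assumption here.
block = np.zeros((4, 256, 3), dtype="float32")

states = block[:, :-1, :]   # model input, like dataset.data[..., :-1, :]
targets = block[:, 1:, :]   # ground truth, like dataset.data[..., 1:, :]

assert states.shape == (4, 255, 3)   # matches InputSpec([None, 255, 3])
assert targets.shape == (4, 255, 3)
```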

ppsci/arch/physx_transformer.py

Lines changed: 13 additions & 1 deletion
@@ -251,6 +251,9 @@ class PhysformerGPT2(base.Arch):
         attn_pdrop (float, optional): The dropout probability used on attention weights. Defaults to 0.0.
         resid_pdrop (float, optional): The dropout probability used on block outputs. Defaults to 0.0.
         initializer_range (float, optional): Initializer range of linear layer. Defaults to 0.05.
+        embedding_model (Optional[base.Arch]): Embedding model. If this parameter is set,
+            the embedding model will map the input data to the embedding space and the
+            output data to the physical space. Defaults to None.
 
     Examples:
         >>> import ppsci
@@ -269,6 +272,7 @@ def __init__(
         attn_pdrop: float = 0.0,
         resid_pdrop: float = 0.0,
         initializer_range: float = 0.05,
+        embedding_model: Optional[base.Arch] = None,
     ):
         super().__init__()
         self.input_keys = input_keys
@@ -296,6 +300,7 @@ def __init__(
         self.linear = nn.Linear(embed_size, embed_size)
 
         self.apply(self._init_weights)
+        self.embedding_model = embedding_model
 
     def _init_weights(self, module):
         if isinstance(module, nn.Linear):
@@ -344,7 +349,7 @@ def _generate_time_series(self, x, max_length):
     def generate(self, x, max_length=256):
         if max_length <= 0:
             raise ValueError(
-                "max_length({max_length}) should be a strictly positive integer."
+                f"max_length({max_length}) should be a strictly positive integer."
             )
         outputs = self._generate_time_series(x, max_length)
         return outputs
@@ -375,10 +380,17 @@ def forward(self, x):
         if self._input_transform is not None:
             x = self._input_transform(x)
         x_tensor = self.concat_to_tensor(x, self.input_keys, axis=-1)
+        if self.embedding_model is not None:
+            x_tensor = self.embedding_model.encoder(x_tensor)
+
         if self.training:
             y = self.forward_tensor(x_tensor)
         else:
             y = self.forward_eval(x_tensor)
+
+        if self.embedding_model is not None:
+            y = (self.embedding_model.decoder(y[0]),)
+
         y = self.split_to_dict(y, self.output_keys)
         if self._output_transform is not None:
             y = self._output_transform(x, y)
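
The forward() change is a straightforward wrap: when an embedding model is attached, the transformer keeps working in the learned embedding space while callers see physical-space tensors at both ends, which is what lets export() bake everything into one graph and lets the docs drop EMBEDDING_MODEL_PATH from the infer command. A minimal sketch of that pattern in plain Python (not the actual PhysformerGPT2 classes; transformer and embedding_model are stand-in callables assumed to expose encoder/decoder):

```python
class LatentRollout:
    """Sketch of the encode -> rollout -> decode wrapping added to forward()."""

    def __init__(self, transformer, embedding_model=None):
        self.transformer = transformer          # operates on embeddings
        self.embedding_model = embedding_model  # exposes .encoder / .decoder

    def __call__(self, states):
        x = states
        if self.embedding_model is not None:
            x = self.embedding_model.encoder(x)   # physical -> embedding space
        y = self.transformer(x)                   # rollout in embedding space
        if self.embedding_model is not None:
            y = self.embedding_model.decoder(y)   # embedding -> physical space
        return y
```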
