@@ -254,14 +254,80 @@ def u_solution_func(in_) -> np.ndarray:
254
254
plt .savefig (osp .join (cfg .output_dir , "./Volterra_IDE.png" ), dpi = 200 )
255
255
256
256
257
+ def export (cfg : DictConfig ):
258
+ # set model
259
+ model = ppsci .arch .MLP (** cfg .MODEL )
260
+
261
+ # initialize solver
262
+ solver = ppsci .solver .Solver (
263
+ model ,
264
+ pretrained_model_path = cfg .INFER .pretrained_model_path ,
265
+ )
266
+ # export model
267
+ from paddle .static import InputSpec
268
+
269
+ input_spec = [
270
+ {
271
+ key : InputSpec ([None , 1 ], "float32" , name = key )
272
+ for key in cfg .MODEL .input_keys
273
+ },
274
+ ]
275
+ solver .export (input_spec , cfg .INFER .export_path )
276
+
277
+
278
+ def inference (cfg : DictConfig ):
279
+ from deploy .python_infer import pinn_predictor
280
+
281
+ predictor = pinn_predictor .PINNPredictor (cfg )
282
+
283
+ # set geometry
284
+ geom = {"timedomain" : ppsci .geometry .TimeDomain (* cfg .BOUNDS )}
285
+
286
+ input_data = geom ["timedomain" ].uniform_points (cfg .EVAL .npoint_eval )
287
+ input_dict = {"x" : input_data }
288
+
289
+ output_dict = predictor .predict (
290
+ {key : input_dict [key ] for key in cfg .MODEL .input_keys }, cfg .INFER .batch_size
291
+ )
292
+
293
+ # mapping data to cfg.INFER.output_keys
294
+ output_dict = {
295
+ store_key : output_dict [infer_key ]
296
+ for store_key , infer_key in zip (cfg .MODEL .output_keys , output_dict .keys ())
297
+ }
298
+
299
+ def u_solution_func (in_ ) -> np .ndarray :
300
+ if isinstance (in_ ["x" ], paddle .Tensor ):
301
+ return paddle .exp (- in_ ["x" ]) * paddle .cosh (in_ ["x" ])
302
+ return np .exp (- in_ ["x" ]) * np .cosh (in_ ["x" ])
303
+
304
+ label_data = u_solution_func ({"x" : input_data })
305
+ output_data = output_dict ["u" ]
306
+
307
+ # save result
308
+ plt .plot (input_data , label_data , "-" , label = r"$u(t)$" )
309
+ plt .plot (input_data , output_data , "o" , label = r"$\hat{u}(t)$" , markersize = 4.0 )
310
+ plt .legend ()
311
+ plt .xlabel (r"$t$" )
312
+ plt .ylabel (r"$u$" )
313
+ plt .title (r"$u-t$" )
314
+ plt .savefig ("./Volterra_IDE_pred.png" , dpi = 200 )
315
+
316
+
257
317
@hydra .main (version_base = None , config_path = "./conf" , config_name = "volterra_ide.yaml" )
258
318
def main (cfg : DictConfig ):
259
319
if cfg .mode == "train" :
260
320
train (cfg )
261
321
elif cfg .mode == "eval" :
262
322
evaluate (cfg )
323
+ elif cfg .mode == "export" :
324
+ export (cfg )
325
+ elif cfg .mode == "infer" :
326
+ inference (cfg )
263
327
else :
264
- raise ValueError (f"cfg.mode should in ['train', 'eval'], but got '{ cfg .mode } '" )
328
+ raise ValueError (
329
+ f"cfg.mode should in ['train', 'eval', 'export', 'infer'], but got '{ cfg .mode } '"
330
+ )
265
331
266
332
267
333
if __name__ == "__main__" :
0 commit comments