@@ -321,7 +321,7 @@ def _construct_params_and_buffers(model_path, programs, params_filename=None):
321
321
return var_dict
322
322
323
323
324
- def _run_dygraph (instance , input , program_holder ):
324
+ def _run_dygraph (instance , input , program_holder , method_name ):
325
325
# 1. prepare inputs, outputs, attrs
326
326
input_tensors = []
327
327
input_tensor_names = []
@@ -348,35 +348,37 @@ def _run_dygraph(instance, input, program_holder):
348
348
input_tensor_names .append (tensor .name )
349
349
input_tensors .append (tensor )
350
350
351
- persistable_tensors = []
352
- origin_persistable_var_name = [
353
- program_holder ._suffix_varname_dict [var_name ]
354
- for var_name in program_holder .persistable_names
355
- ]
356
- for var_name in origin_persistable_var_name :
357
- dy_var_name = instance ._persistable_var_name_dict [var_name ]
358
- if dy_var_name in instance ._parameters :
359
- persistable_tensors .append (instance ._parameters [dy_var_name ])
360
- elif dy_var_name in instance ._buffers :
361
- persistable_tensors .append (instance ._buffers [dy_var_name ])
362
- else :
363
- raise ValueError (
364
- f"The persistable variable { var_name } does not exist in current PirTranslatedLayer."
365
- )
351
+ if instance ._get_partial_program_layer (method_name ) is None :
352
+ persistable_tensors = []
353
+ origin_persistable_var_name = [
354
+ program_holder ._suffix_varname_dict [var_name ]
355
+ for var_name in program_holder .persistable_names
356
+ ]
357
+ for var_name in origin_persistable_var_name :
358
+ dy_var_name = instance ._persistable_var_name_dict [var_name ]
359
+ if dy_var_name in instance ._parameters :
360
+ persistable_tensors .append (instance ._parameters [dy_var_name ])
361
+ elif dy_var_name in instance ._buffers :
362
+ persistable_tensors .append (instance ._buffers [dy_var_name ])
363
+ else :
364
+ raise ValueError (
365
+ f"The persistable variable { var_name } does not exist in current PirTranslatedLayer."
366
+ )
366
367
367
- from paddle .jit .dy2static .pir_partial_program import PartialProgramLayer
368
+ from paddle .jit .dy2static .pir_partial_program import PartialProgramLayer
368
369
369
- inputs = program_holder .input_vars
370
- outputs = program_holder .output_vars
371
- parameters = (persistable_tensors , program_holder .persistable_vars )
370
+ inputs = program_holder .input_vars
371
+ outputs = program_holder .output_vars
372
+ parameters = (persistable_tensors , program_holder .persistable_vars )
372
373
373
- layer = PartialProgramLayer (
374
- program_holder .infer_program ,
375
- inputs ,
376
- outputs ,
377
- parameters ,
378
- )
379
- instance .layer = layer
374
+ layer = PartialProgramLayer (
375
+ program_holder .infer_program ,
376
+ inputs ,
377
+ outputs ,
378
+ parameters ,
379
+ )
380
+ instance ._set_partial_program_layer (method_name , layer )
381
+ layer = instance ._get_partial_program_layer (method_name )
380
382
if instance ._is_test :
381
383
layer .training = False
382
384
else :
@@ -387,7 +389,7 @@ def _run_dygraph(instance, input, program_holder):
387
389
else :
388
390
layer .training = True
389
391
390
- return instance . layer (input_tensors )
392
+ return layer (input_tensors )
391
393
392
394
393
395
def _run_static_graph (inputs , program_holder , src_program ):
@@ -589,6 +591,7 @@ def __init__(
589
591
590
592
self ._is_test = True
591
593
self ._input_args_names = None
594
+ self ._partial_program_layers = {}
592
595
593
596
@staticmethod
594
597
@framework .dygraph_only
@@ -640,7 +643,7 @@ def __i_m_p_l__(self, *input):
640
643
# When using jit.save, it runs in static graph mode.
641
644
# Run in dynamic graph mode when the model is inferring.
642
645
if in_dynamic_mode ():
643
- return _run_dygraph (self , input , program_holder )
646
+ return _run_dygraph (self , input , program_holder , method_name )
644
647
else :
645
648
return _run_static_graph (
646
649
input , program_holder , program_holder .infer_program
@@ -792,3 +795,9 @@ def _output_spec(self, method_name='forward'):
792
795
output_spec .append (spec )
793
796
794
797
return output_spec
798
+
799
+ def _get_partial_program_layer (self , method_name ):
800
+ return self ._partial_program_layers .get (method_name , None )
801
+
802
+ def _set_partial_program_layer (self , method_name , layer ):
803
+ self ._partial_program_layers [method_name ] = layer
0 commit comments