Skip to content

Commit b1f0d00

Browse files
authored
[Serde][Dy2St] Use cached PartialProgramLayer in jit.load (PaddlePaddle#72103)
1 parent c826dc8 commit b1f0d00

File tree

2 files changed

+42
-30
lines changed

2 files changed

+42
-30
lines changed

python/paddle/jit/pir_translated_layer.py

+38-29
Original file line numberDiff line numberDiff line change
@@ -321,7 +321,7 @@ def _construct_params_and_buffers(model_path, programs, params_filename=None):
321321
return var_dict
322322

323323

324-
def _run_dygraph(instance, input, program_holder):
324+
def _run_dygraph(instance, input, program_holder, method_name):
325325
# 1. prepare inputs, outputs, attrs
326326
input_tensors = []
327327
input_tensor_names = []
@@ -348,35 +348,37 @@ def _run_dygraph(instance, input, program_holder):
348348
input_tensor_names.append(tensor.name)
349349
input_tensors.append(tensor)
350350

351-
persistable_tensors = []
352-
origin_persistable_var_name = [
353-
program_holder._suffix_varname_dict[var_name]
354-
for var_name in program_holder.persistable_names
355-
]
356-
for var_name in origin_persistable_var_name:
357-
dy_var_name = instance._persistable_var_name_dict[var_name]
358-
if dy_var_name in instance._parameters:
359-
persistable_tensors.append(instance._parameters[dy_var_name])
360-
elif dy_var_name in instance._buffers:
361-
persistable_tensors.append(instance._buffers[dy_var_name])
362-
else:
363-
raise ValueError(
364-
f"The persistable variable {var_name} does not exist in current PirTranslatedLayer."
365-
)
351+
if instance._get_partial_program_layer(method_name) is None:
352+
persistable_tensors = []
353+
origin_persistable_var_name = [
354+
program_holder._suffix_varname_dict[var_name]
355+
for var_name in program_holder.persistable_names
356+
]
357+
for var_name in origin_persistable_var_name:
358+
dy_var_name = instance._persistable_var_name_dict[var_name]
359+
if dy_var_name in instance._parameters:
360+
persistable_tensors.append(instance._parameters[dy_var_name])
361+
elif dy_var_name in instance._buffers:
362+
persistable_tensors.append(instance._buffers[dy_var_name])
363+
else:
364+
raise ValueError(
365+
f"The persistable variable {var_name} does not exist in current PirTranslatedLayer."
366+
)
366367

367-
from paddle.jit.dy2static.pir_partial_program import PartialProgramLayer
368+
from paddle.jit.dy2static.pir_partial_program import PartialProgramLayer
368369

369-
inputs = program_holder.input_vars
370-
outputs = program_holder.output_vars
371-
parameters = (persistable_tensors, program_holder.persistable_vars)
370+
inputs = program_holder.input_vars
371+
outputs = program_holder.output_vars
372+
parameters = (persistable_tensors, program_holder.persistable_vars)
372373

373-
layer = PartialProgramLayer(
374-
program_holder.infer_program,
375-
inputs,
376-
outputs,
377-
parameters,
378-
)
379-
instance.layer = layer
374+
layer = PartialProgramLayer(
375+
program_holder.infer_program,
376+
inputs,
377+
outputs,
378+
parameters,
379+
)
380+
instance._set_partial_program_layer(method_name, layer)
381+
layer = instance._get_partial_program_layer(method_name)
380382
if instance._is_test:
381383
layer.training = False
382384
else:
@@ -387,7 +389,7 @@ def _run_dygraph(instance, input, program_holder):
387389
else:
388390
layer.training = True
389391

390-
return instance.layer(input_tensors)
392+
return layer(input_tensors)
391393

392394

393395
def _run_static_graph(inputs, program_holder, src_program):
@@ -589,6 +591,7 @@ def __init__(
589591

590592
self._is_test = True
591593
self._input_args_names = None
594+
self._partial_program_layers = {}
592595

593596
@staticmethod
594597
@framework.dygraph_only
@@ -640,7 +643,7 @@ def __i_m_p_l__(self, *input):
640643
# When using jit.save, it runs in static graph mode.
641644
# Run in dynamic graph mode when the model is inferring.
642645
if in_dynamic_mode():
643-
return _run_dygraph(self, input, program_holder)
646+
return _run_dygraph(self, input, program_holder, method_name)
644647
else:
645648
return _run_static_graph(
646649
input, program_holder, program_holder.infer_program
@@ -792,3 +795,9 @@ def _output_spec(self, method_name='forward'):
792795
output_spec.append(spec)
793796

794797
return output_spec
798+
799+
def _get_partial_program_layer(self, method_name):
    """Return the cached partial-program layer for ``method_name``.

    Looks ``method_name`` up in ``self._partial_program_layers``.

    Args:
        method_name: Key under which the layer was stored.

    Returns:
        The stored layer, or ``None`` when nothing has been stored under
        ``method_name`` yet.
    """
    # ``dict.get`` already falls back to ``None``, so no explicit default
    # is needed.
    return self._partial_program_layers.get(method_name)
801+
802+
def _set_partial_program_layer(self, method_name, layer):
    """Store ``layer`` in the per-method cache under ``method_name``.

    Writes into ``self._partial_program_layers``; a layer previously
    stored under the same name is replaced.

    Args:
        method_name: Key to store the layer under.
        layer: The layer object to cache.
    """
    layer_cache = self._partial_program_layers
    layer_cache[method_name] = layer

test/legacy_test/test_jit_layer.py

+4-1
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,7 @@
1616
import sys
1717
import tempfile
1818
import unittest
19+
from pathlib import Path
1920

2021
import numpy as np
2122

@@ -24,7 +25,9 @@
2425
from paddle.jit.layer import Layer
2526
from paddle.static import InputSpec
2627

27-
sys.path.append("../dygraph_to_static")
28+
sys.path.append(
29+
str(Path(__file__).resolve().parent.parent / "dygraph_to_static")
30+
)
2831
from dygraph_to_static_utils import enable_to_static_guard
2932

3033
paddle.seed(1)

0 commit comments

Comments
 (0)