Skip to content

Commit c56d902

Browse files
authored
return saved targets' name list (#16240)
1 parent 996a747 commit c56d902

File tree

1 file changed

+3
-1
lines changed

1 file changed

+3
-1
lines changed

python/paddle/fluid/io.py

+3-1
Original file line numberDiff line numberDiff line change
@@ -895,7 +895,7 @@ def save_inference_model(dirname,
895895
True is supported.
896896
897897
Returns:
898-
None
898+
target_var_name_list(list): The fetch variables' name list
899899
900900
Raises:
901901
ValueError: If `feed_var_names` is not a list of basestring.
@@ -954,6 +954,7 @@ def save_inference_model(dirname,
954954
var, 1., name="save_infer_model/scale_{}".format(i))
955955
uniq_target_vars.append(var)
956956
target_vars = uniq_target_vars
957+
target_var_name_list = [var.name for var in target_vars]
957958

958959
# when a pserver and a trainer running on the same machine, mkdir may conflict
959960
try:
@@ -1010,6 +1011,7 @@ def save_inference_model(dirname,
10101011
params_filename = os.path.basename(params_filename)
10111012

10121013
save_persistables(executor, dirname, main_program, params_filename)
1014+
return target_var_name_list
10131015

10141016

10151017
def load_inference_model(dirname,

0 commit comments

Comments
 (0)