
Add optimizer save and load #16986


Merged · 26 commits · Jun 6, 2019

Commits (the diff below reflects changes from 23 of the 26 commits)
222700d
save optimizer related vars in dygraph
JiabinYang Apr 19, 2019
31fa880
save optimizer related vars in dygraph
JiabinYang Apr 19, 2019
8f3b7a3
test=develop, add optimizer save and load
JiabinYang Apr 19, 2019
f97cd9d
test=develop, add optimizer save and load
JiabinYang Apr 19, 2019
dda6302
Merge branch 'develop' of https://github.com/PaddlePaddle/Paddle into…
JiabinYang Apr 19, 2019
9d3cd10
test=develop, merge code and add multi-optimizer save and load
JiabinYang Apr 19, 2019
b887881
test=develop, merge code and add multi-optimizer save and load
JiabinYang Apr 19, 2019
b5f971a
test=develop, merge branch 'develop' of https://github.com/PaddlePadd…
JiabinYang Apr 19, 2019
a8e8325
test=develop, fix test_imperative_checkpoint
JiabinYang Apr 22, 2019
bab9583
test=develop, fix include error
JiabinYang Apr 22, 2019
495b3bc
test=develop, fix include error
JiabinYang Apr 22, 2019
3a1a16c
test=develop, renew api spec
JiabinYang Apr 23, 2019
d69205c
test=develop, refine code
JiabinYang Apr 24, 2019
a4f3fee
test=develop, set default value for checkpoint
JiabinYang Apr 25, 2019
6c34937
Merge branch 'develop' of https://github.com/PaddlePaddle/Paddle into…
JiabinYang May 6, 2019
4700a29
test=develop, fix ci error
JiabinYang May 8, 2019
e5ecc65
test=develop, merge develop code
JiabinYang Jun 3, 2019
6c89fd1
Merge branch 'develop' of https://github.com/PaddlePaddle/Paddle into…
JiabinYang Jun 4, 2019
1cf0bde
test=develop, change API.spec and make api more readable
JiabinYang Jun 4, 2019
eb40601
Merge branch 'develop' of https://github.com/PaddlePaddle/Paddle into…
JiabinYang Jun 4, 2019
35bb46e
test=develop, change API.spec and make api more readable
JiabinYang Jun 4, 2019
c19458a
Merge branch 'develop' of https://github.com/PaddlePaddle/Paddle into…
JiabinYang Jun 5, 2019
335540b
test=develop, refine version and time stamp
JiabinYang Jun 5, 2019
53042d9
test=develop, add example code and refine code
JiabinYang Jun 5, 2019
84d4322
test=develop, refine doc
JiabinYang Jun 5, 2019
f91f443
test=develop, change version
JiabinYang Jun 5, 2019
13 changes: 13 additions & 0 deletions paddle/fluid/API.spec

Large diffs are not rendered by default.

3 changes: 3 additions & 0 deletions python/paddle/fluid/dygraph/base.py
@@ -92,3 +92,6 @@ def to_variable(value, block=None, name=None):
return py_var
elif isinstance(value, framework.Variable):
return value
else:
raise TypeError(
"to_variable only accepts 'ndarray' and 'Variable' as value's input")
71 changes: 60 additions & 11 deletions python/paddle/fluid/dygraph/checkpoint.py
@@ -16,13 +16,18 @@

import os
import collections
from .. import core
from ..framework import Variable, default_main_program
import pickle
from . import learning_rate_scheduler
import warnings

__all__ = ['save_persistables', 'load_persistables']


def save_persistables(vardict, dirname, filename=None):
def save_persistables(model_dict,
optimizer=None,
dirname='save_dir',
filename=None):
"""
This function filters out all variables in layer.parameters from the
give `layer` and then trys to load these variables from the folder
@@ -34,12 +39,12 @@ def save_persistables(vardict, dirname, filename=None):
the file name.

Args:
vardict(dict of Parameters): The parameters will
model_dict(dict of Parameters): The parameters will
be saved. If it is None, nothing
will be deal.
dirname(str): The directory path.
filename(str|None): The file which saved all variables. If variables were
saved in differnet files, set it to None.
saved in different files, set it to None.
Default: None

Returns:
@@ -71,11 +76,11 @@ def save_persistables(vardict, dirname, filename=None):
fluid.dygraph.save_persistables(ptb_model.state_dict(), dirname=param_path,
layer=ptb_model)
"""
if isinstance(vardict, collections.OrderedDict):
_save_var_to_file(vardict, dirname, filename)
if isinstance(model_dict, collections.OrderedDict):
_save_var_to_file(model_dict, optimizer, dirname, filename)


def load_persistables(dirname):
def load_persistables(dirname='save_dir'):
"""
This function trys to load persistable variables from the folder
`dirname` or the file `filename`.
@@ -86,7 +91,8 @@ def load_persistables(dirname):
the file name.

Args:
dirname(str): The directory path.
dirname(str): The directory path. default is save_dir
optimizer(Optimizer): Optimizer to be save
Collaborator commented: to be saved?


Returns:
dict: The parameter-dict resumed from file
@@ -103,7 +109,7 @@ def load_persistables(dirname):
return _load_var_from_file(dirname)


def _save_var_to_file(stat_dict, file_dir, file_name):
def _save_var_to_file(stat_dict, optimizers, file_dir, file_name):
save_block = default_main_program().global_block()
save_var_map = {}
for var_key, each_var in stat_dict.items():
@@ -117,6 +123,32 @@ def _save_var_to_file(stat_dict, file_dir, file_name):
'file_path': os.path.join(file_dir,
os.path.normpath(each_var.name))
})
if isinstance(optimizers, (list, tuple)):
optimizers = optimizers
else:
optimizers = [optimizers]
Contributor commented: use list(obj) instead?

Author replied: what's the benefit for this?

if os.path.exists(os.path.join(file_dir, os.path.normpath("optimizers"))):
pass
else:
os.mkdir(os.path.join(file_dir, os.path.normpath("optimizers")))
for optimizer in optimizers:
if isinstance(optimizer._learning_rate,
learning_rate_scheduler.LearningRateDecay):
try:
f = open(
os.path.join(file_dir, "optimizers",
os.path.normpath(str(optimizer._name))), "wb")
pickle.dump(optimizer._learning_rate, f, 2)
f.close()
except ():
raise IOError("Can't load %s",
os.path.join(
file_dir, "optimizers",
os.path.normpath(str(optimizer._name))))
else:
warnings.warn(
"Optimizer not saved, Only optimizer with 'LearningRateDecay' under DyGraph mode need to be saved"
)

if file_name is not None:
save_var_list = []
@@ -138,6 +170,8 @@ def walk_filename(file_dir):
var_name_list = []
if os.path.exists(base_path):
for dirpath, dirnames, filenames in os.walk(base_path):
if "optimizers" in dirpath:
continue
pt = dirpath.replace(base_path, "", 1)
if pt.startswith("/") or pt.startswith("\\"):
pt = pt[1:]
@@ -152,6 +186,7 @@ def walk_filename(file_dir):

load_block = default_main_program().global_block()
load_var_map = {}
load_optimizer_map = {}
file_var_list = walk_filename(file_dir)
for var_name in file_var_list:
new_var = Variable(block=load_block, name=var_name)
@@ -165,8 +200,22 @@ def walk_filename(file_dir):
})

load_var_map[new_var.name] = new_var

return load_var_map
opt_path = os.path.join(file_dir, "optimizers")
for _, _, optimizers in os.walk(opt_path):
for optimizer in optimizers:
try:
f = open(os.path.join(opt_path, optimizer), "rb")
load_optimizer_map[optimizer] = pickle.load(f)
f.close()
except IOError:
raise IOError("Can't load %s",
os.path.join(
file_dir, "optimizers",
os.path.normpath(str(optimizer._name))))
if len(load_optimizer_map) == 0:
warnings.warn("No optimizer loaded")

return load_var_map, load_optimizer_map


def _clone_var_in_block_(block, var):
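A usage sketch of the new save/load flow, assuming a fluid.dygraph.Layer subclass (MyModel is hypothetical) and an SGD optimizer driven by a PiecewiseDecay schedule:

    import paddle.fluid as fluid

    with fluid.dygraph.guard():
        model = MyModel("my_model")  # hypothetical Layer subclass
        sgd = fluid.optimizer.SGD(learning_rate=fluid.dygraph.PiecewiseDecay(
            [1000, 2000], [0.1, 0.01, 0.001], begin=0))
        # ... forward pass, avg_loss.backward(), sgd.minimize(avg_loss) ...

        # Parameters are written one file per variable under save_dir/; the
        # optimizer's LearningRateDecay object is pickled under
        # save_dir/optimizers/, using optimizer._name as the file name.
        fluid.dygraph.save_persistables(model.state_dict(), [sgd], dirname="save_dir")

        # load_persistables now returns (parameter_dict, optimizer_dict).
        params, optimizers = fluid.dygraph.load_persistables("save_dir")
        model.load_dict(params)
        sgd.load(optimizers)  # assumes this run regenerates the same _name key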
4 changes: 2 additions & 2 deletions python/paddle/fluid/dygraph/learning_rate_scheduler.py
@@ -63,13 +63,13 @@ def __init__(self, boundaries, values, begin, step=1, dtype='float32'):

self.vars = []
for value in values:
self.vars.append(self.create_lr_var(value))
self.vars.append(value)

def step(self):
for i in range(len(self.boundaries)):
if self.step_num < self.boundaries[i]:
return self.vars[i]
return self.vars[len(self.values) - 1]
return self.create_lr_var(self.vars[len(self.values) - 1])


class NaturalExpDecay(LearningRateDecay):
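A hedged reading of this change: PiecewiseDecay now keeps the raw Python floats in self.vars and only materializes a framework Variable inside step(), which keeps the decay object picklable for the optimizer save path in checkpoint.py above. A small sketch under that assumption:

    import paddle.fluid as fluid

    with fluid.dygraph.guard():
        decay = fluid.dygraph.PiecewiseDecay([100], [0.1, 0.01], begin=0)
        # decay.vars holds plain floats now, so the
        # pickle.dump(optimizer._learning_rate, f, 2) call in _save_var_to_file
        # can serialize the schedule without touching a live framework Variable.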
31 changes: 24 additions & 7 deletions python/paddle/fluid/optimizer.py
@@ -1,4 +1,4 @@
# Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -16,25 +16,25 @@

import numpy as np
from collections import defaultdict
from functools import reduce

from paddle.fluid import core
from paddle.fluid.distribute_lookup_table import find_distributed_lookup_table
from paddle.fluid.framework import Program, Variable, name_scope, default_main_program, default_startup_program
from paddle.fluid.layers import tensor

from . import framework
from . import layers
from . import unique_name
from .backward import append_backward
from .clip import append_gradient_clip_ops, error_clip_callback
from .dygraph import base as imperative_base
from .dygraph.learning_rate_scheduler import LearningRateDecay
from .framework import program_guard
from .initializer import Constant
from .layer_helper import LayerHelper
from .layers import ops
from .regularizer import append_regularization_ops
from .dygraph import base as imperative_base
from .dygraph.learning_rate_scheduler import LearningRateDecay
from paddle.fluid import core
from paddle.fluid.layers import tensor
from functools import reduce
Contributor commented: Why import those?

from .wrapped_decorator import signature_safe_contextmanager

__all__ = [
@@ -63,14 +63,18 @@ def __init__(self, learning_rate, regularization=None, name=None):
raise TypeError(
"learning rate should be float or LearningRateDecay, got %s here"
% type(learning_rate))
if name is not None:
self._name = unique_name.generate(name)
else:
self._name = unique_name.generate(self.__class__.__name__)
else:
if not isinstance(learning_rate, float) and \
not isinstance(learning_rate, framework.Variable):
raise TypeError(
"learning rate should be float or Variable, got %s here" %
type(learning_rate))
self._name = name

self._name = name
self.regularization = regularization
self._learning_rate = learning_rate
# the learning rate type should be inferenced from loss
@@ -89,6 +93,19 @@ def __init__(self, learning_rate, regularization=None, name=None):
self.helper = None
self._opti_name_list = []

def load(self, stat_dict):
"""
load optimizer with learning rate decay in dygraph mode
:return: None

Args:
stat_dict: the dict load by load_persistable method
"""
if framework.in_dygraph_mode():
self._learning_rate = stat_dict[self._name]
else:
raise TypeError("load_dict can only be used under DyGraph mode")

def get_opti_var_name_list(self):
return self._opti_name_list

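A short sketch of how the generated optimizer name ties save and load together (hedged; the "_0" suffix assumes this is the first Adam created in the program):

    import paddle.fluid as fluid

    with fluid.dygraph.guard():
        adam = fluid.optimizer.Adam(
            learning_rate=fluid.dygraph.NaturalExpDecay(0.1, 10000, 0.5))
        # With a LearningRateDecay schedule, __init__ assigns a unique _name
        # (e.g. "AdamOptimizer_0"); _save_var_to_file uses it as the pickle
        # file name under save_dir/optimizers, and load() looks it up again:
        _, opt_dict = fluid.dygraph.load_persistables("save_dir")
        adam.load(opt_dict)  # sets adam._learning_rate = opt_dict[adam._name]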
@@ -135,14 +135,14 @@ def test_save_load_persistables(self):

avg_loss.backward()
sgd.minimize(avg_loss)
fluid.dygraph.save_persistables(mnist.state_dict(),
fluid.dygraph.save_persistables(mnist.state_dict(), [sgd],
"save_dir")
mnist.clear_gradients()

for param in mnist.parameters():
dy_param_init_value[param.name] = param.numpy()

restore = fluid.dygraph.load_persistables("save_dir")
restore, _ = fluid.dygraph.load_persistables("save_dir")
mnist.load_dict(restore)

self.assertEqual(len(dy_param_init_value), len(restore))
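For orientation, the on-disk layout the new code produces (a hedged reconstruction from _save_var_to_file above; the parameter file names are illustrative):

    save_dir/
        conv2d_0.w_0        # one file per persistable variable, via a save op
        conv2d_0.b_0
        ...
        optimizers/
            SGDOptimizer_0  # pickled LearningRateDecay, named after optimizer._name
                            # (written only when the optimizer uses such a schedule)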