Skip to content

Commit b5f8784

Browse files
authored
Refine Model of high level API (#25559)
* Refine Model 1. Take the network (instance of Layer) as the input of Model. 2. Refine set_dict/load_dict of Layer. 3. Refine Input interface, so update code sample about Input
1 parent 4152d39 commit b5f8784

24 files changed

+619
-518
lines changed

python/paddle/fluid/dygraph/checkpoint.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -94,7 +94,9 @@ def save_dygraph(state_dict, model_path):
9494
pickle.dump(model_dict, f, protocol=2)
9595

9696

97-
@dygraph_only
97+
# TODO(qingqing01): remove dygraph_only to support loading static model.
98+
# maybe need to unify the loading interface after 2.0 API is ready.
99+
#@dygraph_only
98100
def load_dygraph(model_path, keep_name_table=False):
99101
'''
100102
:api_attr: imperative

python/paddle/fluid/dygraph/layers.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -16,19 +16,19 @@
1616
import contextlib
1717
import sys
1818
import numpy as np
19-
import collections
2019
import six
2120
import re
21+
import copy
22+
import weakref
23+
import warnings
24+
2225
from . import parallel_helper
2326
from .. import unique_name
2427
from paddle.fluid import core
2528
from .layer_object_helper import LayerObjectHelper
2629
from .base import program_desc_tracing_guard, param_guard
2730
from paddle.fluid import framework
2831
from ..param_attr import ParamAttr
29-
import copy
30-
import weakref
31-
import warnings
3232

3333
__all__ = ['Layer']
3434

python/paddle/fluid/framework.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -66,7 +66,6 @@
6666
_dygraph_tracer_ = None
6767
_dygraph_current_expected_place_ = None
6868
_current_device = None
69-
7069
global_prog_seed = 0
7170

7271

python/paddle/incubate/hapi/__init__.py

Lines changed: 10 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -16,14 +16,22 @@
1616
from . import progressbar
1717
from . import callbacks
1818
from . import download
19+
1920
from . import model
21+
from .model import *
22+
2023
from . import metrics
2124
from . import loss
2225
from . import datasets
2326
from . import distributed
2427
from . import vision
2528
from . import text
2629

30+
from . import device
31+
from .device import *
32+
33+
from .dygraph_layer_patch import monkey_patch_layer
34+
2735
logger.setup_logger()
2836

2937
__all__ = [
@@ -35,6 +43,6 @@
3543
'loss',
3644
'vision',
3745
'text',
38-
]
46+
] + model.__all__ + device.__all__
3947

40-
__all__ += model.__all__
48+
monkey_patch_layer()

python/paddle/incubate/hapi/callbacks.py

Lines changed: 20 additions & 37 deletions
Original file line numberDiff line numberDiff line change
@@ -291,30 +291,22 @@ class ProgBarLogger(Callback):
291291
Examples:
292292
.. code-block:: python
293293
294-
import numpy as np
295-
from paddle import fluid
296-
from paddle.incubate.hapi.metrics import Accuracy
297-
from paddle.incubate.hapi.loss import CrossEntropy
298-
from paddle.incubate.hapi.datasets import MNIST
299-
from paddle.incubate.hapi.vision.models import LeNet
300-
from paddle.incubate.hapi.callbacks import ProgBarLogger
301-
from paddle.incubate.hapi.model import Input, set_device
294+
import paddle.fluid as fluid
295+
import paddle.incubate.hapi as hapi
302296
303-
inputs = [Input([-1, 1, 28, 28], 'float32', name='image')]
304-
labels = [Input([None, 1], 'int64', name='label')]
297+
inputs = [hapi.Input('image', [-1, 1, 28, 28], 'float32')]
298+
labels = [hapi.Input('label', [None, 1], 'int64')]
305299
306-
train_dataset = MNIST(mode='train')
300+
train_dataset = hapi.datasets.MNIST(mode='train')
307301
308-
model = LeNet()
302+
model = hapi.Model(hapi.vision.LeNet(), inputs, labels)
309303
310304
optim = fluid.optimizer.Adam(0.001)
311-
model.prepare(optimizer=optim,
312-
loss_function=CrossEntropy(),
313-
metrics=Accuracy(),
314-
inputs=inputs,
315-
labels=labels)
305+
model.prepare(optimizer=optim,
306+
loss_function=hapi.loss.CrossEntropy(),
307+
metrics=hapi.metrics.Accuracy())
316308
317-
callback = ProgBarLogger(log_freq=10)
309+
callback = hapi.callbacks.ProgBarLogger(log_freq=10)
318310
model.fit(train_dataset, batch_size=64, callbacks=callback)
319311
"""
320312

@@ -433,31 +425,22 @@ class ModelCheckpoint(Callback):
433425
Examples:
434426
.. code-block:: python
435427
436-
import numpy as np
437-
from paddle import fluid
438-
from paddle.incubate.hapi.metrics import Accuracy
439-
from paddle.incubate.hapi.loss import CrossEntropy
440-
from paddle.incubate.hapi.datasets import MNIST
441-
442-
from paddle.incubate.hapi.vision.models import LeNet
443-
from paddle.incubate.hapi.callbacks import ModelCheckpoint
444-
from paddle.incubate.hapi.model import Input, set_device
428+
import paddle.fluid as fluid
429+
import paddle.incubate.hapi as hapi
445430
446-
inputs = [Input([-1, 1, 28, 28], 'float32', name='image')]
447-
labels = [Input([None, 1], 'int64', name='label')]
431+
inputs = [hapi.Input('image', [-1, 1, 28, 28], 'float32')]
432+
labels = [hapi.Input('label', [None, 1], 'int64')]
448433
449-
train_dataset = MNIST(mode='train')
434+
train_dataset = hapi.datasets.MNIST(mode='train')
450435
451-
model = LeNet()
436+
model = hapi.Model(hapi.vision.LeNet(), inputs, labels)
452437
453438
optim = fluid.optimizer.Adam(0.001)
454-
model.prepare(optimizer=optim,
455-
loss_function=CrossEntropy(),
456-
metrics=Accuracy(),
457-
inputs=inputs,
458-
labels=labels)
439+
model.prepare(optimizer=optim,
440+
loss_function=hapi.loss.CrossEntropy(),
441+
metrics=hapi.metrics.Accuracy())
459442
460-
callback = ModelCheckpoint(save_dir='./temp')
443+
callback = hapi.callbacks.ModelCheckpoint(save_dir='./temp')
461444
model.fit(train_dataset, batch_size=64, callbacks=callback)
462445
"""
463446

python/paddle/incubate/hapi/device.py

Lines changed: 66 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,66 @@
1+
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# http://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
15+
import six
16+
17+
import paddle.fluid as fluid
18+
from paddle.fluid.dygraph.parallel import ParallelEnv
19+
20+
__all__ = ['set_device', ]
21+
22+
# TODO(qingqing01): remove or refine _global_device, set_device and get_device
23+
# after core framework supporting these function.
24+
_global_device = None
25+
26+
27+
def set_device(device):
28+
"""
29+
Args:
30+
device (str): specify device type, 'cpu' or 'gpu'.
31+
32+
Returns:
33+
fluid.CUDAPlace or fluid.CPUPlace: Created GPU or CPU place.
34+
35+
Examples:
36+
.. code-block:: python
37+
38+
import paddle.incubate.hapi as hapi
39+
40+
input = hapi.set_device('gpu')
41+
"""
42+
43+
assert isinstance(device, six.string_types) and device.lower() in ['cpu', 'gpu'], \
44+
"Expected device in ['cpu', 'gpu'], but got {}".format(device)
45+
46+
device = fluid.CUDAPlace(ParallelEnv().dev_id) \
47+
if device.lower() == 'gpu' and fluid.is_compiled_with_cuda() \
48+
else fluid.CPUPlace()
49+
50+
global _global_device
51+
_global_device = device
52+
return device
53+
54+
55+
def _get_device():
56+
"""
57+
Return global device.
58+
"""
59+
if _global_device is not None:
60+
device = _global_device
61+
else:
62+
if fluid.is_compiled_with_cuda():
63+
device = fluid.CUDAPlace(ParallelEnv().dev_id)
64+
else:
65+
device = fluid.CPUPlace()
66+
return device
Lines changed: 104 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,104 @@
1+
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# http://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
15+
import warnings
16+
17+
import paddle.fluid as fluid
18+
from paddle.fluid.framework import in_dygraph_mode
19+
20+
from .device import _get_device
21+
22+
23+
def monkey_patch_layer():
24+
def load_dict(self,
25+
stat_dict,
26+
include_sublayers=True,
27+
use_structured_name=True):
28+
'''
29+
Set parameters from stat_dict. All the parameters will be reset by the
30+
tensor in the stat_dict
31+
32+
This api will be Deprecated. Please use set_dict
33+
34+
Parameters:
35+
state_dict(dict) : Dict contains all the parameters
36+
include_sublayers(bool, optional) : If true, also include the
37+
parameters from sublayers. Default: True
38+
use_structured_name(bool, optional) : If true, use structured name
39+
as key, otherwise, use parameter name as key. Default: True
40+
Returns:
41+
None
42+
43+
Examples:
44+
.. code-block:: python
45+
46+
import paddle.fluid as fluid
47+
with fluid.dygraph.guard():
48+
emb = fluid.dygraph.Embedding([10, 10])
49+
50+
state_dict = emb.state_dict()
51+
fluid.save_dygraph( state_dict, "paddle_dy")
52+
53+
para_state_dict, _ = fluid.load_dygraph( "paddle_dy")
54+
emb.load_dict( para_state_dict )
55+
56+
'''
57+
58+
def _check_match(key, param):
59+
state = stat_dict.get(key, None)
60+
if state is None:
61+
raise ValueError(
62+
"{} is not found in the providing file.".format(key))
63+
if list(state.shape) != list(param.shape):
64+
raise ValueError(
65+
"{} receives a shape {}, but the expected shape is {}.".
66+
format(key, list(state.shape), list(param.shape)))
67+
return param, state
68+
69+
matched_param_state = []
70+
for key, param in self.state_dict().items():
71+
key_name = key if use_structured_name else param.name
72+
try:
73+
match_res = _check_match(key_name, param)
74+
except ValueError as err:
75+
warnings.warn(("Skip loading for {}. ".format(key) + str(err)))
76+
matched_param_state.append(match_res)
77+
78+
if in_dygraph_mode():
79+
for param, state in matched_param_state:
80+
param.set_value(state)
81+
else:
82+
83+
def _set_var(var, ndarray):
84+
t = fluid.global_scope().find_var(var.name).get_tensor()
85+
p = t._place()
86+
if p.is_cpu_place():
87+
place = fluid.CPUPlace()
88+
elif p.is_cuda_pinned_place():
89+
place = fluid.CUDAPinnedPlace()
90+
else:
91+
p = fluid.core.Place()
92+
p.set_place(t._place())
93+
place = fluid.CUDAPlace(p.gpu_device_id())
94+
t.set(ndarray, place)
95+
96+
executor = fluid.Executor(_get_device())._default_executor
97+
# restore parameter states
98+
fluid.core._create_loaded_parameter(
99+
[param for param, state in matched_param_state],
100+
fluid.global_scope(), executor)
101+
for param, state in matched_param_state:
102+
_set_var(param, state)
103+
104+
setattr(fluid.dygraph.Layer, 'load_dict', load_dict)

python/paddle/incubate/hapi/loss.py

Lines changed: 11 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -86,16 +86,13 @@ class CrossEntropy(Loss):
8686
Examples:
8787
.. code-block:: python
8888
89-
from paddle.incubate.hapi.model import Input
90-
from paddle.incubate.hapi.vision.models import LeNet
91-
from paddle.incubate.hapi.loss import CrossEntropy
89+
import paddle.fluid as fluid
90+
import paddle.incubate.hapi as hapi
9291
93-
inputs = [Input([-1, 1, 28, 28], 'float32', name='image')]
94-
labels = [Input([None, 1], 'int64', name='label')]
92+
fluid.enable_dygraph()
9593
96-
model = LeNet()
97-
loss = CrossEntropy()
98-
model.prepare(loss_function=loss, inputs=inputs, labels=labels)
94+
model = hapi.Model(hapi.vision.LeNet())
95+
model.prepare(loss_function=hapi.loss.CrossEntropy())
9996
10097
"""
10198

@@ -123,16 +120,14 @@ class SoftmaxWithCrossEntropy(Loss):
123120
Examples:
124121
.. code-block:: python
125122
126-
from paddle.incubate.hapi.model import Input
127-
from paddle.incubate.hapi.vision.models import LeNet
128-
from paddle.incubate.hapi.loss import SoftmaxWithCrossEntropy
123+
import paddle.fluid as fluid
124+
import paddle.incubate.hapi as hapi
129125
130-
inputs = [Input([-1, 1, 28, 28], 'float32', name='image')]
131-
labels = [Input([None, 1], 'int64', name='label')]
126+
fluid.enable_dygraph()
132127
133-
model = LeNet(classifier_activation=None)
134-
loss = SoftmaxWithCrossEntropy()
135-
model.prepare(loss_function=loss, inputs=inputs, labels=labels)
128+
model = hapi.Model(hapi.vision.LeNet(classifier_activation=None))
129+
loss = hapi.loss.SoftmaxWithCrossEntropy()
130+
model.prepare(loss_function=loss)
136131
"""
137132

138133
def __init__(self, average=True):

0 commit comments

Comments
 (0)