
Commit f096e71

Make param_attr a strongly typed class
Fix PaddlePaddle#5819
1 parent 91bfb07 · commit f096e71

9 files changed: +128 −145 lines

python/paddle/v2/fluid/__init__.py

+2 −1
@@ -13,13 +13,14 @@
 import optimizer
 import backward
 import regularizer
+from param_attr import ParamAttr

 from core import LoDTensor, CPUPlace, GPUPlace

 Tensor = LoDTensor
 __all__ = framework.__all__ + executor.__all__ + [
     'io', 'initializer', 'layers', 'nets', 'optimizer', 'backward',
-    'regularizer', 'LoDTensor', 'CPUPlace', 'GPUPlace', 'Tensor'
+    'regularizer', 'LoDTensor', 'CPUPlace', 'GPUPlace', 'Tensor', 'ParamAttr'
 ]

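With ParamAttr exported from the package root, user code can pass a typed attribute object instead of a bare dict. A minimal usage sketch; the data/layer names are hypothetical, and ParamAttr's full constructor lives in param_attr.py, which is not shown in this excerpt:

    import paddle.v2.fluid as fluid

    # Assumed from the diff below: ParamAttr carries at least a name, so
    # two layers passing the same name share one parameter.
    data = fluid.layers.data(name="x", shape=[32], dtype="float32")
    attr = fluid.ParamAttr(name="fc.w")
    out = fluid.layers.fc(input=data, size=128, param_attr=attr)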

python/paddle/v2/fluid/layer_helper.py

+28 −43
@@ -1,8 +1,10 @@
 import copy
 import itertools

-from framework import Variable, default_main_program, default_startup_program, unique_name, dtype_is_floating
+from framework import Variable, default_main_program, default_startup_program, \
+    unique_name, dtype_is_floating
 from paddle.v2.fluid.initializer import Constant, Xavier
+from param_attr import ParamAttr


 class LayerHelper(object):

@@ -59,31 +61,15 @@ def input(self, input_param_name='input'):

     @property
     def param_attr(self):
-        default = {'name': None}
-        actual = self.kwargs.get('param_attr', None)
-        if actual is None:
-            actual = default
-        for default_field in default.keys():
-            if default_field not in actual:
-                actual[default_field] = default[default_field]
-        return actual
+        return ParamAttr.to_attr(self.kwargs.get('param_attr', None))

     @property
     def bias_attr(self):
-        default = {'name': None}
-        bias_attr = self.kwargs.get('bias_attr', None)
-        if bias_attr is None:
-            bias_attr = default
-
-        if isinstance(bias_attr, dict):
-            for default_field in default.keys():
-                if default_field not in bias_attr:
-                    bias_attr[default_field] = default[default_field]
-        return bias_attr
+        return ParamAttr.to_attr(self.kwargs.get('bias_attr', None))

     def multiple_param_attr(self, length):
         param_attr = self.param_attr
-        if isinstance(param_attr, dict):
+        if isinstance(param_attr, ParamAttr):
             param_attr = [param_attr]

         if len(param_attr) != 1 and len(param_attr) != length:

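Both properties now funnel through ParamAttr.to_attr, replacing the old dict-merging logic. param_attr.py is among the nine changed files but is not part of this excerpt; a plausible sketch of the normalization it performs, assuming the constructor fields implied by the rest of the diff:

    @staticmethod
    def to_attr(arg):
        # Sketch only: normalize whatever the layer author passed into a
        # strongly typed ParamAttr.
        if arg is None:
            return ParamAttr()            # all defaults, name=None
        elif isinstance(arg, ParamAttr):
            return arg                    # already typed; pass through
        elif isinstance(arg, (list, tuple)):
            return [ParamAttr.to_attr(a) for a in arg]   # per-input attrs
        elif isinstance(arg, dict):
            return ParamAttr(**arg)       # legacy dict-style attr
        else:
            raise TypeError("Unsupported param_attr type: %s" % type(arg))
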
@@ -111,23 +97,30 @@ def input_dtype(self, input_param_name='input'):
             raise ValueError("Data Type mismatch")
         return dtype

-    def create_parameter(self, attr, shape, dtype, suffix='w',
-                         initializer=None):
+    def create_parameter(self,
+                         attr,
+                         shape,
+                         dtype,
+                         is_bias=False,
+                         default_initializer=None):
         # Deepcopy the attr so that parameters can be shared in program
-        attr_copy = copy.deepcopy(attr)
-        if initializer is not None:
-            attr_copy['initializer'] = initializer
+        assert isinstance(attr, ParamAttr)
+        suffix = 'b' if is_bias else 'w'
+
+        if default_initializer is None:
+            if is_bias:
+                attr.set_default_bias_initializer()
+            else:
+                attr.set_default_param_initializer()
         else:
-            attr_copy['initializer'] = self._get_default_initializer(dtype)
-        if attr_copy['name'] is None:
-            attr_copy['name'] = unique_name(".".join([self.name, suffix]))
+            attr.set_default_initializer(default_initializer)
+        if attr.name is None:
+            attr.name = unique_name(".".join([self.name, suffix]))
+
         self.startup_program.global_block().create_parameter(
-            dtype=dtype, shape=shape, **attr_copy)
+            dtype=dtype, shape=shape, **attr.to_kwargs(with_initializer=True))
         return self.main_program.global_block().create_parameter(
-            name=attr_copy['name'],
-            dtype=dtype,
-            shape=shape,
-            trainable=attr_copy.get('trainable', True))
+            dtype=dtype, shape=shape, **attr.to_kwargs())

     def create_tmp_variable(self, dtype):
         return self.main_program.current_block().create_var(

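Note the split: the startup program receives the initializer (to_kwargs(with_initializer=True)) while the main program gets only the bookkeeping fields, so the parameter is initialized exactly once. A plausible sketch of the ParamAttr methods used above; the field names are assumed from the old dict keys, and Constant/Xavier are the initializers already imported at the top of layer_helper.py:

    def set_default_initializer(self, initializer):
        # Only fill in a default; an initializer the user set wins.
        if self.initializer is None:
            self.initializer = initializer

    def set_default_param_initializer(self):
        self.set_default_initializer(Xavier())

    def set_default_bias_initializer(self):
        self.set_default_initializer(Constant(0.0))

    def to_kwargs(self, with_initializer=False):
        # The visible call sites imply at least these fields; the real
        # class may carry more (e.g. regularizer, learning rate).
        kwargs = {'name': self.name, 'trainable': self.trainable}
        if with_initializer:
            kwargs['initializer'] = self.initializer
        return kwargs
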
@@ -152,11 +145,7 @@ def set_variable_initializer(self, var, initializer):
             persistable=True,
             initializer=initializer)

-    def append_bias_op(self,
-                       input_var,
-                       bias_initializer,
-                       dim_start=1,
-                       dim_end=None):
+    def append_bias_op(self, input_var, dim_start=1, dim_end=None):
         """
         Append bias operator and return its output. If the user does not set
         bias_attr, append_bias_op will return input_var

@@ -176,11 +165,7 @@ def append_bias_op(self,
             return input_var

         b = self.create_parameter(
-            attr=bias_attr,
-            shape=size,
-            dtype=input_var.dtype,
-            suffix='b',
-            initializer=bias_initializer)
+            attr=bias_attr, shape=size, dtype=input_var.dtype, is_bias=True)
         tmp = self.create_tmp_variable(dtype=input_var.dtype)
         self.append_op(
             type='elementwise_add',

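For layer implementations the migration is mechanical: the bias initializer no longer travels as a separate append_bias_op argument but rides inside the bias_attr ParamAttr. A hypothetical call site, before and after:

    # Before: the initializer was threaded through explicitly.
    # pre_act = helper.append_bias_op(tmp, bias_initializer=Constant(0.0))

    # After: helper.bias_attr already carries (or defaults) the initializer.
    pre_act = helper.append_bias_op(tmp)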
