 import copy
 import itertools
 
-from framework import Variable, default_main_program, default_startup_program, unique_name, dtype_is_floating
+from framework import Variable, default_main_program, default_startup_program, \
+    unique_name, dtype_is_floating
 from paddle.v2.fluid.initializer import Constant, Xavier
+from param_attr import ParamAttr
 
 
 class LayerHelper(object):
@@ -59,31 +61,15 @@ def input(self, input_param_name='input'):
 
     @property
     def param_attr(self):
-        default = {'name': None}
-        actual = self.kwargs.get('param_attr', None)
-        if actual is None:
-            actual = default
-        for default_field in default.keys():
-            if default_field not in actual:
-                actual[default_field] = default[default_field]
-        return actual
+        return ParamAttr.to_attr(self.kwargs.get('param_attr', None))
 
     @property
     def bias_attr(self):
-        default = {'name': None}
-        bias_attr = self.kwargs.get('bias_attr', None)
-        if bias_attr is None:
-            bias_attr = default
-
-        if isinstance(bias_attr, dict):
-            for default_field in default.keys():
-                if default_field not in bias_attr:
-                    bias_attr[default_field] = default[default_field]
-        return bias_attr
+        return ParamAttr.to_attr(self.kwargs.get('bias_attr', None))
 
     def multiple_param_attr(self, length):
         param_attr = self.param_attr
-        if isinstance(param_attr, dict):
+        if isinstance(param_attr, ParamAttr):
             param_attr = [param_attr]
 
         if len(param_attr) != 1 and len(param_attr) != length:
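
The dict-merging defaults removed above are now centralized in ParamAttr. As a rough sketch only, here is the kind of normalization ParamAttr.to_attr presumably performs, inferred from its call sites in this diff; the real implementation lives in param_attr.py and may differ:

# Hypothetical sketch of ParamAttr.to_attr, inferred from call sites here;
# not the actual body in param_attr.py.
class ParamAttr(object):
    def __init__(self, name=None, initializer=None, trainable=True):
        self.name = name                # filled in later via unique_name
        self.initializer = initializer  # filled in by the set_default_* hooks
        self.trainable = trainable

    @staticmethod
    def to_attr(arg):
        if arg is None:
            return ParamAttr()          # replaces the old {'name': None} default
        if isinstance(arg, (list, tuple)):
            return [ParamAttr.to_attr(a) for a in arg]  # per-input attrs
        if isinstance(arg, ParamAttr):
            return arg                  # already normalized
        raise TypeError("unsupported param_attr value: %r" % (arg,))

Returning a list for list inputs would explain why multiple_param_attr above now checks isinstance(param_attr, ParamAttr) before wrapping.
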
@@ -111,23 +97,30 @@ def input_dtype(self, input_param_name='input'):
             raise ValueError("Data Type mismatch")
         return dtype
 
-    def create_parameter(self, attr, shape, dtype, suffix='w',
-                         initializer=None):
+    def create_parameter(self,
+                         attr,
+                         shape,
+                         dtype,
+                         is_bias=False,
+                         default_initializer=None):
         # Deepcopy the attr so that parameters can be shared in program
-        attr_copy = copy.deepcopy(attr)
-        if initializer is not None:
-            attr_copy['initializer'] = initializer
+        assert isinstance(attr, ParamAttr)
+        suffix = 'b' if is_bias else 'w'
+
+        if default_initializer is None:
+            if is_bias:
+                attr.set_default_bias_initializer()
+            else:
+                attr.set_default_param_initializer()
         else:
-            attr_copy['initializer'] = self._get_default_initializer(dtype)
-        if attr_copy['name'] is None:
-            attr_copy['name'] = unique_name(".".join([self.name, suffix]))
+            attr.set_default_initializer(default_initializer)
+        if attr.name is None:
+            attr.name = unique_name(".".join([self.name, suffix]))
+
         self.startup_program.global_block().create_parameter(
-            dtype=dtype, shape=shape, **attr_copy)
+            dtype=dtype, shape=shape, **attr.to_kwargs(with_initializer=True))
         return self.main_program.global_block().create_parameter(
-            name=attr_copy['name'],
-            dtype=dtype,
-            shape=shape,
-            trainable=attr_copy.get('trainable', True))
+            dtype=dtype, shape=shape, **attr.to_kwargs())
 
     def create_tmp_variable(self, dtype):
         return self.main_program.current_block().create_var(
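
To illustrate the new surface area, a hedged usage sketch; `helper` is assumed to be a LayerHelper built inside a layer function, and the shapes are placeholders, none of which come from this diff:

# Illustrative only: `helper` and the dimensions below are assumptions.
input_dim, output_dim = 128, 64
w = helper.create_parameter(
    attr=helper.param_attr,      # always a ParamAttr now, never a raw dict
    shape=[input_dim, output_dim],
    dtype='float32')             # no default_initializer: falls back to the
                                 # ParamAttr param default
b = helper.create_parameter(
    attr=helper.bias_attr,
    shape=[output_dim],
    dtype='float32',
    is_bias=True)                # selects the 'b' name suffix and the
                                 # default bias initializer

Note that only the startup-program copy receives the initializer (to_kwargs(with_initializer=True)); the main-program copy omits it, mirroring the old attr_copy split between the two blocks.
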
@@ -152,11 +145,7 @@ def set_variable_initializer(self, var, initializer):
             persistable=True,
             initializer=initializer)
 
-    def append_bias_op(self,
-                       input_var,
-                       bias_initializer,
-                       dim_start=1,
-                       dim_end=None):
+    def append_bias_op(self, input_var, dim_start=1, dim_end=None):
         """
         Append bias operator and return its output. If the user does not set
         bias_attr, append_bias_op will return input_var
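
A hedged before/after sketch of a caller (`helper` and `pre_act` are illustrative names): the initializer that used to be threaded through the bias_initializer argument now travels inside bias_attr instead:

# Before: out = helper.append_bias_op(pre_act, bias_initializer=Constant(0.0))
# After: the initializer rides along in bias_attr, e.g. the layer is built
# with bias_attr=ParamAttr(initializer=Constant(0.0)), so the call shrinks to:
out = helper.append_bias_op(pre_act, dim_start=1)
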
@@ -176,11 +165,7 @@ def append_bias_op(self,
             return input_var
 
         b = self.create_parameter(
-            attr=bias_attr,
-            shape=size,
-            dtype=input_var.dtype,
-            suffix='b',
-            initializer=bias_initializer)
+            attr=bias_attr, shape=size, dtype=input_var.dtype, is_bias=True)
         tmp = self.create_tmp_variable(dtype=input_var.dtype)
         self.append_op(
             type='elementwise_add',
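
Finally, a guess at how the set_default_*initializer hooks behave, consistent with the Constant/Xavier import kept at the top of this file; treat this as an assumption about param_attr.py, not its verified body:

# Assumed behavior of the ParamAttr default-initializer hooks; the real
# param_attr.py may differ in detail.
from paddle.v2.fluid.initializer import Constant, Xavier

class ParamAttrSketch(object):  # stand-in continuing the earlier sketch
    def __init__(self):
        self.initializer = None

    def set_default_initializer(self, initializer):
        # respect an initializer the user already chose explicitly
        if self.initializer is None:
            self.initializer = initializer

    def set_default_param_initializer(self):
        self.set_default_initializer(Xavier())       # weights: Xavier default

    def set_default_bias_initializer(self):
        self.set_default_initializer(Constant(0.0))  # biases: zero default
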