@@ -223,5 +223,109 @@ def test_xavier_initializer_supplied_arguments(self):
223
223
self .assertEqual (init_op .attr ('seed' ), 134 )
224
224
225
225
226
+ class TestMSRAInitializer (unittest .TestCase ):
227
+ def test_uniform_msra_initializer (self ):
228
+ """Test MSRA initializer with uniform distribution on
229
+ for matrix multiply.
230
+ """
231
+ program = framework .Program ()
232
+ block = program .global_block ()
233
+ param = block .create_parameter (
234
+ dtype = "float32" ,
235
+ shape = [5 , 10 ],
236
+ lod_level = 0 ,
237
+ name = "param" ,
238
+ initializer = initializer .MSRAInitializer ())
239
+ self .assertEqual (len (block .ops ), 1 )
240
+ init_op = block .ops [0 ]
241
+ self .assertEqual (init_op .type , 'uniform_random' )
242
+ limit = np .sqrt (6.0 / param .shape [0 ])
243
+ self .assertAlmostEqual (init_op .attr ('min' ), - limit , delta = DELTA )
244
+ self .assertAlmostEqual (init_op .attr ('max' ), limit , delta = DELTA )
245
+ self .assertEqual (init_op .attr ('seed' ), 0 )
246
+
247
+ def test_uniform_msra_initializer_conv (self ):
248
+ """Test MSRA initializer with uniform distribution on
249
+ for convolutions.
250
+ """
251
+ program = framework .Program ()
252
+ block = program .global_block ()
253
+ param = block .create_parameter (
254
+ dtype = "float32" ,
255
+ shape = [5 , 10 , 15 , 20 ],
256
+ lod_level = 0 ,
257
+ name = "param" ,
258
+ initializer = initializer .MSRAInitializer ())
259
+ self .assertEqual (len (block .ops ), 1 )
260
+ init_op = block .ops [0 ]
261
+ self .assertEqual (init_op .type , 'uniform_random' )
262
+ receptive_field_size = float (15 * 20 )
263
+ limit = np .sqrt (6.0 / (param .shape [1 ] * receptive_field_size ))
264
+ self .assertAlmostEqual (init_op .attr ('min' ), - limit , delta = DELTA )
265
+ self .assertAlmostEqual (init_op .attr ('max' ), limit , delta = DELTA )
266
+ self .assertEqual (init_op .attr ('seed' ), 0 )
267
+
268
+ def test_normal_msra_initializer (self ):
269
+ """Test MSRA initializer with normal distribution on
270
+ for matrix multiply.
271
+ """
272
+ program = framework .Program ()
273
+ block = program .global_block ()
274
+ param = block .create_parameter (
275
+ dtype = "float32" ,
276
+ shape = [5 , 10 ],
277
+ lod_level = 0 ,
278
+ name = "param" ,
279
+ initializer = initializer .MSRAInitializer (uniform = False ))
280
+ self .assertEqual (len (block .ops ), 1 )
281
+ init_op = block .ops [0 ]
282
+ self .assertEqual (init_op .type , 'gaussian_random' )
283
+ std = np .sqrt (2.0 / param .shape [0 ])
284
+ self .assertAlmostEqual (init_op .attr ('mean' ), 0.0 , delta = DELTA )
285
+ self .assertAlmostEqual (init_op .attr ('std' ), std , delta = DELTA )
286
+ self .assertEqual (init_op .attr ('seed' ), 0 )
287
+
288
+ def test_normal_msra_initializer_conv (self ):
289
+ """Test MSRA initializer with normal distribution on
290
+ for convolutions.
291
+ """
292
+ program = framework .Program ()
293
+ block = program .global_block ()
294
+ param = block .create_parameter (
295
+ dtype = "float32" ,
296
+ shape = [5 , 10 , 15 , 20 ],
297
+ lod_level = 0 ,
298
+ name = "param" ,
299
+ initializer = initializer .MSRAInitializer (uniform = False ))
300
+ self .assertEqual (len (block .ops ), 1 )
301
+ init_op = block .ops [0 ]
302
+ self .assertEqual (init_op .type , 'gaussian_random' )
303
+ receptive_field_size = float (15 * 20 )
304
+ std = np .sqrt (2.0 / (param .shape [1 ] * receptive_field_size ))
305
+ self .assertAlmostEqual (init_op .attr ('mean' ), 0.0 , delta = DELTA )
306
+ self .assertAlmostEqual (init_op .attr ('std' ), std , delta = DELTA )
307
+ self .assertEqual (init_op .attr ('seed' ), 0 )
308
+
309
+ def test_msra_initializer_supplied_arguments (self ):
310
+ """Test the MSRA initializer with supplied arguments
311
+ """
312
+ program = framework .Program ()
313
+ block = program .global_block ()
314
+ block .create_parameter (
315
+ dtype = "float32" ,
316
+ shape = [5 , 10 ],
317
+ lod_level = 0 ,
318
+ name = "param" ,
319
+ initializer = initializer .MSRAInitializer (
320
+ fan_in = 12 , seed = 134 ))
321
+ self .assertEqual (len (block .ops ), 1 )
322
+ init_op = block .ops [0 ]
323
+ self .assertEqual (init_op .type , 'uniform_random' )
324
+ limit = np .sqrt (6.0 / 12 )
325
+ self .assertAlmostEqual (init_op .attr ('min' ), - limit , delta = DELTA )
326
+ self .assertAlmostEqual (init_op .attr ('max' ), limit , delta = DELTA )
327
+ self .assertEqual (init_op .attr ('seed' ), 134 )
328
+
329
+
226
330
if __name__ == '__main__' :
227
331
unittest .main ()
0 commit comments