diff --git a/test/deprecated/legacy_test/CMakeLists.txt b/test/deprecated/legacy_test/CMakeLists.txt
index e6bc75b07828c..7ae59b054953b 100644
--- a/test/deprecated/legacy_test/CMakeLists.txt
+++ b/test/deprecated/legacy_test/CMakeLists.txt
@@ -716,7 +716,6 @@ set(STATIC_BUILD_TESTS
     test_fuse_bn_act_pass_deprecated
     test_layer_norm_op_deprecated
     test_lookup_table_v2_op_deprecated
-    test_momentum_op
     test_momentum_op_deprecated
     test_nce_deprecated
     test_sparse_conv_op
diff --git a/test/legacy_test/CMakeLists.txt b/test/legacy_test/CMakeLists.txt
index 053b0da2fc4e5..9ea05ac57a950 100644
--- a/test/legacy_test/CMakeLists.txt
+++ b/test/legacy_test/CMakeLists.txt
@@ -1058,6 +1058,7 @@ set(STATIC_BUILD_TESTS
     test_layer_norm_op
     test_eigh_op
     test_matmul_v2_op
+    test_momentum_op
     test_paddle_save_load_binary
     test_assign_pos_op
     test_bucketize_api
diff --git a/test/deprecated/legacy_test/test_momentum_op.py b/test/legacy_test/test_momentum_op.py
similarity index 93%
rename from test/deprecated/legacy_test/test_momentum_op.py
rename to test/legacy_test/test_momentum_op.py
index c48601326f4bd..c8d4bd815d1ae 100644
--- a/test/deprecated/legacy_test/test_momentum_op.py
+++ b/test/legacy_test/test_momentum_op.py
@@ -760,59 +760,68 @@ def get_program(self, weight_attr, bias_attr=False):
 
     def test_param_has_l2decay(self):
         paddle.enable_static()
-        weight_attr = paddle.ParamAttr(
-            name="weight",
-            initializer=paddle.nn.initializer.Constant(value=0.5),
-            regularizer=paddle.regularizer.L2Decay(0.1),
-        )
-        program = self.get_program(weight_attr, bias_attr=False)
-        ops = program.global_block().ops
+        with paddle.pir_utils.OldIrGuard():
+            weight_attr = paddle.ParamAttr(
+                name="weight",
+                initializer=paddle.nn.initializer.Constant(value=0.5),
+                regularizer=paddle.regularizer.L2Decay(0.1),
+            )
+            program = self.get_program(weight_attr, bias_attr=False)
+            ops = program.global_block().ops
 
-        self.assertEqual(ops[-1].attr('regularization_method'), 'l2_decay')
-        self.assertEqual(ops[-1].attr('regularization_coeff'), np.float32(0.1))
-        for i in range(len(ops)):
-            self.assertTrue('sum' not in ops[i].type)
-            self.assertTrue('scale' not in ops[i].type)
+            self.assertEqual(ops[-1].attr('regularization_method'), 'l2_decay')
+            self.assertEqual(
+                ops[-1].attr('regularization_coeff'), np.float32(0.1)
+            )
+            for i in range(len(ops)):
+                self.assertTrue('sum' not in ops[i].type)
+                self.assertTrue('scale' not in ops[i].type)
 
     def test_param_has_l1decay(self):
         paddle.enable_static()
-        weight_attr = paddle.ParamAttr(
-            name="weight",
-            initializer=paddle.nn.initializer.Constant(value=0.5),
-            regularizer=paddle.regularizer.L1Decay(0.1),
-        )
-        bias_attr = paddle.ParamAttr(
-            name="bias",
-            initializer=paddle.nn.initializer.Constant(value=0.0),
-            regularizer=None,
-        )
-        program = self.get_program(weight_attr, bias_attr)
-        ops = program.global_block().ops
-
-        self.assertEqual(ops[-1].type, 'momentum')
-        self.assertEqual(ops[-2].type, 'momentum')
-        self.assertEqual(ops[-3].type, 'sum')
-        self.assertEqual(ops[-4].type, 'scale')
-        self.assertEqual(ops[-5].type, 'sign')
-        self.assertEqual(ops[-6].type, 'matmul_v2_grad')
-        if 'weight' in ops[-1].input('Param'):
-            self.assertEqual(ops[-1].attr('regularization_method'), '')
-            self.assertEqual(ops[-1].attr('regularization_coeff'), 0)
-        if 'bias' in ops[-2].input('Param'):
-            self.assertEqual(ops[-2].attr('regularization_method'), 'l2_decay')
-            self.assertEqual(
-                ops[-2].attr('regularization_coeff'), np.float32(0.5)
+        with paddle.pir_utils.OldIrGuard():
+            weight_attr = paddle.ParamAttr(
+                name="weight",
+                initializer=paddle.nn.initializer.Constant(value=0.5),
+                regularizer=paddle.regularizer.L1Decay(0.1),
             )
+            bias_attr = paddle.ParamAttr(
+                name="bias",
+                initializer=paddle.nn.initializer.Constant(value=0.0),
+                regularizer=None,
+            )
+            program = self.get_program(weight_attr, bias_attr)
+            ops = program.global_block().ops
+
+            self.assertEqual(ops[-1].type, 'momentum')
+            self.assertEqual(ops[-2].type, 'momentum')
+            self.assertEqual(ops[-3].type, 'sum')
+            self.assertEqual(ops[-4].type, 'scale')
+            self.assertEqual(ops[-5].type, 'sign')
+            self.assertEqual(ops[-6].type, 'matmul_v2_grad')
+            if 'weight' in ops[-1].input('Param'):
+                self.assertEqual(ops[-1].attr('regularization_method'), '')
+                self.assertEqual(ops[-1].attr('regularization_coeff'), 0)
+            if 'bias' in ops[-2].input('Param'):
+                self.assertEqual(
+                    ops[-2].attr('regularization_method'), 'l2_decay'
+                )
+                self.assertEqual(
+                    ops[-2].attr('regularization_coeff'), np.float32(0.5)
+                )
 
     def test_param_has_no_regularizer(self):
         paddle.enable_static()
-        program = self.get_program(weight_attr=None)
-        ops = program.global_block().ops
-        self.assertEqual(ops[-1].attr('regularization_method'), 'l2_decay')
-        self.assertEqual(ops[-1].attr('regularization_coeff'), np.float32(0.5))
-        for i in range(len(ops)):
-            self.assertTrue('sum' not in ops[i].type)
-            self.assertTrue('scale' not in ops[i].type)
+        with paddle.pir_utils.OldIrGuard():
+            program = self.get_program(weight_attr=None)
+            ops = program.global_block().ops
+            self.assertEqual(ops[-1].attr('regularization_method'), 'l2_decay')
+            self.assertEqual(
+                ops[-1].attr('regularization_coeff'), np.float32(0.5)
+            )
+            for i in range(len(ops)):
+                self.assertTrue('sum' not in ops[i].type)
+                self.assertTrue('scale' not in ops[i].type)
 
 
 class TestMomentumOpVsMomentumOpWithDecayAPI(unittest.TestCase):
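
Note on the pattern above: each rewritten test body wraps its old-IR-specific logic in paddle.pir_utils.OldIrGuard(), the context manager that temporarily switches static-graph program construction back to the legacy (pre-PIR) IR, so that Program.global_block().ops still yields old-style ops whose .type and .attr('regularization_method') the assertions inspect. A minimal standalone sketch of the same pattern follows; the toy network is a hypothetical stand-in for the test's own get_program helper, which this hunk does not show:

    import paddle

    paddle.enable_static()
    # Build and inspect a program under the legacy (pre-PIR) IR.
    with paddle.pir_utils.OldIrGuard():
        main, startup = paddle.static.Program(), paddle.static.Program()
        with paddle.static.program_guard(main, startup):
            x = paddle.static.data(name='x', shape=[4, 4], dtype='float32')
            w = paddle.create_parameter(shape=[4, 4], dtype='float32')
            loss = paddle.mean(paddle.matmul(x, w))
            paddle.optimizer.Momentum(learning_rate=0.1, momentum=0.9).minimize(loss)
        # Under the guard these are legacy op descriptions, so .type and
        # .attr(...) behave as the test assertions expect.
        print([op.type for op in main.global_block().ops])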