@@ -105,28 +105,9 @@ def train(self, use_prim, use_cinn):
105
105
106
106
return res
107
107
108
- def test_cinn (self ):
109
- paddle .disable_static ()
110
- use_cinn = True
111
- if isinstance (
112
- framework ._current_expected_place (), framework .core .CPUPlace
113
- ):
114
- # TODO(jiabin): CINN will crashed in this case open it when fixed
115
- use_cinn = False
116
-
117
- dy_res = self .train (use_prim = False , use_cinn = False )
118
- comp_st_cinn_res = self .train (use_prim = True , use_cinn = use_cinn )
119
-
120
- for i in range (len (dy_res )):
121
- np .testing .assert_allclose (
122
- comp_st_cinn_res [i ].numpy (),
123
- dy_res [i ].numpy (),
124
- rtol = 1e-7 ,
125
- atol = 1e-7 ,
126
- )
108
+ def test_reshape_grad_comp (self ):
127
109
paddle .enable_static ()
128
110
129
- def test_reshape_grad_comp (self ):
130
111
def actual (primal , shape , cotangent ):
131
112
core ._set_prim_backward_enabled (True )
132
113
mp , sp = paddle .static .Program (), paddle .static .Program ()
@@ -143,7 +124,7 @@ def actual(primal, shape, cotangent):
143
124
return exe .run (
144
125
program = mp ,
145
126
feed = {'primal' : primal , 'cotangent' : cotangent },
146
- fetch_list = [x_cotangent [0 ]. name ],
127
+ fetch_list = [x_cotangent [0 ]],
147
128
)[0 ]
148
129
149
130
def desired (primal , shape , cotangent ):
@@ -162,7 +143,7 @@ def desired(primal, shape, cotangent):
162
143
return exe .run (
163
144
program = mp ,
164
145
feed = {'primal' : primal , 'cotangent' : cotangent },
165
- fetch_list = [x_cotangent [0 ]. name ],
146
+ fetch_list = [x_cotangent [0 ]],
166
147
)[0 ]
167
148
168
149
if (self .dtype == np .float16 ) and isinstance (
@@ -178,6 +159,7 @@ def desired(primal, shape, cotangent):
178
159
atol = self .rtol ,
179
160
)
180
161
core ._set_prim_backward_enabled (False )
162
+ paddle .disable_static ()
181
163
182
164
183
165
if __name__ == '__main__' :
0 commit comments