12
12
# See the License for the specific language governing permissions and
13
13
# limitations under the License.
14
14
15
- import os
16
15
import unittest
17
16
18
17
import numpy as np
@@ -64,14 +63,7 @@ def test_check_output(self):
64
63
self .check_output (check_pir = True )
65
64
66
65
def test_check_grad (self ):
67
- places = []
68
- if (
69
- os .environ .get ('FLAGS_CI_both_cpu_and_gpu' , 'False' ).lower ()
70
- in ['1' , 'true' , 'on' ]
71
- or not core .is_compiled_with_cuda ()
72
- or core .is_compiled_with_rocm ()
73
- ):
74
- places .append (base .CPUPlace ())
66
+ places = [base .CPUPlace ()]
75
67
if core .is_compiled_with_cuda () and (not core .is_compiled_with_rocm ()):
76
68
places .append (base .CUDAPlace (0 ))
77
69
for p in places :
@@ -161,6 +153,11 @@ def init_config(self):
161
153
self ._input_shape = (32 , 32 )
162
154
163
155
156
+ class TestCholeskyOpZeroSize (TestCholeskyOp ):
157
+ def init_config (self ):
158
+ self ._input_shape = (0 , 0 )
159
+
160
+
164
161
class TestDygraph (unittest .TestCase ):
165
162
def test_dygraph (self ):
166
163
if core .is_compiled_with_rocm ():
@@ -176,27 +173,20 @@ def test_dygraph(self):
176
173
177
174
class TestCholeskySingularAPI (unittest .TestCase ):
178
175
def setUp (self ):
179
- self .places = []
180
- if (
181
- os .environ .get ('FLAGS_CI_both_cpu_and_gpu' , 'False' ).lower ()
182
- in ['1' , 'true' , 'on' ]
183
- or not core .is_compiled_with_cuda ()
184
- or core .is_compiled_with_rocm ()
185
- ):
186
- self .places .append (base .CPUPlace ())
176
+ self .places = [base .CPUPlace ()]
187
177
if core .is_compiled_with_cuda () and (not core .is_compiled_with_rocm ()):
188
178
self .places .append (base .CUDAPlace (0 ))
189
179
190
- def check_static_result (self , place , with_out = False ):
180
+ def check_static_result (self , place , input_shape , with_out = False ):
191
181
with paddle .static .program_guard (
192
182
paddle .static .Program (), paddle .static .Program ()
193
183
):
194
184
input = paddle .static .data (
195
- name = "input" , shape = [ 4 , 4 ] , dtype = "float64"
185
+ name = "input" , shape = input_shape , dtype = "float64"
196
186
)
197
187
result = paddle .cholesky (input )
198
188
199
- input_np = np .zeros ([ 4 , 4 ] ).astype ("float64" )
189
+ input_np = np .zeros (input_shape ).astype ("float64" )
200
190
201
191
exe = base .Executor (place )
202
192
try :
@@ -211,7 +201,9 @@ def check_static_result(self, place, with_out=False):
211
201
212
202
def test_static (self ):
213
203
for place in self .places :
214
- self .check_static_result (place = place )
204
+ self .check_static_result (place = place , input_shape = [4 , 4 ])
205
+ self .check_static_result (place = place , input_shape = [0 , 0 ])
206
+ self .check_static_result (place = place , input_shape = [5 , 0 , 0 ])
215
207
216
208
def test_dygraph (self ):
217
209
for place in self .places :
@@ -222,9 +214,12 @@ def test_dygraph(self):
222
214
[[10 , 11 , 12 ], [13 , 14 , 15 ], [16 , 17 , 18 ]],
223
215
]
224
216
).astype ("float64" )
217
+ input_np_zero = np .zeros ((0 , 3 , 3 ), dtype = "float64" )
225
218
input = paddle .to_tensor (input_np )
219
+ input_zero = paddle .to_tensor (input_np_zero )
226
220
try :
227
221
result = paddle .cholesky (input )
222
+ result_zero = paddle .cholesky (input_zero )
228
223
except RuntimeError as ex :
229
224
print ("The mat is singular" )
230
225
except ValueError as ex :
0 commit comments