10
10
import pandas as pd
11
11
import json
12
12
import subprocess
13
+ import glob
13
14
from cesium_app .model_util import create_token_user
14
15
15
16
@@ -154,41 +155,44 @@ def test_download_prediction_csv_class(driver, project, dataset, featureset,
154
155
model , prediction ):
155
156
driver .get ('/' )
156
157
_click_download (project .id , driver )
157
- assert os .path .exists ('/tmp/cesium_prediction_results.csv' )
158
+ matching_downloads_paths = glob .glob ('/tmp/cesium_prediction_results*.csv' )
159
+ assert len (matching_downloads_paths ) == 1
158
160
try :
159
161
npt .assert_equal (
160
- np .genfromtxt ('/tmp/cesium_prediction_results.csv' , dtype = 'str' ),
162
+ np .genfromtxt (matching_downloads_paths [ 0 ] , dtype = 'str' ),
161
163
['ts_name,label,prediction' ,
162
164
'0,Mira,Mira' ,
163
165
'1,Classical_Cepheid,Classical_Cepheid' ,
164
166
'2,Mira,Mira' ,
165
167
'3,Classical_Cepheid,Classical_Cepheid' ,
166
168
'4,Mira,Mira' ])
167
169
finally :
168
- os .remove ('/tmp/cesium_prediction_results.csv' )
170
+ os .remove (matching_downloads_paths [ 0 ] )
169
171
170
172
171
173
@pytest .mark .parametrize ('model__type' , ['LinearSGDClassifier' ])
172
174
def test_download_prediction_csv_class_unlabeled (driver , project , unlabeled_prediction ):
173
175
driver .get ('/' )
174
176
_click_download (project .id , driver )
175
- assert os .path .exists ('/tmp/cesium_prediction_results.csv' )
177
+ matching_downloads_paths = glob .glob ('/tmp/cesium_prediction_results*.csv' )
178
+ assert len (matching_downloads_paths ) == 1
176
179
try :
177
- result = np .genfromtxt ('/tmp/cesium_prediction_results.csv' , dtype = 'str' )
180
+ result = np .genfromtxt (matching_downloads_paths [ 0 ] , dtype = 'str' )
178
181
assert result [0 ] == 'ts_name,prediction'
179
182
assert all ([el [0 ].isdigit () and el [1 ] == ',' and el [2 :] in
180
183
['Mira' , 'Classical_Cepheid' ] for el in result [1 :]])
181
184
finally :
182
- os .remove ('/tmp/cesium_prediction_results.csv' )
185
+ os .remove (matching_downloads_paths [ 0 ] )
183
186
184
187
185
188
def test_download_prediction_csv_class_prob (driver , project , dataset ,
186
189
featureset , model , prediction ):
187
190
driver .get ('/' )
188
191
_click_download (project .id , driver )
189
- assert os .path .exists ('/tmp/cesium_prediction_results.csv' )
192
+ matching_downloads_paths = glob .glob ('/tmp/cesium_prediction_results*.csv' )
193
+ assert len (matching_downloads_paths ) == 1
190
194
try :
191
- result = pd .read_csv ('/tmp/cesium_prediction_results.csv' )
195
+ result = pd .read_csv (matching_downloads_paths [ 0 ] )
192
196
npt .assert_array_equal (result .ts_name , np .arange (5 ))
193
197
npt .assert_array_equal (result .label , ['Mira' , 'Classical_Cepheid' ,
194
198
'Mira' , 'Classical_Cepheid' ,
@@ -198,16 +202,17 @@ def test_download_prediction_csv_class_prob(driver, project, dataset,
198
202
[1 , 0 , 1 , 0 , 1 ])
199
203
assert (pred_probs .values >= 0.0 ).all ()
200
204
finally :
201
- os .remove ('/tmp/cesium_prediction_results.csv' )
205
+ os .remove (matching_downloads_paths [ 0 ] )
202
206
203
207
204
208
@pytest .mark .parametrize ('featureset__name, model__type' , [('regr' , 'LinearRegressor' )])
205
209
def test_download_prediction_csv_regr (driver , project , dataset , featureset , model , prediction ):
206
210
driver .get ('/' )
207
211
_click_download (project .id , driver )
208
- assert os .path .exists ('/tmp/cesium_prediction_results.csv' )
212
+ matching_downloads_paths = glob .glob ('/tmp/cesium_prediction_results*.csv' )
213
+ assert len (matching_downloads_paths ) == 1
209
214
try :
210
- results = np .genfromtxt ('/tmp/cesium_prediction_results.csv' ,
215
+ results = np .genfromtxt (matching_downloads_paths [ 0 ] ,
211
216
dtype = 'str' , delimiter = ',' )
212
217
npt .assert_equal (results [0 ],
213
218
['ts_name' , 'label' , 'prediction' ])
@@ -219,7 +224,7 @@ def test_download_prediction_csv_regr(driver, project, dataset, featureset, mode
219
224
[3 , 2.2 , 2.2 ],
220
225
[4 , 3.1 , 3.1 ]])
221
226
finally :
222
- os .remove ('/tmp/cesium_prediction_results.csv' )
227
+ os .remove (matching_downloads_paths [ 0 ] )
223
228
224
229
225
230
def test_predict_specific_ts_name (driver , project , dataset , featureset , model ):
0 commit comments