Skip to content

Commit 41ee78f

Browse files
committed
Fix prediction results download path bug in tests
1 parent 199867e commit 41ee78f

File tree

1 file changed

+17
-12
lines changed

1 file changed

+17
-12
lines changed

cesium_app/tests/frontend/test_predict.py

Lines changed: 17 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,7 @@
1010
import pandas as pd
1111
import json
1212
import subprocess
13+
import glob
1314
from cesium_app.model_util import create_token_user
1415

1516

@@ -154,41 +155,44 @@ def test_download_prediction_csv_class(driver, project, dataset, featureset,
154155
model, prediction):
155156
driver.get('/')
156157
_click_download(project.id, driver)
157-
assert os.path.exists('/tmp/cesium_prediction_results.csv')
158+
matching_downloads_paths = glob.glob('/tmp/cesium_prediction_results*.csv')
159+
assert len(matching_downloads_paths) == 1
158160
try:
159161
npt.assert_equal(
160-
np.genfromtxt('/tmp/cesium_prediction_results.csv', dtype='str'),
162+
np.genfromtxt(matching_downloads_paths[0], dtype='str'),
161163
['ts_name,label,prediction',
162164
'0,Mira,Mira',
163165
'1,Classical_Cepheid,Classical_Cepheid',
164166
'2,Mira,Mira',
165167
'3,Classical_Cepheid,Classical_Cepheid',
166168
'4,Mira,Mira'])
167169
finally:
168-
os.remove('/tmp/cesium_prediction_results.csv')
170+
os.remove(matching_downloads_paths[0])
169171

170172

171173
@pytest.mark.parametrize('model__type', ['LinearSGDClassifier'])
172174
def test_download_prediction_csv_class_unlabeled(driver, project, unlabeled_prediction):
173175
driver.get('/')
174176
_click_download(project.id, driver)
175-
assert os.path.exists('/tmp/cesium_prediction_results.csv')
177+
matching_downloads_paths = glob.glob('/tmp/cesium_prediction_results*.csv')
178+
assert len(matching_downloads_paths) == 1
176179
try:
177-
result = np.genfromtxt('/tmp/cesium_prediction_results.csv', dtype='str')
180+
result = np.genfromtxt(matching_downloads_paths[0], dtype='str')
178181
assert result[0] == 'ts_name,prediction'
179182
assert all([el[0].isdigit() and el[1] == ',' and el[2:] in
180183
['Mira', 'Classical_Cepheid'] for el in result[1:]])
181184
finally:
182-
os.remove('/tmp/cesium_prediction_results.csv')
185+
os.remove(matching_downloads_paths[0])
183186

184187

185188
def test_download_prediction_csv_class_prob(driver, project, dataset,
186189
featureset, model, prediction):
187190
driver.get('/')
188191
_click_download(project.id, driver)
189-
assert os.path.exists('/tmp/cesium_prediction_results.csv')
192+
matching_downloads_paths = glob.glob('/tmp/cesium_prediction_results*.csv')
193+
assert len(matching_downloads_paths) == 1
190194
try:
191-
result = pd.read_csv('/tmp/cesium_prediction_results.csv')
195+
result = pd.read_csv(matching_downloads_paths[0])
192196
npt.assert_array_equal(result.ts_name, np.arange(5))
193197
npt.assert_array_equal(result.label, ['Mira', 'Classical_Cepheid',
194198
'Mira', 'Classical_Cepheid',
@@ -198,16 +202,17 @@ def test_download_prediction_csv_class_prob(driver, project, dataset,
198202
[1, 0, 1, 0, 1])
199203
assert (pred_probs.values >= 0.0).all()
200204
finally:
201-
os.remove('/tmp/cesium_prediction_results.csv')
205+
os.remove(matching_downloads_paths[0])
202206

203207

204208
@pytest.mark.parametrize('featureset__name, model__type', [('regr', 'LinearRegressor')])
205209
def test_download_prediction_csv_regr(driver, project, dataset, featureset, model, prediction):
206210
driver.get('/')
207211
_click_download(project.id, driver)
208-
assert os.path.exists('/tmp/cesium_prediction_results.csv')
212+
matching_downloads_paths = glob.glob('/tmp/cesium_prediction_results*.csv')
213+
assert len(matching_downloads_paths) == 1
209214
try:
210-
results = np.genfromtxt('/tmp/cesium_prediction_results.csv',
215+
results = np.genfromtxt(matching_downloads_paths[0],
211216
dtype='str', delimiter=',')
212217
npt.assert_equal(results[0],
213218
['ts_name', 'label', 'prediction'])
@@ -219,7 +224,7 @@ def test_download_prediction_csv_regr(driver, project, dataset, featureset, mode
219224
[3, 2.2, 2.2],
220225
[4, 3.1, 3.1]])
221226
finally:
222-
os.remove('/tmp/cesium_prediction_results.csv')
227+
os.remove(matching_downloads_paths[0])
223228

224229

225230
def test_predict_specific_ts_name(driver, project, dataset, featureset, model):

0 commit comments

Comments
 (0)