Skip to content

Commit 3e220d9

Browse files
extended tests
1 parent e343631 commit 3e220d9

File tree

2 files changed

+183
-66
lines changed

2 files changed

+183
-66
lines changed

.coveragerc

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,4 @@
11
[run]
2+
concurrency = multiprocessing
23
source = cmethods
34
omit = *tests*

tests/test_methods.py

Lines changed: 182 additions & 66 deletions
Original file line numberDiff line numberDiff line change
@@ -9,13 +9,33 @@
99
from typing import List, Tuple
1010

1111
import numpy as np
12+
import pytest
1213
import xarray as xr
1314
from sklearn.metrics import mean_squared_error
1415

15-
from cmethods.CMethods import CMethods
16+
from cmethods.CMethods import CMethods, UnknownMethodError
1617

1718

1819
class TestMethods(unittest.TestCase):
20+
def setUp(self) -> None:
21+
obsh_add, obsp_add, simh_add, simp_add = self.get_datasets(kind="+")
22+
obsh_mult, obsp_mult, simh_mult, simp_mult = self.get_datasets(kind="*")
23+
24+
self.data = {
25+
"+": {
26+
"obsh": obsh_add["+"],
27+
"obsp": obsp_add["+"],
28+
"simh": simh_add["+"],
29+
"simp": simp_add["+"],
30+
},
31+
"*": {
32+
"obsh": obsh_mult["*"],
33+
"obsp": obsp_mult["*"],
34+
"simh": simh_mult["*"],
35+
"simp": simp_mult["*"],
36+
},
37+
}
38+
1939
def get_datasets(
2040
self,
2141
kind: str,
@@ -95,180 +115,276 @@ def test_linear_scaling(self) -> None:
95115
"""Tests the linear scaling method"""
96116

97117
for kind in ("+", "*"):
98-
obsh, obsp, simh, simp = self.get_datasets(kind=kind)
99118
ls_result = CMethods().linear_scaling(
100-
obs=obsh[kind][:, 0, 0],
101-
simh=simh[kind][:, 0, 0],
102-
simp=simp[kind][:, 0, 0],
119+
obs=self.data[kind]["obsh"][:, 0, 0],
120+
simh=self.data[kind]["simh"][:, 0, 0],
121+
simp=self.data[kind]["simp"][:, 0, 0],
103122
kind=kind,
104123
)
105124
assert isinstance(ls_result, xr.core.dataarray.DataArray)
106125
assert mean_squared_error(
107-
ls_result, obsp[kind][:, 0, 0], squared=False
126+
ls_result, self.data[kind]["obsp"][:, 0, 0], squared=False
108127
) < mean_squared_error(
109-
simp[kind][:, 0, 0], obsp[kind][:, 0, 0], squared=False
128+
self.data[kind]["simp"][:, 0, 0],
129+
self.data[kind]["obsp"][:, 0, 0],
130+
squared=False,
110131
)
111132

112133
def test_variance_scaling(self) -> None:
113134
"""Tests the variance scaling method"""
114-
115-
obsh, obsp, simh, simp = self.get_datasets(kind="+")
135+
kind = "+"
116136
vs_result = CMethods().variance_scaling(
117-
obs=obsh["+"][:, 0, 0],
118-
simh=simh["+"][:, 0, 0],
119-
simp=simp["+"][:, 0, 0],
137+
obs=self.data[kind]["obsh"][:, 0, 0],
138+
simh=self.data[kind]["simh"][:, 0, 0],
139+
simp=self.data[kind]["simp"][:, 0, 0],
120140
kind="+",
121141
)
122142
assert isinstance(vs_result, xr.core.dataarray.DataArray)
123143
assert mean_squared_error(
124-
vs_result, obsp["+"][:, 0, 0], squared=False
125-
) < mean_squared_error(simp["+"][:, 0, 0], obsp["+"][:, 0, 0], squared=False)
144+
vs_result, self.data[kind]["obsp"][:, 0, 0], squared=False
145+
) < mean_squared_error(
146+
self.data[kind]["simp"][:, 0, 0],
147+
self.data[kind]["obsp"][:, 0, 0],
148+
squared=False,
149+
)
126150

127151
def test_delta_method(self) -> None:
128152
"""Tests the delta method"""
129153

130154
for kind in ("+", "*"):
131-
obsh, obsp, simh, simp = self.get_datasets(kind=kind)
132155
dm_result = CMethods().delta_method(
133-
obs=obsh[kind][:, 0, 0],
134-
simh=simh[kind][:, 0, 0],
135-
simp=simp[kind][:, 0, 0],
156+
obs=self.data[kind]["obsh"][:, 0, 0],
157+
simh=self.data[kind]["simh"][:, 0, 0],
158+
simp=self.data[kind]["simp"][:, 0, 0],
136159
kind=kind,
137160
)
138161
assert isinstance(dm_result, xr.core.dataarray.DataArray)
139162
assert mean_squared_error(
140-
dm_result, obsp[kind][:, 0, 0], squared=False
163+
dm_result, self.data[kind]["obsp"][:, 0, 0], squared=False
141164
) < mean_squared_error(
142-
simp[kind][:, 0, 0], obsp[kind][:, 0, 0], squared=False
165+
self.data[kind]["simp"][:, 0, 0],
166+
self.data[kind]["obsp"][:, 0, 0],
167+
squared=False,
143168
)
144169

145170
def test_quantile_mapping(self) -> None:
146171
"""Tests the quantile mapping method"""
147172

148173
for kind in ("+", "*"):
149-
obsh, obsp, simh, simp = self.get_datasets(kind=kind)
150174
qm_result = CMethods().quantile_mapping(
151-
obs=obsh[kind][:, 0, 0],
152-
simh=simh[kind][:, 0, 0],
153-
simp=simp[kind][:, 0, 0],
175+
obs=self.data[kind]["obsh"][:, 0, 0],
176+
simh=self.data[kind]["simh"][:, 0, 0],
177+
simp=self.data[kind]["simp"][:, 0, 0],
154178
n_quantiles=100,
155179
kind=kind,
156180
)
157181
assert isinstance(qm_result, xr.core.dataarray.DataArray)
158182
assert mean_squared_error(
159-
qm_result, obsp[kind][:, 0, 0], squared=False
183+
qm_result, self.data[kind]["obsp"][:, 0, 0], squared=False
160184
) < mean_squared_error(
161-
simp[kind][:, 0, 0], obsp[kind][:, 0, 0], squared=False
185+
self.data[kind]["simp"][:, 0, 0],
186+
self.data[kind]["obsp"][:, 0, 0],
187+
squared=False,
162188
)
163189

164190
def test_detrended_quantile_mapping(self) -> None:
165191
"""Tests the detrendeed quantile mapping method"""
166192

167193
for kind in ("+", "*"):
168-
obsh, obsp, simh, simp = self.get_datasets(kind=kind)
169194
dqm_result = CMethods().quantile_mapping(
170-
obs=obsh[kind][:, 0, 0],
171-
simh=simh[kind][:, 0, 0],
172-
simp=simp[kind][:, 0, 0],
195+
obs=self.data[kind]["obsh"][:, 0, 0],
196+
simh=self.data[kind]["simh"][:, 0, 0],
197+
simp=self.data[kind]["simp"][:, 0, 0],
173198
n_quantiles=100,
174199
kind=kind,
175200
detrended=True,
176201
)
177202
assert isinstance(dqm_result, xr.core.dataarray.DataArray)
178203
assert mean_squared_error(
179-
dqm_result, obsp[kind][:, 0, 0], squared=False
204+
dqm_result, self.data[kind]["obsp"][:, 0, 0], squared=False
180205
) < mean_squared_error(
181-
simp[kind][:, 0, 0], obsp[kind][:, 0, 0], squared=False
206+
self.data[kind]["simp"][:, 0, 0],
207+
self.data[kind]["obsp"][:, 0, 0],
208+
squared=False,
182209
)
183210

184211
def test_quantile_delta_mapping(self) -> None:
185212
"""Tests the quantile delta mapping method"""
186213

187214
for kind in ("+", "*"):
188-
obsh, obsp, simh, simp = self.get_datasets(kind=kind)
189215
qdm_result = CMethods().quantile_delta_mapping(
190-
obs=obsh[kind][:, 0, 0],
191-
simh=simh[kind][:, 0, 0],
192-
simp=simp[kind][:, 0, 0],
216+
obs=self.data[kind]["obsh"][:, 0, 0],
217+
simh=self.data[kind]["simh"][:, 0, 0],
218+
simp=self.data[kind]["simp"][:, 0, 0],
193219
n_quantiles=100,
194220
kind=kind,
195221
)
196222

197223
assert isinstance(qdm_result, xr.core.dataarray.DataArray)
198224
assert mean_squared_error(
199-
qdm_result, obsp[kind][:, 0, 0], squared=False
225+
qdm_result, self.data[kind]["obsp"][:, 0, 0], squared=False
200226
) < mean_squared_error(
201-
simp[kind][:, 0, 0], obsp[kind][:, 0, 0], squared=False
227+
self.data[kind]["simp"][:, 0, 0],
228+
self.data[kind]["obsp"][:, 0, 0],
229+
squared=False,
202230
)
203231

204232
def test_3d_sclaing_methods(self) -> None:
205233
"""Tests the scaling based methods for 3-dimentsional data sets"""
206234

207235
kind = "+"
208-
obsh, obsp, simh, simp = self.get_datasets(kind=kind)
209236
for method in CMethods().SCALING_METHODS:
210237
result = CMethods().adjust_3d(
211238
method=method,
212-
obs=obsh[kind],
213-
simh=simh[kind],
214-
simp=simp[kind],
239+
obs=self.data[kind]["obsh"],
240+
simh=self.data[kind]["simh"],
241+
simp=self.data[kind]["simp"],
215242
kind=kind,
216243
goup="time.month", # default
217244
)
218245
assert isinstance(result, xr.core.dataarray.DataArray)
219-
for lat in range(len(obsh.lat)):
220-
for lon in range(len(obsh.lon)):
246+
for lat in range(len(self.data[kind]["obsh"].lat)):
247+
for lon in range(len(self.data[kind]["obsh"].lon)):
221248
assert mean_squared_error(
222-
result[:, lat, lon], obsp[kind][:, lat, lon], squared=False
249+
result[:, lat, lon],
250+
self.data[kind]["obsp"][:, lat, lon],
251+
squared=False,
223252
) < mean_squared_error(
224-
simp[kind][:, lat, lon],
225-
obsp[kind][:, lat, lon],
253+
self.data[kind]["simp"][:, lat, lon],
254+
self.data[kind]["obsp"][:, lat, lon],
226255
squared=False,
227256
)
228257

229258
def test_3d_distribution_methods(self) -> None:
230259
"""Tests the distribution based methods for 3-dimentsional data sets"""
231260

232261
for kind in ("+", "*"):
233-
obsh, obsp, simh, simp = self.get_datasets(kind=kind)
234262
for method in CMethods().DISTRIBUTION_METHODS:
235263
result = CMethods().adjust_3d(
236264
method=method,
237-
obs=obsh[kind],
238-
simh=simh[kind],
239-
simp=simp[kind],
240-
n_quantiles=100,
265+
obs=self.data[kind]["obsh"],
266+
simh=self.data[kind]["simh"],
267+
simp=self.data[kind]["simp"],
268+
n_quantiles=25,
241269
)
242270
assert isinstance(result, xr.core.dataarray.DataArray)
243-
for lat in range(len(obsh.lat)):
244-
for lon in range(len(obsh.lon)):
271+
for lat in range(len(self.data[kind]["obsh"].lat)):
272+
for lon in range(len(self.data[kind]["obsh"].lon)):
245273
assert mean_squared_error(
246-
result[:, lat, lon], obsp[kind][:, lat, lon], squared=False
274+
result[:, lat, lon],
275+
self.data[kind]["obsp"][:, lat, lon],
276+
squared=False,
247277
) < mean_squared_error(
248-
simp[kind][:, lat, lon],
249-
obsp[kind][:, lat, lon],
278+
self.data[kind]["simp"][:, lat, lon],
279+
self.data[kind]["obsp"][:, lat, lon],
250280
squared=False,
251281
)
252282

253283
def test_n_jobs(self) -> None:
254-
obsh, obsp, simh, simp = self.get_datasets(kind="+")
284+
kind = "+"
255285
result = CMethods().adjust_3d(
256286
method="quantile_mapping",
257-
obs=obsh["+"],
258-
simh=simh["+"],
259-
simp=simp["+"],
260-
n_quantiles=100,
287+
obs=self.data[kind]["obsh"],
288+
simh=self.data[kind]["simh"],
289+
simp=self.data[kind]["simp"],
290+
n_quantiles=25,
261291
n_jobs=2,
262292
)
263293
assert isinstance(result, xr.core.dataarray.DataArray)
264-
for lat in range(len(obsh.lat)):
265-
for lon in range(len(obsh.lon)):
294+
for lat in range(len(self.data[kind]["obsh"].lat)):
295+
for lon in range(len(self.data[kind]["obsh"].lon)):
266296
assert mean_squared_error(
267-
result[:, lat, lon], obsp["+"][:, lat, lon], squared=False
297+
result[:, lat, lon],
298+
self.data[kind]["obsp"][:, lat, lon],
299+
squared=False,
268300
) < mean_squared_error(
269-
simp["+"][:, lat, lon], obsp["+"][:, lat, lon], squared=False
301+
self.data[kind]["simp"][:, lat, lon],
302+
self.data[kind]["obsp"][:, lat, lon],
303+
squared=False,
270304
)
271305

306+
def test_get_available_methods(self) -> None:
307+
assert CMethods().get_available_methods() == [
308+
"linear_scaling",
309+
"variance_scaling",
310+
"delta_method",
311+
"quantile_mapping",
312+
"quantile_delta_mapping",
313+
]
314+
315+
def test_unknown_method(self) -> None:
316+
with pytest.raises(UnknownMethodError):
317+
CMethods.get_function("LOCI_INTENSITY_SCALING")
318+
319+
kind = "+"
320+
with pytest.raises(UnknownMethodError):
321+
CMethods().adjust_3d(
322+
method="distribution_mapping",
323+
obs=self.data[kind]["obsh"],
324+
simh=self.data[kind]["simh"],
325+
simp=self.data[kind]["simp"],
326+
kind=kind,
327+
)
328+
329+
def test_not_implemented_methods(self) -> None:
330+
kind = "+"
331+
with pytest.raises(ValueError):
332+
CMethods.empirical_quantile_mapping(
333+
self.data[kind]["obsh"],
334+
self.data[kind]["simh"],
335+
self.data[kind]["simp"],
336+
n_quantiles=10,
337+
)
338+
339+
def test_invalid_adjustment_type(self) -> None:
340+
kind = "+"
341+
with pytest.raises(ValueError):
342+
CMethods.linear_scaling(
343+
self.data[kind]["obsh"],
344+
self.data[kind]["simh"],
345+
self.data[kind]["simp"],
346+
kind="/",
347+
)
348+
with pytest.raises(ValueError):
349+
CMethods.variance_scaling(
350+
self.data[kind]["obsh"],
351+
self.data[kind]["simh"],
352+
self.data[kind]["simp"],
353+
kind="*",
354+
)
355+
with pytest.raises(ValueError):
356+
CMethods.delta_method(
357+
self.data[kind]["obsh"],
358+
self.data[kind]["simh"],
359+
self.data[kind]["simp"],
360+
kind="/",
361+
)
362+
with pytest.raises(ValueError):
363+
CMethods.quantile_mapping(
364+
self.data[kind]["obsh"],
365+
self.data[kind]["simh"],
366+
self.data[kind]["simp"],
367+
kind="/",
368+
n_quantiles=10,
369+
)
370+
with pytest.raises(ValueError):
371+
CMethods.quantile_delta_mapping(
372+
self.data[kind]["obsh"],
373+
self.data[kind]["simh"],
374+
self.data[kind]["simp"],
375+
kind="/",
376+
n_quantiles=10,
377+
)
378+
379+
def test_get_pdf(self) -> None:
380+
assert (CMethods.get_pdf(np.arange(10), [0, 5, 11]) == np.array((5, 5))).all()
381+
382+
def test_get_adjusted_scaling_factor(self) -> None:
383+
assert CMethods().get_adjusted_scaling_factor(10, 5) == 5
384+
assert CMethods().get_adjusted_scaling_factor(10, 11) == 10
385+
assert CMethods().get_adjusted_scaling_factor(-10, -11) == -10
386+
assert CMethods().get_adjusted_scaling_factor(-11, -10) == -10
387+
272388

273389
if __name__ == "__main__":
274390
unittest.main()

0 commit comments

Comments
 (0)