Skip to content

Commit 4877490

Browse files
Added test for rfft/irfft with default n for even sizes
1 parent eda7834 commit 4877490

File tree

1 file changed

+16
-4
lines changed

1 file changed

+16
-4
lines changed

mkl_fft/tests/test_interfaces.py

Lines changed: 16 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -65,6 +65,12 @@ def test_scipy_rfft(norm, dtype):
6565
xx = mfi.scipy_fft.irfft(w, n=x.shape[0], norm=norm, workers=None, plan=None)
6666
tol = 64 * np.finfo(np.dtype(dtype)).eps
6767
assert np.allclose(x, xx, atol=tol, rtol=tol)
68+
69+
x = np.ones(510, dtype=dtype)
70+
w = mfi.scipy_fft.rfft(x, norm=norm, workers=None, plan=None)
71+
xx = mfi.scipy_fft.irfft(w, norm=norm, workers=None, plan=None)
72+
tol = 64 * np.finfo(np.dtype(dtype)).eps
73+
assert np.allclose(x, xx, atol=tol, rtol=tol)
6874

6975

7076
@pytest.mark.parametrize('norm', [None, "forward", "backward", "ortho"])
@@ -99,20 +105,26 @@ def test_numpy_fftn(norm, dtype):
99105

100106
@pytest.mark.parametrize('norm', [None, "forward", "backward", "ortho"])
101107
@pytest.mark.parametrize('dtype', [np.float32, np.float64])
102-
def test_scipy_rftn(norm, dtype):
108+
def test_scipy_rfftn(norm, dtype):
103109
x = np.ones((37, 83), dtype=dtype)
104110
w = mfi.scipy_fft.rfftn(x, norm=norm, workers=None, plan=None)
105-
xx = mfi.scipy_fft.ifftn(w, s=x.shape, norm=norm, workers=None, plan=None)
111+
xx = mfi.scipy_fft.irfftn(w, s=x.shape, norm=norm, workers=None, plan=None)
112+
tol = 64 * np.finfo(np.dtype(dtype)).eps
113+
assert np.allclose(x, xx, atol=tol, rtol=tol)
114+
115+
x = np.ones((36, 82), dtype=dtype)
116+
w = mfi.scipy_fft.rfftn(x, norm=norm, workers=None, plan=None)
117+
xx = mfi.scipy_fft.irfftn(w, norm=norm, workers=None, plan=None)
106118
tol = 64 * np.finfo(np.dtype(dtype)).eps
107119
assert np.allclose(x, xx, atol=tol, rtol=tol)
108120

109121

110122
@pytest.mark.parametrize('norm', [None, "forward", "backward", "ortho"])
111123
@pytest.mark.parametrize('dtype', [np.float32, np.float64])
112-
def test_numpy_rftn(norm, dtype):
124+
def test_numpy_rfftn(norm, dtype):
113125
x = np.ones((37, 83), dtype=dtype)
114126
w = mfi.numpy_fft.rfftn(x, norm=norm)
115-
xx = mfi.numpy_fft.ifftn(w, s=x.shape, norm=norm)
127+
xx = mfi.numpy_fft.irfftn(w, s=x.shape, norm=norm)
116128
tol = 64 * np.finfo(np.dtype(dtype)).eps
117129
assert np.allclose(x, xx, atol=tol, rtol=tol)
118130

0 commit comments

Comments
 (0)