Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
25 changes: 23 additions & 2 deletions src/atmos_flux_inversion/correlations.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,10 +17,11 @@
from numpy import fromfunction, asarray, hstack, flip
from numpy import exp, square, fmin, sqrt
from numpy import logical_or, concatenate, isnan
from numpy import where
from numpy import where, prod
from numpy import sum as array_sum
from scipy.special import gamma, kv as K_nu
from scipy.sparse.linalg.interface import LinearOperator
from scipy.fftpack import dctn, fftn

import pyfftw.interfaces.cache
from pyfftw import next_fast_len
Expand Down Expand Up @@ -284,8 +285,12 @@ def from_array(cls, corr_array, is_cyclic=True):
self = cls(shape, computational_shape)

for axis in reversed(range(ndims)):
sub_index = [
slice(None) for i in range(ndims)
]
sub_index[axis] = slice(1, -1)
corr_array = concatenate(
[corr_array, flip(corr_array[1:-1], axis)],
[corr_array, flip(corr_array[tuple(sub_index)], axis)],
axis=axis)

# Advantages over dctn: guaranteed same format and gets
Expand Down Expand Up @@ -439,6 +444,22 @@ def kron(self, other):
other._fourier_near_zero[other_index])
return newinst

def det(self):
"""Find the determinant of the operator.

Returns
-------
float
"""
correlations = self._ifft(self._corr_fourier)
if self._is_cyclic:
spectrum = fftn(correlations, shape=self._underlying_shape).real
# The order is different from la.eigvalsh, but that
# doesn't matter
return prod(spectrum)
spectrum = dctn(correlations, type=1, shape=self._underlying_shape)
return prod(spectrum)


def make_matrix(corr_func, shape):
"""Make a correlation matrix for a domain with shape `shape`.
Expand Down
74 changes: 72 additions & 2 deletions src/atmos_flux_inversion/linalg.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@
# implicitly called on the operator2.dot(chunk) that usually follows
# this.
from numpy import einsum
from numpy import concatenate, zeros, nonzero
from numpy import concatenate, zeros, nonzero, prod
from numpy import asarray, atleast_2d, stack, where, sqrt

from scipy.sparse.linalg import lgmres
Expand Down Expand Up @@ -278,6 +278,24 @@ def matrix_sqrt(mat):
cls=type(mat)))


def matrix_determinant(operator):
"""Find the determinant of the operator.

Parameters
----------
operator: LinearOperator

Returns
-------
determinant: float
"""
if hasattr(operator, "det"):
return operator.det()
if isinstance(operator, DaskMatrixLinearOperator):
operator = operator.A
return la.det(operator)


class DaskKroneckerProductOperator(DaskLinearOperator):
"""Operator for Kronecker product.

Expand Down Expand Up @@ -491,6 +509,25 @@ def quadratic_form(self, mat):
operator2.dot(chunk))
return result

def det(self):
"""Find the determinant of the operator.

Returns
-------
float
"""
op1 = self._operator1
op2 = self._operator2
if (
self.shape[0] != self.shape[1] or
op1.shape[0] != op1.shape[1]
):
raise ValueError("Determinant only defined for square operators.")
return (
matrix_determinant(op1) ** op2.shape[0] *
matrix_determinant(op2) ** op1.shape[0]
)


class SchmidtKroneckerProduct(DaskLinearOperator):
"""Kronecker product of two operators using Schmidt decomposition.
Expand Down Expand Up @@ -580,6 +617,25 @@ def _matvec(self, vector):

return asarray(result)

def det(self):
"""Find the determinant of the operator.

Returns
-------
float
"""
op1 = self._operator1
op2 = self._operator2
if (
self.shape[0] != self.shape[1] or
op1.shape[0] != op1.shape[1]
):
raise ValueError("Determinant only defined for square operators.")
return (
matrix_determinant(op1) ** op2.shape[0] *
matrix_determinant(op2) ** op1.shape[0]
)


class SelfAdjointLinearOperator(DaskLinearOperator):
"""Self-adjoint linear operators.
Expand Down Expand Up @@ -705,5 +761,19 @@ def solve(self, vector):
return where(self._diag_near_zero, 0, result)

def sqrt(self):
"""Find S such that S.T @ S == self."""
"""Find S such that S.T @ S == self.

Returns
-------
LinearOperator
"""
return DiagonalOperator(sqrt(self._diag))

def det(self):
"""Find the determinant of the operator.

Returns
-------
float
"""
return prod(self._diag)
160 changes: 160 additions & 0 deletions src/atmos_flux_inversion/tests.py
Original file line number Diff line number Diff line change
Expand Up @@ -1222,6 +1222,16 @@ def test_acyclic_from_array(self):
np_tst.assert_allclose(op.dot(np.eye(*mat.shape)),
mat)

def test_acyclic_from_array2d(self):
"""Test from_array with correlations assumed acyclic."""
array = [[1, .5, .25, .125, .0625, .03125],
[.5, .25, .125, .0625, .03125, .015625]]
# At one point this operation crashed, so this is an
# "assertDoesNotRaise"
(atmos_flux_inversion.correlations.
HomogeneousIsotropicCorrelation.
from_array(array, False))


class TestSchmidtKroneckerProduct(unittest2.TestCase):
"""Test the Schmidt Kronecker product implementation for LinearOperators.
Expand Down Expand Up @@ -3163,5 +3173,155 @@ def test_simple_variances(self):
np_tst.assert_allclose(corr_matrix, np.diag(series.values))


class TestDeterminants(unittest2.TestCase):
"""Test the determinant-finding abilities of operators."""

def test_matrix_determinant(self):
"""Test the matrix_determinant function."""
test_ops = [
np.eye(3),
np.eye(10, dtype=DTYPE),
atmos_flux_inversion.linalg.DiagonalOperator(np.ones(10)),
atmos_flux_inversion.correlations.
HomogeneousIsotropicCorrelation.from_array([1, 0, 0, 0, 0]),
atmos_flux_inversion.correlations.
HomogeneousIsotropicCorrelation.from_array([1, 0, 0, 0, 0], False),
atmos_flux_inversion.correlations.
HomogeneousIsotropicCorrelation.from_array(
[[1, 0, 0, 0, 0],
[0, 0, 0, 0, 0],
[0, 0, 0, 0, 0],
[0, 0, 0, 0, 0]]
),
atmos_flux_inversion.correlations.
HomogeneousIsotropicCorrelation.from_array(
[[1, 0, 0, 0, 0],
[0, 0, 0, 0, 0],
[0, 0, 0, 0, 0],
[0, 0, 0, 0, 0]],
False
),
]

for test_op in test_ops:
with self.subTest(test_op=test_op):
result = atmos_flux_inversion.linalg.matrix_determinant(
test_op
)
self.assertIsInstance(result, float)
self.assertEqual(result, 1.0)

def test_diagonal_determinant(self):
"""Test determinants of DiagonalOperators."""
test_cases = (
np.arange(10),
np.arange(1, 10),
np.ones(15),
)

for test_case in test_cases:
with self.subTest(test_case=test_case):
op = atmos_flux_inversion.linalg.DiagonalOperator(
test_case
)
self.assertEqual(op.det(), np.prod(test_case))

def test_kronecker_determinant_simple(self):
"""Test determinants of Kronecker products of simple inputs."""
kron_classes = (
atmos_flux_inversion.linalg.DaskKroneckerProductOperator,
atmos_flux_inversion.correlations.SchmidtKroneckerProduct,
)
test_mats = (np.eye(3), np.eye(10, dtype=DTYPE))
test_ops = (
atmos_flux_inversion.linalg.DiagonalOperator(np.ones(10)),
)

for kron_class in kron_classes:
for left, right in itertools.product(
test_mats,
test_mats + test_ops
):
with self.subTest(
kron_class=kron_class,
left=left, right=right
):
kron_op = kron_class(left, right)
result = kron_op.det()
self.assertEqual(result, 1)

def test_kronecker_determinant_hard(self):
"""Test determinants of KroneckerProducts of other inputs."""
kron_classes = (
atmos_flux_inversion.linalg.DaskKroneckerProductOperator,
atmos_flux_inversion.correlations.SchmidtKroneckerProduct,
)
test_mats = (
np.eye(3),
np.eye(10, dtype=DTYPE),
np.arange(4).reshape(2, 2),
np.random.rand(4, 4),
np.random.rand(5, 5),
)

for kron_class in kron_classes:
for left, right in itertools.product(
test_mats,
test_mats
):
with self.subTest(
kron_class=kron_class,
left=left, right=right
):
kron_op = kron_class(left, right)
result = kron_op.det()
expected = la.det(np.kron(left, right))
self.assertAlmostEqual(result, expected)

def test_cyclic_fourier_determinants(self):
"""Test determinants of periodic HomogeneousIsotropicCorrelation."""
from_function = (atmos_flux_inversion.correlations.
HomogeneousIsotropicCorrelation.from_function)
test_dists = (.3, 1, 3,)
test_shapes = (10, 15, (4, 5))

for test_dist, corr_class, shape in itertools.product(
test_dists,
atmos_flux_inversion.correlations.
DistanceCorrelationFunction.__subclasses__(),
test_shapes
):
with self.subTest(test_dist=test_dist,
corr_class=corr_class.__name__,
shape=shape):
op = from_function(corr_class(test_dist), shape,
is_cyclic=True)
mat = op.dot(np.eye(*op.shape))

self.assertAlmostEqual(op.det(), la.det(mat), 6)

def test_acyclic_fourier_determinants(self):
"""Test determinants of aperiodic HomogeneousIsotropicCorrelation."""
from_function = (atmos_flux_inversion.correlations.
HomogeneousIsotropicCorrelation.from_function)
test_dists = (.3, 1, 3,)
test_shapes = (10, 15, (4, 5))

for test_dist, corr_class, shape in itertools.product(
test_dists,
atmos_flux_inversion.correlations.
DistanceCorrelationFunction.__subclasses__(),
test_shapes
):
with self.subTest(test_dist=test_dist,
corr_class=corr_class.__name__,
shape=shape):
op = from_function(corr_class(test_dist), shape,
is_cyclic=False)
mat = op.dot(np.eye(*op.shape))

self.assertAlmostEqual(op.det(), la.det(mat), 4)


if __name__ == "__main__":
unittest2.main()