Skip to content

Commit 50bcc41

Browse files
committed
Fixed basis update python code
1 parent 9713f43 commit 50bcc41

File tree

1 file changed

+8
-5
lines changed

1 file changed

+8
-5
lines changed

stochtree/data.py

Lines changed: 8 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -58,15 +58,18 @@ def update_basis(self, basis: np.array):
5858
basis : np.array
5959
Numpy array of basis vectors.
6060
"""
61-
basis_ = np.expand_dims(basis, 1) if np.ndim(basis) == 1 else basis
62-
n, p = basis_.shape
63-
basis_rowmajor = np.ascontiguousarray(basis_)
6461
if not self.has_basis():
6562
raise ValueError("This dataset does not have a basis to update. Please use `add_basis` to create and initialize the values in the Dataset's basis matrix.")
6663
if not isinstance(basis, np.ndarray):
6764
raise ValueError("basis must be a numpy array.")
68-
if basis.ndim != 2:
69-
raise ValueError("basis must be a 2-dimensional numpy array.")
65+
if np.ndim(basis) == 1:
66+
basis_ = np.expand_dims(basis, 1)
67+
elif np.ndim(basis) == 2:
68+
basis_ = basis
69+
else:
70+
raise ValueError("basis must be a numpy array with one or two dimension.")
71+
n, p = basis_.shape
72+
basis_rowmajor = np.ascontiguousarray(basis_)
7073
if self.num_basis() != p:
7174
raise ValueError(f"The number of columns in the new basis ({p}) must match the number of columns in the existing basis ({self.num_basis()}).")
7275
if self.num_observations() != n:

0 commit comments

Comments
 (0)