Skip to content

Commit 5bcbaa8

Browse files
Merge pull request #314 from astro-informatics/change_legacy_import
Change legacy import
2 parents 258cfbb + 2acb6d2 commit 5bcbaa8

File tree

2 files changed

+8
-5
lines changed

2 files changed

+8
-5
lines changed

.codecov.yml

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,7 @@
11
ignore:
22
- "**/utils*"
3+
- "tests/*"
4+
35
coverage:
46
status:
57
patch: false

harmonic/utils.py

Lines changed: 6 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -2,8 +2,6 @@
22
import numpy as np
33
import harmonic as hm
44
import getdist
5-
from harmonic import model_legacy
6-
from getdist import plots
75
import matplotlib as plt
86

97

@@ -263,7 +261,7 @@ def cross_validation(
263261
domains: List,
264262
hyper_parameters: List,
265263
nfold=2,
266-
modelClass=model_legacy.KernelDensityEstimate,
264+
modelClass=None,
267265
seed: int = -1,
268266
) -> List:
269267
"""Perform n-fold validation for given model using chains to be split into
@@ -285,8 +283,8 @@ def cross_validation(
285283
hyper_parameters (List): List of hyper_parameters where each entry is a
286284
hyper_parameter list to be considered.
287285
288-
modelClass (Model): Model that is being cross validated (default =
289-
KernelDensityEstimate).
286+
modelClass (Model): Model that is being cross validated (defaults to
287+
KernelDensityEstimate inside function).
290288
291289
seed (int): Seed for random number generator when drawing the chains
292290
(if this is negative the seed is not set).
@@ -301,6 +299,9 @@ def cross_validation(
301299
302300
"""
303301

302+
if modelClass is None:
303+
modelClass = hm.model_legacy.KernelDensityEstimate
304+
304305
ln_validation_variances = np.zeros((nfold, len(hyper_parameters)))
305306

306307
if seed > 0:

0 commit comments

Comments
 (0)