Skip to content

Commit 0000b8c

Browse files
authored
Use ArviZ-stats (#232)
* use arviz_stats * fix imports * update python versions
1 parent cb2aab3 commit 0000b8c

File tree

9 files changed

+56
-59
lines changed

9 files changed

+56
-59
lines changed

.github/workflows/test.yml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,7 @@ jobs:
1111
runs-on: ubuntu-latest
1212
strategy:
1313
matrix:
14-
python-version: ["3.10", "3.11"]
14+
python-version: ["3.11", "3.12", "3.13"]
1515

1616
name: Set up Python ${{ matrix.python-version }}
1717
steps:

docs/api_reference.rst

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -13,4 +13,4 @@ methods in the current release of PyMC-BART.
1313
=============================
1414

1515
.. automodule:: pymc_bart
16-
:members: BART, PGBART, compute_variable_importance, get_variable_inclusion, plot_convergence, plot_ice, plot_pdp, plot_scatter_submodels, plot_variable_importance, plot_variable_inclusion, ContinuousSplitRule, OneHotSplitRule, SubsetSplitRule
16+
:members: BART, PGBART, compute_variable_importance, get_variable_inclusion, plot_ice, plot_pdp, plot_scatter_submodels, plot_variable_importance, plot_variable_inclusion, ContinuousSplitRule, OneHotSplitRule, SubsetSplitRule

docs/index.rst

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -29,7 +29,7 @@ interpretation of those models and perform variable selection.
2929
Installation
3030
============
3131

32-
PyMC-BART requires a working Python interpreter (3.10+). We recommend installing Python and key numerical libraries using the `Anaconda distribution <https://www.anaconda.com/products/individual#Downloads>`_, which has one-click installers available on all major platforms.
32+
PyMC-BART requires a working Python interpreter (3.11+). We recommend installing Python and key numerical libraries using the `Anaconda distribution <https://www.anaconda.com/products/individual#Downloads>`_, which has one-click installers available on all major platforms.
3333

3434
Assuming a standard Python environment is installed on your machine, PyMC-BART itself can be installed either using pip or conda-forge.
3535

env-dev.yml

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -3,8 +3,7 @@ channels:
33
- conda-forge
44
- defaults
55
dependencies:
6-
- pymc>=5.16.2,<=5.19.1
7-
- arviz>=0.18.0
6+
- pymc>=5.16.2,<=5.23.0
87
- numba
98
- matplotlib
109
- numpy
@@ -20,4 +19,5 @@ dependencies:
2019
- flake8
2120
- pip
2221
- pip:
22+
- arviz-stats[xarray]>=0.6.0
2323
- -e .

env.yml

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -3,12 +3,12 @@ channels:
33
- conda-forge
44
- defaults
55
dependencies:
6-
- pymc>=5.16.2,<=5.19.1
7-
- arviz>=0.18.0
6+
- pymc>=5.16.2,<=5.23.0
87
- numba
98
- matplotlib
109
- numpy
1110
- pytensor
1211
- pip
1312
- pip:
1413
- pymc-bart
14+
- arviz-stats[xarray]>=0.6.0

pymc_bart/pgbart.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -346,7 +346,7 @@ def resample(
346346
new_particles.append(particles[idx].copy())
347347
else:
348348
new_particles.append(particles[idx])
349-
seen.append(idx)
349+
seen.append(int(idx))
350350

351351
particles[1:] = new_particles
352352

pymc_bart/utils.py

Lines changed: 43 additions & 46 deletions
Original file line numberDiff line numberDiff line change
@@ -4,16 +4,16 @@
44
import warnings
55
from typing import Any, Callable, Optional, Union
66

7-
import arviz as az
87
import matplotlib.pyplot as plt
98
import numpy as np
109
import numpy.typing as npt
1110
import pytensor.tensor as pt
11+
from arviz_base import rcParams
12+
from arviz_stats.base import array_stats
1213
from numba import jit
1314
from pytensor.tensor.variable import Variable
1415
from scipy.interpolate import griddata
1516
from scipy.signal import savgol_filter
16-
from scipy.stats import norm
1717

1818
from .tree import Tree
1919

@@ -76,12 +76,12 @@ def _sample_posterior(
7676

7777

7878
def plot_convergence(
79-
idata: az.InferenceData,
79+
idata: Any,
8080
var_name: Optional[str] = None,
8181
kind: str = "ecdf",
8282
figsize: Optional[tuple[float, float]] = None,
8383
ax=None,
84-
) -> list[plt.Axes]:
84+
) -> None:
8585
"""
8686
Plot convergence diagnostics.
8787
@@ -102,39 +102,12 @@ def plot_convergence(
102102
-------
103103
list[ax] : matplotlib axes
104104
"""
105-
ess_threshold = idata["posterior"]["chain"].size * 100
106-
ess = np.atleast_2d(az.ess(idata, method="bulk", var_names=var_name)[var_name].values)
107-
rhat = np.atleast_2d(az.rhat(idata, var_names=var_name)[var_name].values)
108-
109-
if figsize is None:
110-
figsize = (10, 3)
111-
112-
if kind == "ecdf":
113-
kind_func: Callable[..., Any] = az.plot_ecdf
114-
sharey = True
115-
elif kind == "kde":
116-
kind_func = az.plot_kde
117-
sharey = False
118-
119-
if ax is None:
120-
_, ax = plt.subplots(1, 2, figsize=figsize, sharex="col", sharey=sharey)
121-
122-
for idx, (essi, rhati) in enumerate(zip(ess, rhat)):
123-
kind_func(essi, ax=ax[0], plot_kwargs={"color": f"C{idx}"})
124-
kind_func(rhati, ax=ax[1], plot_kwargs={"color": f"C{idx}"})
125-
126-
ax[0].axvline(ess_threshold, color="0.7", ls="--")
127-
# Assume Rhats are N(1, 0.005) iid. Then compute the 0.99 quantile
128-
# scaled by the sample size and use it as a threshold.
129-
ax[1].axvline(norm(1, 0.005).ppf(0.99 ** (1 / ess.size)), color="0.7", ls="--")
130-
131-
ax[0].set_xlabel("ESS")
132-
ax[1].set_xlabel("R-hat")
133-
if kind == "kde":
134-
ax[0].set_yticks([])
135-
ax[1].set_yticks([])
136-
137-
return ax
105+
warnings.warn(
106+
"This function has been deprecated. "
107+
"Use az.plot_convergence_dist() instead. "
108+
"https://arviz-plots.readthedocs.io/en/latest/api/generated/arviz_plots.plot_convergence_dist.html",
109+
FutureWarning,
110+
)
138111

139112

140113
def plot_ice(
@@ -408,7 +381,7 @@ def identity(x):
408381
if var in var_discrete:
409382
_, idx_uni = np.unique(new_x, return_index=True)
410383
y_means = p_di.mean(0)[idx_uni]
411-
hdi = az.hdi(p_di)[idx_uni]
384+
hdi = array_stats.hdi(p_di, prob=rcParams["stats.ci_prob"], axis=0)[idx_uni]
412385
axes[count].errorbar(
413386
new_x[idx_uni],
414387
y_means,
@@ -418,11 +391,13 @@ def identity(x):
418391
)
419392
axes[count].set_xticks(new_x[idx_uni])
420393
else:
421-
az.plot_hdi(
394+
_plot_hdi(
422395
new_x,
423396
p_di,
424397
smooth=smooth,
425-
fill_kwargs={"alpha": alpha, "color": color},
398+
alpha=alpha,
399+
color=color,
400+
smooth_kwargs=smooth_kwargs,
426401
ax=axes[count],
427402
)
428403
if smooth:
@@ -659,7 +634,7 @@ def _create_pdp_data(
659634
def _smooth_mean(
660635
new_x: npt.NDArray,
661636
p_di: npt.NDArray,
662-
kind: str = "pdp",
637+
kind: str = "neutral",
663638
smooth_kwargs: Optional[dict[str, Any]] = None,
664639
) -> tuple[np.ndarray, np.ndarray]:
665640
"""
@@ -688,7 +663,10 @@ def _smooth_mean(
688663
smooth_kwargs.setdefault("polyorder", 2)
689664
x_data = np.linspace(np.nanmin(new_x), np.nanmax(new_x), 200)
690665
x_data[0] = (x_data[0] + x_data[1]) / 2
691-
if kind == "pdp":
666+
667+
if kind == "neutral":
668+
interp = griddata(new_x, p_di, x_data)
669+
elif kind == "pdp":
692670
interp = griddata(new_x, p_di.mean(0), x_data)
693671
else:
694672
interp = griddata(new_x, p_di.T, x_data)
@@ -800,7 +778,7 @@ def plot_variable_inclusion(idata, X, labels=None, figsize=None, plot_kwargs=Non
800778

801779

802780
def compute_variable_importance( # noqa: PLR0915 PLR0912
803-
idata: az.InferenceData,
781+
idata: Any,
804782
bartrv: Variable,
805783
X: npt.NDArray,
806784
method: str = "VI",
@@ -904,7 +882,7 @@ def compute_variable_importance( # noqa: PLR0915 PLR0912
904882
[pearsonr2(predicted_all[j], predicted_subset[j]) for j in range(samples)]
905883
)
906884
r2_mean[idx] = np.mean(r_2)
907-
r2_hdi[idx] = az.hdi(r_2)
885+
r2_hdi[idx] = array_stats.hdi(r_2, prob=rcParams["stats.ci_prob"])
908886
preds[idx] = predicted_subset.squeeze()
909887

910888
if method in ["backward", "backward_VI"]:
@@ -954,7 +932,7 @@ def compute_variable_importance( # noqa: PLR0915 PLR0912
954932

955933
# Save values for plotting later
956934
r2_mean[i_var - init] = max_r_2
957-
r2_hdi[i_var - init] = az.hdi(r_2_without_least_important_vars)
935+
r2_hdi[i_var - init] = array_stats.hdi(r_2_without_least_important_vars, prob=rcParams["stats.ci_prob"])
958936
preds[i_var - init] = least_important_samples.squeeze()
959937

960938
# extend current list of least important variable
@@ -1079,7 +1057,7 @@ def plot_variable_importance(
10791057
)
10801058
ax.fill_between(
10811059
[-0.5, n_vars - 0.5],
1082-
*az.hdi(r_2_ref),
1060+
*array_stats.hdi(r_2_ref, prob=rcParams["stats.ci_prob"]),
10831061
alpha=0.1,
10841062
color=plot_kwargs.get("color_ref", "grey"),
10851063
)
@@ -1229,3 +1207,22 @@ def pearsonr2(A, B):
12291207
am = A - np.mean(A)
12301208
bm = B - np.mean(B)
12311209
return (am @ bm) ** 2 / (np.sum(am**2) * np.sum(bm**2))
1210+
1211+
1212+
def _plot_hdi(x, y, smooth, color, alpha, smooth_kwargs, ax):
    """
    Shade a highest-density-interval band of ``y`` against ``x`` on ``ax``.

    Parameters
    ----------
    x : array-like
        Positions along the horizontal axis; converted with ``np.asarray``.
    y : array-like
        Values whose HDI is computed along axis 0; converted with
        ``np.asarray``. Presumably shaped (draws, len(x)) -- confirm with callers.
    smooth : bool
        When true, the interval bounds are passed through ``_smooth_mean``
        before plotting; otherwise they are only reordered by ascending ``x``.
    color : matplotlib color
        Forwarded to ``ax.fill_between``.
    alpha : float
        Forwarded to ``ax.fill_between``.
    smooth_kwargs : dict or None
        Forwarded to ``_smooth_mean``; unused when ``smooth`` is false.
    ax : matplotlib axes
        Axes the band is drawn on.

    Returns
    -------
    ax : the same axes object that was passed in.

    Raises
    ------
    TypeError
        If ``smooth`` is true and the first element of ``x`` is a
        ``np.datetime64``.
    """
    x_arr = np.asarray(x)
    # Interval probability comes from the global ArviZ configuration.
    bounds = array_stats.hdi(np.asarray(y), rcParams["stats.ci_prob"], axis=0)

    if not smooth:
        # Sort by x so the filled region is drawn left to right.
        order = np.argsort(x_arr)
        band_x, band_y = x_arr[order], bounds[order]
    elif isinstance(x_arr[0], np.datetime64):
        raise TypeError("Cannot deal with x as type datetime. Recommend setting smooth=False.")
    else:
        band_x, band_y = _smooth_mean(x_arr, bounds, smooth_kwargs=smooth_kwargs)

    # Columns 0 and 1 of the interval array are the two edges of the band.
    edge_a, edge_b = band_y[:, 0], band_y[:, 1]
    ax.fill_between(band_x, edge_a, edge_b, color=color, alpha=alpha)
    return ax

requirements.txt

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
1-
pymc>=5.16.2, <=5.23.0
2-
arviz>=0.18.0
1+
pymc>=5.16.2,<=5.23.0
2+
arviz-stats[xarray]>=0.6.0
33
numba
44
matplotlib
5-
numpy
5+
numpy>=2.0

setup.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -29,9 +29,9 @@
2929
"Development Status :: 5 - Production/Stable",
3030
"Programming Language :: Python",
3131
"Programming Language :: Python :: 3",
32-
"Programming Language :: Python :: 3.9",
33-
"Programming Language :: Python :: 3.10",
3432
"Programming Language :: Python :: 3.11",
33+
"Programming Language :: Python :: 3.12",
34+
"Programming Language :: Python :: 3.13",
3535
"License :: OSI Approved :: Apache Software License",
3636
"Intended Audience :: Science/Research",
3737
"Topic :: Scientific/Engineering",

0 commit comments

Comments
 (0)