4
4
import warnings
5
5
from typing import Any , Callable , Optional , Union
6
6
7
- import arviz as az
8
7
import matplotlib .pyplot as plt
9
8
import numpy as np
10
9
import numpy .typing as npt
11
10
import pytensor .tensor as pt
11
+ from arviz_base import rcParams
12
+ from arviz_stats .base import array_stats
12
13
from numba import jit
13
14
from pytensor .tensor .variable import Variable
14
15
from scipy .interpolate import griddata
15
16
from scipy .signal import savgol_filter
16
- from scipy .stats import norm
17
17
18
18
from .tree import Tree
19
19
@@ -76,12 +76,12 @@ def _sample_posterior(
76
76
77
77
78
78
def plot_convergence (
79
- idata : az . InferenceData ,
79
+ idata : Any ,
80
80
var_name : Optional [str ] = None ,
81
81
kind : str = "ecdf" ,
82
82
figsize : Optional [tuple [float , float ]] = None ,
83
83
ax = None ,
84
- ) -> list [ plt . Axes ] :
84
+ ) -> None :
85
85
"""
86
86
Plot convergence diagnostics.
87
87
@@ -102,39 +102,12 @@ def plot_convergence(
102
102
-------
103
103
list[ax] : matplotlib axes
104
104
"""
105
- ess_threshold = idata ["posterior" ]["chain" ].size * 100
106
- ess = np .atleast_2d (az .ess (idata , method = "bulk" , var_names = var_name )[var_name ].values )
107
- rhat = np .atleast_2d (az .rhat (idata , var_names = var_name )[var_name ].values )
108
-
109
- if figsize is None :
110
- figsize = (10 , 3 )
111
-
112
- if kind == "ecdf" :
113
- kind_func : Callable [..., Any ] = az .plot_ecdf
114
- sharey = True
115
- elif kind == "kde" :
116
- kind_func = az .plot_kde
117
- sharey = False
118
-
119
- if ax is None :
120
- _ , ax = plt .subplots (1 , 2 , figsize = figsize , sharex = "col" , sharey = sharey )
121
-
122
- for idx , (essi , rhati ) in enumerate (zip (ess , rhat )):
123
- kind_func (essi , ax = ax [0 ], plot_kwargs = {"color" : f"C{ idx } " })
124
- kind_func (rhati , ax = ax [1 ], plot_kwargs = {"color" : f"C{ idx } " })
125
-
126
- ax [0 ].axvline (ess_threshold , color = "0.7" , ls = "--" )
127
- # Assume Rhats are N(1, 0.005) iid. Then compute the 0.99 quantile
128
- # scaled by the sample size and use it as a threshold.
129
- ax [1 ].axvline (norm (1 , 0.005 ).ppf (0.99 ** (1 / ess .size )), color = "0.7" , ls = "--" )
130
-
131
- ax [0 ].set_xlabel ("ESS" )
132
- ax [1 ].set_xlabel ("R-hat" )
133
- if kind == "kde" :
134
- ax [0 ].set_yticks ([])
135
- ax [1 ].set_yticks ([])
136
-
137
- return ax
105
+ warnings .warn (
106
+ "This function has been deprecated"
107
+ "Use az.plot_convergence_dist() instead."
108
+ "https://arviz-plots.readthedocs.io/en/latest/api/generated/arviz_plots.plot_convergence_dist.html" ,
109
+ FutureWarning ,
110
+ )
138
111
139
112
140
113
def plot_ice (
@@ -408,7 +381,7 @@ def identity(x):
408
381
if var in var_discrete :
409
382
_ , idx_uni = np .unique (new_x , return_index = True )
410
383
y_means = p_di .mean (0 )[idx_uni ]
411
- hdi = az .hdi (p_di )[idx_uni ]
384
+ hdi = array_stats .hdi (p_di , prob = rcParams [ "stats.ci_prob" ], axis = 0 )[idx_uni ]
412
385
axes [count ].errorbar (
413
386
new_x [idx_uni ],
414
387
y_means ,
@@ -418,11 +391,13 @@ def identity(x):
418
391
)
419
392
axes [count ].set_xticks (new_x [idx_uni ])
420
393
else :
421
- az . plot_hdi (
394
+ _plot_hdi (
422
395
new_x ,
423
396
p_di ,
424
397
smooth = smooth ,
425
- fill_kwargs = {"alpha" : alpha , "color" : color },
398
+ alpha = alpha ,
399
+ color = color ,
400
+ smooth_kwargs = smooth_kwargs ,
426
401
ax = axes [count ],
427
402
)
428
403
if smooth :
@@ -659,7 +634,7 @@ def _create_pdp_data(
659
634
def _smooth_mean (
660
635
new_x : npt .NDArray ,
661
636
p_di : npt .NDArray ,
662
- kind : str = "pdp " ,
637
+ kind : str = "neutral " ,
663
638
smooth_kwargs : Optional [dict [str , Any ]] = None ,
664
639
) -> tuple [np .ndarray , np .ndarray ]:
665
640
"""
@@ -688,7 +663,10 @@ def _smooth_mean(
688
663
smooth_kwargs .setdefault ("polyorder" , 2 )
689
664
x_data = np .linspace (np .nanmin (new_x ), np .nanmax (new_x ), 200 )
690
665
x_data [0 ] = (x_data [0 ] + x_data [1 ]) / 2
691
- if kind == "pdp" :
666
+
667
+ if kind == "neutral" :
668
+ interp = griddata (new_x , p_di , x_data )
669
+ elif kind == "pdp" :
692
670
interp = griddata (new_x , p_di .mean (0 ), x_data )
693
671
else :
694
672
interp = griddata (new_x , p_di .T , x_data )
@@ -800,7 +778,7 @@ def plot_variable_inclusion(idata, X, labels=None, figsize=None, plot_kwargs=Non
800
778
801
779
802
780
def compute_variable_importance ( # noqa: PLR0915 PLR0912
803
- idata : az . InferenceData ,
781
+ idata : Any ,
804
782
bartrv : Variable ,
805
783
X : npt .NDArray ,
806
784
method : str = "VI" ,
@@ -904,7 +882,7 @@ def compute_variable_importance( # noqa: PLR0915 PLR0912
904
882
[pearsonr2 (predicted_all [j ], predicted_subset [j ]) for j in range (samples )]
905
883
)
906
884
r2_mean [idx ] = np .mean (r_2 )
907
- r2_hdi [idx ] = az .hdi (r_2 )
885
+ r2_hdi [idx ] = array_stats .hdi (r_2 , prob = rcParams [ "stats.ci_prob" ] )
908
886
preds [idx ] = predicted_subset .squeeze ()
909
887
910
888
if method in ["backward" , "backward_VI" ]:
@@ -954,7 +932,7 @@ def compute_variable_importance( # noqa: PLR0915 PLR0912
954
932
955
933
# Save values for plotting later
956
934
r2_mean [i_var - init ] = max_r_2
957
- r2_hdi [i_var - init ] = az .hdi (r_2_without_least_important_vars )
935
+ r2_hdi [i_var - init ] = array_stats .hdi (r_2_without_least_important_vars )
958
936
preds [i_var - init ] = least_important_samples .squeeze ()
959
937
960
938
# extend current list of least important variable
@@ -1079,7 +1057,7 @@ def plot_variable_importance(
1079
1057
)
1080
1058
ax .fill_between (
1081
1059
[- 0.5 , n_vars - 0.5 ],
1082
- * az .hdi (r_2_ref ),
1060
+ * array_stats .hdi (r_2_ref , prob = rcParams [ "stats.ci_prob" ] ),
1083
1061
alpha = 0.1 ,
1084
1062
color = plot_kwargs .get ("color_ref" , "grey" ),
1085
1063
)
@@ -1229,3 +1207,22 @@ def pearsonr2(A, B):
1229
1207
am = A - np .mean (A )
1230
1208
bm = B - np .mean (B )
1231
1209
return (am @ bm ) ** 2 / (np .sum (am ** 2 ) * np .sum (bm ** 2 ))
1210
+
1211
+
1212
+ def _plot_hdi (x , y , smooth , color , alpha , smooth_kwargs , ax ):
1213
+ x = np .asarray (x )
1214
+ y = np .asarray (y )
1215
+ hdi_prob = rcParams ["stats.ci_prob" ]
1216
+ hdi_data = array_stats .hdi (y , hdi_prob , axis = 0 )
1217
+ if smooth :
1218
+ if isinstance (x [0 ], np .datetime64 ):
1219
+ raise TypeError ("Cannot deal with x as type datetime. Recommend setting smooth=False." )
1220
+
1221
+ x_data , y_data = _smooth_mean (x , hdi_data , smooth_kwargs = smooth_kwargs )
1222
+ else :
1223
+ idx = np .argsort (x )
1224
+ x_data = x [idx ]
1225
+ y_data = hdi_data [idx ]
1226
+
1227
+ ax .fill_between (x_data , y_data [:, 0 ], y_data [:, 1 ], color = color , alpha = alpha )
1228
+ return ax
0 commit comments