8
8
9
9
from bayesflow .utils import logging
10
10
from bayesflow .utils .dict_utils import dicts_to_arrays
11
+ from bayesflow .utils .plot_utils import create_legends
11
12
12
13
13
14
def pairs_samples (
@@ -17,8 +18,10 @@ def pairs_samples(
17
18
height : float = 2.5 ,
18
19
color : str | tuple = "#132a70" ,
19
20
alpha : float = 0.9 ,
21
+ label : str = "Posterior" ,
20
22
label_fontsize : int = 14 ,
21
23
tick_fontsize : int = 12 ,
24
+ show_single_legend : bool = False ,
22
25
** kwargs ,
23
26
) -> sns .PairGrid :
24
27
"""
@@ -37,13 +40,18 @@ def pairs_samples(
37
40
height : float, optional, default: 2.5
38
41
The height of the pair plot
39
42
color : str, optional, default : '#8f2727'
40
- The color of the plot
43
+ The primary color of the plot
41
44
alpha : float in [0, 1], optional, default: 0.9
42
45
The opacity of the plot
46
+ label : str, optional, default: "Posterior"
47
+ Label for the dataset to plot
43
48
label_fontsize : int, optional, default: 14
44
49
The font size of the x and y-label texts (parameter names)
45
50
tick_fontsize : int, optional, default: 12
46
- The font size of the axis ticklabels
51
+ The font size of the axis tick labels
52
+ show_single_legend : bool, optional, default: False
53
+ Optional toggle for the user to choose whether a single dataset
54
+ should also display legend
47
55
**kwargs : dict, optional
48
56
Additional keyword arguments passed to the sns.PairGrid constructor
49
57
"""
@@ -59,8 +67,11 @@ def pairs_samples(
59
67
height = height ,
60
68
color = color ,
61
69
alpha = alpha ,
70
+ label = label ,
62
71
label_fontsize = label_fontsize ,
63
72
tick_fontsize = tick_fontsize ,
73
+ show_single_legend = show_single_legend ,
74
+ ** kwargs ,
64
75
)
65
76
66
77
return g
@@ -72,17 +83,27 @@ def _pairs_samples(
72
83
color : str | tuple = "#132a70" ,
73
84
color2 : str | tuple = "gray" ,
74
85
alpha : float = 0.9 ,
86
+ label : str = "Posterior" ,
75
87
label_fontsize : int = 14 ,
76
88
tick_fontsize : int = 12 ,
77
89
legend_fontsize : int = 14 ,
90
+ show_single_legend : bool = False ,
78
91
** kwargs ,
79
92
) -> sns .PairGrid :
80
- # internal version of pairs_samples creating the seaborn plot
93
+ """
94
+ Internal version of pairs_samples creating the seaborn PairPlot
95
+ for both a single dataset and multiple datasets.
81
96
82
- # Parameters
83
- # ----------
84
- # plot_data : output of bayesflow.utils.dict_utils.dicts_to_arrays
85
- # other arguments are documented in pairs_samples
97
+ Parameters
98
+ ----------
99
+ plot_data : output of bayesflow.utils.dict_utils.dicts_to_arrays
100
+ Formatted data to plot from the sample dataset
101
+ color2 : str, optional, default: 'gray'
102
+ Secondary color for the pair plots.
103
+ This is the color used for the prior draws.
104
+
105
+ Other arguments are documented in pairs_samples
106
+ """
86
107
87
108
estimates_shape = plot_data ["estimates" ].shape
88
109
if len (estimates_shape ) != 2 :
@@ -136,7 +157,7 @@ def _pairs_samples(
136
157
common_norm = False ,
137
158
)
138
159
139
- # add scatterplots to the upper diagonal
160
+ # add scatter plots to the upper diagonal
140
161
g .map_upper (sns .scatterplot , alpha = 0.6 , s = 40 , edgecolor = "k" , color = color , lw = 0 )
141
162
142
163
# add KDEs to the lower diagonal
@@ -146,11 +167,6 @@ def _pairs_samples(
146
167
logging .exception ("KDE failed due to the following exception:\n " + repr (e ) + "\n Substituting scatter plot." )
147
168
g .map_lower (sns .scatterplot , alpha = 0.6 , s = 40 , edgecolor = "k" , color = color , lw = 0 )
148
169
149
- # need to add legend here such that colors are recognized
150
- if plot_data ["priors" ] is not None :
151
- g .add_legend (fontsize = legend_fontsize , loc = "center right" )
152
- g ._legend .set_title (None )
153
-
154
170
# Generate grids
155
171
dim = g .axes .shape [0 ]
156
172
for i in range (dim ):
@@ -165,32 +181,48 @@ def _pairs_samples(
165
181
g .axes [i , j ].tick_params (axis = "both" , which = "major" , labelsize = tick_fontsize )
166
182
g .axes [i , j ].tick_params (axis = "both" , which = "minor" , labelsize = tick_fontsize )
167
183
168
- # adjust font size of labels
184
+ # adjust the font size of labels
169
185
# the labels themselves remain the same as before, i.e., variable_names
170
186
g .axes [i , 0 ].set_ylabel (variable_names [i ], fontsize = label_fontsize )
171
187
g .axes [dim - 1 , i ].set_xlabel (variable_names [i ], fontsize = label_fontsize )
172
188
189
+ # need to add legend here such that colors are recognized
190
+ # if plot_data["priors"] is not None:
191
+ # g.add_legend(fontsize=legend_fontsize, loc="center right")
192
+ # g._legend.set_title(None)
193
+
194
+ create_legends (
195
+ g ,
196
+ plot_data ,
197
+ color = color ,
198
+ color2 = color2 ,
199
+ legend_fontsize = legend_fontsize ,
200
+ label = label ,
201
+ show_single_legend = show_single_legend ,
202
+ )
203
+
173
204
# Return figure
174
205
g .tight_layout ()
175
206
176
207
return g
177
208
178
209
179
- # create a histogram plot on a twin y axis
180
- # this ensures that the y scaling of the diagonal plots
181
- # in independent of the y scaling of the off-diagonal plots
182
210
def histplot_twinx (x , ** kwargs ):
183
- # Create a twin axis
184
- ax2 = plt .gca ().twinx ()
211
+ """
212
+ # create a histogram plot on a twin y-axis
213
+ # this ensures that the y scaling of the diagonal plots
214
+ # in independent of the y scaling of the off-diagonal plots
185
215
216
+ Parameters
217
+ ----------
218
+ x : np.ndarray
219
+ Data to be plotted.
220
+ """
186
221
# create a histogram on the twin axis
187
- sns .histplot (x , ** kwargs , ax = ax2 )
222
+ sns .histplot (x , legend = False , ** kwargs )
188
223
189
224
# make the twin axis invisible
190
225
plt .gca ().spines ["right" ].set_visible (False )
191
226
plt .gca ().spines ["top" ].set_visible (False )
192
- ax2 .set_ylabel ("" )
193
- ax2 .set_yticks ([])
194
- ax2 .set_yticklabels ([])
195
227
196
228
return None
0 commit comments