Skip to content

Commit 6acc628

Browse files
authored
Better pairplots (#505)
* Hacky fix for pairplots
* Ensure that target sits in front of other elements
* Ensure consistent spacing between plot and legends + cleanup
* Update docs
* Fix the propagation of `legend_fontsize`
* Minor fix to comply with code style
1 parent dc02245 commit 6acc628

File tree

3 files changed

+151
-37
lines changed

3 files changed

+151
-37
lines changed

bayesflow/diagnostics/plots/pairs_posterior.py

Lines changed: 29 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@
77
import seaborn as sns
88

99
from bayesflow.utils.dict_utils import dicts_to_arrays
10+
from bayesflow.utils.plot_utils import create_legends
1011

1112
from .pairs_samples import _pairs_samples
1213

@@ -21,6 +22,7 @@ def pairs_posterior(
2122
height: int = 3,
2223
post_color: str | tuple = "#132a70",
2324
prior_color: str | tuple = "gray",
25+
target_color: str | tuple = "red",
2426
alpha: float = 0.9,
2527
label_fontsize: int = 14,
2628
tick_fontsize: int = 12,
@@ -37,25 +39,27 @@ def pairs_posterior(
3739
Optional true parameter values that have generated the observed dataset.
3840
priors : np.ndarray of shape (n_prior_draws, n_params) or None, optional (default: None)
3941
Optional prior samples obtained from the prior.
40-
dataset_id: Optional ID of the dataset for whose posterior the pairs plot shall be generated.
41-
Should only be specified if estimates contains posterior draws from multiple datasets.
42+
dataset_id: Optional ID of the dataset for whose posterior the pair plots shall be generated.
43+
Should only be specified if estimates contain posterior draws from multiple datasets.
4244
variable_keys : list or None, optional, default: None
4345
Select keys from the dictionary provided in samples.
4446
By default, select all keys.
4547
variable_names : list or None, optional, default: None
4648
The parameter names for nice plot titles. Inferred if None
4749
height : float, optional, default: 3
48-
The height of the pairplot
50+
The height of the pair plots
4951
label_fontsize : int, optional, default: 14
5052
The font size of the x and y-label texts (parameter names)
5153
tick_fontsize : int, optional, default: 12
52-
The font size of the axis ticklabels
54+
The font size of the axis tick labels
5355
legend_fontsize : int, optional, default: 16
5456
The font size of the legend text
5557
post_color : str, optional, default: '#132a70'
5658
The color for the posterior histograms and KDEs
5759
prior_color : str, optional, default: gray
5860
The color for the optional prior histograms and KDEs
61+
target_color : str, optional, default: red
62+
The color for the optional true parameter lines and points
5963
alpha : float in [0, 1], optional, default: 0.9
6064
The opacity of the posterior plots
6165
@@ -81,7 +85,7 @@ def pairs_posterior(
8185
variable_names=variable_names,
8286
)
8387

84-
# dicts_to_arrays will keep dataset axis even if it is of length 1
88+
# dicts_to_arrays will keep the dataset axis even if it is of length 1
8589
# however, pairs plotting requires the dataset axis to be removed
8690
estimates_shape = plot_data["estimates"].shape
8791
if len(estimates_shape) == 3 and estimates_shape[0] == 1:
@@ -109,14 +113,30 @@ def pairs_posterior(
109113
# Create DataFrame with variable names as columns
110114
g.data = pd.DataFrame(targets, columns=targets.variable_names)
111115
g.data["_source"] = "True Parameter"
112-
g.map_diag(plot_true_params)
116+
g.map_diag(plot_true_params_as_lines, color=target_color)
117+
g.map_offdiag(plot_true_params_as_points, color=target_color)
118+
119+
create_legends(
120+
g,
121+
plot_data,
122+
color=post_color,
123+
color2=prior_color,
124+
legend_fontsize=legend_fontsize,
125+
show_single_legend=False,
126+
)
113127

114128
return g
115129

116130

117-
def plot_true_params(x, hue=None, **kwargs):
118-
"""Custom function to plot true parameters on the diagonal."""
131+
def plot_true_params_as_lines(x, hue=None, color=None, **kwargs):
    """
    Custom function to plot true parameters on the diagonal as dashed lines.

    Parameters
    ----------
    x : object supporting ``len`` and ``.iloc`` (presumably a pandas Series)
        Values of a single variable; only the first entry is drawn.
    hue : optional
        Unused. It is accepted so that the function can handle the case of
        plotting both posterior and prior.
    color : optional, default: None
        Color forwarded to ``plt.axvline`` for the dashed line.
    **kwargs
        Accepted for call compatibility and ignored.

    Returns
    -------
    None
    """
    # Nothing to draw without a true value; indexing an empty input would raise.
    # This mirrors the emptiness check used for the off-diagonal points.
    if len(x) == 0:
        return None

    # Get the single true value for the diagonal
    param = x.iloc[0]

    # Only plot a vertical line for the true parameter on the current axes
    plt.axvline(param, color=color, linestyle="--")
    return None
137+
138+
139+
def plot_true_params_as_points(x, y, color=None, marker="x", **kwargs):
    """
    Draw the true parameter pair on an off-diagonal panel as one marker.

    Parameters
    ----------
    x, y : objects supporting ``len`` and ``.iloc`` (presumably pandas Series)
        Values of the two variables; only the first entry of each is used.
    color : optional, default: None
        Color forwarded to ``plt.scatter``.
    marker : str, optional, default: "x"
        Marker style forwarded to ``plt.scatter``.
    **kwargs
        Additional keyword arguments forwarded to ``plt.scatter``.
    """
    # Skip drawing when either coordinate is missing
    if len(x) == 0 or len(y) == 0:
        return

    first_x = x.iloc[0]
    first_y = y.iloc[0]
    plt.scatter(first_x, first_y, color=color, marker=marker, **kwargs)

bayesflow/diagnostics/plots/pairs_samples.py

Lines changed: 55 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@
88

99
from bayesflow.utils import logging
1010
from bayesflow.utils.dict_utils import dicts_to_arrays
11+
from bayesflow.utils.plot_utils import create_legends
1112

1213

1314
def pairs_samples(
@@ -17,8 +18,10 @@ def pairs_samples(
1718
height: float = 2.5,
1819
color: str | tuple = "#132a70",
1920
alpha: float = 0.9,
21+
label: str = "Posterior",
2022
label_fontsize: int = 14,
2123
tick_fontsize: int = 12,
24+
show_single_legend: bool = False,
2225
**kwargs,
2326
) -> sns.PairGrid:
2427
"""
@@ -37,13 +40,18 @@ def pairs_samples(
3740
height : float, optional, default: 2.5
3841
The height of the pair plot
3942
color : str, optional, default : '#132a70'
40-
The color of the plot
43+
The primary color of the plot
4144
alpha : float in [0, 1], optional, default: 0.9
4245
The opacity of the plot
46+
label : str, optional, default: "Posterior"
47+
Label for the dataset to plot
4348
label_fontsize : int, optional, default: 14
4449
The font size of the x and y-label texts (parameter names)
4550
tick_fontsize : int, optional, default: 12
46-
The font size of the axis ticklabels
51+
The font size of the axis tick labels
52+
show_single_legend : bool, optional, default: False
53+
Optional toggle for the user to choose whether a single dataset
54+
should also display a legend
4755
**kwargs : dict, optional
4856
Additional keyword arguments passed to the sns.PairGrid constructor
4957
"""
@@ -59,8 +67,11 @@ def pairs_samples(
5967
height=height,
6068
color=color,
6169
alpha=alpha,
70+
label=label,
6271
label_fontsize=label_fontsize,
6372
tick_fontsize=tick_fontsize,
73+
show_single_legend=show_single_legend,
74+
**kwargs,
6475
)
6576

6677
return g
@@ -72,17 +83,27 @@ def _pairs_samples(
7283
color: str | tuple = "#132a70",
7384
color2: str | tuple = "gray",
7485
alpha: float = 0.9,
86+
label: str = "Posterior",
7587
label_fontsize: int = 14,
7688
tick_fontsize: int = 12,
7789
legend_fontsize: int = 14,
90+
show_single_legend: bool = False,
7891
**kwargs,
7992
) -> sns.PairGrid:
80-
# internal version of pairs_samples creating the seaborn plot
93+
"""
94+
Internal version of pairs_samples creating the seaborn PairPlot
95+
for both a single dataset and multiple datasets.
8196
82-
# Parameters
83-
# ----------
84-
# plot_data : output of bayesflow.utils.dict_utils.dicts_to_arrays
85-
# other arguments are documented in pairs_samples
97+
Parameters
98+
----------
99+
plot_data : output of bayesflow.utils.dict_utils.dicts_to_arrays
100+
Formatted data to plot from the sample dataset
101+
color2 : str, optional, default: 'gray'
102+
Secondary color for the pair plots.
103+
This is the color used for the prior draws.
104+
105+
Other arguments are documented in pairs_samples
106+
"""
86107

87108
estimates_shape = plot_data["estimates"].shape
88109
if len(estimates_shape) != 2:
@@ -136,7 +157,7 @@ def _pairs_samples(
136157
common_norm=False,
137158
)
138159

139-
# add scatterplots to the upper diagonal
160+
# add scatter plots to the upper diagonal
140161
g.map_upper(sns.scatterplot, alpha=0.6, s=40, edgecolor="k", color=color, lw=0)
141162

142163
# add KDEs to the lower diagonal
@@ -146,11 +167,6 @@ def _pairs_samples(
146167
logging.exception("KDE failed due to the following exception:\n" + repr(e) + "\nSubstituting scatter plot.")
147168
g.map_lower(sns.scatterplot, alpha=0.6, s=40, edgecolor="k", color=color, lw=0)
148169

149-
# need to add legend here such that colors are recognized
150-
if plot_data["priors"] is not None:
151-
g.add_legend(fontsize=legend_fontsize, loc="center right")
152-
g._legend.set_title(None)
153-
154170
# Generate grids
155171
dim = g.axes.shape[0]
156172
for i in range(dim):
@@ -165,32 +181,48 @@ def _pairs_samples(
165181
g.axes[i, j].tick_params(axis="both", which="major", labelsize=tick_fontsize)
166182
g.axes[i, j].tick_params(axis="both", which="minor", labelsize=tick_fontsize)
167183

168-
# adjust font size of labels
184+
# adjust the font size of labels
169185
# the labels themselves remain the same as before, i.e., variable_names
170186
g.axes[i, 0].set_ylabel(variable_names[i], fontsize=label_fontsize)
171187
g.axes[dim - 1, i].set_xlabel(variable_names[i], fontsize=label_fontsize)
172188

189+
# need to add legend here such that colors are recognized
190+
# if plot_data["priors"] is not None:
191+
# g.add_legend(fontsize=legend_fontsize, loc="center right")
192+
# g._legend.set_title(None)
193+
194+
create_legends(
195+
g,
196+
plot_data,
197+
color=color,
198+
color2=color2,
199+
legend_fontsize=legend_fontsize,
200+
label=label,
201+
show_single_legend=show_single_legend,
202+
)
203+
173204
# Return figure
174205
g.tight_layout()
175206

176207
return g
177208

178209

179-
# create a histogram plot on a twin y axis
180-
# this ensures that the y scaling of the diagonal plots
181-
# in independent of the y scaling of the off-diagonal plots
182210
def histplot_twinx(x, **kwargs):
183-
# Create a twin axis
184-
ax2 = plt.gca().twinx()
211+
"""
212+
# create a histogram plot on a twin y-axis
213+
# this ensures that the y scaling of the diagonal plots
214+
# is independent of the y scaling of the off-diagonal plots
185215
216+
Parameters
217+
----------
218+
x : np.ndarray
219+
Data to be plotted.
220+
"""
186221
# create a histogram on the twin axis
187-
sns.histplot(x, **kwargs, ax=ax2)
222+
sns.histplot(x, legend=False, **kwargs)
188223

189224
# make the twin axis invisible
190225
plt.gca().spines["right"].set_visible(False)
191226
plt.gca().spines["top"].set_visible(False)
192-
ax2.set_ylabel("")
193-
ax2.set_yticks([])
194-
ax2.set_yticklabels([])
195227

196228
return None

bayesflow/utils/plot_utils.py

Lines changed: 67 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,7 @@
66

77
from matplotlib.collections import LineCollection
88
from matplotlib.colors import Normalize
9-
from matplotlib.patches import Rectangle
9+
from matplotlib.patches import Rectangle, Patch
1010
from matplotlib.legend_handler import HandlerPatch
1111

1212
from .validators import check_estimates_prior_shapes
@@ -67,7 +67,7 @@ def prepare_plot_data(
6767
)
6868
check_estimates_prior_shapes(plot_data["estimates"], plot_data["targets"])
6969

70-
# store variable information at top level for easy access
70+
# store variable information at the top level for easy access
7171
variable_names = plot_data["estimates"].variable_names
7272
num_variables = len(variable_names)
7373
plot_data["variable_names"] = variable_names
@@ -249,7 +249,7 @@ def prettify_subplots(axes: np.ndarray, num_subplots: int, tick: bool = True, ti
249249

250250
def make_quadratic(ax: plt.Axes, x_data: np.ndarray, y_data: np.ndarray):
251251
"""
252-
Utility to make a subplots quadratic in order to avoid visual illusions
252+
Utility to make subplots quadratic to avoid visual illusions
253253
in, e.g., recovery plots.
254254
"""
255255

@@ -269,7 +269,7 @@ def make_quadratic(ax: plt.Axes, x_data: np.ndarray, y_data: np.ndarray):
269269

270270
def gradient_line(x, y, c=None, cmap: str = "viridis", lw: float = 2.0, alpha: float = 1, ax=None):
271271
"""
272-
Plot a 1D line with color gradient determined by `c` (same shape as x and y).
272+
Plot a 1D line with a color gradient determined by `c` (same shape as x and y).
273273
"""
274274
if ax is None:
275275
ax = plt.gca()
@@ -304,7 +304,7 @@ def gradient_legend(ax, label, cmap, norm, loc="upper right"):
304304
- loc: legend location (default 'upper right')
305305
"""
306306

307-
# Custom dummy handle to represent the gradient
307+
# Custom placeholder handle to represent the gradient
308308
class _GradientSwatch(Rectangle):
309309
pass
310310

@@ -358,3 +358,65 @@ def add_gradient_plot(
358358
label=label,
359359
alpha=0.01,
360360
)
361+
362+
363+
def create_legends(
    g,
    plot_data: dict,
    color: str | tuple = "#132a70",
    color2: str | tuple = "gray",
    label: str = "Posterior",
    show_single_legend: bool = False,
    legend_fontsize: int = 14,
    target_color: str | tuple = "red",
):
    """
    Helper function to create a figure-level legend for pair plots.

    An entry is added for the prior (if ``plot_data["priors"]`` is present),
    for the primary dataset, and for the targets (if ``plot_data["targets"]``
    is present). The legend is attached to ``g.figure`` and placed to the
    right of the plot.

    Parameters
    ----------
    g : sns.PairGrid
        Seaborn object for the pair plots
    plot_data : output of bayesflow.utils.dict_utils.dicts_to_arrays
        Formatted data to plot from the sample dataset
    color : str or tuple, optional, default: '#132a70'
        The primary color of the plot
    color2 : str or tuple, optional, default: 'gray'
        The secondary color for the plot (used for the prior entry)
    label : str, optional, default: "Posterior"
        Label for the dataset to plot
    show_single_legend : bool, optional, default: False
        Optional toggle for the user to choose whether a single dataset
        should also display a legend
    legend_fontsize : int, optional, default: 14
        Font size for the legend
    target_color : str or tuple, optional, default: 'red'
        The color of the legend entry for the targets. Pass the same color
        that was used to draw the targets so that legend and plot agree.

    Returns
    -------
    None
    """
    handles = []
    labels = []

    if plot_data.get("priors") is not None:
        handles.append(Patch(color=color2, label="Prior"))
        labels.append("Prior")

    # Keep the handle's own label in sync with the displayed label
    handles.append(Patch(color=color, label=label))
    labels.append(label)

    if plot_data.get("targets") is not None:
        # Dashed line + marker mirrors how targets are drawn (lines on the
        # diagonal, points off the diagonal)
        target_handle = plt.Line2D(
            [0],
            [0],
            color=target_color,
            linestyle="--",
            marker="x",
            label="Targets",
        )
        handles.append(target_handle)
        labels.append("Targets")

    # Only draw a legend if there is more than one entry to distinguish,
    # unless the caller explicitly asks for a single-entry legend
    if len(handles) > 1 or show_single_legend:
        g.figure.legend(
            handles=handles,
            labels=labels,
            loc="center left",
            bbox_to_anchor=(1, 0.5),
            frameon=False,
            fontsize=legend_fontsize,
        )

0 commit comments

Comments
 (0)