Skip to content

Commit 67db53b

Browse files
committed
Reduce matplotlib multithreading errors with Agg backend
1 parent 1e3e8ff commit 67db53b

File tree

7 files changed

+199
-177
lines changed

7 files changed

+199
-177
lines changed

slideflow/heatmap.py

+53-52
Original file line numberDiff line numberDiff line change
@@ -740,70 +740,71 @@ def save(
740740
"""
741741
import matplotlib.pyplot as plt
742742

743-
if self.predictions is None:
744-
raise errors.HeatmapError(
745-
"Cannot plot Heatmap which is not yet generated; generate with "
746-
"either heatmap.generate() or Heatmap(..., generate=True)"
747-
)
743+
with sf.util.matplotlib_backend('Agg'):
744+
if self.predictions is None:
745+
raise errors.HeatmapError(
746+
"Cannot plot Heatmap which is not yet generated; generate with "
747+
"either heatmap.generate() or Heatmap(..., generate=True)"
748+
)
748749

749-
# Save heatmaps in .npz format
750-
self.save_npz(os.path.join(outdir, f'{self.slide.name}.npz'))
750+
# Save heatmaps in .npz format
751+
self.save_npz(os.path.join(outdir, f'{self.slide.name}.npz'))
751752

752-
def _savefig(label, bbox_inches='tight', **kwargs):
753-
plt.savefig(
754-
os.path.join(outdir, f'{self.slide.name}-{label}.png'),
755-
bbox_inches=bbox_inches,
756-
**kwargs
757-
)
753+
def _savefig(label, bbox_inches='tight', **kwargs):
754+
plt.savefig(
755+
os.path.join(outdir, f'{self.slide.name}-{label}.png'),
756+
bbox_inches=bbox_inches,
757+
**kwargs
758+
)
758759

759-
log.info('Saving base figures...')
760+
log.info('Saving base figures...')
760761

761-
# Prepare matplotlib figure
762-
ax = self._prepare_ax()
762+
# Prepare matplotlib figure
763+
ax = self._prepare_ax()
763764

764-
thumb_kwargs = dict(roi_color=roi_color, linewidth=linewidth)
765+
thumb_kwargs = dict(roi_color=roi_color, linewidth=linewidth)
765766

766-
# Save base thumbnail as separate figure
767-
self.plot_thumbnail(show_roi=False, ax=ax, **thumb_kwargs) # type: ignore
768-
_savefig('raw')
767+
# Save base thumbnail as separate figure
768+
self.plot_thumbnail(show_roi=False, ax=ax, **thumb_kwargs) # type: ignore
769+
_savefig('raw')
769770

770-
# Save thumbnail + ROI as separate figure
771-
self.plot_thumbnail(show_roi=True, ax=ax, **thumb_kwargs) # type: ignore
772-
_savefig('raw+roi')
771+
# Save thumbnail + ROI as separate figure
772+
self.plot_thumbnail(show_roi=True, ax=ax, **thumb_kwargs) # type: ignore
773+
_savefig('raw+roi')
773774

774-
if logit_cmap:
775-
self.plot_with_logit_cmap(logit_cmap, show_roi=show_roi, ax=ax)
776-
_savefig('custom')
777-
else:
778-
heatmap_kwargs = dict(
779-
show_roi=show_roi,
780-
interpolation=interpolation,
781-
**kwargs
782-
)
783-
save_kwargs = dict(
784-
bbox_inches='tight',
785-
facecolor=ax.get_facecolor(),
786-
edgecolor='none'
787-
)
788-
# Make heatmap plots and sliders for each outcome category
789-
for i in range(self.num_classes):
790-
log.info(f'Making {i+1}/{self.num_classes}...')
791-
self.plot(i, heatmap_alpha=0.6, ax=ax, **heatmap_kwargs)
792-
_savefig(str(i), **save_kwargs)
775+
if logit_cmap:
776+
self.plot_with_logit_cmap(logit_cmap, show_roi=show_roi, ax=ax)
777+
_savefig('custom')
778+
else:
779+
heatmap_kwargs = dict(
780+
show_roi=show_roi,
781+
interpolation=interpolation,
782+
**kwargs
783+
)
784+
save_kwargs = dict(
785+
bbox_inches='tight',
786+
facecolor=ax.get_facecolor(),
787+
edgecolor='none'
788+
)
789+
# Make heatmap plots and sliders for each outcome category
790+
for i in range(self.num_classes):
791+
log.info(f'Making {i+1}/{self.num_classes}...')
792+
self.plot(i, heatmap_alpha=0.6, ax=ax, **heatmap_kwargs)
793+
_savefig(str(i), **save_kwargs)
793794

794-
self.plot(i, heatmap_alpha=1, ax=ax, **heatmap_kwargs)
795-
_savefig(f'{i}-solid', **save_kwargs)
795+
self.plot(i, heatmap_alpha=1, ax=ax, **heatmap_kwargs)
796+
_savefig(f'{i}-solid', **save_kwargs)
796797

797-
# Uncertainty map
798-
if self.uq:
799-
log.info('Making uncertainty heatmap...')
800-
self.plot_uncertainty(heatmap_alpha=0.6, ax=ax, **heatmap_kwargs)
801-
_savefig('UQ', **save_kwargs)
798+
# Uncertainty map
799+
if self.uq:
800+
log.info('Making uncertainty heatmap...')
801+
self.plot_uncertainty(heatmap_alpha=0.6, ax=ax, **heatmap_kwargs)
802+
_savefig('UQ', **save_kwargs)
802803

803-
self.plot_uncertainty(heatmap_alpha=1, ax=ax, **heatmap_kwargs)
804-
_savefig('UQ-solid', **save_kwargs)
804+
self.plot_uncertainty(heatmap_alpha=1, ax=ax, **heatmap_kwargs)
805+
_savefig('UQ-solid', **save_kwargs)
805806

806-
plt.close()
807+
plt.close()
807808
log.info(f'Saved heatmaps for [green]{self.slide.name}')
808809

809810
def view(self):

slideflow/mosaic.py

+11-10
Original file line numberDiff line numberDiff line change
@@ -596,16 +596,17 @@ def save(self, filename: str, **kwargs: Any) -> None:
596596
"""
597597
import matplotlib.pyplot as plt
598598

599-
self.plot(**kwargs)
600-
log.info('Exporting figure...')
601-
try:
602-
if not os.path.exists(os.path.dirname(filename)):
603-
os.makedirs(os.path.dirname(filename))
604-
except FileNotFoundError:
605-
pass
606-
plt.savefig(filename, bbox_inches='tight')
607-
log.info(f'Saved figure to [green]{filename}')
608-
plt.close()
599+
with sf.util.matplotlib_backend('Agg'):
600+
self.plot(**kwargs)
601+
log.info('Exporting figure...')
602+
try:
603+
if not os.path.exists(os.path.dirname(filename)):
604+
os.makedirs(os.path.dirname(filename))
605+
except FileNotFoundError:
606+
pass
607+
plt.savefig(filename, bbox_inches='tight')
608+
log.info(f'Saved figure to [green]{filename}')
609+
plt.close()
609610

610611
def save_report(self, filename: str) -> None:
611612
"""Saves a report of which tiles (and their corresponding slide)

slideflow/slide/report.py

+26-22
Original file line numberDiff line numberDiff line change
@@ -319,17 +319,18 @@ def __init__(
319319
if b is not None and b > self.bb_threshold:
320320
self.warn_txt += f'{slide},{b}\n'
321321

322-
if np.any(n_tiles) and self.num_tiles_chart(n_tiles):
323-
with tempfile.NamedTemporaryFile(suffix='.png') as temp:
324-
plt.savefig(temp.name)
325-
pdf.image(temp.name, 107, pdf.y, w=50)
326-
plt.close()
327-
328-
if np.any(bb) and self.blur_chart(bb):
329-
with tempfile.NamedTemporaryFile(suffix='.png') as temp:
330-
plt.savefig(temp.name)
331-
pdf.image(temp.name, 155, pdf.y, w=50)
332-
plt.close()
322+
with sf.util.matplotlib_backend('Agg'):
323+
if np.any(n_tiles) and self.num_tiles_chart(n_tiles):
324+
with tempfile.NamedTemporaryFile(suffix='.png') as temp:
325+
plt.savefig(temp.name)
326+
pdf.image(temp.name, 107, pdf.y, w=50)
327+
plt.close()
328+
329+
if np.any(bb) and self.blur_chart(bb):
330+
with tempfile.NamedTemporaryFile(suffix='.png') as temp:
331+
plt.savefig(temp.name)
332+
pdf.image(temp.name, 155, pdf.y, w=50)
333+
plt.close()
333334

334335
# Bounding box
335336
pdf.set_x(20)
@@ -451,11 +452,12 @@ def num_tiles_chart(self, num_tiles: np.ndarray) -> bool:
451452
import matplotlib.pyplot as plt
452453
import seaborn as sns
453454
if np.any(num_tiles):
454-
plt.rc('font', size=14)
455-
sns.histplot(num_tiles, bins=20)
456-
plt.title('Number of tiles extracted')
457-
plt.ylabel('Number of slides', fontsize=16, fontname='Arial')
458-
plt.xlabel('Tiles extracted', fontsize=16, fontname='Arial')
455+
with sf.util.matplotlib_backend('Agg'):
456+
plt.rc('font', size=14)
457+
sns.histplot(num_tiles, bins=20)
458+
plt.title('Number of tiles extracted')
459+
plt.ylabel('Number of slides', fontsize=16, fontname='Arial')
460+
plt.xlabel('Tiles extracted', fontsize=16, fontname='Arial')
459461
return True
460462
else:
461463
return False
@@ -472,12 +474,14 @@ def blur_chart(self, blur_arr: np.ndarray) -> bool:
472474
with np.errstate(divide='ignore'):
473475
log_b = np.log(blur_arr)
474476
log_b = log_b[np.isfinite(log_b)]
475-
plt.rc('font', size=14)
476-
sns.histplot(log_b, bins=20)
477-
plt.title('Quality Control: Blur Burden'+warn_txt)
478-
plt.ylabel('Count', fontsize=16, fontname='Arial')
479-
plt.xlabel('log(blur burden)', fontsize=16, fontname='Arial')
480-
plt.axvline(x=-3, color='r', linestyle='--')
477+
478+
with sf.util.matplotlib_backend('Agg'):
479+
plt.rc('font', size=14)
480+
sns.histplot(log_b, bins=20)
481+
plt.title('Quality Control: Blur Burden'+warn_txt)
482+
plt.ylabel('Count', fontsize=16, fontname='Arial')
483+
plt.xlabel('log(blur burden)', fontsize=16, fontname='Arial')
484+
plt.axvline(x=-3, color='r', linestyle='--')
481485
return True
482486
else:
483487
return False

slideflow/stats/metrics.py

+12-10
Original file line numberDiff line numberDiff line change
@@ -81,21 +81,23 @@ def prc_fit(self):
8181

8282
def save_roc(self, outdir, name):
8383
import matplotlib.pyplot as plt
84-
auroc_str = 'NA' if not self.auroc else f'{self.auroc:.2f}'
85-
sf.stats.plot.roc(self.fpr, self.tpr, f'AUC = {auroc_str}')
86-
full_path = join(outdir, f'{name}.png')
87-
plt.savefig(full_path)
88-
plt.close()
84+
with sf.util.matplotlib_backend('Agg'):
85+
auroc_str = 'NA' if not self.auroc else f'{self.auroc:.2f}'
86+
sf.stats.plot.roc(self.fpr, self.tpr, f'AUC = {auroc_str}')
87+
full_path = join(outdir, f'{name}.png')
88+
plt.savefig(full_path)
89+
plt.close()
8990
if self.neptune_run:
9091
self.neptune_run[f'results/graphs/{name}'].upload(full_path)
9192

9293
def save_prc(self, outdir, name):
9394
import matplotlib.pyplot as plt
94-
ap_str = 'NA' if not self.ap else f'{self.ap:.2f}'
95-
sf.stats.plot.prc(self.precision, self.recall, label=f'AP = {ap_str}')
96-
full_path = join(outdir, f'{name}.png')
97-
plt.savefig(full_path)
98-
plt.close()
95+
with sf.util.matplotlib_backend('Agg'):
96+
ap_str = 'NA' if not self.ap else f'{self.ap:.2f}'
97+
sf.stats.plot.prc(self.precision, self.recall, label=f'AP = {ap_str}')
98+
full_path = join(outdir, f'{name}.png')
99+
plt.savefig(full_path)
100+
plt.close()
99101
if self.neptune_run:
100102
self.neptune_run[f'results/graphs/{name}'].upload(full_path)
101103

slideflow/stats/plot.py

+37-34
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,4 @@
1+
import slideflow as sf
12
import os
23
import warnings
34
from typing import TYPE_CHECKING, List, Optional, Union
@@ -36,28 +37,29 @@ def combined_roc(
3637
"""
3738
import matplotlib.pyplot as plt
3839

39-
plt.clf()
40-
plt.title(name)
41-
colors = ('b', 'g', 'r', 'c', 'm', 'y', 'k')
42-
aurocs = []
43-
for i, (yt, yp) in enumerate(zip(y_true, y_pred)):
44-
fpr, tpr, threshold = metrics.roc_curve(yt, yp)
45-
roc_auc = metrics.auc(fpr, tpr)
46-
aurocs += [roc_auc]
47-
label = f'{labels[i]} (AUC: {roc_auc:.2f})'
48-
plt.plot(fpr, tpr, colors[i % len(colors)], label=label)
49-
plt.legend(loc='lower right')
50-
plt.plot([0, 1], [0, 1], 'r--')
51-
plt.xlim([0, 1])
52-
plt.ylim([0, 1])
53-
plt.ylabel('TPR')
54-
plt.xlabel('FPR')
55-
plt.savefig(os.path.join(save_dir, f'{name}.png'))
56-
if neptune_run:
57-
neptune_run[f'results/graphs/{name}'].upload(
58-
os.path.join(save_dir, f'{name}.png')
59-
)
60-
plt.close()
40+
with sf.util.matplotlib_backend('Agg'):
41+
plt.clf()
42+
plt.title(name)
43+
colors = ('b', 'g', 'r', 'c', 'm', 'y', 'k')
44+
aurocs = []
45+
for i, (yt, yp) in enumerate(zip(y_true, y_pred)):
46+
fpr, tpr, threshold = metrics.roc_curve(yt, yp)
47+
roc_auc = metrics.auc(fpr, tpr)
48+
aurocs += [roc_auc]
49+
label = f'{labels[i]} (AUC: {roc_auc:.2f})'
50+
plt.plot(fpr, tpr, colors[i % len(colors)], label=label)
51+
plt.legend(loc='lower right')
52+
plt.plot([0, 1], [0, 1], 'r--')
53+
plt.xlim([0, 1])
54+
plt.ylim([0, 1])
55+
plt.ylabel('TPR')
56+
plt.xlabel('FPR')
57+
plt.savefig(os.path.join(save_dir, f'{name}.png'))
58+
if neptune_run:
59+
neptune_run[f'results/graphs/{name}'].upload(
60+
os.path.join(save_dir, f'{name}.png')
61+
)
62+
plt.close()
6163
return aurocs
6264

6365

@@ -176,16 +178,17 @@ def scatter(
176178
yp_sub = y_pred
177179

178180
# Perform scatter for each outcome
179-
for i in range(y_true.shape[1]):
180-
r_squared += [metrics.r2_score(y_true[:, i], y_pred[:, i])]
181-
with warnings.catch_warnings():
182-
warnings.filterwarnings("ignore", category=UserWarning)
183-
p = sns.jointplot(x=yt_sub[:, i], y=yp_sub[:, i], kind="reg")
184-
p.set_axis_labels('y_true', 'y_pred')
185-
plt.savefig(os.path.join(data_dir, f'Scatter{name}-{i}.png'))
186-
if neptune_run:
187-
neptune_run[f'results/graphs/Scatter{name}-{i}'].upload(
188-
os.path.join(data_dir, f'Scatter{name}-{i}.png')
189-
)
190-
plt.close()
181+
with sf.util.matplotlib_backend('Agg'):
182+
for i in range(y_true.shape[1]):
183+
r_squared += [metrics.r2_score(y_true[:, i], y_pred[:, i])]
184+
with warnings.catch_warnings():
185+
warnings.filterwarnings("ignore", category=UserWarning)
186+
p = sns.jointplot(x=yt_sub[:, i], y=yp_sub[:, i], kind="reg")
187+
p.set_axis_labels('y_true', 'y_pred')
188+
plt.savefig(os.path.join(data_dir, f'Scatter{name}-{i}.png'))
189+
if neptune_run:
190+
neptune_run[f'results/graphs/Scatter{name}-{i}'].upload(
191+
os.path.join(data_dir, f'Scatter{name}-{i}.png')
192+
)
193+
plt.close()
191194
return r_squared

slideflow/stats/slidemap.py

+8-6
Original file line numberDiff line numberDiff line change
@@ -984,9 +984,10 @@ def save_plot(
984984
"""
985985
import matplotlib.pyplot as plt
986986

987-
self.plot(**kwargs)
988-
plt.savefig(filename, bbox_inches='tight', dpi=dpi)
989-
plt.close()
987+
with sf.util.matplotlib_backend('Agg'):
988+
self.plot(**kwargs)
989+
plt.savefig(filename, bbox_inches='tight', dpi=dpi)
990+
plt.close()
990991
log.info(f"Saved 2D UMAP to [green]{filename}")
991992

992993
def save_3d(
@@ -1013,9 +1014,10 @@ def save_3d(
10131014
"""
10141015
import matplotlib.pyplot as plt
10151016

1016-
self.plot_3d(**kwargs)
1017-
plt.savefig(filename, bbox_inches='tight', dpi=dpi)
1018-
plt.close()
1017+
with sf.util.matplotlib_backend('Agg'):
1018+
self.plot_3d(**kwargs)
1019+
plt.savefig(filename, bbox_inches='tight', dpi=dpi)
1020+
plt.close()
10191021
log.info(f"Saved 3D UMAP to [green]{filename}")
10201022

10211023
def save_coordinates(self, path: str) -> None:

0 commit comments

Comments
 (0)