Skip to content

Commit 0a6d0ad

Browse files
committed
Fix Tensorboard logging with Tensorflow [#224]
- Fixes Tensorboard logging to include both training and validation logs, at the batch and epoch level, when using the Tensorboard backend. Validation logs were previously not being written due to a change in how the Tensorflow/Keras `Model.fit()` API interacted with Tensorboard callbacks. - Add an informative error message to an assertion check in `Dataset.train_val_split()` - All tests passed
1 parent 1174263 commit 0a6d0ad

File tree

2 files changed

+155
-96
lines changed

2 files changed

+155
-96
lines changed

slideflow/dataset.py

Lines changed: 12 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -2812,9 +2812,18 @@ def train_val_split(
28122812
tfr for tfr in tfrecord_dir_list
28132813
if path_to_name(tfr) in train_slides
28142814
]
2815-
2816-
assert(len(val_tfrecords) == len(val_slides))
2817-
assert(len(training_tfrecords) == len(train_slides))
2815+
if not len(val_tfrecords) == len(val_slides):
2816+
raise errors.DatasetError(
2817+
f"Number of validation tfrecords ({len(val_tfrecords)}) does not "
2818+
f"match the number of validation slides ({len(val_slides)}). "
2819+
"This may happen if multiple tfrecords were found for some slides."
2820+
)
2821+
if not len(training_tfrecords) == len(train_slides):
2822+
raise errors.DatasetError(
2823+
f"Number of training tfrecords ({len(val_tfrecords)}) does not "
2824+
f"match the number of training slides ({len(val_slides)}). "
2825+
"This may happen if multiple tfrecords were found for some slides."
2826+
)
28182827
training_dts = copy.deepcopy(self)
28192828
training_dts = training_dts.filter(filters={'slide': train_slides})
28202829
val_dts = copy.deepcopy(self)

slideflow/model/tensorflow.py

Lines changed: 143 additions & 93 deletions
Original file line numberDiff line numberDiff line change
@@ -652,6 +652,10 @@ def __init__(self, parent: "Trainer", cb_args: SimpleNamespace) -> None:
652652
self.results = {'epochs': {}} # type: Dict[str, Dict]
653653
self.neptune_run = self.parent.neptune_run
654654
self.global_step = 0
655+
self.train_summary_writer = tf.summary.create_file_writer(
656+
join(self.parent.outdir, 'train'))
657+
self.val_summary_writer = tf.summary.create_file_writer(
658+
join(self.parent.outdir, 'validation'))
655659

656660
# Circumvents buffer overflow error with Python 3.10.
657661
# Without this, a buffer overflow error will be encountered when
@@ -663,6 +667,128 @@ def __init__(self, parent: "Trainer", cb_args: SimpleNamespace) -> None:
663667
plt.figure()
664668
plt.close()
665669

670+
def _log_training_metrics(self, logs):
671+
"""Log training metrics to Tensorboard/Neptune."""
672+
# Log to Tensorboard.
673+
for _log in logs:
674+
tf.summary.scalar(
675+
f'batch_{_log}',
676+
data=logs[_log],
677+
step=self.global_step)
678+
# Log to neptune.
679+
if self.neptune_run:
680+
self.neptune_run['metrics/train/batch/loss'].log(
681+
logs['loss'],
682+
step=self.global_step)
683+
sf.util.neptune_utils.list_log(
684+
self.neptune_run,
685+
'metrics/train/batch/accuracy',
686+
logs['accuracy'],
687+
step=self.global_step)
688+
689+
def _log_validation_metrics(self, metrics):
690+
"""Log validation metrics to Tensorboard/Neptune."""
691+
# Tensorboard logging for validation metrics
692+
with self.val_summary_writer.as_default():
693+
for _log in metrics:
694+
tf.summary.scalar(
695+
f'batch_{_log}',
696+
data=metrics[_log],
697+
step=self.global_step)
698+
# Log to neptune
699+
if self.neptune_run:
700+
for v in metrics:
701+
self.neptune_run[f"metrics/val/batch/{v}"].log(
702+
round(metrics[v], 3),
703+
step=self.global_step
704+
)
705+
if self.last_ema != -1:
706+
self.neptune_run["metrics/val/batch/exp_moving_avg"].log(
707+
round(self.last_ema, 3),
708+
step=self.global_step
709+
)
710+
self.neptune_run["early_stop/stopped_early"] = False
711+
712+
def _log_epoch_evaluation(self, epoch_results, metrics, accuracy, loss, logs={}):
713+
"""Log the end-of-epoch evaluation to CSV, Tensorboard, and Neptune."""
714+
epoch = self.epoch_count
715+
run = self.neptune_run
716+
sf.util.update_results_log(
717+
self.cb_args.results_log,
718+
'trained_model',
719+
{f'epoch{epoch}': epoch_results}
720+
)
721+
with self.val_summary_writer.as_default():
722+
# Note: Tensorboard epoch logging starts with index=0,
723+
# whereas all other logging starts with index=1
724+
if isinstance(accuracy, (list, tuple, np.ndarray)):
725+
for i in range(len(accuracy)):
726+
tf.summary.scalar(f'epoch_accuracy-{i}', data=accuracy[i], step=epoch-1)
727+
elif accuracy is not None:
728+
tf.summary.scalar(f'epoch_accuracy', data=accuracy, step=epoch-1)
729+
if isinstance(loss, (list, tuple, np.ndarray)):
730+
for i in range(len(loss)):
731+
tf.summary.scalar(f'epoch_loss-{i}', data=loss[i], step=epoch-1)
732+
else:
733+
tf.summary.scalar(f'epoch_loss', data=loss, step=epoch-1)
734+
735+
# Log epoch results to Neptune
736+
if run:
737+
# Training epoch metrics
738+
run['metrics/train/epoch/loss'].log(logs['loss'], step=epoch)
739+
sf.util.neptune_utils.list_log(
740+
run,
741+
'metrics/train/epoch/accuracy',
742+
logs['accuracy'],
743+
step=epoch
744+
)
745+
# Validation epoch metrics
746+
run['metrics/val/epoch/loss'].log(loss, step=epoch)
747+
sf.util.neptune_utils.list_log(
748+
run,
749+
'metrics/val/epoch/accuracy',
750+
accuracy,
751+
step=epoch
752+
)
753+
for metric in metrics:
754+
if metrics[metric]['tile'] is None:
755+
continue
756+
for outcome in metrics[metric]['tile']:
757+
# If only one outcome, log to metrics/val/epoch/[metric].
758+
# If more than one outcome, log to
759+
# metrics/val/epoch/[metric]/[outcome_name]
760+
def metric_label(s):
761+
if len(metrics[metric]['tile']) == 1:
762+
return f'metrics/val/epoch/{s}_{metric}'
763+
else:
764+
return f'metrics/val/epoch/{s}_{metric}/{outcome}'
765+
766+
tile_metric = metrics[metric]['tile'][outcome]
767+
slide_metric = metrics[metric]['slide'][outcome]
768+
patient_metric = metrics[metric]['patient'][outcome]
769+
770+
# If only one value for a metric, log to .../[metric]
771+
# If more than one value for a metric (e.g. AUC for each
772+
# category), log to .../[metric]/[i]
773+
sf.util.neptune_utils.list_log(
774+
run,
775+
metric_label('tile'),
776+
tile_metric,
777+
step=epoch
778+
)
779+
sf.util.neptune_utils.list_log(
780+
run,
781+
metric_label('slide'),
782+
slide_metric,
783+
step=epoch
784+
)
785+
sf.util.neptune_utils.list_log(
786+
run,
787+
metric_label('patient'),
788+
patient_metric,
789+
step=epoch
790+
)
791+
666792
def _metrics_from_dataset(
667793
self,
668794
epoch_label: str,
@@ -728,18 +854,10 @@ def on_epoch_end(self, epoch: int, logs={}) -> None:
728854
self.model.stop_training = self.early_stop
729855

730856
def on_train_batch_end(self, batch: int, logs={}) -> None:
731-
# Neptune logging for training metrics
732-
if self.neptune_run:
733-
self.neptune_run['metrics/train/batch/loss'].log(
734-
logs['loss'],
735-
step=self.global_step
736-
)
737-
sf.util.neptune_utils.list_log(
738-
self.neptune_run,
739-
'metrics/train/batch/accuracy',
740-
logs['accuracy'],
741-
step=self.global_step
742-
)
857+
# Tensorboard logging for training metrics
858+
if batch > 0 and batch % self.cb_args.log_frequency == 0:
859+
#with self.train_summary_writer.as_default():
860+
self._log_training_metrics(logs)
743861

744862
# Check if manual early stopping has been triggered
745863
if (self.hp.early_stop
@@ -802,19 +920,10 @@ def on_train_batch_end(self, batch: int, logs={}) -> None:
802920
print('\r\033[K', end='')
803921
self.moving_average += [early_stop_value]
804922

805-
# Log to neptune
806-
if self.neptune_run:
807-
for v in val_metrics:
808-
self.neptune_run[f"metrics/val/batch/{v}"].log(
809-
round(val_metrics[v], 3),
810-
step=self.global_step
811-
)
812-
if self.last_ema != -1:
813-
self.neptune_run["metrics/val/batch/exp_moving_avg"].log(
814-
round(self.last_ema, 3),
815-
step=self.global_step
816-
)
817-
self.neptune_run["early_stop/stopped_early"] = False
923+
self._log_validation_metrics(val_metrics)  # mid-training validation results, not the training-batch `logs`
924+
# Log training metrics if not already logged this batch
925+
if batch % self.cb_args.log_frequency > 0:
926+
self._log_training_metrics(logs)
818927

819928
# Base logging message
820929
batch_msg = f'[blue]Batch {batch:<5}[/]'
@@ -915,73 +1024,9 @@ def evaluate_model(self, logs={}) -> None:
9151024
self.results['epochs'][f'epoch{epoch}'][f'patient_{m}'] = metrics[m]['patient']
9161025

9171026
epoch_results = self.results['epochs'][f'epoch{epoch}']
918-
sf.util.update_results_log(
919-
self.cb_args.results_log,
920-
'trained_model',
921-
{f'epoch{epoch}': epoch_results}
1027+
self._log_epoch_evaluation(
1028+
epoch_results, metrics=metrics, accuracy=acc, loss=loss, logs=logs
9221029
)
923-
# Log epoch results to Neptune
924-
if self.neptune_run:
925-
# Training epoch metrics
926-
self.neptune_run['metrics/train/epoch/loss'].log(
927-
logs['loss'],
928-
step=epoch
929-
)
930-
sf.util.neptune_utils.list_log(
931-
self.neptune_run,
932-
'metrics/train/epoch/accuracy',
933-
logs['accuracy'],
934-
step=epoch
935-
)
936-
# Validation epoch metrics
937-
self.neptune_run['metrics/val/epoch/loss'].log(
938-
val_metrics['loss'],
939-
step=epoch
940-
)
941-
sf.util.neptune_utils.list_log(
942-
self.neptune_run,
943-
'metrics/val/epoch/accuracy',
944-
val_metrics['accuracy'],
945-
step=epoch
946-
)
947-
for metric in metrics:
948-
if metrics[metric]['tile'] is None:
949-
continue
950-
for outcome in metrics[metric]['tile']:
951-
# If only one outcome, log to metrics/val/epoch/[metric].
952-
# If more than one outcome, log to
953-
# metrics/val/epoch/[metric]/[outcome_name]
954-
def metric_label(s):
955-
if len(metrics[metric]['tile']) == 1:
956-
return f'metrics/val/epoch/{s}_{metric}'
957-
else:
958-
return f'metrics/val/epoch/{s}_{metric}/{outcome}'
959-
960-
tile_metric = metrics[metric]['tile'][outcome]
961-
slide_metric = metrics[metric]['slide'][outcome]
962-
patient_metric = metrics[metric]['patient'][outcome]
963-
964-
# If only one value for a metric, log to .../[metric]
965-
# If more than one value for a metric (e.g. AUC for each
966-
# category), log to .../[metric]/[i]
967-
sf.util.neptune_utils.list_log(
968-
self.neptune_run,
969-
metric_label('tile'),
970-
tile_metric,
971-
step=epoch
972-
)
973-
sf.util.neptune_utils.list_log(
974-
self.neptune_run,
975-
metric_label('slide'),
976-
slide_metric,
977-
step=epoch
978-
)
979-
sf.util.neptune_utils.list_log(
980-
self.neptune_run,
981-
metric_label('patient'),
982-
patient_metric,
983-
step=epoch
984-
)
9851030

9861031

9871032
class Trainer:
@@ -1791,7 +1836,8 @@ def train(
17911836
save_predictions=save_predictions,
17921837
save_model=save_model,
17931838
results_log=results_log,
1794-
reduce_method=reduce_method
1839+
reduce_method=reduce_method,
1840+
log_frequency=log_frequency
17951841
)
17961842

17971843
# Create callbacks for early stopping, checkpoint saving,
@@ -1806,11 +1852,15 @@ def train(
18061852
)
18071853
callbacks += [cp_callback]
18081854
if use_tensorboard:
1855+
log.debug(
1856+
"Logging with Tensorboard to {} every {} batches.".format(
1857+
self.outdir, log_frequency
1858+
))
18091859
tensorboard_callback = tf.keras.callbacks.TensorBoard(
18101860
log_dir=self.outdir,
18111861
histogram_freq=0,
18121862
write_graph=False,
1813-
update_freq=log_frequency
1863+
update_freq='batch'
18141864
)
18151865
callbacks += [tensorboard_callback]
18161866

0 commit comments

Comments
 (0)