Skip to content

Commit d9dc997

Browse files
committed
TensorBoard validation log fix
- Fix validation metrics, which were again being logged improperly to TensorBoard with the TensorFlow backend
- Increase the frequency of mid-training validation checks during testing
1 parent 46b5687 commit d9dc997

File tree

2 files changed

+22
-38
lines changed

2 files changed

+22
-38
lines changed

slideflow/model/tensorflow.py

Lines changed: 12 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -670,11 +670,12 @@ def __init__(self, parent: "Trainer", cb_args: SimpleNamespace) -> None:
670670
def _log_training_metrics(self, logs):
671671
"""Log training metrics to Tensorboard/Neptune."""
672672
# Log to Tensorboard.
673-
for _log in logs:
674-
tf.summary.scalar(
675-
f'batch_{_log}',
676-
data=logs[_log],
677-
step=self.global_step)
673+
with self.train_summary_writer.as_default():
674+
for _log in logs:
675+
tf.summary.scalar(
676+
f'batch_{_log}',
677+
data=logs[_log],
678+
step=self.global_step)
678679
# Log to neptune.
679680
if self.neptune_run:
680681
self.neptune_run['metrics/train/batch/loss'].log(
@@ -890,10 +891,13 @@ def on_train_batch_end(self, batch: int, logs={}) -> None:
890891
verbosity='quiet',
891892
)
892893
val_metrics = {'loss': loss}
894+
val_log_metrics = {'loss': loss}
893895
if isinstance(acc, float):
894896
val_metrics['accuracy'] = acc
897+
val_log_metrics['accuracy'] = acc
895898
elif acc is not None:
896899
val_metrics.update({f'accuracy-{i+1}': acc[i] for i in range(len(acc))})
900+
val_log_metrics.update({f'out-{i}_accuracy': acc[i] for i in range(len(acc))})
897901

898902
val_loss = val_metrics['loss']
899903
self.model.stop_training = False
@@ -920,7 +924,7 @@ def on_train_batch_end(self, batch: int, logs={}) -> None:
920924
print('\r\033[K', end='')
921925
self.moving_average += [early_stop_value]
922926

923-
self._log_validation_metrics(logs)
927+
self._log_validation_metrics(val_log_metrics)
924928
# Log training metrics if not already logged this batch
925929
if batch % self.cb_args.log_frequency > 0:
926930
self._log_training_metrics(logs)
@@ -1356,7 +1360,7 @@ def _verify_img_format(self, dataset: "sf.Dataset") -> None:
13561360
def load(self, model: str) -> tf.keras.Model:
13571361
self.model = load(
13581362
model,
1359-
method=self.load_method,
1363+
method=self.load_method,
13601364
custom_objects=self.custom_objects
13611365
)
13621366

@@ -2503,7 +2507,7 @@ def _predict(self, inp):
25032507

25042508

25052509
def load(
2506-
path: str,
2510+
path: str,
25072511
method: str = 'full',
25082512
custom_objects: Optional[Dict[str, Any]] = None,):
25092513
"""Load Tensorflow model from location.

slideflow/test/__init__.py

Lines changed: 10 additions & 30 deletions
Original file line numberDiff line numberDiff line change
@@ -107,6 +107,13 @@ def __init__(
107107
# Rebuild tfrecord indices
108108
self.project.dataset(self.tile_px, 1208).build_index(True)
109109

110+
# Set up training keyword arguments.
111+
self.train_kwargs = dict(
112+
validate_on_batch=5,
113+
steps_per_epoch_override=50,
114+
save_predictions=True
115+
)
116+
110117
def _get_model(self, name: str, epoch: int = 1) -> str:
111118
assert self.project is not None
112119
prev_run_dirs = [
@@ -326,9 +333,6 @@ def train_perf(self, **train_kwargs) -> None:
326333
exp_label='manual_hp',
327334
outcomes='category1',
328335
val_k=1,
329-
validate_on_batch=10,
330-
save_predictions=True,
331-
steps_per_epoch_override=20,
332336
params='sweep.json',
333337
pretrain=None,
334338
**train_kwargs
@@ -374,6 +378,9 @@ def test_training(
374378
additional slide-level input. Defaults to True.
375379
"""
376380
assert self.project is not None
381+
for k in self.train_kwargs:
382+
if k not in train_kwargs:
383+
train_kwargs[k] = self.train_kwargs[k]
377384
# Disable checkpoints for tensorflow backend, to save disk space
378385
if (sf.backend() == 'tensorflow'
379386
and 'save_checkpoints' not in train_kwargs):
@@ -408,9 +415,6 @@ def test_training(
408415
outcomes='category1',
409416
val_k=1,
410417
params=hp,
411-
validate_on_batch=10,
412-
steps_per_epoch_override=20,
413-
save_predictions=True,
414418
pretrain=None,
415419
**resume_kw,
416420
**train_kwargs
@@ -436,9 +440,6 @@ def test_training(
436440
outcomes='category1',
437441
val_k=1,
438442
params=hp,
439-
validate_on_batch=10,
440-
steps_per_epoch_override=20,
441-
save_predictions=True,
442443
pretrain=to_resume,
443444
**train_kwargs
444445
)
@@ -455,9 +456,6 @@ def test_training(
455456
outcomes=['category1', 'category2'],
456457
val_k=1,
457458
params=self.setup_hp('categorical'),
458-
validate_on_batch=10,
459-
steps_per_epoch_override=20,
460-
save_predictions=True,
461459
pretrain=None,
462460
**train_kwargs
463461
)
@@ -474,9 +472,6 @@ def test_training(
474472
outcomes=['linear1'],
475473
val_k=1,
476474
params=self.setup_hp('linear'),
477-
validate_on_batch=10,
478-
steps_per_epoch_override=20,
479-
save_predictions=True,
480475
pretrain=None,
481476
**train_kwargs
482477
)
@@ -493,9 +488,6 @@ def test_training(
493488
outcomes=['linear1', 'linear2'],
494489
val_k=1,
495490
params=self.setup_hp('linear'),
496-
validate_on_batch=10,
497-
steps_per_epoch_override=20,
498-
save_predictions=True,
499491
pretrain=None,
500492
**train_kwargs
501493
)
@@ -514,9 +506,6 @@ def test_training(
514506
input_header='category2',
515507
params=self.setup_hp('categorical'),
516508
val_k=1,
517-
validate_on_batch=10,
518-
steps_per_epoch_override=20,
519-
save_predictions=True,
520509
pretrain=None,
521510
**train_kwargs
522511
)
@@ -535,9 +524,6 @@ def test_training(
535524
input_header='event',
536525
params=self.setup_hp('cph'),
537526
val_k=1,
538-
validate_on_batch=10,
539-
steps_per_epoch_override=20,
540-
save_predictions=True,
541527
pretrain=None,
542528
**train_kwargs
543529
)
@@ -558,9 +544,6 @@ def test_training(
558544
input_header=['event', 'category1'],
559545
params=self.setup_hp('cph'),
560546
val_k=1,
561-
validate_on_batch=10,
562-
steps_per_epoch_override=20,
563-
save_predictions=True,
564547
pretrain=None,
565548
**train_kwargs
566549
)
@@ -581,9 +564,6 @@ def test_training(
581564
outcomes='category1',
582565
val_k=1,
583566
params=hp,
584-
validate_on_batch=10,
585-
steps_per_epoch_override=20,
586-
save_predictions=True,
587567
from_wsi=True,
588568
pretrain=None,
589569
**train_kwargs

0 commit comments

Comments
 (0)