@@ -652,6 +652,10 @@ def __init__(self, parent: "Trainer", cb_args: SimpleNamespace) -> None:
652
652
self .results = {'epochs' : {}} # type: Dict[str, Dict]
653
653
self .neptune_run = self .parent .neptune_run
654
654
self .global_step = 0
655
+ self .train_summary_writer = tf .summary .create_file_writer (
656
+ join (self .parent .outdir , 'train' ))
657
+ self .val_summary_writer = tf .summary .create_file_writer (
658
+ join (self .parent .outdir , 'validation' ))
655
659
656
660
# Circumvents buffer overflow error with Python 3.10.
657
661
# Without this, a buffer overflow error will be encountered when
@@ -663,6 +667,128 @@ def __init__(self, parent: "Trainer", cb_args: SimpleNamespace) -> None:
663
667
plt .figure ()
664
668
plt .close ()
665
669
670
+ def _log_training_metrics (self , logs ):
671
+ """Log training metrics to Tensorboard/Neptune."""
672
+ # Log to Tensorboard.
673
+ for _log in logs :
674
+ tf .summary .scalar (
675
+ f'batch_{ _log } ' ,
676
+ data = logs [_log ],
677
+ step = self .global_step )
678
+ # Log to neptune.
679
+ if self .neptune_run :
680
+ self .neptune_run ['metrics/train/batch/loss' ].log (
681
+ logs ['loss' ],
682
+ step = self .global_step )
683
+ sf .util .neptune_utils .list_log (
684
+ self .neptune_run ,
685
+ 'metrics/train/batch/accuracy' ,
686
+ logs ['accuracy' ],
687
+ step = self .global_step )
688
+
689
+ def _log_validation_metrics (self , metrics ):
690
+ """Log validation metrics to Tensorboard/Neptune."""
691
+ # Tensorboard logging for validation metrics
692
+ with self .val_summary_writer .as_default ():
693
+ for _log in metrics :
694
+ tf .summary .scalar (
695
+ f'batch_{ _log } ' ,
696
+ data = metrics [_log ],
697
+ step = self .global_step )
698
+ # Log to neptune
699
+ if self .neptune_run :
700
+ for v in metrics :
701
+ self .neptune_run [f"metrics/val/batch/{ v } " ].log (
702
+ round (metrics [v ], 3 ),
703
+ step = self .global_step
704
+ )
705
+ if self .last_ema != - 1 :
706
+ self .neptune_run ["metrics/val/batch/exp_moving_avg" ].log (
707
+ round (self .last_ema , 3 ),
708
+ step = self .global_step
709
+ )
710
+ self .neptune_run ["early_stop/stopped_early" ] = False
711
+
712
+ def _log_epoch_evaluation (self , epoch_results , metrics , accuracy , loss , logs = {}):
713
+ """Log the end-of-epoch evaluation to CSV, Tensorboard, and Neptune."""
714
+ epoch = self .epoch_count
715
+ run = self .neptune_run
716
+ sf .util .update_results_log (
717
+ self .cb_args .results_log ,
718
+ 'trained_model' ,
719
+ {f'epoch{ epoch } ' : epoch_results }
720
+ )
721
+ with self .val_summary_writer .as_default ():
722
+ # Note: Tensorboard epoch logging starts with index=0,
723
+ # whereas all other logging starts with index=1
724
+ if isinstance (accuracy , (list , tuple , np .ndarray )):
725
+ for i in range (len (accuracy )):
726
+ tf .summary .scalar (f'epoch_accuracy-{ i } ' , data = accuracy [i ], step = epoch - 1 )
727
+ elif accuracy is not None :
728
+ tf .summary .scalar (f'epoch_accuracy' , data = accuracy , step = epoch - 1 )
729
+ if isinstance (loss , (list , tuple , np .ndarray )):
730
+ for i in range (len (loss )):
731
+ tf .summary .scalar (f'epoch_loss-{ i } ' , data = loss [i ], step = epoch - 1 )
732
+ else :
733
+ tf .summary .scalar (f'epoch_loss' , data = loss , step = epoch - 1 )
734
+
735
+ # Log epoch results to Neptune
736
+ if run :
737
+ # Training epoch metrics
738
+ run ['metrics/train/epoch/loss' ].log (logs ['loss' ], step = epoch )
739
+ sf .util .neptune_utils .list_log (
740
+ run ,
741
+ 'metrics/train/epoch/accuracy' ,
742
+ logs ['accuracy' ],
743
+ step = epoch
744
+ )
745
+ # Validation epoch metrics
746
+ run ['metrics/val/epoch/loss' ].log (loss , step = epoch )
747
+ sf .util .neptune_utils .list_log (
748
+ run ,
749
+ 'metrics/val/epoch/accuracy' ,
750
+ accuracy ,
751
+ step = epoch
752
+ )
753
+ for metric in metrics :
754
+ if metrics [metric ]['tile' ] is None :
755
+ continue
756
+ for outcome in metrics [metric ]['tile' ]:
757
+ # If only one outcome, log to metrics/val/epoch/[metric].
758
+ # If more than one outcome, log to
759
+ # metrics/val/epoch/[metric]/[outcome_name]
760
+ def metric_label (s ):
761
+ if len (metrics [metric ]['tile' ]) == 1 :
762
+ return f'metrics/val/epoch/{ s } _{ metric } '
763
+ else :
764
+ return f'metrics/val/epoch/{ s } _{ metric } /{ outcome } '
765
+
766
+ tile_metric = metrics [metric ]['tile' ][outcome ]
767
+ slide_metric = metrics [metric ]['slide' ][outcome ]
768
+ patient_metric = metrics [metric ]['patient' ][outcome ]
769
+
770
+ # If only one value for a metric, log to .../[metric]
771
+ # If more than one value for a metric (e.g. AUC for each
772
+ # category), log to .../[metric]/[i]
773
+ sf .util .neptune_utils .list_log (
774
+ run ,
775
+ metric_label ('tile' ),
776
+ tile_metric ,
777
+ step = epoch
778
+ )
779
+ sf .util .neptune_utils .list_log (
780
+ run ,
781
+ metric_label ('slide' ),
782
+ slide_metric ,
783
+ step = epoch
784
+ )
785
+ sf .util .neptune_utils .list_log (
786
+ run ,
787
+ metric_label ('patient' ),
788
+ patient_metric ,
789
+ step = epoch
790
+ )
791
+
666
792
def _metrics_from_dataset (
667
793
self ,
668
794
epoch_label : str ,
@@ -728,18 +854,10 @@ def on_epoch_end(self, epoch: int, logs={}) -> None:
728
854
self .model .stop_training = self .early_stop
729
855
730
856
def on_train_batch_end (self , batch : int , logs = {}) -> None :
731
- # Neptune logging for training metrics
732
- if self .neptune_run :
733
- self .neptune_run ['metrics/train/batch/loss' ].log (
734
- logs ['loss' ],
735
- step = self .global_step
736
- )
737
- sf .util .neptune_utils .list_log (
738
- self .neptune_run ,
739
- 'metrics/train/batch/accuracy' ,
740
- logs ['accuracy' ],
741
- step = self .global_step
742
- )
857
+ # Tensorboard logging for training metrics
858
+ if batch > 0 and batch % self .cb_args .log_frequency == 0 :
859
+ #with self.train_summary_writer.as_default():
860
+ self ._log_training_metrics (logs )
743
861
744
862
# Check if manual early stopping has been triggered
745
863
if (self .hp .early_stop
@@ -802,19 +920,10 @@ def on_train_batch_end(self, batch: int, logs={}) -> None:
802
920
print ('\r \033 [K' , end = '' )
803
921
self .moving_average += [early_stop_value ]
804
922
805
- # Log to neptune
806
- if self .neptune_run :
807
- for v in val_metrics :
808
- self .neptune_run [f"metrics/val/batch/{ v } " ].log (
809
- round (val_metrics [v ], 3 ),
810
- step = self .global_step
811
- )
812
- if self .last_ema != - 1 :
813
- self .neptune_run ["metrics/val/batch/exp_moving_avg" ].log (
814
- round (self .last_ema , 3 ),
815
- step = self .global_step
816
- )
817
- self .neptune_run ["early_stop/stopped_early" ] = False
923
+ self ._log_validation_metrics (logs )
924
+ # Log training metrics if not already logged this batch
925
+ if batch % self .cb_args .log_frequency > 0 :
926
+ self ._log_training_metrics (logs )
818
927
819
928
# Base logging message
820
929
batch_msg = f'[blue]Batch { batch :<5} [/]'
@@ -915,73 +1024,9 @@ def evaluate_model(self, logs={}) -> None:
915
1024
self .results ['epochs' ][f'epoch{ epoch } ' ][f'patient_{ m } ' ] = metrics [m ]['patient' ]
916
1025
917
1026
epoch_results = self .results ['epochs' ][f'epoch{ epoch } ' ]
918
- sf .util .update_results_log (
919
- self .cb_args .results_log ,
920
- 'trained_model' ,
921
- {f'epoch{ epoch } ' : epoch_results }
1027
+ self ._log_epoch_evaluation (
1028
+ epoch_results , metrics = metrics , accuracy = acc , loss = loss , logs = logs
922
1029
)
923
- # Log epoch results to Neptune
924
- if self .neptune_run :
925
- # Training epoch metrics
926
- self .neptune_run ['metrics/train/epoch/loss' ].log (
927
- logs ['loss' ],
928
- step = epoch
929
- )
930
- sf .util .neptune_utils .list_log (
931
- self .neptune_run ,
932
- 'metrics/train/epoch/accuracy' ,
933
- logs ['accuracy' ],
934
- step = epoch
935
- )
936
- # Validation epoch metrics
937
- self .neptune_run ['metrics/val/epoch/loss' ].log (
938
- val_metrics ['loss' ],
939
- step = epoch
940
- )
941
- sf .util .neptune_utils .list_log (
942
- self .neptune_run ,
943
- 'metrics/val/epoch/accuracy' ,
944
- val_metrics ['accuracy' ],
945
- step = epoch
946
- )
947
- for metric in metrics :
948
- if metrics [metric ]['tile' ] is None :
949
- continue
950
- for outcome in metrics [metric ]['tile' ]:
951
- # If only one outcome, log to metrics/val/epoch/[metric].
952
- # If more than one outcome, log to
953
- # metrics/val/epoch/[metric]/[outcome_name]
954
- def metric_label (s ):
955
- if len (metrics [metric ]['tile' ]) == 1 :
956
- return f'metrics/val/epoch/{ s } _{ metric } '
957
- else :
958
- return f'metrics/val/epoch/{ s } _{ metric } /{ outcome } '
959
-
960
- tile_metric = metrics [metric ]['tile' ][outcome ]
961
- slide_metric = metrics [metric ]['slide' ][outcome ]
962
- patient_metric = metrics [metric ]['patient' ][outcome ]
963
-
964
- # If only one value for a metric, log to .../[metric]
965
- # If more than one value for a metric (e.g. AUC for each
966
- # category), log to .../[metric]/[i]
967
- sf .util .neptune_utils .list_log (
968
- self .neptune_run ,
969
- metric_label ('tile' ),
970
- tile_metric ,
971
- step = epoch
972
- )
973
- sf .util .neptune_utils .list_log (
974
- self .neptune_run ,
975
- metric_label ('slide' ),
976
- slide_metric ,
977
- step = epoch
978
- )
979
- sf .util .neptune_utils .list_log (
980
- self .neptune_run ,
981
- metric_label ('patient' ),
982
- patient_metric ,
983
- step = epoch
984
- )
985
1030
986
1031
987
1032
class Trainer :
@@ -1791,7 +1836,8 @@ def train(
1791
1836
save_predictions = save_predictions ,
1792
1837
save_model = save_model ,
1793
1838
results_log = results_log ,
1794
- reduce_method = reduce_method
1839
+ reduce_method = reduce_method ,
1840
+ log_frequency = log_frequency
1795
1841
)
1796
1842
1797
1843
# Create callbacks for early stopping, checkpoint saving,
@@ -1806,11 +1852,15 @@ def train(
1806
1852
)
1807
1853
callbacks += [cp_callback ]
1808
1854
if use_tensorboard :
1855
+ log .debug (
1856
+ "Logging with Tensorboard to {} every {} batches." .format (
1857
+ self .outdir , log_frequency
1858
+ ))
1809
1859
tensorboard_callback = tf .keras .callbacks .TensorBoard (
1810
1860
log_dir = self .outdir ,
1811
1861
histogram_freq = 0 ,
1812
1862
write_graph = False ,
1813
- update_freq = log_frequency
1863
+ update_freq = 'batch'
1814
1864
)
1815
1865
callbacks += [tensorboard_callback ]
1816
1866
0 commit comments