@@ -107,6 +107,13 @@ def __init__(
107
107
# Rebuild tfrecord indices
108
108
self .project .dataset (self .tile_px , 1208 ).build_index (True )
109
109
110
+ # Set up training keyword arguments.
111
+ self .train_kwargs = dict (
112
+ validate_on_batch = 5 ,
113
+ steps_per_epoch_override = 50 ,
114
+ save_predictions = True
115
+ )
116
+
110
117
def _get_model (self , name : str , epoch : int = 1 ) -> str :
111
118
assert self .project is not None
112
119
prev_run_dirs = [
@@ -326,9 +333,6 @@ def train_perf(self, **train_kwargs) -> None:
326
333
exp_label = 'manual_hp' ,
327
334
outcomes = 'category1' ,
328
335
val_k = 1 ,
329
- validate_on_batch = 10 ,
330
- save_predictions = True ,
331
- steps_per_epoch_override = 20 ,
332
336
params = 'sweep.json' ,
333
337
pretrain = None ,
334
338
** train_kwargs
@@ -374,6 +378,9 @@ def test_training(
374
378
additional slide-level input. Defaults to True.
375
379
"""
376
380
assert self .project is not None
381
+ for k in self .train_kwargs :
382
+ if k not in train_kwargs :
383
+ train_kwargs [k ] = self .train_kwargs [k ]
377
384
# Disable checkpoints for tensorflow backend, to save disk space
378
385
if (sf .backend () == 'tensorflow'
379
386
and 'save_checkpoints' not in train_kwargs ):
@@ -408,9 +415,6 @@ def test_training(
408
415
outcomes = 'category1' ,
409
416
val_k = 1 ,
410
417
params = hp ,
411
- validate_on_batch = 10 ,
412
- steps_per_epoch_override = 20 ,
413
- save_predictions = True ,
414
418
pretrain = None ,
415
419
** resume_kw ,
416
420
** train_kwargs
@@ -436,9 +440,6 @@ def test_training(
436
440
outcomes = 'category1' ,
437
441
val_k = 1 ,
438
442
params = hp ,
439
- validate_on_batch = 10 ,
440
- steps_per_epoch_override = 20 ,
441
- save_predictions = True ,
442
443
pretrain = to_resume ,
443
444
** train_kwargs
444
445
)
@@ -455,9 +456,6 @@ def test_training(
455
456
outcomes = ['category1' , 'category2' ],
456
457
val_k = 1 ,
457
458
params = self .setup_hp ('categorical' ),
458
- validate_on_batch = 10 ,
459
- steps_per_epoch_override = 20 ,
460
- save_predictions = True ,
461
459
pretrain = None ,
462
460
** train_kwargs
463
461
)
@@ -474,9 +472,6 @@ def test_training(
474
472
outcomes = ['linear1' ],
475
473
val_k = 1 ,
476
474
params = self .setup_hp ('linear' ),
477
- validate_on_batch = 10 ,
478
- steps_per_epoch_override = 20 ,
479
- save_predictions = True ,
480
475
pretrain = None ,
481
476
** train_kwargs
482
477
)
@@ -493,9 +488,6 @@ def test_training(
493
488
outcomes = ['linear1' , 'linear2' ],
494
489
val_k = 1 ,
495
490
params = self .setup_hp ('linear' ),
496
- validate_on_batch = 10 ,
497
- steps_per_epoch_override = 20 ,
498
- save_predictions = True ,
499
491
pretrain = None ,
500
492
** train_kwargs
501
493
)
@@ -514,9 +506,6 @@ def test_training(
514
506
input_header = 'category2' ,
515
507
params = self .setup_hp ('categorical' ),
516
508
val_k = 1 ,
517
- validate_on_batch = 10 ,
518
- steps_per_epoch_override = 20 ,
519
- save_predictions = True ,
520
509
pretrain = None ,
521
510
** train_kwargs
522
511
)
@@ -535,9 +524,6 @@ def test_training(
535
524
input_header = 'event' ,
536
525
params = self .setup_hp ('cph' ),
537
526
val_k = 1 ,
538
- validate_on_batch = 10 ,
539
- steps_per_epoch_override = 20 ,
540
- save_predictions = True ,
541
527
pretrain = None ,
542
528
** train_kwargs
543
529
)
@@ -558,9 +544,6 @@ def test_training(
558
544
input_header = ['event' , 'category1' ],
559
545
params = self .setup_hp ('cph' ),
560
546
val_k = 1 ,
561
- validate_on_batch = 10 ,
562
- steps_per_epoch_override = 20 ,
563
- save_predictions = True ,
564
547
pretrain = None ,
565
548
** train_kwargs
566
549
)
@@ -581,9 +564,6 @@ def test_training(
581
564
outcomes = 'category1' ,
582
565
val_k = 1 ,
583
566
params = hp ,
584
- validate_on_batch = 10 ,
585
- steps_per_epoch_override = 20 ,
586
- save_predictions = True ,
587
567
from_wsi = True ,
588
568
pretrain = None ,
589
569
** train_kwargs
0 commit comments