@@ -6,29 +6,26 @@ import {
6
6
LayersModel ,
7
7
model ,
8
8
node ,
9
- onesLike ,
10
9
Optimizer ,
11
10
Scalar ,
12
11
SymbolicTensor ,
13
12
Tensor ,
14
- TensorContainer ,
15
- tidy ,
16
- where ,
17
- zerosLike
13
+ TensorContainer
18
14
} from '@tensorflow/tfjs-node' ;
19
15
import { green , red } from 'chalk' ;
20
16
import { Table } from 'console-table-printer' ;
21
17
import { FeatureExtractor } from '../feature-engineering/feature-extractor' ;
22
18
import { prepareDatasetsForBinaryClassification } from '../feature-engineering/prepare-datasets-for-binary-classification' ;
23
19
import { ConfusionMatrix } from '../testing/confusion-matrix' ;
24
20
import { Metrics } from '../testing/metrics' ;
21
+ import { binarize } from '../utils/binarize' ;
25
22
26
23
export type BinaryClassificationTrainerOptions = {
27
- batchSize : number ;
28
- epochs : number ;
29
- patience : number ;
30
- inputFeatureExtractors : Array < FeatureExtractor < any , any > > ;
31
- outputFeatureExtractor : FeatureExtractor < any , any > ;
24
+ batchSize ? : number ;
25
+ epochs ? : number ;
26
+ patience ? : number ;
27
+ inputFeatureExtractors ? : Array < FeatureExtractor < any , any > > ;
28
+ outputFeatureExtractor ? : FeatureExtractor < any , any > ;
32
29
model ?: LayersModel ;
33
30
hiddenLayers ?: Array < layers . Layer > ;
34
31
optimizer ?: string | Optimizer ;
@@ -40,60 +37,37 @@ export class BinaryClassificationTrainer {
40
37
protected epochs : number ;
41
38
protected patience : number ;
42
39
protected tensorBoardLogsDirectory ?: string ;
43
- protected inputFeatureExtractors : Array < FeatureExtractor < any , any > > ;
44
- protected outputFeatureExtractor : FeatureExtractor < any , any > ;
40
+ protected inputFeatureExtractors ? : Array < FeatureExtractor < any , any > > ;
41
+ protected outputFeatureExtractor ? : FeatureExtractor < any , any > ;
45
42
protected model ! : LayersModel ;
46
43
44
+ protected static DEFAULT_BATCH_SIZE : number = 32 ;
45
+ protected static DEFAULT_EPOCHS : number = 1000 ;
46
+ protected static DEFAULT_PATIENCE : number = 20 ;
47
+
47
48
constructor ( options : BinaryClassificationTrainerOptions ) {
48
- this . batchSize = options . batchSize ;
49
- this . epochs = options . epochs ;
50
- this . patience = options . patience ;
49
+ this . batchSize = options . batchSize ?? BinaryClassificationTrainer . DEFAULT_BATCH_SIZE ;
50
+ this . epochs = options . epochs ?? BinaryClassificationTrainer . DEFAULT_EPOCHS ;
51
+ this . patience = options . patience ?? BinaryClassificationTrainer . DEFAULT_PATIENCE ;
51
52
this . tensorBoardLogsDirectory = options . tensorBoardLogsDirectory ;
52
53
this . inputFeatureExtractors = options . inputFeatureExtractors ;
53
54
this . outputFeatureExtractor = options . outputFeatureExtractor ;
54
55
55
- if ( options . model !== undefined ) {
56
- this . model = options . model ;
57
- } else {
58
- if ( options . hiddenLayers !== undefined && options . inputFeatureExtractors !== undefined ) {
59
- const inputLayer = input ( { shape : [ options . inputFeatureExtractors . length ] } ) ;
60
- let symbolicTensor = inputLayer ;
61
-
62
- options . hiddenLayers . forEach ( ( layer ) => {
63
- symbolicTensor = layer . apply ( symbolicTensor ) as SymbolicTensor ;
64
- } ) ;
65
-
66
- const outputLayer = layers
67
- . dense ( { units : 1 , activation : 'sigmoid' } )
68
- . apply ( symbolicTensor ) as SymbolicTensor ;
69
-
70
- this . model = model ( {
71
- inputs : inputLayer ,
72
- outputs : outputLayer
73
- } ) ;
74
- } else {
75
- throw new Error ( 'hiddenLayers and inputFeaturesCount options are required when the model is not provided!' ) ;
76
- }
77
- }
78
-
79
- this . model . compile ( {
80
- optimizer : options . optimizer ?? 'adam' ,
81
- loss : 'binaryCrossentropy'
82
- } ) ;
56
+ this . initializeModel ( options ) ;
83
57
}
84
58
85
59
public async trainAndTest ( {
86
60
data,
87
61
trainingDataset,
88
62
validationDataset,
89
63
testingDataset,
90
- printResults
64
+ printTestingResults
91
65
} : {
92
66
data ?: Array < any > ,
93
67
trainingDataset ?: data . Dataset < TensorContainer > ;
94
68
validationDataset ?: data . Dataset < TensorContainer > ;
95
69
testingDataset ?: data . Dataset < TensorContainer > ;
96
- printResults ?: boolean ;
70
+ printTestingResults ?: boolean ;
97
71
} ) : Promise < {
98
72
loss : number ;
99
73
confusionMatrix : ConfusionMatrix ;
@@ -111,7 +85,15 @@ export class BinaryClassificationTrainer {
111
85
callbacks . push ( node . tensorBoard ( this . tensorBoardLogsDirectory ) ) ;
112
86
}
113
87
114
- if ( trainingDataset === undefined || validationDataset === undefined || testingDataset === undefined ) {
88
+ if (
89
+ trainingDataset === undefined ||
90
+ validationDataset === undefined ||
91
+ testingDataset === undefined
92
+ ) {
93
+ if ( this . inputFeatureExtractors === undefined || this . outputFeatureExtractor === undefined ) {
94
+ throw new Error ( 'trainingDataset, validationDataset and testingDataset are required when inputFeatureExtractors and outputFeatureExtractor are not provided!' ) ;
95
+ }
96
+
115
97
const datasets = await prepareDatasetsForBinaryClassification ( {
116
98
data : data as Array < any > ,
117
99
inputFeatureExtractors : this . inputFeatureExtractors ,
@@ -130,19 +112,50 @@ export class BinaryClassificationTrainer {
130
112
callbacks
131
113
} ) ;
132
114
133
- return await this . test ( { testingDataset, printResults } ) ;
115
+ return await this . test ( { testingDataset, printTestingResults } ) ;
134
116
}
135
117
136
118
public async save ( path : string ) : Promise < void > {
137
119
await this . model . save ( `file://${ path } ` ) ;
138
120
}
139
121
122
+ private initializeModel ( options : BinaryClassificationTrainerOptions ) : void {
123
+ if ( options . model !== undefined ) {
124
+ this . model = options . model ;
125
+ } else {
126
+ if ( options . hiddenLayers !== undefined && options . inputFeatureExtractors !== undefined ) {
127
+ const inputLayer = input ( { shape : [ options . inputFeatureExtractors . length ] } ) ;
128
+ let symbolicTensor = inputLayer ;
129
+
130
+ for ( const layer of options . hiddenLayers ) {
131
+ symbolicTensor = layer . apply ( symbolicTensor ) as SymbolicTensor ;
132
+ }
133
+
134
+ const outputLayer = layers
135
+ . dense ( { units : 1 , activation : 'sigmoid' } )
136
+ . apply ( symbolicTensor ) as SymbolicTensor ;
137
+
138
+ this . model = model ( {
139
+ inputs : inputLayer ,
140
+ outputs : outputLayer
141
+ } ) ;
142
+ } else {
143
+ throw new Error ( 'hiddenLayers and inputFeatureExtractors options are required when the model is not provided!' ) ;
144
+ }
145
+ }
146
+
147
+ this . model . compile ( {
148
+ optimizer : options . optimizer ?? 'adam' ,
149
+ loss : 'binaryCrossentropy'
150
+ } ) ;
151
+ }
152
+
140
153
private async test ( {
141
154
testingDataset,
142
- printResults
155
+ printTestingResults
143
156
} : {
144
157
testingDataset : data . Dataset < TensorContainer > ;
145
- printResults ?: boolean ;
158
+ printTestingResults ?: boolean ;
146
159
} ) : Promise < {
147
160
loss : number ;
148
161
confusionMatrix : ConfusionMatrix ;
@@ -151,23 +164,24 @@ export class BinaryClassificationTrainer {
151
164
const lossTensor = ( await this . model . evaluateDataset ( testingDataset as data . Dataset < any > , { } ) ) as Scalar ;
152
165
const [ loss ] = await lossTensor . data ( ) ;
153
166
154
- const testingData = ( await testingDataset . toArray ( ) ) as Array < {
167
+ const [ testingData ] = ( await testingDataset . toArray ( ) ) as Array < {
155
168
xs : Tensor ;
156
169
ys : Tensor ;
157
170
} > ;
158
- const testXs = testingData [ 0 ] . xs ;
159
- const testYs = testingData [ 0 ] . ys ;
171
+
172
+ const testXs = testingData . xs ;
173
+ const testYs = testingData . ys ;
160
174
161
175
const predictions = this . model . predict ( testXs ) as Tensor ;
162
- const binarizedPredictions = this . binarize ( predictions ) ;
176
+ const binarizedPredictions = binarize ( predictions ) ;
163
177
164
- const trueValues = ( await testYs . data ( ) ) as Float32Array ;
165
- const predictedValues = ( await binarizedPredictions . data ( ) ) as Float32Array ;
178
+ const trueValues = await testYs . data < 'float32' > ( ) ;
179
+ const predictedValues = await binarizedPredictions . data < 'float32' > ( ) ;
166
180
167
181
const confusionMatrix = this . calculateConfusionMatrix ( trueValues , predictedValues ) ;
168
182
const metrics = this . calculateMetrics ( confusionMatrix ) ;
169
183
170
- if ( printResults ) {
184
+ if ( printTestingResults ) {
171
185
this . printTestResults ( loss , confusionMatrix , metrics ) ;
172
186
}
173
187
@@ -301,12 +315,4 @@ export class BinaryClassificationTrainer {
301
315
302
316
metricsTable . printTable ( ) ;
303
317
}
304
-
305
- private binarize ( tensor : Tensor , threshold = 0.5 ) : Tensor {
306
- return tidy ( ( ) => {
307
- const condition = tensor . greater ( threshold ) ;
308
-
309
- return where ( condition , onesLike ( tensor ) , zerosLike ( tensor ) ) ;
310
- } ) ;
311
- }
312
318
}
0 commit comments