Skip to content

Commit ab21b33

Browse files
Merge pull request #16 from RonasIT/refactoring
Refactoring
2 parents 6eb3aa6 + c3fd638 commit ab21b33

File tree

4 files changed

+84
-75
lines changed

4 files changed

+84
-75
lines changed

README.md

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -115,14 +115,15 @@ then you can create the datasets for training, validation and testing manually,
115115
and pass them as the `trainingDataset`, `validationDataset` and `testingDataset`
116116
parameters.
117117

118-
You can also print the testing results by setting the `printResults` to `true`.
118+
You can also print the testing results by setting the `printTestingResults` to
119+
`true`.
119120

120121
An example can be found below:
121122

122123
```typescript
123124
await trainer.trainAndTest({
124125
data,
125-
printResults: true
126+
printTestingResults: true
126127
});
127128
```
128129

packages/tfjs-node-helpers-example/src/app/app.ts

Lines changed: 2 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -7,20 +7,13 @@ import { OwnsTheCarFeatureExtractor } from './feature-extractors/owns-the-car';
77
import { join } from 'node:path';
88
import data from '../assets/data.json';
99

10-
const EPOCHS = 1000;
11-
const PATIENCE = 20;
12-
const BATCH_SIZE = 32;
13-
14-
export async function startApplication() {
10+
export async function startApplication(): Promise<void> {
1511
await train();
1612
await predict();
1713
}
1814

1915
async function train(): Promise<void> {
2016
const trainer = new BinaryClassificationTrainer({
21-
batchSize: BATCH_SIZE,
22-
epochs: EPOCHS,
23-
patience: PATIENCE,
2417
hiddenLayers: [
2518
layers.dense({ units: 128, activation: 'mish' }),
2619
layers.dense({ units: 128, activation: 'mish' })
@@ -35,7 +28,7 @@ async function train(): Promise<void> {
3528

3629
await trainer.trainAndTest({
3730
data,
38-
printResults: true
31+
printTestingResults: true
3932
});
4033

4134
await trainer.save(join(__dirname, './trained_model'));

packages/tfjs-node-helpers/src/classification/binary-classification-trainer.ts

Lines changed: 70 additions & 64 deletions
Original file line numberDiff line numberDiff line change
@@ -6,29 +6,26 @@ import {
66
LayersModel,
77
model,
88
node,
9-
onesLike,
109
Optimizer,
1110
Scalar,
1211
SymbolicTensor,
1312
Tensor,
14-
TensorContainer,
15-
tidy,
16-
where,
17-
zerosLike
13+
TensorContainer
1814
} from '@tensorflow/tfjs-node';
1915
import { green, red } from 'chalk';
2016
import { Table } from 'console-table-printer';
2117
import { FeatureExtractor } from '../feature-engineering/feature-extractor';
2218
import { prepareDatasetsForBinaryClassification } from '../feature-engineering/prepare-datasets-for-binary-classification';
2319
import { ConfusionMatrix } from '../testing/confusion-matrix';
2420
import { Metrics } from '../testing/metrics';
21+
import { binarize } from '../utils/binarize';
2522

2623
export type BinaryClassificationTrainerOptions = {
27-
batchSize: number;
28-
epochs: number;
29-
patience: number;
30-
inputFeatureExtractors: Array<FeatureExtractor<any, any>>;
31-
outputFeatureExtractor: FeatureExtractor<any, any>;
24+
batchSize?: number;
25+
epochs?: number;
26+
patience?: number;
27+
inputFeatureExtractors?: Array<FeatureExtractor<any, any>>;
28+
outputFeatureExtractor?: FeatureExtractor<any, any>;
3229
model?: LayersModel;
3330
hiddenLayers?: Array<layers.Layer>;
3431
optimizer?: string | Optimizer;
@@ -40,60 +37,37 @@ export class BinaryClassificationTrainer {
4037
protected epochs: number;
4138
protected patience: number;
4239
protected tensorBoardLogsDirectory?: string;
43-
protected inputFeatureExtractors: Array<FeatureExtractor<any, any>>;
44-
protected outputFeatureExtractor: FeatureExtractor<any, any>;
40+
protected inputFeatureExtractors?: Array<FeatureExtractor<any, any>>;
41+
protected outputFeatureExtractor?: FeatureExtractor<any, any>;
4542
protected model!: LayersModel;
4643

44+
protected static DEFAULT_BATCH_SIZE: number = 32;
45+
protected static DEFAULT_EPOCHS: number = 1000;
46+
protected static DEFAULT_PATIENCE: number = 20;
47+
4748
constructor(options: BinaryClassificationTrainerOptions) {
48-
this.batchSize = options.batchSize;
49-
this.epochs = options.epochs;
50-
this.patience = options.patience;
49+
this.batchSize = options.batchSize ?? BinaryClassificationTrainer.DEFAULT_BATCH_SIZE;
50+
this.epochs = options.epochs ?? BinaryClassificationTrainer.DEFAULT_EPOCHS;
51+
this.patience = options.patience ?? BinaryClassificationTrainer.DEFAULT_PATIENCE;
5152
this.tensorBoardLogsDirectory = options.tensorBoardLogsDirectory;
5253
this.inputFeatureExtractors = options.inputFeatureExtractors;
5354
this.outputFeatureExtractor = options.outputFeatureExtractor;
5455

55-
if (options.model !== undefined) {
56-
this.model = options.model;
57-
} else {
58-
if (options.hiddenLayers !== undefined && options.inputFeatureExtractors !== undefined) {
59-
const inputLayer = input({ shape: [options.inputFeatureExtractors.length] });
60-
let symbolicTensor = inputLayer;
61-
62-
options.hiddenLayers.forEach((layer) => {
63-
symbolicTensor = layer.apply(symbolicTensor) as SymbolicTensor;
64-
});
65-
66-
const outputLayer = layers
67-
.dense({ units: 1, activation: 'sigmoid' })
68-
.apply(symbolicTensor) as SymbolicTensor;
69-
70-
this.model = model({
71-
inputs: inputLayer,
72-
outputs: outputLayer
73-
});
74-
} else {
75-
throw new Error('hiddenLayers and inputFeaturesCount options are required when the model is not provided!');
76-
}
77-
}
78-
79-
this.model.compile({
80-
optimizer: options.optimizer ?? 'adam',
81-
loss: 'binaryCrossentropy'
82-
});
56+
this.initializeModel(options);
8357
}
8458

8559
public async trainAndTest({
8660
data,
8761
trainingDataset,
8862
validationDataset,
8963
testingDataset,
90-
printResults
64+
printTestingResults
9165
}: {
9266
data?: Array<any>,
9367
trainingDataset?: data.Dataset<TensorContainer>;
9468
validationDataset?: data.Dataset<TensorContainer>;
9569
testingDataset?: data.Dataset<TensorContainer>;
96-
printResults?: boolean;
70+
printTestingResults?: boolean;
9771
}): Promise<{
9872
loss: number;
9973
confusionMatrix: ConfusionMatrix;
@@ -111,7 +85,15 @@ export class BinaryClassificationTrainer {
11185
callbacks.push(node.tensorBoard(this.tensorBoardLogsDirectory));
11286
}
11387

114-
if (trainingDataset === undefined || validationDataset === undefined || testingDataset === undefined) {
88+
if (
89+
trainingDataset === undefined ||
90+
validationDataset === undefined ||
91+
testingDataset === undefined
92+
) {
93+
if (this.inputFeatureExtractors === undefined || this.outputFeatureExtractor === undefined) {
94+
throw new Error('trainingDataset, validationDataset and testingDataset are required when inputFeatureExtractors and outputFeatureExtractor are not provided!');
95+
}
96+
11597
const datasets = await prepareDatasetsForBinaryClassification({
11698
data: data as Array<any>,
11799
inputFeatureExtractors: this.inputFeatureExtractors,
@@ -130,19 +112,50 @@ export class BinaryClassificationTrainer {
130112
callbacks
131113
});
132114

133-
return await this.test({ testingDataset, printResults });
115+
return await this.test({ testingDataset, printTestingResults });
134116
}
135117

136118
public async save(path: string): Promise<void> {
137119
await this.model.save(`file://${path}`);
138120
}
139121

122+
private initializeModel(options: BinaryClassificationTrainerOptions): void {
123+
if (options.model !== undefined) {
124+
this.model = options.model;
125+
} else {
126+
if (options.hiddenLayers !== undefined && options.inputFeatureExtractors !== undefined) {
127+
const inputLayer = input({ shape: [options.inputFeatureExtractors.length] });
128+
let symbolicTensor = inputLayer;
129+
130+
for (const layer of options.hiddenLayers) {
131+
symbolicTensor = layer.apply(symbolicTensor) as SymbolicTensor;
132+
}
133+
134+
const outputLayer = layers
135+
.dense({ units: 1, activation: 'sigmoid' })
136+
.apply(symbolicTensor) as SymbolicTensor;
137+
138+
this.model = model({
139+
inputs: inputLayer,
140+
outputs: outputLayer
141+
});
142+
} else {
143+
throw new Error('hiddenLayers and inputFeatureExtractors options are required when the model is not provided!');
144+
}
145+
}
146+
147+
this.model.compile({
148+
optimizer: options.optimizer ?? 'adam',
149+
loss: 'binaryCrossentropy'
150+
});
151+
}
152+
140153
private async test({
141154
testingDataset,
142-
printResults
155+
printTestingResults
143156
}: {
144157
testingDataset: data.Dataset<TensorContainer>;
145-
printResults?: boolean;
158+
printTestingResults?: boolean;
146159
}): Promise<{
147160
loss: number;
148161
confusionMatrix: ConfusionMatrix;
@@ -151,23 +164,24 @@ export class BinaryClassificationTrainer {
151164
const lossTensor = (await this.model.evaluateDataset(testingDataset as data.Dataset<any>, {})) as Scalar;
152165
const [loss] = await lossTensor.data();
153166

154-
const testingData = (await testingDataset.toArray()) as Array<{
167+
const [testingData] = (await testingDataset.toArray()) as Array<{
155168
xs: Tensor;
156169
ys: Tensor;
157170
}>;
158-
const testXs = testingData[0].xs;
159-
const testYs = testingData[0].ys;
171+
172+
const testXs = testingData.xs;
173+
const testYs = testingData.ys;
160174

161175
const predictions = this.model.predict(testXs) as Tensor;
162-
const binarizedPredictions = this.binarize(predictions);
176+
const binarizedPredictions = binarize(predictions);
163177

164-
const trueValues = (await testYs.data()) as Float32Array;
165-
const predictedValues = (await binarizedPredictions.data()) as Float32Array;
178+
const trueValues = await testYs.data<'float32'>();
179+
const predictedValues = await binarizedPredictions.data<'float32'>();
166180

167181
const confusionMatrix = this.calculateConfusionMatrix(trueValues, predictedValues);
168182
const metrics = this.calculateMetrics(confusionMatrix);
169183

170-
if (printResults) {
184+
if (printTestingResults) {
171185
this.printTestResults(loss, confusionMatrix, metrics);
172186
}
173187

@@ -301,12 +315,4 @@ export class BinaryClassificationTrainer {
301315

302316
metricsTable.printTable();
303317
}
304-
305-
private binarize(tensor: Tensor, threshold = 0.5): Tensor {
306-
return tidy(() => {
307-
const condition = tensor.greater(threshold);
308-
309-
return where(condition, onesLike(tensor), zerosLike(tensor));
310-
});
311-
}
312318
}
Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,9 @@
1+
import { onesLike, Tensor, tidy, where, zerosLike } from '@tensorflow/tfjs-node';
2+
3+
export const binarize = (tensor: Tensor, threshold = 0.5): Tensor => tidy(
4+
() => where(
5+
tensor.greater(threshold),
6+
onesLike(tensor),
7+
zerosLike(tensor)
8+
)
9+
);

0 commit comments

Comments
 (0)