Skip to content

Commit bfad4b9

Browse files
committed
Refactored another test case.
1 parent ff56ed6 commit bfad4b9

File tree

1 file changed

+42
-2
lines changed

1 file changed

+42
-2
lines changed

src/test/java/SimpleKNNClassifierTest.java

Lines changed: 42 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -106,18 +106,58 @@ the algorithm must predict the label (which its positive or negative) for the po
106106
classifier.train(Arrays.asList(pointA, pointB, pointC, pointD));
107107
List<Neighbor> similarNeighbors = classifier.similarNeighbors(pointE, 2);
108108

109-
Neighbor n1 = new Neighbor(new LabeledInstance(null, pointA.getModel()), 0d, null);
110-
Neighbor n2 = new Neighbor(new LabeledInstance(null, pointB.getModel()), 0d, null);
109+
Neighbor n1 = new Neighbor(new LabeledInstance(negativeLabel, pointA.getModel()), 0d, null);
110+
Neighbor n2 = new Neighbor(new LabeledInstance(negativeLabel, pointB.getModel()), 0d, null);
111111

112112
Truth.assertThat(similarNeighbors)
113113
.containsAllIn(Arrays.asList(n1, n2));
114114
}
115115

116+
@Test
117+
public void when_trained_with_k_fold_it_should_predict_a_positive_label(){
118+
/*
119+
given a set of negative points:
120+
- A(2,4); B(3,2); C(4,4)
121+
and a set of positive points:
122+
- D(4,1); E(5,5); F(6,3)
123+
the algorithm must predict the label (which its positive or negative) for the point G(10,7)
124+
*/
125+
126+
String positiveLabel = "positive";
127+
String negativeLabel = "negative";
128+
129+
LabeledInstance pointA = new LabeledInstance(negativeLabel, new TestModel(null, Arrays.asList(2d, 4d)));
130+
LabeledInstance pointB = new LabeledInstance(negativeLabel, new TestModel(null, Arrays.asList(3d, 2d)));
131+
LabeledInstance pointC = new LabeledInstance(negativeLabel, new TestModel(null, Arrays.asList(4d, 4d)));
132+
133+
LabeledInstance pointD = new LabeledInstance(positiveLabel, new TestModel(null, Arrays.asList(4d, 1d)));
134+
LabeledInstance pointE = new LabeledInstance(positiveLabel, new TestModel(null, Arrays.asList(5d, 5d)));
135+
LabeledInstance pointF = new LabeledInstance(positiveLabel, new TestModel(null, Arrays.asList(6d, 3d)));
136+
137+
classifier.setK(3);
138+
classifier.train(Arrays.asList(pointA, pointB, pointC, pointD, pointE, pointF), 3);
139+
140+
double scoreExpected = Math.sqrt(29)/100;
141+
Prediction predictedInstance = new Prediction(positiveLabel, scoreExpected);
142+
143+
Prediction predictedInstance1 = classifier.predict(pointF);
144+
Truth.assertThat(predictedInstance.getLabel())
145+
.isEqualTo(positiveLabel);
146+
Truth.assertThat(predictedInstance.getScore())
147+
.isEqualTo(scoreExpected);
148+
}
149+
150+
151+
116152
@Test(expected = IllegalArgumentException.class)
117153
public void when_similarNeighbors_its_called_with_null_neighbors_args_it_should_raise_an_exception(){
118154
classifier.similarNeighbors(null, 10);
119155
}
120156

157+
@Test(expected = IllegalArgumentException.class)
158+
public void it_should_raise_an_exception_when_predict_list_method_its_called(){
159+
classifier.predict(Collections.emptyList());
160+
}
121161
}
122162

123163

0 commit comments

Comments
 (0)