@@ -106,18 +106,58 @@ the algorithm must predict the label (which its positive or negative) for the po
106
106
classifier .train (Arrays .asList (pointA , pointB , pointC , pointD ));
107
107
List <Neighbor > similarNeighbors = classifier .similarNeighbors (pointE , 2 );
108
108
109
- Neighbor n1 = new Neighbor (new LabeledInstance (null , pointA .getModel ()), 0d , null );
110
- Neighbor n2 = new Neighbor (new LabeledInstance (null , pointB .getModel ()), 0d , null );
109
+ Neighbor n1 = new Neighbor (new LabeledInstance (negativeLabel , pointA .getModel ()), 0d , null );
110
+ Neighbor n2 = new Neighbor (new LabeledInstance (negativeLabel , pointB .getModel ()), 0d , null );
111
111
112
112
Truth .assertThat (similarNeighbors )
113
113
.containsAllIn (Arrays .asList (n1 , n2 ));
114
114
}
115
115
116
+ @ Test
117
+ public void when_trained_with_k_fold_it_should_predict_a_positive_label (){
118
+ /*
119
+ given a set of negative points:
120
+ - A(2,4); B(3,2); C(4,4)
121
+ and a set of positive points:
122
+ - D(4,1); E(5,5); F(6,3)
123
+ the algorithm must predict the label (which its positive or negative) for the point G(10,7)
124
+ */
125
+
126
+ String positiveLabel = "positive" ;
127
+ String negativeLabel = "negative" ;
128
+
129
+ LabeledInstance pointA = new LabeledInstance (negativeLabel , new TestModel (null , Arrays .asList (2d , 4d )));
130
+ LabeledInstance pointB = new LabeledInstance (negativeLabel , new TestModel (null , Arrays .asList (3d , 2d )));
131
+ LabeledInstance pointC = new LabeledInstance (negativeLabel , new TestModel (null , Arrays .asList (4d , 4d )));
132
+
133
+ LabeledInstance pointD = new LabeledInstance (positiveLabel , new TestModel (null , Arrays .asList (4d , 1d )));
134
+ LabeledInstance pointE = new LabeledInstance (positiveLabel , new TestModel (null , Arrays .asList (5d , 5d )));
135
+ LabeledInstance pointF = new LabeledInstance (positiveLabel , new TestModel (null , Arrays .asList (6d , 3d )));
136
+
137
+ classifier .setK (3 );
138
+ classifier .train (Arrays .asList (pointA , pointB , pointC , pointD , pointE , pointF ), 3 );
139
+
140
+ double scoreExpected = Math .sqrt (29 )/100 ;
141
+ Prediction predictedInstance = new Prediction (positiveLabel , scoreExpected );
142
+
143
+ Prediction predictedInstance1 = classifier .predict (pointF );
144
+ Truth .assertThat (predictedInstance .getLabel ())
145
+ .isEqualTo (positiveLabel );
146
+ Truth .assertThat (predictedInstance .getScore ())
147
+ .isEqualTo (scoreExpected );
148
+ }
149
+
150
+
151
+
116
152
@ Test (expected = IllegalArgumentException .class )
117
153
public void when_similarNeighbors_its_called_with_null_neighbors_args_it_should_raise_an_exception (){
118
154
classifier .similarNeighbors (null , 10 );
119
155
}
120
156
157
+ @ Test (expected = IllegalArgumentException .class )
158
+ public void it_should_raise_an_exception_when_predict_list_method_its_called (){
159
+ classifier .predict (Collections .emptyList ());
160
+ }
121
161
}
122
162
123
163
0 commit comments