Skip to content

Commit 5eaefe7

Browse files
committed
Added support methods for spark k-nn graph building
1 parent 6d40913 commit 5eaefe7

File tree

4 files changed

+42
-3
lines changed

4 files changed

+42
-3
lines changed

src/main/java/info/debatty/java/stringsimilarity/StringProfile.java

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -51,7 +51,7 @@ public double cosineSimilarity(StringProfile other) throws Exception {
5151
throw new Exception("Profiles were not created using the same kshingling object!");
5252
}
5353

54-
return this.vector.dotProduct(other.vector) / (this.vector.norm() * other.vector.norm());
54+
return this.vector.cosineSimilarity(other.vector);
5555
}
5656

5757
/**

src/main/java/info/debatty/java/utils/SparseBooleanVector.java

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,7 @@
2424

2525
package info.debatty.java.utils;
2626

27+
import java.io.Serializable;
2728
import java.util.HashMap;
2829
import java.util.SortedSet;
2930
import java.util.TreeSet;
@@ -32,7 +33,7 @@
3233
*
3334
* @author tibo
3435
*/
35-
public class SparseBooleanVector {
36+
public class SparseBooleanVector implements Serializable {
3637

3738
/**
3839
* Indicates the positions that hold the value "true"

src/main/java/info/debatty/java/utils/SparseIntegerVector.java

Lines changed: 25 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,7 @@
2424

2525
package info.debatty.java.utils;
2626

27+
import java.io.Serializable;
2728
import java.util.HashMap;
2829
import java.util.SortedSet;
2930
import java.util.TreeSet;
@@ -32,7 +33,7 @@
3233
* Sparse vector of int, implemented using two arrays
3334
* @author Thibault Debatty
3435
*/
35-
public class SparseIntegerVector {
36+
public class SparseIntegerVector implements Serializable {
3637

3738
protected int[] keys;
3839
protected int[] values;
@@ -81,6 +82,29 @@ public SparseIntegerVector(int[] array) {
8182
}
8283
}
8384

85+
public double cosineSimilarity(SparseIntegerVector other) {
86+
double den = this.norm() * other.norm();
87+
double agg = 0;
88+
int i = 0;
89+
int j = 0;
90+
while (i < this.keys.length && j < other.keys.length) {
91+
int k1 = this.keys[i];
92+
int k2 = other.keys[j];
93+
94+
if (k1 == k2) {
95+
agg += this.values[i] * other.values[j] / den;
96+
i++;
97+
j++;
98+
99+
} else if (k1 < k2) {
100+
i++;
101+
} else {
102+
j++;
103+
}
104+
}
105+
return agg;
106+
}
107+
84108
/**
85109
*
86110
* @param other

src/test/java/info/debatty/java/utils/SparseIntegerVectorTest.java

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -81,6 +81,20 @@ public void testDotProduct_doubleArr() {
8181
double result = instance.dotProduct(other);
8282
assertEquals(expResult, result, 0.0);
8383
}
84+
85+
86+
/**
87+
* Test of cosineSimilarity method, of class SparseIntegerVector.
88+
*/
89+
@Test
90+
public void testCosineSimilarity() {
91+
System.out.println("cosineSimilarity");
92+
SparseIntegerVector other = new SparseIntegerVector(new int[]{0, 1, 2, 3});
93+
SparseIntegerVector instance = new SparseIntegerVector(new int[]{1, 2, 0, 0});
94+
double expResult = instance.dotProduct(other) / (instance.norm() * other.norm());
95+
double result = instance.cosineSimilarity(other);
96+
assertEquals(expResult, result, 0.0);
97+
}
8498

8599
/**
86100
* Test of norm method, of class SparseIntegerVector.

0 commit comments

Comments
 (0)