@@ -15,6 +15,109 @@ import org.junit.Test
1515
1616class MatrixTest {
1717
18+ @Test
19+ @Throws(Exception ::class )
20+ fun testMass () {
21+ val tss = doubleArrayOf(10.0 , 10.0 , 10.0 , 11.0 , 12.0 , 11.0 , 10.0 , 10.0 , 11.0 , 12.0 , 11.0 , 14.0 , 10.0 , 10.0 )
22+ val dimsTss = longArrayOf(14 , 1 , 1 , 1 )
23+
24+ val query = doubleArrayOf(4.0 , 3.0 , 8.0 )
25+ val dimsQuery = longArrayOf(3 , 1 , 1 , 1 )
26+
27+ Array (tss, dimsTss).use { t ->
28+ Array (query, dimsQuery).use { q ->
29+
30+ val expectedDistance = doubleArrayOf(
31+ 1.732051 , 0.328954 , 1.210135 , 3.150851 , 3.245858 , 2.822044 ,
32+ 0.328954 , 1.210135 , 3.150851 , 0.248097 , 3.30187 , 2.82205 )
33+ val result = Matrix .mass(q, t)
34+ val distances = result.getData<DoubleArray >()
35+
36+ Assert .assertArrayEquals(expectedDistance, distances, 1e- 3 )
37+
38+ result.close()
39+ }
40+ }
41+
42+ }
43+
44+ @Test
45+ @Throws(Exception ::class )
46+ fun testMassMultiple () {
47+ val tss = doubleArrayOf(10.0 , 10.0 , 10.0 , 11.0 , 12.0 , 11.0 , 10.0 , 10.0 , 11.0 , 12.0 , 11.0 , 14.0 , 10.0 , 10.0 )
48+ val dimsTss = longArrayOf(7 , 2 , 1 , 1 )
49+
50+ val query = doubleArrayOf(10.0 , 10.0 , 11.0 , 11.0 , 10.0 , 11.0 , 10.0 , 10.0 )
51+ val dimsQuery = longArrayOf(4 , 2 , 1 , 1 )
52+
53+ Array (tss, dimsTss).use { t ->
54+ Array (query, dimsQuery).use { q ->
55+
56+ val expectedDistance = doubleArrayOf(
57+ 1.8388 , 0.8739 , 1.5307 , 3.6955 , 3.2660 , 3.4897 , 2.8284 , 1.2116 ,
58+ 1.5307 , 2.1758 , 2.5783 , 3.7550 , 2.8284 , 2.8284 , 3.2159 , 0.5020 )
59+ val result = Matrix .mass(q, t)
60+ val distances = result.getData<DoubleArray >()
61+
62+ Assert .assertArrayEquals(expectedDistance, distances, 1e- 3 )
63+
64+ result.close()
65+ }
66+ }
67+
68+ }
69+
70+ @Test
71+ @Throws(Exception ::class )
72+ fun testFindBestNOccurrences () {
73+ val tss = doubleArrayOf(10.0 , 10.0 , 11.0 , 11.0 , 12.0 , 11.0 , 10.0 , 10.0 , 11.0 , 12.0 , 11.0 , 10.0 , 10.0 , 11.0 , 10.0 , 10.0 , 11.0 , 11.0 , 12.0 , 11.0 , 10.0 , 10.0 , 11.0 , 12.0 , 11.0 , 10.0 , 10.0 , 11.0 )
74+ val dimsTss = longArrayOf(28 , 1 , 1 , 1 )
75+
76+ val query = doubleArrayOf(10.0 , 11.0 , 12.0 )
77+ val dimsQuery = longArrayOf(3 , 1 , 1 , 1 )
78+
79+ Array (tss, dimsTss).use { t ->
80+ Array (query, dimsQuery).use { q ->
81+ val result = Matrix .findBestNOccurrences(q, t, 1 )
82+ val distances = result[0 ].getData<DoubleArray >()
83+ val indexes = result[1 ].getData<IntArray >()
84+
85+ Assert .assertEquals(distances[0 ], 0.0 , DELTA )
86+ Assert .assertEquals(indexes[0 ], 7 )
87+
88+ result[0 ].close()
89+ result[1 ].close()
90+ }
91+ }
92+
93+ }
94+
95+ @Test
96+ @Throws(Exception ::class )
97+ fun testFindBestNOccurrencesMultipleQueries () {
98+ val tss = doubleArrayOf(10.0 , 10.0 , 11.0 , 11.0 , 10.0 , 11.0 , 10.0 , 10.0 , 11.0 , 11.0 , 10.0 , 11.0 , 10.0 , 10.0 , 11.0 , 10.0 , 10.0 , 11.0 , 10.0 , 11.0 , 11.0 , 10.0 , 11.0 , 11.0 , 14.0 , 10.0 , 11.0 , 10.0 )
99+ val dimsTss = longArrayOf(14 , 2 , 1 , 1 )
100+
101+ val query = doubleArrayOf(11.0 , 11.0 , 10.0 , 11.0 , 10.0 , 11.0 , 11.0 , 12.0 )
102+ val dimsQuery = longArrayOf(4 , 2 , 1 , 1 )
103+
104+ Array (tss, dimsTss).use { t ->
105+ Array (query, dimsQuery).use { q ->
106+ val result = Matrix .findBestNOccurrences(q, t, 4 )
107+
108+ val distance = getSingleValueDouble(result[0 ], 2 , 0 , 1 , 0 )
109+ Assert .assertEquals(distance, 1.83880 , 1e- 3 )
110+
111+ val index = getSingleValueInt(result[1 ], 3 , 1 , 0 , 0 )
112+ Assert .assertEquals(index.toLong(), 2 )
113+
114+ result[0 ].close()
115+ result[1 ].close()
116+ }
117+ }
118+
119+ }
120+
18121
19122 @Test
20123 @Throws(Exception ::class )
@@ -277,3 +380,27 @@ class MatrixTest {
277380 }
278381 }
279382}
383+
384+ private fun getSingleValueDouble (arr : Array , dim0 : Long , dim1 : Long , dim2 : Long , dim3 : Long ): Double {
385+ val data = arr.getData<DoubleArray >()
386+
387+ val dims4 = arr.dims
388+ var offset = dims4[0 ] * dims4[1 ] * dims4[2 ] * dim3
389+ offset + = dims4[0 ] * dims4[1 ] * dim2
390+ offset + = dims4[0 ] * dim1
391+ offset + = dim0
392+
393+ return data[offset.toInt()]
394+ }
395+
396+ private fun getSingleValueInt (arr : Array , dim0 : Long , dim1 : Long , dim2 : Long , dim3 : Long ): Int {
397+ val data = arr.getData<IntArray >()
398+
399+ val dims4 = arr.dims
400+ var offset = dims4[0 ] * dims4[1 ] * dims4[2 ] * dim3
401+ offset + = dims4[0 ] * dims4[1 ] * dim2
402+ offset + = dims4[0 ] * dim1
403+ offset + = dim0
404+
405+ return data[offset.toInt()]
406+ }
0 commit comments