Skip to content

Commit 638898d

Browse files
authored
Implement NGram (ML Feature) (#734)
1 parent 064503b commit 638898d

File tree

2 files changed

+200
-0
lines changed

2 files changed

+200
-0
lines changed
Lines changed: 69 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,69 @@
1+
// Licensed to the .NET Foundation under one or more agreements.
2+
// The .NET Foundation licenses this file to you under the MIT license.
3+
// See the LICENSE file in the project root for more information.
4+
5+
using System.IO;
6+
using Microsoft.Spark.ML.Feature;
7+
using Microsoft.Spark.Sql;
8+
using Microsoft.Spark.Sql.Types;
9+
using Microsoft.Spark.UnitTest.TestUtils;
10+
using Xunit;
11+
12+
namespace Microsoft.Spark.E2ETest.IpcTests.ML.Feature
13+
{
14+
/// <summary>
15+
/// Test suite for <see cref="NGram"/> class.
16+
/// </summary>
17+
[Collection("Spark E2E Tests")]
18+
public class NGramTests : FeatureBaseTests<NGram>
19+
{
20+
private readonly SparkSession _spark;
21+
22+
public NGramTests(SparkFixture fixture) : base(fixture)
23+
{
24+
_spark = fixture.Spark;
25+
}
26+
27+
/// <summary>
28+
/// Test case to test the methods in <see cref="NGram"/> class.
29+
/// </summary>
30+
[Fact]
31+
public void TestNGram()
32+
{
33+
string expectedUid = "theUid";
34+
string expectedInputCol = "input_col";
35+
string expectedOutputCol = "output_col";
36+
int expectedN = 2;
37+
38+
DataFrame input = _spark.Sql("SELECT split('Hi I heard about Spark', ' ') as input_col");
39+
40+
NGram nGram = new NGram(expectedUid)
41+
.SetInputCol(expectedInputCol)
42+
.SetOutputCol(expectedOutputCol)
43+
.SetN(expectedN);
44+
45+
StructType outputSchema = nGram.TransformSchema(input.Schema());
46+
47+
DataFrame output = nGram.Transform(input);
48+
49+
Assert.Contains(output.Schema().Fields, (f => f.Name == expectedOutputCol));
50+
Assert.Contains(outputSchema.Fields, (f => f.Name == expectedOutputCol));
51+
Assert.Equal(expectedInputCol, nGram.GetInputCol());
52+
Assert.Equal(expectedOutputCol, nGram.GetOutputCol());
53+
Assert.Equal(expectedN, nGram.GetN());
54+
55+
using (var tempDirectory = new TemporaryDirectory())
56+
{
57+
string savePath = Path.Join(tempDirectory.Path, "NGram");
58+
nGram.Save(savePath);
59+
60+
NGram loadedNGram = NGram.Load(savePath);
61+
Assert.Equal(nGram.Uid(), loadedNGram.Uid());
62+
}
63+
64+
Assert.Equal(expectedUid, nGram.Uid());
65+
66+
TestFeatureBase(nGram, "inputCol", "input_col");
67+
}
68+
}
69+
}
Lines changed: 131 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,131 @@
1+
// Licensed to the .NET Foundation under one or more agreements.
2+
// The .NET Foundation licenses this file to you under the MIT license.
3+
// See the LICENSE file in the project root for more information.
4+
5+
using Microsoft.Spark.Interop;
6+
using Microsoft.Spark.Interop.Ipc;
7+
using Microsoft.Spark.Sql;
8+
using Microsoft.Spark.Sql.Types;
9+
10+
namespace Microsoft.Spark.ML.Feature
11+
{
12+
/// <summary>
13+
/// Class <see cref="NGram"/> transformer that converts the input array of strings into
14+
/// an array of n-grams. Null values in the input array are ignored. It returns an array
15+
/// of n-grams where each n-gram is represented by a space-separated string of words.
16+
/// </summary>
17+
public class NGram : FeatureBase<NGram>, IJvmObjectReferenceProvider
18+
{
19+
private static readonly string s_nGramClassName =
20+
"org.apache.spark.ml.feature.NGram";
21+
22+
/// <summary>
23+
/// Create a <see cref="NGram"/> without any parameters.
24+
/// </summary>
25+
public NGram() : base(s_nGramClassName)
26+
{
27+
}
28+
29+
/// <summary>
30+
/// Create a <see cref="NGram"/> with a UID that is used to give the
31+
/// <see cref="NGram"/> a unique ID.
32+
/// </summary>
33+
/// <param name="uid">An immutable unique ID for the object and its derivatives.
34+
/// </param>
35+
public NGram(string uid) : base(s_nGramClassName, uid)
36+
{
37+
}
38+
39+
internal NGram(JvmObjectReference jvmObject) : base(jvmObject)
40+
{
41+
}
42+
43+
JvmObjectReference IJvmObjectReferenceProvider.Reference => _jvmObject;
44+
45+
/// <summary>
46+
/// Gets the column that the <see cref="NGram"/> should read from.
47+
/// </summary>
48+
/// <returns>string, input column</returns>
49+
public string GetInputCol() => (string)_jvmObject.Invoke("getInputCol");
50+
51+
/// <summary>
52+
/// Sets the column that the <see cref="NGram"/> should read from.
53+
/// </summary>
54+
/// <param name="value">The name of the column to as the source</param>
55+
/// <returns>New <see cref="NGram"/> object</returns>
56+
public NGram SetInputCol(string value) => WrapAsNGram(_jvmObject.Invoke("setInputCol", value));
57+
58+
/// <summary>
59+
/// Gets the output column that the <see cref="NGram"/> writes.
60+
/// </summary>
61+
/// <returns>string, the output column</returns>
62+
public string GetOutputCol() => (string)_jvmObject.Invoke("getOutputCol");
63+
64+
/// <summary>
65+
/// Sets the output column that the <see cref="NGram"/> writes.
66+
/// </summary>
67+
/// <param name="value">The name of the new column</param>
68+
/// <returns>New <see cref="NGram"/> object</returns>
69+
public NGram SetOutputCol(string value) => WrapAsNGram(_jvmObject.Invoke("setOutputCol", value));
70+
71+
/// <summary>
72+
/// Gets N value for <see cref="NGram"/>.
73+
/// </summary>
74+
/// <returns>int, N value</returns>
75+
public int GetN() => (int)_jvmObject.Invoke("getN");
76+
77+
/// <summary>
78+
/// Sets N value for <see cref="NGram"/>.
79+
/// </summary>
80+
/// <param name="value">N value</param>
81+
/// <returns>New <see cref="NGram"/> object</returns>
82+
public NGram SetN(int value) => WrapAsNGram(_jvmObject.Invoke("setN", value));
83+
84+
/// <summary>
85+
/// Executes the <see cref="NGram"/> and transforms the DataFrame to include the new
86+
/// column.
87+
/// </summary>
88+
/// <param name="source">The DataFrame to transform</param>
89+
/// <returns>
90+
/// New <see cref="DataFrame"/> object with the source <see cref="DataFrame"/> transformed.
91+
/// </returns>
92+
public DataFrame Transform(DataFrame source) =>
93+
new DataFrame((JvmObjectReference)_jvmObject.Invoke("transform", source));
94+
95+
/// <summary>
96+
/// Check transform validity and derive the output schema from the input schema.
97+
///
98+
/// This checks for validity of interactions between parameters during Transform and
99+
/// raises an exception if any parameter value is invalid.
100+
///
101+
/// Typical implementation should first conduct verification on schema change and parameter
102+
/// validity, including complex parameter interaction checks.
103+
/// </summary>
104+
/// <param name="value">
105+
/// The <see cref="StructType"/> of the <see cref="DataFrame"/> which will be transformed.
106+
/// </param>
107+
/// <returns>
108+
/// The <see cref="StructType"/> of the output schema that would have been derived from the
109+
/// input schema, if Transform had been called.
110+
/// </returns>
111+
public StructType TransformSchema(StructType value) =>
112+
new StructType(
113+
(JvmObjectReference)_jvmObject.Invoke(
114+
"transformSchema",
115+
DataType.FromJson(_jvmObject.Jvm, value.Json)));
116+
117+
/// <summary>
118+
/// Loads the <see cref="NGram"/> that was previously saved using Save.
119+
/// </summary>
120+
/// <param name="path">The path the previous <see cref="NGram"/> was saved to</param>
121+
/// <returns>New <see cref="NGram"/> object, loaded from path</returns>
122+
public static NGram Load(string path) =>
123+
WrapAsNGram(
124+
SparkEnvironment.JvmBridge.CallStaticJavaMethod(
125+
s_nGramClassName,
126+
"load",
127+
path));
128+
129+
private static NGram WrapAsNGram(object obj) => new NGram((JvmObjectReference)obj);
130+
}
131+
}

0 commit comments

Comments
 (0)