Skip to content

Commit c96efde

Browse files
authored
Add assembly check when pipeline trying to get stages (#1049)
1 parent c2ee198 commit c96efde

17 files changed

+173
-94
lines changed

src/csharp/Microsoft.Spark/ML/Feature/Base.cs

Lines changed: 0 additions & 25 deletions
Original file line number | Diff line number | Diff line change
@@ -119,29 +119,4 @@ protected static T WrapAsType<T>(JvmObjectReference reference)
119119
return (T)constructor.Invoke(new object[] { reference });
120120
}
121121
}
122-
123-
/// <summary>
124-
/// DotnetUtils is used to hold basic general helper functions that
125-
/// are used within ML scope.
126-
/// </summary>
127-
internal class DotnetUtils
128-
{
129-
/// <summary>
130-
/// Helper function for getting the exact class name from jvm object.
131-
/// </summary>
132-
/// <param name="jvmObject">The reference to object created in JVM.</param>
133-
/// <returns>A string Tuple2 of constructor class name and method name</returns>
134-
internal static (string, string) GetUnderlyingType(JvmObjectReference jvmObject)
135-
{
136-
var jvmClass = (JvmObjectReference)jvmObject.Invoke("getClass");
137-
var returnClass = (string)jvmClass.Invoke("getTypeName");
138-
string[] dotnetClass = returnClass.Replace("com.microsoft.azure.synapse.ml", "Synapse.ML")
139-
.Replace("org.apache.spark.ml", "Microsoft.Spark.ML")
140-
.Split(".".ToCharArray());
141-
string[] renameClass = dotnetClass.Select(x => char.ToUpper(x[0]) + x.Substring(1)).ToArray();
142-
string constructorClass = string.Join(".", renameClass);
143-
string methodName = "WrapAs" + dotnetClass[dotnetClass.Length - 1];
144-
return (constructorClass, methodName);
145-
}
146-
}
147122
}

src/csharp/Microsoft.Spark/ML/Feature/Bucketizer.cs

Lines changed: 4 additions & 4 deletions
Original file line number | Diff line number | Diff line change
@@ -24,13 +24,13 @@ public class Bucketizer :
2424
IJavaMLWritable,
2525
IJavaMLReadable<Bucketizer>
2626
{
27-
private static readonly string s_bucketizerClassName =
27+
private static readonly string s_className =
2828
"org.apache.spark.ml.feature.Bucketizer";
2929

3030
/// <summary>
3131
/// Create a <see cref="Bucketizer"/> without any parameters
3232
/// </summary>
33-
public Bucketizer() : base(s_bucketizerClassName)
33+
public Bucketizer() : base(s_className)
3434
{
3535
}
3636

@@ -39,7 +39,7 @@ public Bucketizer() : base(s_bucketizerClassName)
3939
/// <see cref="Bucketizer"/> a unique ID
4040
/// </summary>
4141
/// <param name="uid">An immutable unique ID for the object and its derivatives.</param>
42-
public Bucketizer(string uid) : base(s_bucketizerClassName, uid)
42+
public Bucketizer(string uid) : base(s_className, uid)
4343
{
4444
}
4545

@@ -163,7 +163,7 @@ public Bucketizer SetOutputCols(List<string> value) =>
163163
public static Bucketizer Load(string path) =>
164164
WrapAsBucketizer(
165165
SparkEnvironment.JvmBridge.CallStaticJavaMethod(
166-
s_bucketizerClassName, "load", path));
166+
s_className, "load", path));
167167

168168
/// <summary>
169169
/// Executes the <see cref="Bucketizer"/> and transforms the DataFrame to include the new

src/csharp/Microsoft.Spark/ML/Feature/CountVectorizer.cs

Lines changed: 4 additions & 4 deletions
Original file line number | Diff line number | Diff line change
@@ -13,13 +13,13 @@ public class CountVectorizer :
1313
IJavaMLWritable,
1414
IJavaMLReadable<CountVectorizer>
1515
{
16-
private static readonly string s_countVectorizerClassName =
16+
private static readonly string s_className =
1717
"org.apache.spark.ml.feature.CountVectorizer";
1818

1919
/// <summary>
2020
/// Creates a <see cref="CountVectorizer"/> without any parameters.
2121
/// </summary>
22-
public CountVectorizer() : base(s_countVectorizerClassName)
22+
public CountVectorizer() : base(s_className)
2323
{
2424
}
2525

@@ -28,7 +28,7 @@ public CountVectorizer() : base(s_countVectorizerClassName)
2828
/// <see cref="CountVectorizer"/> a unique ID.
2929
/// </summary>
3030
/// <param name="uid">An immutable unique ID for the object and its derivatives.</param>
31-
public CountVectorizer(string uid) : base(s_countVectorizerClassName, uid)
31+
public CountVectorizer(string uid) : base(s_className, uid)
3232
{
3333
}
3434

@@ -52,7 +52,7 @@ public override CountVectorizerModel Fit(DataFrame dataFrame) =>
5252
public static CountVectorizer Load(string path) =>
5353
WrapAsCountVectorizer((JvmObjectReference)
5454
SparkEnvironment.JvmBridge.CallStaticJavaMethod(
55-
s_countVectorizerClassName, "load", path));
55+
s_className, "load", path));
5656

5757
/// <summary>
5858
/// Gets the binary toggle to control the output vector values. If True, all nonzero counts

src/csharp/Microsoft.Spark/ML/Feature/CountVectorizerModel.cs

Lines changed: 4 additions & 4 deletions
Original file line number | Diff line number | Diff line change
@@ -15,7 +15,7 @@ public class CountVectorizerModel :
1515
IJavaMLWritable,
1616
IJavaMLReadable<CountVectorizerModel>
1717
{
18-
private static readonly string s_countVectorizerModelClassName =
18+
private static readonly string s_className =
1919
"org.apache.spark.ml.feature.CountVectorizerModel";
2020

2121
/// <summary>
@@ -24,7 +24,7 @@ public class CountVectorizerModel :
2424
/// <param name="vocabulary">The vocabulary to use</param>
2525
public CountVectorizerModel(List<string> vocabulary)
2626
: this(SparkEnvironment.JvmBridge.CallConstructor(
27-
s_countVectorizerModelClassName, vocabulary))
27+
s_className, vocabulary))
2828
{
2929
}
3030

@@ -36,7 +36,7 @@ public CountVectorizerModel(List<string> vocabulary)
3636
/// <param name="vocabulary">The vocabulary to use</param>
3737
public CountVectorizerModel(string uid, List<string> vocabulary)
3838
: this(SparkEnvironment.JvmBridge.CallConstructor(
39-
s_countVectorizerModelClassName, uid, vocabulary))
39+
s_className, uid, vocabulary))
4040
{
4141
}
4242

@@ -54,7 +54,7 @@ internal CountVectorizerModel(JvmObjectReference jvmObject) : base(jvmObject)
5454
public static CountVectorizerModel Load(string path) =>
5555
WrapAsCountVectorizerModel(
5656
SparkEnvironment.JvmBridge.CallStaticJavaMethod(
57-
s_countVectorizerModelClassName, "load", path));
57+
s_className, "load", path));
5858

5959
/// <summary>
6060
/// Gets the binary toggle to control the output vector values. If True, all nonzero counts

src/csharp/Microsoft.Spark/ML/Feature/FeatureHasher.cs

Lines changed: 4 additions & 4 deletions
Original file line number | Diff line number | Diff line change
@@ -15,13 +15,13 @@ public class FeatureHasher :
1515
IJavaMLWritable,
1616
IJavaMLReadable<FeatureHasher>
1717
{
18-
private static readonly string s_featureHasherClassName =
18+
private static readonly string s_className =
1919
"org.apache.spark.ml.feature.FeatureHasher";
2020

2121
/// <summary>
2222
/// Creates a <see cref="FeatureHasher"/> without any parameters.
2323
/// </summary>
24-
public FeatureHasher() : base(s_featureHasherClassName)
24+
public FeatureHasher() : base(s_className)
2525
{
2626
}
2727

@@ -30,7 +30,7 @@ public FeatureHasher() : base(s_featureHasherClassName)
3030
/// <see cref="FeatureHasher"/> a unique ID.
3131
/// </summary>
3232
/// <param name="uid">An immutable unique ID for the object and its derivatives.</param>
33-
public FeatureHasher(string uid) : base(s_featureHasherClassName, uid)
33+
public FeatureHasher(string uid) : base(s_className, uid)
3434
{
3535
}
3636

@@ -48,7 +48,7 @@ internal FeatureHasher(JvmObjectReference jvmObject) : base(jvmObject)
4848
public static FeatureHasher Load(string path) =>
4949
WrapAsFeatureHasher(
5050
SparkEnvironment.JvmBridge.CallStaticJavaMethod(
51-
s_featureHasherClassName,
51+
s_className,
5252
"load",
5353
path));
5454

src/csharp/Microsoft.Spark/ML/Feature/HashingTF.cs

Lines changed: 4 additions & 4 deletions
Original file line number | Diff line number | Diff line change
@@ -21,13 +21,13 @@ public class HashingTF :
2121
IJavaMLWritable,
2222
IJavaMLReadable<HashingTF>
2323
{
24-
private static readonly string s_hashingTfClassName =
24+
private static readonly string s_className =
2525
"org.apache.spark.ml.feature.HashingTF";
2626

2727
/// <summary>
2828
/// Create a <see cref="HashingTF"/> without any parameters
2929
/// </summary>
30-
public HashingTF() : base(s_hashingTfClassName)
30+
public HashingTF() : base(s_className)
3131
{
3232
}
3333

@@ -36,7 +36,7 @@ public HashingTF() : base(s_hashingTfClassName)
3636
/// <see cref="HashingTF"/> a unique ID
3737
/// </summary>
3838
/// <param name="uid">An immutable unique ID for the object and its derivatives.</param>
39-
public HashingTF(string uid) : base(s_hashingTfClassName, uid)
39+
public HashingTF(string uid) : base(s_className, uid)
4040
{
4141
}
4242

@@ -52,7 +52,7 @@ internal HashingTF(JvmObjectReference jvmObject) : base(jvmObject)
5252
public static HashingTF Load(string path) =>
5353
WrapAsHashingTF(
5454
SparkEnvironment.JvmBridge.CallStaticJavaMethod(
55-
s_hashingTfClassName, "load", path));
55+
s_className, "load", path));
5656

5757
/// <summary>
5858
/// Gets the binary toggle that controls term frequency counts

src/csharp/Microsoft.Spark/ML/Feature/IDF.cs

Lines changed: 4 additions & 4 deletions
Original file line number | Diff line number | Diff line change
@@ -22,12 +22,12 @@ public class IDF :
2222
IJavaMLWritable,
2323
IJavaMLReadable<IDF>
2424
{
25-
private static readonly string s_IDFClassName = "org.apache.spark.ml.feature.IDF";
25+
private static readonly string s_className = "org.apache.spark.ml.feature.IDF";
2626

2727
/// <summary>
2828
/// Create a <see cref="IDF"/> without any parameters
2929
/// </summary>
30-
public IDF() : base(s_IDFClassName)
30+
public IDF() : base(s_className)
3131
{
3232
}
3333

@@ -36,7 +36,7 @@ public IDF() : base(s_IDFClassName)
3636
/// <see cref="IDF"/> a unique ID
3737
/// </summary>
3838
/// <param name="uid">An immutable unique ID for the object and its derivatives.</param>
39-
public IDF(string uid) : base(s_IDFClassName, uid)
39+
public IDF(string uid) : base(s_className, uid)
4040
{
4141
}
4242

@@ -103,7 +103,7 @@ public override IDFModel Fit(DataFrame source) =>
103103
public static IDF Load(string path)
104104
{
105105
return WrapAsIDF(
106-
SparkEnvironment.JvmBridge.CallStaticJavaMethod(s_IDFClassName, "load", path));
106+
SparkEnvironment.JvmBridge.CallStaticJavaMethod(s_className, "load", path));
107107
}
108108

109109
/// <summary>

src/csharp/Microsoft.Spark/ML/Feature/IDFModel.cs

Lines changed: 4 additions & 4 deletions
Original file line number | Diff line number | Diff line change
@@ -17,13 +17,13 @@ public class IDFModel :
1717
IJavaMLWritable,
1818
IJavaMLReadable<IDFModel>
1919
{
20-
private static readonly string s_IDFModelClassName =
20+
private static readonly string s_className =
2121
"org.apache.spark.ml.feature.IDFModel";
2222

2323
/// <summary>
2424
/// Create a <see cref="IDFModel"/> without any parameters
2525
/// </summary>
26-
public IDFModel() : base(s_IDFModelClassName)
26+
public IDFModel() : base(s_className)
2727
{
2828
}
2929

@@ -32,7 +32,7 @@ public IDFModel() : base(s_IDFModelClassName)
3232
/// <see cref="IDFModel"/> a unique ID
3333
/// </summary>
3434
/// <param name="uid">An immutable unique ID for the object and its derivatives.</param>
35-
public IDFModel(string uid) : base(s_IDFModelClassName, uid)
35+
public IDFModel(string uid) : base(s_className, uid)
3636
{
3737
}
3838

@@ -96,7 +96,7 @@ public static IDFModel Load(string path)
9696
{
9797
return WrapAsIDFModel(
9898
SparkEnvironment.JvmBridge.CallStaticJavaMethod(
99-
s_IDFModelClassName, "load", path));
99+
s_className, "load", path));
100100
}
101101

102102
/// <summary>

src/csharp/Microsoft.Spark/ML/Feature/NGram.cs

Lines changed: 4 additions & 4 deletions
Original file line number | Diff line number | Diff line change
@@ -19,13 +19,13 @@ public class NGram :
1919
IJavaMLWritable,
2020
IJavaMLReadable<NGram>
2121
{
22-
private static readonly string s_nGramClassName =
22+
private static readonly string s_className =
2323
"org.apache.spark.ml.feature.NGram";
2424

2525
/// <summary>
2626
/// Create a <see cref="NGram"/> without any parameters.
2727
/// </summary>
28-
public NGram() : base(s_nGramClassName)
28+
public NGram() : base(s_className)
2929
{
3030
}
3131

@@ -35,7 +35,7 @@ public NGram() : base(s_nGramClassName)
3535
/// </summary>
3636
/// <param name="uid">An immutable unique ID for the object and its derivatives.
3737
/// </param>
38-
public NGram(string uid) : base(s_nGramClassName, uid)
38+
public NGram(string uid) : base(s_className, uid)
3939
{
4040
}
4141

@@ -123,7 +123,7 @@ public override StructType TransformSchema(StructType value) =>
123123
public static NGram Load(string path) =>
124124
WrapAsNGram(
125125
SparkEnvironment.JvmBridge.CallStaticJavaMethod(
126-
s_nGramClassName,
126+
s_className,
127127
"load",
128128
path));
129129

src/csharp/Microsoft.Spark/ML/Feature/Pipeline.cs

Lines changed: 24 additions & 13 deletions
Original file line number | Diff line number | Diff line change
@@ -3,10 +3,11 @@
33
// See the LICENSE file in the project root for more information.
44

55
using System;
6-
using System.Reflection;
76
using Microsoft.Spark.Interop;
87
using Microsoft.Spark.Interop.Ipc;
98
using Microsoft.Spark.Sql;
9+
using Microsoft.Spark.Utils;
10+
using System.Collections.Generic;
1011

1112
namespace Microsoft.Spark.ML.Feature
1213
{
@@ -26,12 +27,12 @@ public class Pipeline :
2627
IJavaMLWritable,
2728
IJavaMLReadable<Pipeline>
2829
{
29-
private static readonly string s_pipelineClassName = "org.apache.spark.ml.Pipeline";
30+
private static readonly string s_className = "org.apache.spark.ml.Pipeline";
3031

3132
/// <summary>
3233
/// Creates a <see cref="Pipeline"/> without any parameters.
3334
/// </summary>
34-
public Pipeline() : base(s_pipelineClassName)
35+
public Pipeline() : base(s_className)
3536
{
3637
}
3738

@@ -40,7 +41,7 @@ public Pipeline() : base(s_pipelineClassName)
4041
/// <see cref="Pipeline"/> a unique ID.
4142
/// </summary>
4243
/// <param name="uid">An immutable unique ID for the object and its derivatives.</param>
43-
public Pipeline(string uid) : base(s_pipelineClassName, uid)
44+
public Pipeline(string uid) : base(s_className, uid)
4445
{
4546
}
4647

@@ -57,24 +58,34 @@ internal Pipeline(JvmObjectReference jvmObject) : base(jvmObject)
5758
/// <returns><see cref="Pipeline"/> object</returns>
5859
public Pipeline SetStages(JavaPipelineStage[] value) =>
5960
WrapAsPipeline((JvmObjectReference)SparkEnvironment.JvmBridge.CallStaticJavaMethod(
60-
"org.apache.spark.mllib.api.dotnet.MLUtils", "setPipelineStages",
61-
Reference, value.ToJavaArrayList()));
61+
"org.apache.spark.mllib.api.dotnet.MLUtils",
62+
"setPipelineStages",
63+
Reference,
64+
value.ToJavaArrayList()));
6265

6366
/// <summary>
6467
/// Get the stages of pipeline instance.
6568
/// </summary>
6669
/// <returns>A sequence of <see cref="JavaPipelineStage"/> stages</returns>
6770
public JavaPipelineStage[] GetStages()
6871
{
69-
JvmObjectReference[] jvmObjects = (JvmObjectReference[])Reference.Invoke("getStages");
70-
JavaPipelineStage[] result = new JavaPipelineStage[jvmObjects.Length];
72+
var jvmObjects = (JvmObjectReference[])Reference.Invoke("getStages");
73+
var result = new JavaPipelineStage[jvmObjects.Length];
74+
Dictionary<string, Type> classMapping = JvmObjectUtils.ConstructJavaClassMapping(
75+
typeof(JavaPipelineStage),
76+
"s_className");
77+
7178
for (int i = 0; i < jvmObjects.Length; i++)
7279
{
73-
(string constructorClass, string methodName) = DotnetUtils.GetUnderlyingType(jvmObjects[i]);
74-
Type type = Type.GetType(constructorClass);
75-
MethodInfo method = type.GetMethod(methodName, BindingFlags.NonPublic | BindingFlags.Static);
76-
result[i] = (JavaPipelineStage)method.Invoke(null, new object[] { jvmObjects[i] });
80+
if (JvmObjectUtils.TryConstructInstanceFromJvmObject(
81+
jvmObjects[i],
82+
classMapping,
83+
out JavaPipelineStage instance))
84+
{
85+
result[i] = instance;
86+
}
7787
}
88+
7889
return result;
7990
}
8091

@@ -91,7 +102,7 @@ override public PipelineModel Fit(DataFrame dataset) =>
91102
/// <param name="path">The path the previous <see cref="Pipeline"/> was saved to</param>
92103
/// <returns>New <see cref="Pipeline"/> object, loaded from path.</returns>
93104
public static Pipeline Load(string path) => WrapAsPipeline(
94-
SparkEnvironment.JvmBridge.CallStaticJavaMethod(s_pipelineClassName, "load", path));
105+
SparkEnvironment.JvmBridge.CallStaticJavaMethod(s_className, "load", path));
95106

96107
/// <summary>
97108
/// Saves the object so that it can be loaded later using Load. Note that these objects

0 commit comments

Comments
 (0)