
Commit d6f0b09
Author: Prashanth Govindarajan
Changes to support Arrow 2.0 and Spark 3.0 (#711)
Parent: 1bcd37b

5 files changed: +78 −33 lines

src/csharp/Microsoft.Spark.E2ETest/IpcTests/Sql/DataFrameTests.cs
Lines changed: 2 additions & 2 deletions

```diff
@@ -290,7 +290,7 @@ public void TestDataFrameVectorUdf()
         }
     }
 
-    [SkipIfSparkVersionIsGreaterOrEqualTo(Versions.V3_0_0)]
+    [Fact]
     public void TestGroupedMapUdf()
     {
         DataFrame df = _spark
@@ -368,7 +368,7 @@ private static RecordBatch ArrowBasedCountCharacters(RecordBatch records)
             returnLength);
     }
 
-    [SkipIfSparkVersionIsGreaterOrEqualTo(Versions.V3_0_0)]
+    [Fact]
    public void TestDataFrameGroupedMapUdf()
    {
        DataFrame df = _spark
```
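With Spark 3.0 support in place, both grouped-map tests drop their version gate and run as plain `[Fact]`s. For context, a skip attribute like the removed `SkipIfSparkVersionIsGreaterOrEqualTo` is conventionally a small `FactAttribute` subclass; the sketch below is illustrative only, and the `SparkSettings` stub stands in for however the E2E suite actually exposes the Spark version under test.

```csharp
using System;
using Xunit;

// Stub standing in for the test suite's real Spark-version discovery.
static class SparkSettings
{
    public static Version Version { get; } = new Version(3, 0, 0);
}

// Illustrative version-gated fact: xUnit reports a test as skipped
// whenever FactAttribute.Skip is set to a non-null reason.
public sealed class SkipIfSparkVersionIsGreaterOrEqualToAttribute : FactAttribute
{
    public SkipIfSparkVersionIsGreaterOrEqualToAttribute(string version)
    {
        if (SparkSettings.Version >= new Version(version))
        {
            Skip = $"Not supported on Spark {SparkSettings.Version}";
        }
    }
}
```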

src/csharp/Microsoft.Spark.UnitTest/WorkerFunctionTests.cs
Lines changed: 5 additions & 5 deletions

```diff
@@ -110,7 +110,7 @@ public void TestArrowWorkerFunctionForBool()
     new ArrowUdfWrapper<StringArray, BooleanArray, BooleanArray>(
         (strings, flags) => (BooleanArray)ToArrowArray(
             Enumerable.Range(0, strings.Length)
-                .Select(i => flags.GetBoolean(i) || strings.GetString(i).Contains("true"))
+                .Select(i => flags.GetValue(i).Value || strings.GetString(i).Contains("true"))
                 .ToArray())).Execute);
 
 IArrowArray[] input = new[]
@@ -120,10 +120,10 @@ public void TestArrowWorkerFunctionForBool()
 };
 var results = (BooleanArray)func.Func(input, new[] { 0, 1 });
 Assert.Equal(4, results.Length);
-Assert.True(results.GetBoolean(0));
-Assert.True(results.GetBoolean(1));
-Assert.True(results.GetBoolean(2));
-Assert.False(results.GetBoolean(3));
+Assert.True(results.GetValue(0).Value);
+Assert.True(results.GetValue(1).Value);
+Assert.True(results.GetValue(2).Value);
+Assert.False(results.GetValue(3).Value);
 }
 
 /// <summary>
```
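These test updates track an Arrow API change: the `BooleanArray.GetBoolean(int)` accessor from 0.15.x is gone, and Arrow 2.0 exposes `GetValue(int)` returning a nullable `bool?` so a null slot is distinguishable from `false`. A minimal sketch of the nullable read path, assuming Apache.Arrow 2.0's `BooleanArray.Builder`:

```csharp
using System;
using Apache.Arrow;

class BooleanArrayDemo
{
    static void Main()
    {
        // Build a BooleanArray containing a null slot.
        BooleanArray flags = new BooleanArray.Builder()
            .Append(true)
            .AppendNull()
            .Build();

        bool? first = flags.GetValue(0);  // true
        bool? second = flags.GetValue(1); // null: the slot is unset

        // The tests above call .Value directly because their inputs are
        // known to be non-null; a null-tolerant read coalesces instead.
        Console.WriteLine(first ?? false);  // True
        Console.WriteLine(second ?? false); // False
    }
}
```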

src/csharp/Microsoft.Spark.Worker.UnitTest/CommandExecutorTests.cs
Lines changed: 33 additions & 7 deletions

```diff
@@ -702,7 +702,7 @@ public void TestDataFrameSqlCommandExecutorWithEmptyInput(
     IpcOptions ipcOptions)
 {
     var udfWrapper = new Sql.DataFrameUdfWrapper<ArrowStringDataFrameColumn, ArrowStringDataFrameColumn>(
-        (strings) => strings.Apply(cur=> $"udf: {cur}"));
+        (strings) => strings.Apply(cur => $"udf: {cur}"));
 
     var command = new SqlCommand()
     {
@@ -874,15 +874,28 @@ await arrowWriter.WriteRecordBatchAsync(
     RecordBatch outputBatch = await arrowReader.ReadNextRecordBatchAsync();
 
     Assert.Equal(numRows, outputBatch.Length);
-    Assert.Equal(2, outputBatch.ColumnCount);
+    StringArray stringArray;
+    Int64Array longArray;
+    if (sparkVersion < new Version(Versions.V3_0_0))
+    {
+        Assert.Equal(2, outputBatch.ColumnCount);
+        stringArray = (StringArray)outputBatch.Column(0);
+        longArray = (Int64Array)outputBatch.Column(1);
+    }
+    else
+    {
+        Assert.Equal(1, outputBatch.ColumnCount);
+        var structArray = (StructArray)outputBatch.Column(0);
+        Assert.Equal(2, structArray.Fields.Count);
+        stringArray = (StringArray)structArray.Fields[0];
+        longArray = (Int64Array)structArray.Fields[1];
+    }
 
-    var stringArray = (StringArray)outputBatch.Column(0);
     for (int i = 0; i < numRows; ++i)
     {
         Assert.Equal($"udf: {i}", stringArray.GetString(i));
     }
 
-    var longArray = (Int64Array)outputBatch.Column(1);
     for (int i = 0; i < numRows; ++i)
     {
         Assert.Equal(100 + i, longArray.Values[i]);
@@ -981,15 +994,28 @@ await arrowWriter.WriteRecordBatchAsync(
     RecordBatch outputBatch = await arrowReader.ReadNextRecordBatchAsync();
 
     Assert.Equal(numRows, outputBatch.Length);
-    Assert.Equal(2, outputBatch.ColumnCount);
+    StringArray stringArray;
+    DoubleArray doubleArray;
+    if (sparkVersion < new Version(Versions.V3_0_0))
+    {
+        Assert.Equal(2, outputBatch.ColumnCount);
+        stringArray = (StringArray)outputBatch.Column(0);
+        doubleArray = (DoubleArray)outputBatch.Column(1);
+    }
+    else
+    {
+        Assert.Equal(1, outputBatch.ColumnCount);
+        var structArray = (StructArray)outputBatch.Column(0);
+        Assert.Equal(2, structArray.Fields.Count);
+        stringArray = (StringArray)structArray.Fields[0];
+        doubleArray = (DoubleArray)structArray.Fields[1];
+    }
 
-    var stringArray = (StringArray)outputBatch.Column(0);
     for (int i = 0; i < numRows; ++i)
     {
         Assert.Equal($"udf: {i}", stringArray.GetString(i));
     }
 
-    var doubleArray = (DoubleArray)outputBatch.Column(1);
     for (int i = 0; i < numRows; ++i)
     {
         Assert.Equal(100 + i, doubleArray.Values[i]);
```
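The extra branch in both tests reflects a Spark 3.0 protocol change: grouped-map results now arrive as a single struct column whose fields are the logical result columns, where Spark 2.x delivered them at the top level of the batch. A hypothetical helper capturing the same unwrap logic (the name `UnwrapColumns` and the boolean flag are illustrative, not part of the commit):

```csharp
using System.Linq;
using Apache.Arrow;

static class GroupedMapOutput
{
    // Pre-3.0 batches carry the result columns at the top level, while
    // Spark 3.0+ wraps them as the fields of one StructArray column.
    public static IArrowArray[] UnwrapColumns(RecordBatch batch, bool isSpark3OrLater)
    {
        if (!isSpark3OrLater)
        {
            return Enumerable.Range(0, batch.ColumnCount)
                .Select(batch.Column)
                .ToArray();
        }

        var structArray = (StructArray)batch.Column(0);
        return structArray.Fields.ToArray();
    }
}
```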

src/csharp/Microsoft.Spark.Worker/Command/SqlCommandExecutor.cs
Lines changed: 36 additions & 14 deletions

```diff
@@ -422,8 +422,7 @@ private CommandExecutorStat ExecuteArrowSqlCommand(
 
         var recordBatch = new RecordBatch(resultSchema, results, numEntries);
 
-        // TODO: Remove sync-over-async once WriteRecordBatch exists.
-        writer.WriteRecordBatchAsync(recordBatch).GetAwaiter().GetResult();
+        writer.WriteRecordBatch(recordBatch);
     }
 
     WriteEnd(outputStream, ipcOptions);
@@ -468,8 +467,7 @@ private CommandExecutorStat ExecuteDataFrameSqlCommand(
                 new ArrowStreamWriter(outputStream, result.Schema, leaveOpen: true, ipcOptions);
         }
 
-        // TODO: Remove sync-over-async once WriteRecordBatch exists.
-        writer.WriteRecordBatchAsync(result).GetAwaiter().GetResult();
+        writer.WriteRecordBatch(result);
     }
 }
@@ -737,6 +735,29 @@ protected internal override CommandExecutorStat ExecuteCore(
     return ExecuteArrowGroupedMapCommand(inputStream, outputStream, commands);
 }
 
+private RecordBatch WrapColumnsInStructIfApplicable(RecordBatch batch)
+{
+    if (_version >= new Version(Versions.V3_0_0))
+    {
+        var fields = new Field[batch.Schema.Fields.Count];
+        for (int i = 0; i < batch.Schema.Fields.Count; ++i)
+        {
+            fields[i] = batch.Schema.GetFieldByIndex(i);
+        }
+
+        var structType = new StructType(fields);
+        var structArray = new StructArray(
+            structType,
+            batch.Length,
+            batch.Arrays.Cast<Apache.Arrow.Array>(),
+            ArrowBuffer.Empty);
+        Schema schema = new Schema.Builder().Field(new Field("Struct", structType, false)).Build();
+        return new RecordBatch(schema, new[] { structArray }, batch.Length);
+    }
+
+    return batch;
+}
+
 private CommandExecutorStat ExecuteArrowGroupedMapCommand(
     Stream inputStream,
     Stream outputStream,
@@ -754,19 +775,19 @@ private CommandExecutorStat ExecuteArrowGroupedMapCommand(
     ArrowStreamWriter writer = null;
     foreach (RecordBatch input in GetInputIterator(inputStream))
     {
-        RecordBatch result = worker.Func(input);
+        RecordBatch batch = worker.Func(input);
 
-        int numEntries = result.Length;
+        RecordBatch final = WrapColumnsInStructIfApplicable(batch);
+        int numEntries = final.Length;
         stat.NumEntriesProcessed += numEntries;
 
         if (writer == null)
         {
             writer =
-                new ArrowStreamWriter(outputStream, result.Schema, leaveOpen: true, ipcOptions);
+                new ArrowStreamWriter(outputStream, final.Schema, leaveOpen: true, ipcOptions);
         }
 
-        // TODO: Remove sync-over-async once WriteRecordBatch exists.
-        writer.WriteRecordBatchAsync(result).GetAwaiter().GetResult();
+        writer.WriteRecordBatch(final);
     }
 
     WriteEnd(outputStream, ipcOptions);
@@ -794,20 +815,21 @@ private CommandExecutorStat ExecuteDataFrameGroupedMapCommand(
     {
         FxDataFrame dataFrame = FxDataFrame.FromArrowRecordBatch(input);
         FxDataFrame resultDataFrame = worker.Func(dataFrame);
+
         IEnumerable<RecordBatch> recordBatches = resultDataFrame.ToArrowRecordBatches();
 
-        foreach (RecordBatch result in recordBatches)
+        foreach (RecordBatch batch in recordBatches)
         {
-            stat.NumEntriesProcessed += result.Length;
+            RecordBatch final = WrapColumnsInStructIfApplicable(batch);
+            stat.NumEntriesProcessed += final.Length;
 
             if (writer == null)
             {
                 writer =
-                    new ArrowStreamWriter(outputStream, result.Schema, leaveOpen: true, ipcOptions);
+                    new ArrowStreamWriter(outputStream, final.Schema, leaveOpen: true, ipcOptions);
             }
 
-            // TODO: Remove sync-over-async once WriteRecordBatch exists.
-            writer.WriteRecordBatchAsync(result).GetAwaiter().GetResult();
+            writer.WriteRecordBatch(final);
         }
     }
```
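`WrapColumnsInStructIfApplicable` is the writer-side half of that protocol change: on Spark 3.0+ every outgoing grouped-map batch is rewritten as one non-nullable struct column whose fields are the original columns, with `ArrowBuffer.Empty` passed as the validity bitmap since no struct slot is null. The deleted TODO comments mark the other Arrow 2.0 win: `ArrowStreamWriter` now ships a synchronous `WriteRecordBatch`, so the sync-over-async workaround goes away. Below is a standalone sketch of the same wrap applied to a hand-built two-column batch, assuming only Apache.Arrow 2.0 builder APIs; `WrapDemo` and the column names are illustrative.

```csharp
using System;
using System.Linq;
using Apache.Arrow;
using Apache.Arrow.Types;

class WrapDemo
{
    static void Main()
    {
        // Hand-build a two-column batch standing in for a UDF result.
        StringArray names = new StringArray.Builder().Append("a").Append("b").Build();
        Int64Array counts = new Int64Array.Builder().Append(1).Append(2).Build();
        Schema flat = new Schema.Builder()
            .Field(new Field("name", StringType.Default, false))
            .Field(new Field("count", Int64Type.Default, false))
            .Build();
        var batch = new RecordBatch(flat, new IArrowArray[] { names, counts }, 2);

        // Collect the flat schema's fields, as the executor does.
        var fields = new Field[batch.Schema.Fields.Count];
        for (int i = 0; i < fields.Length; ++i)
        {
            fields[i] = batch.Schema.GetFieldByIndex(i);
        }

        // Re-emit the columns as the children of a single struct column.
        var structType = new StructType(fields);
        var structArray = new StructArray(
            structType,
            batch.Length,
            batch.Arrays.Cast<Apache.Arrow.Array>(),
            ArrowBuffer.Empty); // empty validity bitmap: no null structs

        Schema wrappedSchema = new Schema.Builder()
            .Field(new Field("Struct", structType, false))
            .Build();
        var wrapped = new RecordBatch(wrappedSchema, new IArrowArray[] { structArray }, batch.Length);

        Console.WriteLine(wrapped.ColumnCount); // 1
    }
}
```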

src/csharp/Microsoft.Spark/Microsoft.Spark.csproj
Lines changed: 2 additions & 5 deletions

```diff
@@ -28,7 +28,7 @@
   </ItemGroup>
 
   <ItemGroup>
-    <PackageReference Include="Apache.Arrow" Version="0.15.1" />
+    <PackageReference Include="Apache.Arrow" Version="2.0.0" />
     <PackageReference Include="Microsoft.CSharp" Version="4.5.0" />
     <PackageReference Include="Microsoft.Data.Analysis" Version="0.4.0" />
     <PackageReference Include="Newtonsoft.Json" Version="11.0.2" />
@@ -37,10 +37,7 @@
   </ItemGroup>
 
   <ItemGroup>
-    <Content Include="..\..\scala\microsoft-spark-*\target\microsoft-spark-*.jar"
-             Link="jars\%(Filename)%(Extension)"
-             Pack="true"
-             PackagePath="jars\%(Filename)%(Extension)" />
+    <Content Include="..\..\scala\microsoft-spark-*\target\microsoft-spark-*.jar" Link="jars\%(Filename)%(Extension)" Pack="true" PackagePath="jars\%(Filename)%(Extension)" />
     <Content Include="build\**" Pack="true" PackagePath="build" />
   </ItemGroup>
```
4643
