Skip to content

Commit 04b9a33

Browse files
authored
Add microsoft-spark-3-1 JAR for the Spark 3.1 support preparation (#834)
1 parent f5d7159 commit 04b9a33

22 files changed

+2454
-0
lines changed

src/scala/microsoft-spark-3-1/pom.xml

+77
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,77 @@
<project xmlns="http://maven.apache.org/POM/4.0.0" xmlns:xsi="http://www.w3.org/2001/XMLSchema-instance"
         xsi:schemaLocation="http://maven.apache.org/POM/4.0.0 http://maven.apache.org/maven-v4_0_0.xsd">
  <modelVersion>4.0.0</modelVersion>
  <parent>
    <groupId>com.microsoft.scala</groupId>
    <artifactId>microsoft-spark</artifactId>
    <version>${microsoft-spark.version}</version>
  </parent>
  <!-- Artifact id encodes the targeted Spark line (3.1) and Scala binary version (2.12). -->
  <artifactId>microsoft-spark-3-1_2.12</artifactId>
  <inceptionYear>2019</inceptionYear>
  <properties>
    <encoding>UTF-8</encoding>
    <scala.version>2.12.10</scala.version>
    <scala.binary.version>2.12</scala.binary.version>
    <spark.version>3.1.1</spark.version>
  </properties>

  <dependencies>
    <dependency>
      <groupId>org.scala-lang</groupId>
      <artifactId>scala-library</artifactId>
      <version>${scala.version}</version>
    </dependency>
    <!-- Spark jars are supplied by the cluster at runtime; mark them 'provided'
         so they are not bundled into this artifact. -->
    <dependency>
      <groupId>org.apache.spark</groupId>
      <artifactId>spark-core_${scala.binary.version}</artifactId>
      <version>${spark.version}</version>
      <scope>provided</scope>
    </dependency>
    <dependency>
      <groupId>org.apache.spark</groupId>
      <artifactId>spark-sql_${scala.binary.version}</artifactId>
      <version>${spark.version}</version>
      <scope>provided</scope>
    </dependency>
    <dependency>
      <groupId>junit</groupId>
      <artifactId>junit</artifactId>
      <version>4.13.1</version>
      <scope>test</scope>
    </dependency>
    <dependency>
      <groupId>org.specs</groupId>
      <artifactId>specs</artifactId>
      <version>1.2.5</version>
      <scope>test</scope>
    </dependency>
  </dependencies>

  <build>
    <sourceDirectory>src/main/scala</sourceDirectory>
    <testSourceDirectory>src/test/scala</testSourceDirectory>
    <plugins>
      <plugin>
        <groupId>org.scala-tools</groupId>
        <artifactId>maven-scala-plugin</artifactId>
        <version>2.15.2</version>
        <executions>
          <execution>
            <goals>
              <goal>compile</goal>
              <goal>testCompile</goal>
            </goals>
          </execution>
        </executions>
        <configuration>
          <scalaVersion>${scala.version}</scalaVersion>
          <args>
            <!-- Emit Java 8 compatible bytecode; warn on deprecated/feature-gated usage. -->
            <arg>-target:jvm-1.8</arg>
            <arg>-deprecation</arg>
            <arg>-feature</arg>
          </args>
        </configuration>
      </plugin>
    </plugins>
  </build>
</project>
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,72 @@
/*
 * Licensed to the .NET Foundation under one or more agreements.
 * The .NET Foundation licenses this file to you under the MIT license.
 * See the LICENSE file in the project root for more information.
 */

package org.apache.spark.api.dotnet

import java.io.DataOutputStream

import org.apache.spark.internal.Logging

import scala.collection.mutable.Queue

/**
 * CallbackClient is used to communicate with the Dotnet CallbackServer.
 * The client manages and maintains a pool of open CallbackConnections.
 * Any callback request is delegated to a new CallbackConnection or
 * unused CallbackConnection.
 * @param serDe Serializer/deserializer used to encode callback payloads
 * @param address The address of the Dotnet CallbackServer
 * @param port The port of the Dotnet CallbackServer
 */
class CallbackClient(serDe: SerDe, address: String, port: Int) extends Logging {
  // Idle connections available for reuse; guarded by `synchronized` on this.
  private[this] val connectionPool: Queue[CallbackConnection] = Queue[CallbackConnection]()

  private[this] var isShutdown: Boolean = false

  /**
   * Sends a callback request on a pooled (or freshly created) connection.
   * On success the connection is returned to the pool for reuse; on failure
   * it is closed and the exception is rethrown to the caller.
   * @param callbackId Identifier of the .NET callback to invoke
   * @param writeBody Writer that serializes the callback arguments to the stream
   */
  final def send(callbackId: Int, writeBody: (DataOutputStream, SerDe) => Unit): Unit =
    getOrCreateConnection() match {
      case Some(connection) =>
        try {
          connection.send(callbackId, writeBody)
          addConnection(connection)
        } catch {
          case e: Exception =>
            logError(s"Error calling callback [callback id = $callbackId].", e)
            // A failed connection may be in an inconsistent protocol state; drop it.
            connection.close()
            throw e
        }
      case None => throw new Exception("Unable to get or create connection.")
    }

  /** Returns a pooled connection, or a new one; None once the client is shut down. */
  private def getOrCreateConnection(): Option[CallbackConnection] = synchronized {
    if (isShutdown) {
      logInfo("Cannot get or create connection while client is shutdown.")
      return None
    }

    if (connectionPool.nonEmpty) {
      return Some(connectionPool.dequeue())
    }

    Some(new CallbackConnection(serDe, address, port))
  }

  /** Returns a healthy connection to the pool for reuse. */
  private def addConnection(connection: CallbackConnection): Unit = synchronized {
    assert(connection != null)
    connectionPool.enqueue(connection)
  }

  /** Closes all pooled connections and rejects any further send attempts. Idempotent. */
  def shutdown(): Unit = synchronized {
    if (isShutdown) {
      logInfo("Shutdown called, but already shutdown.")
      return
    }

    logInfo("Shutting down.")
    // Explicit parens: close() and clear() are side-effecting arity-0 methods
    // (auto-application without parens is deprecated in newer Scala versions).
    connectionPool.foreach(_.close())
    connectionPool.clear()
    isShutdown = true
  }
}
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,112 @@
/*
 * Licensed to the .NET Foundation under one or more agreements.
 * The .NET Foundation licenses this file to you under the MIT license.
 * See the LICENSE file in the project root for more information.
 */

package org.apache.spark.api.dotnet

import java.io.{ByteArrayOutputStream, Closeable, DataInputStream, DataOutputStream}
import java.net.Socket

import org.apache.spark.internal.Logging

/**
 * CallbackConnection is used to process the callback communication
 * between the JVM and Dotnet. It uses a TCP socket to communicate with
 * the Dotnet CallbackServer and the socket is expected to be reused.
 * @param serDe Serializer/deserializer used to encode callback payloads
 * @param address The address of the Dotnet CallbackServer
 * @param port The port of the Dotnet CallbackServer
 */
class CallbackConnection(serDe: SerDe, address: String, port: Int) extends Logging {
  private[this] val socket: Socket = new Socket(address, port)
  private[this] val inputStream: DataInputStream = new DataInputStream(socket.getInputStream)
  private[this] val outputStream: DataOutputStream = new DataOutputStream(socket.getOutputStream)

  /**
   * Invokes a callback on the .NET side and waits for its END_OF_STREAM ack.
   *
   * Wire format: CALLBACK flag, callback id, body length, body bytes,
   * END_OF_STREAM flag. The body is buffered first so its length can be
   * written before the payload.
   * @param callbackId Identifier of the .NET callback to invoke
   * @param writeBody Writer that serializes the callback arguments to the stream
   */
  def send(
      callbackId: Int,
      writeBody: (DataOutputStream, SerDe) => Unit): Unit = {
    logInfo(s"Calling callback [callback id = $callbackId] ...")

    try {
      serDe.writeInt(outputStream, CallbackFlags.CALLBACK)
      serDe.writeInt(outputStream, callbackId)

      // Buffer the body locally so the length prefix can precede the payload.
      val byteArrayOutputStream = new ByteArrayOutputStream()
      writeBody(new DataOutputStream(byteArrayOutputStream), serDe)
      serDe.writeInt(outputStream, byteArrayOutputStream.size)
      byteArrayOutputStream.writeTo(outputStream)
    } catch {
      case e: Exception => {
        throw new Exception("Error writing to stream.", e)
      }
    }

    logInfo("Signaling END_OF_STREAM.")
    try {
      serDe.writeInt(outputStream, CallbackFlags.END_OF_STREAM)
      outputStream.flush()

      // The .NET server echoes END_OF_STREAM on success (or an exception flag,
      // which readFlag converts into a DotnetException).
      val endOfStreamResponse = readFlag(inputStream)
      endOfStreamResponse match {
        case CallbackFlags.END_OF_STREAM =>
          logInfo(s"Received END_OF_STREAM signal. Calling callback [callback id = $callbackId] successful.")
        case _ => {
          throw new Exception(s"Error verifying end of stream. Expected: ${CallbackFlags.END_OF_STREAM}, " +
            s"Received: $endOfStreamResponse")
        }
      }
    } catch {
      case e: Exception => {
        throw new Exception("Error while verifying end of stream.", e)
      }
    }
  }

  /**
   * Sends a best-effort CLOSE notification to the .NET server, then releases
   * the streams and socket. Streams are closed before the socket so their
   * final flush does not fail against an already-closed socket.
   */
  def close(): Unit = {
    try {
      serDe.writeInt(outputStream, CallbackFlags.CLOSE)
      outputStream.flush()
    } catch {
      case e: Exception => logInfo("Unable to send close to .NET callback server.", e)
    }

    close(outputStream)
    close(inputStream)
    close(socket)
  }

  /** Closes the socket, logging (not propagating) any failure. */
  private def close(s: Socket): Unit = {
    try {
      assert(s != null)
      s.close()
    } catch {
      case e: Exception => logInfo("Unable to close socket.", e)
    }
  }

  /** Closes a stream, logging (not propagating) any failure. */
  private def close(c: Closeable): Unit = {
    try {
      assert(c != null)
      c.close()
    } catch {
      case e: Exception => logInfo("Unable to close closeable.", e)
    }
  }

  /**
   * Reads a protocol flag from the server; if the server reported an
   * exception, reads its message and rethrows it as a DotnetException.
   */
  private def readFlag(inputStream: DataInputStream): Int = {
    val callbackFlag = serDe.readInt(inputStream)
    if (callbackFlag == CallbackFlags.DOTNET_EXCEPTION_THROWN) {
      val exceptionMessage = serDe.readString(inputStream)
      throw new DotnetException(exceptionMessage)
    }
    callbackFlag
  }

  // Control flags shared with the .NET CallbackServer; values must stay in
  // sync with the .NET side of the protocol.
  private object CallbackFlags {
    val CLOSE: Int = -1
    val CALLBACK: Int = -2
    val DOTNET_EXCEPTION_THROWN: Int = -3
    val END_OF_STREAM: Int = -4
  }
}
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,113 @@
/*
 * Licensed to the .NET Foundation under one or more agreements.
 * The .NET Foundation licenses this file to you under the MIT license.
 * See the LICENSE file in the project root for more information.
 */

package org.apache.spark.api.dotnet

import java.net.InetSocketAddress
import java.util.concurrent.TimeUnit
import io.netty.bootstrap.ServerBootstrap
import io.netty.channel.nio.NioEventLoopGroup
import io.netty.channel.socket.SocketChannel
import io.netty.channel.socket.nio.NioServerSocketChannel
import io.netty.channel.{ChannelFuture, ChannelInitializer, EventLoopGroup}
import io.netty.handler.codec.LengthFieldBasedFrameDecoder
import io.netty.handler.codec.bytes.{ByteArrayDecoder, ByteArrayEncoder}
import org.apache.spark.internal.Logging
import org.apache.spark.internal.config.dotnet.Dotnet.DOTNET_NUM_BACKEND_THREADS
import org.apache.spark.{SparkConf, SparkEnv}

/**
 * Netty server that invokes JVM calls based upon receiving messages from .NET.
 * The implementation mirrors the RBackend.
 *
 */
class DotnetBackend extends Logging {
  self => // for accessing the this reference in inner class(ChannelInitializer)
  private[this] var channelFuture: ChannelFuture = _
  private[this] var bootstrap: ServerBootstrap = _
  private[this] var bossGroup: EventLoopGroup = _
  // Tracks JVM objects referenced by the .NET side; cleared on close().
  private[this] val objectTracker = new JVMObjectTracker

  // Client used to invoke callbacks back into the .NET process; set lazily
  // via setCallbackClient. @volatile: read outside the synchronized setters.
  @volatile
  private[dotnet] var callbackClient: Option[CallbackClient] = None

  /**
   * Binds the netty server on localhost and returns the actual bound port
   * (useful when portNumber is 0, i.e. an ephemeral port is requested).
   * @param portNumber Requested port; 0 lets the OS pick a free port
   * @return The port the server is actually listening on
   */
  def init(portNumber: Int): Int = {
    val conf = Option(SparkEnv.get).map(_.conf).getOrElse(new SparkConf())
    val numBackendThreads = conf.get(DOTNET_NUM_BACKEND_THREADS)
    logInfo(s"The number of DotnetBackend threads is set to $numBackendThreads.")
    bossGroup = new NioEventLoopGroup(numBackendThreads)
    // A single event loop group serves as both acceptor and worker.
    val workerGroup = bossGroup

    bootstrap = new ServerBootstrap()
      .group(bossGroup, workerGroup)
      .channel(classOf[NioServerSocketChannel])

    bootstrap.childHandler(new ChannelInitializer[SocketChannel]() {
      def initChannel(ch: SocketChannel): Unit = {
        ch.pipeline()
          .addLast("encoder", new ByteArrayEncoder())
          .addLast(
            "frameDecoder",
            // maxFrameLength = 2G
            // lengthFieldOffset = 0
            // lengthFieldLength = 4
            // lengthAdjustment = 0
            // initialBytesToStrip = 4, i.e. strip out the length field itself
            new LengthFieldBasedFrameDecoder(Integer.MAX_VALUE, 0, 4, 0, 4))
          .addLast("decoder", new ByteArrayDecoder())
          .addLast("handler", new DotnetBackendHandler(self, objectTracker))
      }
    })

    channelFuture = bootstrap.bind(new InetSocketAddress("localhost", portNumber))
    channelFuture.syncUninterruptibly()
    channelFuture.channel().localAddress().asInstanceOf[InetSocketAddress].getPort
  }

  /**
   * Connects the callback client to the .NET CallbackServer.
   * Throws if a client is already set (each backend supports one client).
   */
  private[dotnet] def setCallbackClient(address: String, port: Int): Unit = synchronized {
    callbackClient = callbackClient match {
      case Some(_) => throw new Exception("Callback client already set.")
      case None =>
        logInfo(s"Connecting to a callback server at $address:$port")
        Some(new CallbackClient(new SerDe(objectTracker), address, port))
    }
  }

  /** Shuts down the callback client, if any; safe to call repeatedly. */
  private[dotnet] def shutdownCallbackClient(): Unit = synchronized {
    callbackClient match {
      case Some(client) => client.shutdown()
      case None => logInfo("Callback server has already been shutdown.")
    }
    callbackClient = None
  }

  /** Blocks the calling thread until the server channel is closed. */
  def run(): Unit = {
    channelFuture.channel.closeFuture().syncUninterruptibly()
  }

  /**
   * Tears down the server: closes the channel, shuts down the event loop
   * groups, clears tracked JVM objects, notifies the .NET callback server,
   * and stops the shared thread pool.
   */
  def close(): Unit = {
    if (channelFuture != null) {
      // close is a local operation and should finish within milliseconds; timeout just to be safe
      channelFuture.channel().close().awaitUninterruptibly(10, TimeUnit.SECONDS)
      channelFuture = null
    }
    if (bootstrap != null && bootstrap.config().group() != null) {
      bootstrap.config().group().shutdownGracefully()
    }
    if (bootstrap != null && bootstrap.config().childGroup() != null) {
      bootstrap.config().childGroup().shutdownGracefully()
    }
    bootstrap = null

    objectTracker.clear()

    // Send close to .NET callback server.
    shutdownCallbackClient()

    // Shutdown the thread pool whose executors could still be running.
    ThreadPool.shutdown()
  }
}

0 commit comments

Comments
 (0)