Skip to content

Commit 85ae3ff

Browse files
Make serialization contain comm_id to respect jupyter comm handlers
1 parent 5617daf commit 85ae3ff

File tree

4 files changed

+19
-14
lines changed

4 files changed

+19
-14
lines changed

jupyter-lib/shared-compiler/src/main/kotlin/org/jetbrains/kotlinx/jupyter/compiler/util/serializedCompiledScript.kt

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -52,8 +52,9 @@ data class SerializedVariablesState(
5252

5353
@Serializable
5454
class SerializationReply(
55-
val cellId: Int = 1,
56-
val descriptorsState: Map<String, SerializedVariablesState> = emptyMap()
55+
val cell_id: Int = 1,
56+
val descriptorsState: Map<String, SerializedVariablesState> = emptyMap(),
57+
val comm_id: String = ""
5758
)
5859

5960
@Serializable

src/main/kotlin/org/jetbrains/kotlinx/jupyter/message_types.kt

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -561,13 +561,15 @@ class SerializationRequest(
561561
val cellId: Int,
562562
val descriptorsState: Map<String, SerializedVariablesState>,
563563
val topLevelDescriptorName: String = "",
564-
val pathToDescriptor: List<String> = emptyList()
564+
val pathToDescriptor: List<String> = emptyList(),
565+
val commId: String = ""
565566
) : MessageContent()
566567

567568
@Serializable
568569
class SerializationReply(
569-
val cellId: Int = 1,
570-
val descriptorsState: Map<String, SerializedVariablesState> = emptyMap()
570+
val cell_id: Int = 1,
571+
val descriptorsState: Map<String, SerializedVariablesState> = emptyMap(),
572+
val comm_id: String = ""
571573
) : MessageContent()
572574

573575
@Serializable(MessageDataSerializer::class)

src/main/kotlin/org/jetbrains/kotlinx/jupyter/protocol.kt

Lines changed: 6 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -308,21 +308,23 @@ fun JupyterConnection.Socket.shellMessagesHandler(msg: Message, repl: ReplForJup
308308
sendWrapped(msg, makeReplyMessage(msg, MessageType.COMM_INFO_REPLY, content = CommInfoReply(mapOf())))
309309
}
310310
is CommOpen -> {
311-
if (!content.commId.equals(MessageType.SERIALIZATION_REQUEST.name, ignoreCase = true)) {
311+
if (!content.targetName.equals("kotlin_serialization", ignoreCase = true)) {
312312
send(makeReplyMessage(msg, MessageType.NONE))
313313
return
314314
}
315315
log.debug("Message type in CommOpen: $msg, ${msg.type}")
316316
val data = content.data ?: return sendWrapped(msg, makeReplyMessage(msg, MessageType.SERIALIZATION_REPLY))
317-
317+
if (data.isEmpty()) return sendWrapped(msg, makeReplyMessage(msg, MessageType.SERIALIZATION_REPLY))
318+
log.debug("Message data: $data")
318319
val messageContent = getVariablesDescriptorsFromJson(data)
319320
GlobalScope.launch(Dispatchers.Default) {
320321
repl.serializeVariables(
321322
messageContent.topLevelDescriptorName,
322323
messageContent.descriptorsState,
324+
content.commId,
323325
messageContent.pathToDescriptor
324326
) { result ->
325-
sendWrapped(msg, makeReplyMessage(msg, MessageType.COMM_OPEN, content = result))
327+
sendWrapped(msg, makeReplyMessage(msg, MessageType.COMM_MSG, content = result))
326328
}
327329
}
328330
}
@@ -343,7 +345,7 @@ fun JupyterConnection.Socket.shellMessagesHandler(msg: Message, repl: ReplForJup
343345
is SerializationRequest -> {
344346
GlobalScope.launch(Dispatchers.Default) {
345347
if (content.topLevelDescriptorName.isNotEmpty()) {
346-
repl.serializeVariables(content.topLevelDescriptorName, content.descriptorsState, content.pathToDescriptor) { result ->
348+
repl.serializeVariables(content.topLevelDescriptorName, content.descriptorsState, commID = content.commId, content.pathToDescriptor) { result ->
347349
sendWrapped(msg, makeReplyMessage(msg, MessageType.SERIALIZATION_REPLY, content = result))
348350
}
349351
} else {

src/main/kotlin/org/jetbrains/kotlinx/jupyter/repl.kt

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -139,7 +139,7 @@ interface ReplForJupyter {
139139

140140
suspend fun serializeVariables(cellId: Int, topLevelVarName: String, descriptorsState: Map<String, SerializedVariablesState>, callback: (SerializationReply) -> Unit)
141141

142-
suspend fun serializeVariables(topLevelVarName: String, descriptorsState: Map<String, SerializedVariablesState>, pathToDescriptor: List<String> = emptyList(),
142+
suspend fun serializeVariables(topLevelVarName: String, descriptorsState: Map<String, SerializedVariablesState>, commID: String = "", pathToDescriptor: List<String> = emptyList(),
143143
callback: (SerializationReply) -> Unit)
144144

145145
val homeDir: File?
@@ -568,9 +568,8 @@ class ReplForJupyterImpl(
568568
doWithLock(SerializationArgs(descriptorsState, cellId = cellId, topLevelVarName = topLevelVarName, callback = callback), serializationQueue, SerializationReply(cellId, descriptorsState), ::doSerializeVariables)
569569
}
570570

571-
override suspend fun serializeVariables(topLevelVarName: String, descriptorsState: Map<String, SerializedVariablesState>, pathToDescriptor: List<String>,
572-
callback: (SerializationReply) -> Unit) {
573-
doWithLock(SerializationArgs(descriptorsState, topLevelVarName = topLevelVarName, callback = callback, pathToDescriptor = pathToDescriptor), serializationQueue, SerializationReply(), ::doSerializeVariables)
571+
override suspend fun serializeVariables(topLevelVarName: String, descriptorsState: Map<String, SerializedVariablesState>, commID: String, pathToDescriptor: List<String>, callback: (SerializationReply) -> Unit) {
572+
doWithLock(SerializationArgs(descriptorsState, topLevelVarName = topLevelVarName, callback = callback, comm_id = commID ,pathToDescriptor = pathToDescriptor), serializationQueue, SerializationReply(), ::doSerializeVariables)
574573
}
575574

576575
private fun doSerializeVariables(args: SerializationArgs): SerializationReply {
@@ -585,7 +584,7 @@ class ReplForJupyterImpl(
585584
}
586585
log.debug("Serialization cellID: $cellId")
587586
log.debug("Serialization answer: ${resultMap.entries.first().value.fieldDescriptor}")
588-
return SerializationReply(cellId, resultMap)
587+
return SerializationReply(cellId, resultMap, args.comm_id)
589588
}
590589

591590

@@ -626,6 +625,7 @@ class ReplForJupyterImpl(
626625
var cellId: Int = -1,
627626
val topLevelVarName: String = "",
628627
val pathToDescriptor: List<String> = emptyList(),
628+
val comm_id: String = "",
629629
override val callback: (SerializationReply) -> Unit
630630
) : LockQueueArgs<SerializationReply>
631631

0 commit comments

Comments
 (0)