Skip to content
404 changes: 404 additions & 0 deletions Sources/PostgresNIO/Connection/PostgresConnection+CopyFrom.swift

Large diffs are not rendered by default.

Original file line number Diff line number Diff line change
Expand Up @@ -88,7 +88,33 @@ struct ConnectionStateMachine {
case sendParseDescribeBindExecuteSync(PostgresQuery)
case sendBindExecuteSync(PSQLExecuteStatement)
case failQuery(EventLoopPromise<PSQLRowStream>, with: PSQLError, cleanupContext: CleanUpContext?)
/// Fail a query's execution by resuming the continuation with the given error. When `sync` is `true`, send a
/// `Sync` message to the backend.
case failQueryContinuation(AnyErrorContinuation, with: PSQLError, sync: Bool, cleanupContext: CleanUpContext?)
/// Fail a query's execution by resuming the continuation with the given error and send a `Sync` message to the
/// backend.
case succeedQuery(EventLoopPromise<PSQLRowStream>, with: QueryResult)
/// Succeed the continuation with a void result. When `sync` is `true`, send a `Sync` message to the backend.
case succeedQueryContinuation(CheckedContinuation<Void, any Error>, sync: Bool)

/// Trigger a data transfer returning a `PostgresCopyFromWriter` to the given continuation.
///
/// Once the data transfer is triggered, it will send `CopyData` messages to the backend. After that the state
/// machine needs to be prodded again to send a `CopyDone` or `CopyFail` by calling
/// `PostgresChannelHandler.sendCopyDone` or `PostgresChannelHandler.sendCopyFail`.
case triggerCopyData(CheckedContinuation<PostgresCopyFromWriter, any Error>)

/// Send a `CopyDone` and `Sync` message to the backend.
case sendCopyDoneAndSync

/// Send a `CopyFail` message to the backend with the given error message.
case sendCopyFail(message: String)

/// Fail the promise with the given error and close the connection.
///
/// This is used when we want to cancel a COPY operation while waiting for backpressure relieve. In that case we
/// can't recover the connection because we can't send any messages to the backend, so we need to close it.
case failPromiseAndCloseConnection(EventLoopPromise<Void>, error: PSQLError, cleanupContext: CleanUpContext)

// --- streaming actions
// actions if query has requested next row but we are waiting for backend
Expand All @@ -107,6 +133,25 @@ struct ConnectionStateMachine {
case failClose(CloseCommandContext, with: PSQLError, cleanupContext: CleanUpContext?)
}

enum ChannelWritabilityChangedAction {
/// No action needs to be taken based on the writability change.
case none

/// Resume the given continuation successfully.
case succeedPromise(EventLoopPromise<Void>)
}

enum CheckBackendCanReceiveCopyDataAction {
/// Don't perform any action.
case none

/// Succeed the promise with a Void result.
case succeedPromise(EventLoopPromise<Void>)

/// Fail the promise with the given error.
case failPromise(EventLoopPromise<Void>, error: any Error)
}

private var state: State
private let requireBackendKeyData: Bool
private var taskQueue = CircularBuffer<PSQLTask>()
Expand Down Expand Up @@ -587,6 +632,8 @@ struct ConnectionStateMachine {
switch queryContext.query {
case .executeStatement(_, let promise), .unnamed(_, let promise):
return .failQuery(promise, with: psqlErrror, cleanupContext: nil)
case .copyFrom(_, let triggerCopy):
return .failQueryContinuation(.copyFromWriter(triggerCopy), with: psqlErrror, sync: false, cleanupContext: nil)
case .prepareStatement(_, _, _, let promise):
return .failPreparedStatementCreation(promise, with: psqlErrror, cleanupContext: nil)
}
Expand Down Expand Up @@ -660,6 +707,16 @@ struct ConnectionStateMachine {
preconditionFailure("Invalid state: \(self.state)")
}
}

mutating func channelWritabilityChanged(isWritable: Bool) -> ChannelWritabilityChangedAction {
guard case .extendedQuery(var queryState, let connectionContext) = state else {
return .none
}
self.state = .modifying // avoid CoW
let action = queryState.channelWritabilityChanged(isWritable: isWritable)
self.state = .extendedQuery(queryState, connectionContext)
return action
}

// MARK: - Running Queries -

Expand Down Expand Up @@ -752,10 +809,56 @@ struct ConnectionStateMachine {
return self.modify(with: action)
}

mutating func copyInResponseReceived(
_ copyInResponse: PostgresBackendMessage.CopyInResponse
) -> ConnectionAction {
return self.closeConnectionAndCleanup(.unexpectedBackendMessage(.copyInResponse(copyInResponse)))
mutating func copyInResponseReceived(_ copyInResponse: PostgresBackendMessage.CopyInResponse) -> ConnectionAction {
guard case .extendedQuery(var queryState, let connectionContext) = self.state, !queryState.isComplete else {
return self.closeConnectionAndCleanup(.unexpectedBackendMessage(.copyInResponse(copyInResponse)))
}

self.state = .modifying // avoid CoW
let action = queryState.copyInResponseReceived(copyInResponse)
self.state = .extendedQuery(queryState, connectionContext)
return self.modify(with: action)
}


/// Succeed the promise when the channel to the backend is writable and the backend is ready to receive more data.
///
/// The promise may be failed if the backend indicated that it can't handle any more data by sending an
/// `ErrorResponse`. This is mostly the case when malformed data is sent to it. In that case, the data transfer
/// should be aborted to avoid unnecessary work.
mutating func checkBackendCanReceiveCopyData(channelIsWritable: Bool, promise: EventLoopPromise<Void>) -> CheckBackendCanReceiveCopyDataAction {
guard case .extendedQuery(var queryState, let connectionContext) = self.state else {
preconditionFailure("Copy mode is only supported for extended queries")
}

self.state = .modifying // avoid CoW
let action = queryState.checkBackendCanReceiveCopyData(channelIsWritable: channelIsWritable, promise: promise)
self.state = .extendedQuery(queryState, connectionContext)
return action
}

/// Put the state machine out of the copying mode and send a `CopyDone` message to the backend.
mutating func sendCopyDone(continuation: CheckedContinuation<Void, any Error>) -> ConnectionAction {
guard case .extendedQuery(var queryState, let connectionContext) = self.state else {
preconditionFailure("Copy mode is only supported for extended queries")
}

self.state = .modifying // avoid CoW
let action = queryState.sendCopyDone(continuation: continuation)
self.state = .extendedQuery(queryState, connectionContext)
return self.modify(with: action)
}

/// Put the state machine out of the copying mode and send a `CopyFail` message to the backend.
mutating func sendCopyFail(message: String, continuation: CheckedContinuation<Void, any Error>) -> ConnectionAction {
guard case .extendedQuery(var queryState, let connectionContext) = self.state else {
preconditionFailure("Copy mode is only supported for extended queries")
}

self.state = .modifying // avoid CoW
let action = queryState.sendCopyFail(message: message, continuation: continuation)
self.state = .extendedQuery(queryState, connectionContext)
return self.modify(with: action)
}

mutating func emptyQueryResponseReceived() -> ConnectionAction {
Expand All @@ -782,9 +885,10 @@ struct ConnectionStateMachine {

// MARK: Consumer

mutating func cancelQueryStream() -> ConnectionAction {
mutating func cancel() -> ConnectionAction {
guard case .extendedQuery(var queryState, let connectionContext) = self.state else {
preconditionFailure("Tried to cancel stream without active query")
// We are not in a state in which we can cancel. Do nothing.
return .wait
}

self.state = .modifying // avoid CoW
Expand Down Expand Up @@ -866,14 +970,22 @@ struct ConnectionStateMachine {
.forwardRows,
.forwardStreamComplete,
.wait,
.read:
.read,
.triggerCopyData,
.sendCopyDoneAndSync,
.sendCopyFail,
.succeedQueryContinuation,
.failPromiseAndCloseConnection:
preconditionFailure("Invalid query state machine action in state: \(self.state), action: \(action)")

case .evaluateErrorAtConnectionLevel:
return .closeConnectionAndCleanup(cleanupContext)

case .failQuery(let queryContext, with: let error):
return .failQuery(queryContext, with: error, cleanupContext: cleanupContext)
case .failQuery(let promise, with: let error):
return .failQuery(promise, with: error, cleanupContext: cleanupContext)

case .failQueryContinuation(let continuation, with: let error, let sync):
return .failQueryContinuation(continuation, with: error, sync: sync, cleanupContext: cleanupContext)

case .forwardStreamError(let error, let read):
return .forwardStreamError(error, read: read, cleanupContext: cleanupContext)
Expand Down Expand Up @@ -1044,8 +1156,22 @@ extension ConnectionStateMachine {
case .failQuery(let requestContext, with: let error):
let cleanupContext = self.setErrorAndCreateCleanupContextIfNeeded(error)
return .failQuery(requestContext, with: error, cleanupContext: cleanupContext)
case .failQueryContinuation(let continuation, with: let error, let sync):
let cleanupContext = self.setErrorAndCreateCleanupContextIfNeeded(error)
return .failQueryContinuation(continuation, with: error, sync: sync, cleanupContext: cleanupContext)
case .succeedQuery(let requestContext, with: let result):
return .succeedQuery(requestContext, with: result)
case .succeedQueryContinuation(let continuation, let sync):
return .succeedQueryContinuation(continuation, sync: sync)
case .triggerCopyData(let triggerCopy):
return .triggerCopyData(triggerCopy)
case .sendCopyDoneAndSync:
return .sendCopyDoneAndSync
case .sendCopyFail(message: let message):
return .sendCopyFail(message: message)
case .failPromiseAndCloseConnection(let promise, error: let error):
let cleanupContext = self.setErrorAndCreateCleanupContext(error)
return .failPromiseAndCloseConnection(promise, error: error, cleanupContext: cleanupContext)
case .forwardRows(let buffer):
return .forwardRows(buffer)
case .forwardStreamComplete(let buffer, let commandTag):
Expand Down
Loading
Loading