Skip to content

Propagate Connection Closed Information up to top-level (fix #465) #545

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 20 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion Sources/AWSLambdaRuntime/Lambda+LocalServer.swift
Original file line number Diff line number Diff line change
Expand Up @@ -452,7 +452,7 @@ internal struct LambdaHTTPServer {
await self.responsePool.push(
LocalServerResponse(
id: requestId,
status: .ok,
status: .accepted,
// the local server has no mecanism to collect headers set by the lambda function
headers: HTTPHeaders(),
body: body,
Expand Down
6 changes: 6 additions & 0 deletions Sources/AWSLambdaRuntime/Lambda.swift
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,8 @@ public enum Lambda {
var logger = logger
do {
while !Task.isCancelled {

logger.trace("Waiting for next invocation")
let (invocation, writer) = try await runtimeClient.nextInvocation()
logger[metadataKey: "aws-request-id"] = "\(invocation.metadata.requestID)"

Expand Down Expand Up @@ -76,14 +78,18 @@ public enum Lambda {
logger: logger
)
)
logger.trace("Handler finished processing invocation")
} catch {
logger.trace("Handler failed processing invocation", metadata: ["Handler error": "\(error)"])
try await writer.reportError(error)
continue
}
logger.handler.metadata.removeValue(forKey: "aws-request-id")
}
} catch is CancellationError {
// don't allow cancellation error to propagate further
}

}

/// The default EventLoop the Lambda is scheduled on.
Expand Down
31 changes: 22 additions & 9 deletions Sources/AWSLambdaRuntime/LambdaRuntime.swift
Original file line number Diff line number Diff line change
Expand Up @@ -94,16 +94,29 @@ public final class LambdaRuntime<Handler>: Sendable where Handler: StreamingLamb
let ip = String(ipAndPort[0])
guard let port = Int(ipAndPort[1]) else { throw LambdaRuntimeError(code: .invalidPort) }

try await LambdaRuntimeClient.withRuntimeClient(
configuration: .init(ip: ip, port: port),
eventLoop: self.eventLoop,
logger: self.logger
) { runtimeClient in
try await Lambda.runLoop(
runtimeClient: runtimeClient,
handler: handler,
do {
try await LambdaRuntimeClient.withRuntimeClient(
configuration: .init(ip: ip, port: port),
eventLoop: self.eventLoop,
logger: self.logger
)
) { runtimeClient in
try await Lambda.runLoop(
runtimeClient: runtimeClient,
handler: handler,
logger: self.logger
)
}
} catch {
// catch top level errors that have not been handled until now
// this avoids the runtime to crash and generate a backtrace
self.logger.error("LambdaRuntime.run() failed with error", metadata: ["error": "\(error)"])
if let error = error as? LambdaRuntimeError,
error.code != .connectionToControlPlaneLost
{
// if the error is a LambdaRuntimeError but not a connection error,
// we rethrow it to preserve existing behaviour
throw error
}
}

} else {
Expand Down
65 changes: 57 additions & 8 deletions Sources/AWSLambdaRuntime/LambdaRuntimeClient.swift
Original file line number Diff line number Diff line change
Expand Up @@ -67,10 +67,26 @@ final actor LambdaRuntimeClient: LambdaRuntimeClientProtocol {
NIOLoopBound<LambdaChannelHandler<LambdaRuntimeClient>>, any Error
>

private enum ConnectionState {
private enum ConnectionState: Equatable {
case disconnected
case lostConnection
case connecting([ConnectionContinuation])
case connected(Channel, LambdaChannelHandler<LambdaRuntimeClient>)

static func == (lhs: ConnectionState, rhs: ConnectionState) -> Bool {
switch (lhs, rhs) {
case (.disconnected, .disconnected):
return true
case (.connecting, .connecting):
return true
case (.connected, .connected):
return true
case (.lostConnection, .lostConnection):
return true
default:
return false
}
}
}

enum LambdaState {
Expand All @@ -97,6 +113,7 @@ final actor LambdaRuntimeClient: LambdaRuntimeClientProtocol {
private let configuration: Configuration

private var connectionState: ConnectionState = .disconnected

private var lambdaState: LambdaState = .idle(previousRequestID: nil)
private var closingState: ClosingState = .notClosing

Expand All @@ -118,10 +135,7 @@ final actor LambdaRuntimeClient: LambdaRuntimeClientProtocol {
} catch {
result = .failure(error)
}

await runtime.close()

//try? await runtime.close()
return try result.get()
}

Expand Down Expand Up @@ -157,18 +171,27 @@ final actor LambdaRuntimeClient: LambdaRuntimeClientProtocol {

case .connected(let channel, _):
channel.close(mode: .all, promise: nil)

case .lostConnection:
continuation.resume()
}
}
}

@usableFromInline
func nextInvocation() async throws -> (Invocation, Writer) {
try await withTaskCancellationHandler {

if self.connectionState == .lostConnection {
throw LambdaRuntimeError(code: .connectionToControlPlaneLost)
}

return try await withTaskCancellationHandler {
switch self.lambdaState {
case .idle:
self.lambdaState = .waitingForNextInvocation
let handler = try await self.makeOrGetConnection()
let invocation = try await handler.nextInvocation()

guard case .waitingForNextInvocation = self.lambdaState else {
fatalError("Invalid state: \(self.lambdaState)")
}
Expand Down Expand Up @@ -262,7 +285,7 @@ final actor LambdaRuntimeClient: LambdaRuntimeClientProtocol {

private func channelClosed(_ channel: any Channel) {
switch (self.connectionState, self.closingState) {
case (_, .closed):
case (_, .closed), (.lostConnection, _):
fatalError("Invalid state: \(self.connectionState), \(self.closingState)")

case (.disconnected, .notClosing):
Expand All @@ -283,7 +306,7 @@ final actor LambdaRuntimeClient: LambdaRuntimeClientProtocol {
case (.connecting(let array), .notClosing):
self.connectionState = .disconnected
for continuation in array {
continuation.resume(throwing: LambdaRuntimeError(code: .lostConnectionToControlPlane))
continuation.resume(throwing: LambdaRuntimeError(code: .connectionToControlPlaneLost))
}

case (.connecting(let array), .closing(let continuation)):
Expand Down Expand Up @@ -322,6 +345,9 @@ final actor LambdaRuntimeClient: LambdaRuntimeClientProtocol {
return loopBound.value
case .connected(_, let handler):
return handler

case .lostConnection:
throw LambdaRuntimeError(code: .connectionToControlPlaneLost)
}

let bootstrap = ClientBootstrap(group: self.eventLoop)
Expand Down Expand Up @@ -363,7 +389,19 @@ final actor LambdaRuntimeClient: LambdaRuntimeClientProtocol {
)
channel.closeFuture.whenComplete { result in
self.assumeIsolated { runtimeClient in

// resume any pending continuation on the handler
if case .connected(_, let handler) = runtimeClient.connectionState {
if case .connected(_, let lambdaState) = handler.state {
if case .waitingForNextInvocation(let continuation) = lambdaState {
continuation.resume(throwing: LambdaRuntimeError(code: .connectionToControlPlaneLost))
}
}
}

// close the channel
runtimeClient.channelClosed(channel)
runtimeClient.connectionState = .lostConnection
}
}

Expand All @@ -380,8 +418,12 @@ final actor LambdaRuntimeClient: LambdaRuntimeClientProtocol {
}
}
return handler
case .lostConnection:
// this should never happen
fatalError("Lost connection to Lambda service")
}
} catch {

switch self.connectionState {
case .disconnected, .connected:
fatalError("Unexpected state: \(self.connectionState)")
Expand All @@ -394,6 +436,9 @@ final actor LambdaRuntimeClient: LambdaRuntimeClientProtocol {
}
}
throw error
case .lostConnection:
// this should never happen
fatalError("Lost connection to Lambda service")
}
}
}
Expand Down Expand Up @@ -431,6 +476,9 @@ extension LambdaRuntimeClient: LambdaChannelHandlerDelegate {

isolated.connectionState = .disconnected

case .lostConnection:
// this should never happen
fatalError("Lost connection to Lambda service")
}
}
}
Expand Down Expand Up @@ -463,7 +511,7 @@ private final class LambdaChannelHandler<Delegate: LambdaChannelHandlerDelegate>
}
}

private var state: State = .disconnected
var state: State = .disconnected
private var lastError: Error?
private var reusableErrorBuffer: ByteBuffer?
private let logger: Logger
Expand Down Expand Up @@ -885,6 +933,7 @@ extension LambdaChannelHandler: ChannelInboundHandler {
// fail any pending responses with last error or assume peer disconnected
switch self.state {
case .connected(_, .waitingForNextInvocation(let continuation)):
self.state = .disconnected
continuation.resume(throwing: self.lastError ?? ChannelError.ioOnClosedChannel)
default:
break
Expand Down
1 change: 0 additions & 1 deletion Sources/AWSLambdaRuntime/LambdaRuntimeError.swift
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,6 @@ package struct LambdaRuntimeError: Error {

case writeAfterFinishHasBeenSent
case finishAfterFinishHasBeenSent
case lostConnectionToControlPlane
case unexpectedStatusCodeForRequest

case nextInvocationMissingHeaderRequestID
Expand Down
2 changes: 1 addition & 1 deletion Sources/MockServer/MockHTTPServer.swift
Original file line number Diff line number Diff line change
Expand Up @@ -216,7 +216,7 @@ struct HttpServer {
} else if requestHead.uri.hasSuffix("/response") {
responseStatus = .accepted
} else if requestHead.uri.hasSuffix("/error") {
responseStatus = .ok
responseStatus = .accepted
} else {
responseStatus = .notFound
}
Expand Down
Loading