Skip to content

Commit 807d8a0

Browse files
committed
Add connection tests
1 parent 39de666 commit 807d8a0

File tree

4 files changed

+102
-1
lines changed

4 files changed

+102
-1
lines changed

Tests/PostgresNIOTests/New/Extensions/PSQLFrontendMessageDecoder.swift

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -217,6 +217,11 @@ extension PostgresFrontendMessage {
217217
preconditionFailure("TODO: Unimplemented")
218218
case .saslResponse:
219219
preconditionFailure("TODO: Unimplemented")
220+
case .query:
221+
guard let query = buffer.readNullTerminatedString() else {
222+
throw PSQLPartialDecodingError.fieldNotDecodable(type: String.self)
223+
}
224+
return .query(.init(query: query))
220225
case .sync:
221226
return .sync
222227
case .terminate:

Tests/PostgresNIOTests/New/Extensions/PostgresFrontendMessage.swift

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -59,6 +59,11 @@ enum PostgresFrontendMessage: Equatable {
5959
}
6060
}
6161

62+
struct Query: Hashable {
63+
/// The query string.
64+
let query: String
65+
}
66+
6267
struct Parse: Hashable {
6368
/// The name of the destination prepared statement (an empty string selects the unnamed prepared statement).
6469
let preparedStatementName: String
@@ -179,6 +184,7 @@ enum PostgresFrontendMessage: Equatable {
179184
case saslInitialResponse(SASLInitialResponse)
180185
case saslResponse(SASLResponse)
181186
case sslRequest
187+
case query(Query)
182188
case sync
183189
case startup(Startup)
184190
case terminate
@@ -194,6 +200,7 @@ enum PostgresFrontendMessage: Equatable {
194200
case password
195201
case saslInitialResponse
196202
case saslResponse
203+
case query
197204
case sync
198205
case terminate
199206

@@ -217,6 +224,8 @@ enum PostgresFrontendMessage: Equatable {
217224
self = .saslInitialResponse
218225
case UInt8(ascii: "p"):
219226
self = .saslResponse
227+
case UInt8(ascii: "Q"):
228+
self = .query
220229
case UInt8(ascii: "S"):
221230
self = .sync
222231
case UInt8(ascii: "X"):
@@ -246,6 +255,8 @@ enum PostgresFrontendMessage: Equatable {
246255
return UInt8(ascii: "p")
247256
case .saslResponse:
248257
return UInt8(ascii: "p")
258+
case .query:
259+
return UInt8(ascii: "Q")
249260
case .sync:
250261
return UInt8(ascii: "S")
251262
case .terminate:
@@ -283,6 +294,8 @@ extension PostgresFrontendMessage {
283294
preconditionFailure("SSL requests don't have an identifier")
284295
case .startup:
285296
preconditionFailure("Startup messages don't have an identifier")
297+
case .query:
298+
return .query
286299
case .sync:
287300
return .sync
288301
case .terminate:

Tests/PostgresNIOTests/New/PSQLFrontendMessageTests.swift

Lines changed: 13 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -31,7 +31,19 @@ class PSQLFrontendMessageTests: XCTestCase {
3131
XCTAssertEqual(PostgresFrontendMessage.ID.flush.rawValue, byteBuffer.readInteger(as: UInt8.self))
3232
XCTAssertEqual(4, byteBuffer.readInteger(as: Int32.self)) // payload length
3333
}
34-
34+
35+
func testEncodeQuery() {
36+
var encoder = PostgresFrontendMessageEncoder(buffer: .init())
37+
let query = "SELECT * FROM foo"
38+
encoder.query(query)
39+
var byteBuffer = encoder.flushBuffer()
40+
41+
XCTAssertEqual(byteBuffer.readableBytes, 23)
42+
XCTAssertEqual(PostgresFrontendMessage.ID.query.rawValue, byteBuffer.readInteger(as: UInt8.self))
43+
XCTAssertEqual(22, byteBuffer.readInteger(as: Int32.self)) // payload length
44+
XCTAssertEqual([UInt8](query.utf8), byteBuffer.readBytes(length: 17))
45+
}
46+
3547
func testEncodeSync() {
3648
var encoder = PostgresFrontendMessageEncoder(buffer: .init())
3749
encoder.sync()

Tests/PostgresNIOTests/New/PostgresConnectionTests.swift

Lines changed: 71 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -317,6 +317,40 @@ class PostgresConnectionTests: XCTestCase {
317317
}
318318
}
319319

320+
func testCloseImmediatelyWithSimpleQuery() async throws {
321+
let (connection, channel) = try await self.makeTestConnectionWithAsyncTestingChannel()
322+
323+
try await withThrowingTaskGroup(of: Void.self) { [logger] taskGroup async throws -> () in
324+
for _ in 1...2 {
325+
taskGroup.addTask {
326+
try await connection.__simpleQuery("SELECT 1;", logger: logger)
327+
}
328+
}
329+
330+
let query = try await channel.waitForSimpleQueryRequest()
331+
XCTAssertEqual(query.query, "SELECT 1;")
332+
333+
async let close: () = connection.close()
334+
335+
try await channel.closeFuture.get()
336+
XCTAssertEqual(channel.isActive, false)
337+
338+
try await close
339+
340+
while let taskResult = await taskGroup.nextResult() {
341+
switch taskResult {
342+
case .success:
343+
XCTFail("Expected queries to fail")
344+
case .failure(let failure):
345+
guard let error = failure as? PSQLError else {
346+
return XCTFail("Unexpected error type: \(failure)")
347+
}
348+
XCTAssertEqual(error.code, .clientClosedConnection)
349+
}
350+
}
351+
}
352+
}
353+
320354
func testIfServerJustClosesTheErrorReflectsThat() async throws {
321355
let (connection, channel) = try await self.makeTestConnectionWithAsyncTestingChannel()
322356
let logger = self.logger
@@ -346,6 +380,35 @@ class PostgresConnectionTests: XCTestCase {
346380
}
347381
}
348382

383+
func testIfServerJustClosesTheErrorReflectsThatInSimpleQuery() async throws {
384+
let (connection, channel) = try await self.makeTestConnectionWithAsyncTestingChannel()
385+
let logger = self.logger
386+
387+
async let response = try await connection.__simpleQuery("SELECT 1;", logger: logger)
388+
389+
let query = try await channel.waitForSimpleQueryRequest()
390+
XCTAssertEqual(query.query, "SELECT 1;")
391+
392+
try await channel.testingEventLoop.executeInContext { channel.pipeline.fireChannelInactive() }
393+
try await channel.testingEventLoop.executeInContext { channel.pipeline.fireChannelUnregistered() }
394+
395+
do {
396+
_ = try await response
397+
XCTFail("Expected to throw")
398+
} catch {
399+
XCTAssertEqual((error as? PSQLError)?.code, .serverClosedConnection)
400+
}
401+
402+
// retry on same connection
403+
404+
do {
405+
_ = try await connection.__simpleQuery("SELECT 1;", logger: self.logger)
406+
XCTFail("Expected to throw")
407+
} catch {
408+
XCTAssertEqual((error as? PSQLError)?.code, .serverClosedConnection)
409+
}
410+
}
411+
349412
struct TestPrepareStatement: PostgresPreparedStatement {
350413
static let sql = "SELECT datname FROM pg_stat_activity WHERE state = $1"
351414
typealias Row = String
@@ -692,6 +755,14 @@ extension NIOAsyncTestingChannel {
692755
return UnpreparedRequest(parse: parse, describe: describe, bind: bind, execute: execute)
693756
}
694757

758+
func waitForSimpleQueryRequest() async throws -> PostgresFrontendMessage.Query {
759+
let query = try await self.waitForOutboundWrite(as: PostgresFrontendMessage.self)
760+
guard case .query(let query) = query else {
761+
fatalError()
762+
}
763+
return query
764+
}
765+
695766
func waitForPrepareRequest() async throws -> PrepareRequest {
696767
let parse = try await self.waitForOutboundWrite(as: PostgresFrontendMessage.self)
697768
let describe = try await self.waitForOutboundWrite(as: PostgresFrontendMessage.self)

0 commit comments

Comments
 (0)