Skip to content

Commit d81fc86

Browse files
authored
feat(postgrest): add read-only mode for RPC (#600)
* feat(postgrest): add read-only mode for RPC * add integration tests
1 parent 97e19e8 commit d81fc86

10 files changed

+129
-69
lines changed

Sources/PostgREST/PostgrestClient.swift

Lines changed: 57 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
import ConcurrencyExtras
22
import Foundation
3-
import Helpers
43
import HTTPTypes
4+
import Helpers
55

66
public typealias PostgrestError = Helpers.PostgrestError
77
public typealias HTTPError = Helpers.HTTPError
@@ -129,31 +129,69 @@ public final class PostgrestClient: Sendable {
129129
/// - Parameters:
130130
/// - fn: The function name to call.
131131
/// - params: The parameters to pass to the function call.
132+
/// - head: When set to `true`, `data`, will not be returned. Useful if you only need the count.
133+
/// - get: When set to `true`, the function will be called with read-only access mode.
132134
/// - count: Count algorithm to use to count rows returned by the function. Only applicable for [set-returning functions](https://www.postgresql.org/docs/current/functions-srf.html).
133135
public func rpc(
134136
_ fn: String,
135137
params: some Encodable & Sendable,
138+
head: Bool = false,
139+
get: Bool = false,
136140
count: CountOption? = nil
137141
) throws -> PostgrestFilterBuilder {
138-
try PostgrestRpcBuilder(
142+
let method: HTTPTypes.HTTPRequest.Method
143+
var url = configuration.url.appendingPathComponent("rpc/\(fn)")
144+
let bodyData = try configuration.encoder.encode(params)
145+
var body: Data?
146+
147+
if head || get {
148+
method = head ? .head : .get
149+
150+
guard let json = try JSONSerialization.jsonObject(with: bodyData) as? [String: Any] else {
151+
throw PostgrestError(
152+
message: "Params should be a key-value type when using `GET` or `HEAD` options.")
153+
}
154+
155+
for (key, value) in json {
156+
let formattedValue = (value as? [Any]).map(cleanFilterArray) ?? String(describing: value)
157+
url.appendQueryItems([URLQueryItem(name: key, value: formattedValue)])
158+
}
159+
160+
} else {
161+
method = .post
162+
body = bodyData
163+
}
164+
165+
var request = HTTPRequest(
166+
url: url,
167+
method: method,
168+
headers: HTTPFields(configuration.headers),
169+
body: params is NoParams ? nil : body
170+
)
171+
172+
if let count {
173+
request.headers[.prefer] = "count=\(count.rawValue)"
174+
}
175+
176+
return PostgrestFilterBuilder(
139177
configuration: configuration,
140-
request: HTTPRequest(
141-
url: configuration.url.appendingPathComponent("rpc/\(fn)"),
142-
method: .post,
143-
headers: HTTPFields(configuration.headers)
144-
)
145-
).rpc(params: params, count: count)
178+
request: request
179+
)
146180
}
147181

148182
/// Perform a function call.
149183
/// - Parameters:
150184
/// - fn: The function name to call.
185+
/// - head: When set to `true`, `data`, will not be returned. Useful if you only need the count.
186+
/// - get: When set to `true`, the function will be called with read-only access mode.
151187
/// - count: Count algorithm to use to count rows returned by the function. Only applicable for [set-returning functions](https://www.postgresql.org/docs/current/functions-srf.html).
152188
public func rpc(
153189
_ fn: String,
190+
head: Bool = false,
191+
get: Bool = false,
154192
count: CountOption? = nil
155193
) throws -> PostgrestFilterBuilder {
156-
try rpc(fn, params: NoParams(), count: count)
194+
try rpc(fn, params: NoParams(), head: head, get: get, count: count)
157195
}
158196

159197
/// Select a schema to query or perform an function (rpc) call.
@@ -165,4 +203,14 @@ public final class PostgrestClient: Sendable {
165203
configuration.schema = schema
166204
return PostgrestClient(configuration: configuration)
167205
}
206+
207+
private func cleanFilterArray(_ filter: [Any]) -> String {
208+
"{\(filter.map { String(describing: $0) }.joined(separator: ","))}"
209+
}
210+
}
211+
212+
struct NoParams: Encodable {}
213+
214+
extension HTTPField.Name {
215+
static let prefer = Self("Prefer")!
168216
}

Sources/PostgREST/PostgrestFilterBuilder.swift

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
import Foundation
22
import Helpers
33

4-
public class PostgrestFilterBuilder: PostgrestTransformBuilder {
4+
public class PostgrestFilterBuilder: PostgrestTransformBuilder, @unchecked Sendable {
55
public enum Operator: String, CaseIterable, Sendable {
66
case eq, neq, gt, gte, lt, lte, like, ilike, `is`, `in`, cs, cd, sl, sr, nxl, nxr, adj, ov, fts,
77
plfts, phfts, wfts

Sources/PostgREST/PostgrestQueryBuilder.swift

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
import Foundation
22
import Helpers
33

4-
public final class PostgrestQueryBuilder: PostgrestBuilder {
4+
public final class PostgrestQueryBuilder: PostgrestBuilder, @unchecked Sendable {
55
/// Perform a SELECT query on the table or view.
66
/// - Parameters:
77
/// - columns: The columns to retrieve, separated by commas. Columns can be renamed when returned with `customName:columnName`

Sources/PostgREST/PostgrestRpcBuilder.swift

Lines changed: 0 additions & 48 deletions
This file was deleted.

Sources/PostgREST/PostgrestTransformBuilder.swift

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
import Foundation
22
import Helpers
33

4-
public class PostgrestTransformBuilder: PostgrestBuilder {
4+
public class PostgrestTransformBuilder: PostgrestBuilder, @unchecked Sendable {
55
/// Perform a SELECT on the query result.
66
///
77
/// By default, `.insert()`, `.update()`, `.upsert()`, and `.delete()` do not return modified rows. By calling this method, modified rows are returned in `value`.

Tests/IntegrationTests/Potsgrest/PostgresTransformsTests.swift

Lines changed: 40 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -14,14 +14,15 @@ final class PostgrestTransformsTests: XCTestCase {
1414
configuration: PostgrestClient.Configuration(
1515
url: URL(string: "\(DotEnv.SUPABASE_URL)/rest/v1")!,
1616
headers: [
17-
"apikey": DotEnv.SUPABASE_ANON_KEY,
17+
"apikey": DotEnv.SUPABASE_ANON_KEY
1818
],
1919
logger: nil
2020
)
2121
)
2222

2323
func testOrder() async throws {
24-
let res = try await client.from("users")
24+
let res =
25+
try await client.from("users")
2526
.select()
2627
.order("username", ascending: false)
2728
.execute().value as AnyJSON
@@ -63,7 +64,8 @@ final class PostgrestTransformsTests: XCTestCase {
6364
}
6465

6566
func testOrderOnMultipleColumns() async throws {
66-
let res = try await client.from("messages")
67+
let res =
68+
try await client.from("messages")
6769
.select()
6870
.order("channel_id", ascending: false)
6971
.order("username", ascending: false)
@@ -92,7 +94,8 @@ final class PostgrestTransformsTests: XCTestCase {
9294
}
9395

9496
func testLimit() async throws {
95-
let res = try await client.from("users")
97+
let res =
98+
try await client.from("users")
9699
.select()
97100
.limit(1)
98101
.execute().value as AnyJSON
@@ -113,7 +116,8 @@ final class PostgrestTransformsTests: XCTestCase {
113116
}
114117

115118
func testRange() async throws {
116-
let res = try await client.from("users")
119+
let res =
120+
try await client.from("users")
117121
.select()
118122
.range(from: 1, to: 3)
119123
.execute().value as AnyJSON
@@ -148,7 +152,8 @@ final class PostgrestTransformsTests: XCTestCase {
148152
}
149153

150154
func testSingle() async throws {
151-
let res = try await client.from("users")
155+
let res =
156+
try await client.from("users")
152157
.select()
153158
.limit(1)
154159
.single()
@@ -168,7 +173,8 @@ final class PostgrestTransformsTests: XCTestCase {
168173
}
169174

170175
func testSingleOnInsert() async throws {
171-
let res = try await client.from("users")
176+
let res =
177+
try await client.from("users")
172178
.insert(["username": "foo"])
173179
.select()
174180
.single()
@@ -193,7 +199,8 @@ final class PostgrestTransformsTests: XCTestCase {
193199
}
194200

195201
func testSelectOnInsert() async throws {
196-
let res = try await client.from("users")
202+
let res =
203+
try await client.from("users")
197204
.insert(["username": "foo"])
198205
.select("status")
199206
.execute().value as AnyJSON
@@ -215,7 +222,8 @@ final class PostgrestTransformsTests: XCTestCase {
215222
}
216223

217224
func testSelectOnRpc() async throws {
218-
let res = try await client.rpc("get_username_and_status", params: ["name_param": "supabot"])
225+
let res =
226+
try await client.rpc("get_username_and_status", params: ["name_param": "supabot"])
219227
.select("status")
220228
.execute().value as AnyJSON
221229

@@ -230,6 +238,29 @@ final class PostgrestTransformsTests: XCTestCase {
230238
}
231239
}
232240

241+
func testRpcWithArray() async throws {
242+
struct Params: Encodable {
243+
let arr: [Int]
244+
let index: Int
245+
}
246+
let res =
247+
try await client.rpc("get_array_element", params: Params(arr: [37, 420, 64], index: 2))
248+
.execute().value as Int
249+
XCTAssertEqual(res, 420)
250+
}
251+
252+
func testRpcWithReadOnlyAccessMode() async throws {
253+
struct Params: Encodable {
254+
let arr: [Int]
255+
let index: Int
256+
}
257+
let res =
258+
try await client.rpc(
259+
"get_array_element", params: Params(arr: [37, 420, 64], index: 2), get: true
260+
).execute().value as Int
261+
XCTAssertEqual(res, 420)
262+
}
263+
233264
func testCsv() async throws {
234265
let res = try await client.from("users").select().csv().execute().string()
235266
assertInlineSnapshot(of: res, as: .json) {

Tests/PostgRESTTests/BuildURLRequestTests.swift

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -227,6 +227,19 @@ final class BuildURLRequestTests: XCTestCase {
227227
.select()
228228
.gt("created_at", value: Date(timeIntervalSince1970: 0))
229229
},
230+
TestCase(name: "rpc call with head") { client in
231+
try client.rpc("sum", head: true)
232+
},
233+
TestCase(name: "rpc call with get") { client in
234+
try client.rpc("sum", get: true)
235+
},
236+
TestCase(name: "rpc call with get and params") { client in
237+
try client.rpc(
238+
"get_array_element",
239+
params: ["array": [37, 420, 64], "index": 2] as AnyJSON,
240+
get: true
241+
)
242+
},
230243
]
231244

232245
for testCase in testCases {
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,5 @@
1+
curl \
2+
--header "Accept: application/json" \
3+
--header "Content-Type: application/json" \
4+
--header "X-Client-Info: postgrest-swift/x.y.z" \
5+
"https://example.supabase.co/rpc/get_array_element?array=%7B37,420,64%7D&index=2"
Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,5 @@
1+
curl \
2+
--header "Accept: application/json" \
3+
--header "Content-Type: application/json" \
4+
--header "X-Client-Info: postgrest-swift/x.y.z" \
5+
"https://example.supabase.co/rpc/sum"
Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,6 @@
1+
curl \
2+
--head \
3+
--header "Accept: application/json" \
4+
--header "Content-Type: application/json" \
5+
--header "X-Client-Info: postgrest-swift/x.y.z" \
6+
"https://example.supabase.co/rpc/sum"

0 commit comments

Comments
 (0)