Skip to content

Commit 55a8ebc

Browse files
authored
Valid credential threshold (#652)
1 parent 181c87d commit 55a8ebc

File tree

2 files changed

+105
-9
lines changed

2 files changed

+105
-9
lines changed

Sources/SotoCore/Credential/RotatingCredentialProvider.swift

Lines changed: 55 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -24,17 +24,46 @@ import struct Foundation.TimeInterval
2424
///
2525
/// Used for wrapping another credential provider whose `getCredential` method returns an `ExpiringCredential`.
2626
/// If no credential is available, or the current credentials are going to expire in the near future the wrapped credential provider
27-
/// `getCredential` is called again. If current credentials have not expired they are returned otherwise we wait on new
27+
/// `getCredential` is called again. If current credentials have not expired (within a threshold) they are returned otherwise we wait on new
2828
/// credentials being provided.
2929
public final class RotatingCredentialProvider: CredentialProvider {
3030
let expiringCredential: ExpiringValue<Credential>
31-
31+
let validCredentialThreshold: TimeInterval
3232
public let provider: CredentialProvider
3333

34-
public init(context: CredentialProviderFactory.Context, provider: CredentialProvider, remainingTokenLifetimeForUse: TimeInterval? = nil) {
34+
/// Initialize RotatingCredentialProvider
35+
/// - Parameters:
36+
/// - context: Context used to create this credential provider
37+
/// - provider: Credential provider to request credentials from
38+
/// - remainingTokenLifetimeForUse: How near to expiration, before we request new credentials
39+
public init(
40+
context: CredentialProviderFactory.Context,
41+
provider: CredentialProvider,
42+
remainingTokenLifetimeForUse: TimeInterval? = nil
43+
) {
44+
self.provider = provider
45+
self.validCredentialThreshold = 15
46+
self.expiringCredential = .init(threshold: remainingTokenLifetimeForUse ?? 165) {
47+
try await Self.getCredentialAndExpiration(provider: provider, validCredentialThreshold: 15, logger: context.logger)
48+
}
49+
}
50+
51+
/// Initialize RotatingCredentialProvider
52+
/// - Parameters:
53+
/// - context: Context used to create this credential provider
54+
/// - provider: Credential provider to request credentials from
55+
/// - remainingTokenLifetimeForUse: How near to expiration, before we request new credentials
56+
/// - validCredentialThreshold: How near to expiration do we return the current credentials
57+
public init(
58+
context: CredentialProviderFactory.Context,
59+
provider: CredentialProvider,
60+
remainingTokenLifetimeForUse: TimeInterval? = nil,
61+
validCredentialThreshold: TimeInterval
62+
) {
3563
self.provider = provider
64+
self.validCredentialThreshold = validCredentialThreshold
3665
self.expiringCredential = .init(threshold: remainingTokenLifetimeForUse ?? 3 * 60) {
37-
try await Self.getCredentialAndExpiration(provider: provider, logger: context.logger)
66+
try await Self.getCredentialAndExpiration(provider: provider, validCredentialThreshold: validCredentialThreshold, logger: context.logger)
3867
}
3968
}
4069

@@ -51,17 +80,34 @@ public final class RotatingCredentialProvider: CredentialProvider {
5180

5281
public func getCredential(logger: Logger) async throws -> Credential {
5382
try await self.expiringCredential.getValue {
54-
try await Self.getCredentialAndExpiration(provider: self.provider, logger: logger)
83+
try await Self.getCredentialAndExpiration(
84+
provider: self.provider,
85+
validCredentialThreshold: self.validCredentialThreshold,
86+
logger: logger
87+
)
5588
}
5689
}
5790

58-
static func getCredentialAndExpiration(provider: CredentialProvider, logger: Logger) async throws -> (Credential, Date) {
59-
logger.debug("Refeshing AWS credentials", metadata: ["aws-credential-provider": .string("\(self)(\(provider.description))")])
91+
static func getCredentialAndExpiration(
92+
provider: CredentialProvider,
93+
validCredentialThreshold: TimeInterval,
94+
logger: Logger
95+
) async throws -> (Credential, Date) {
96+
logger.debug(
97+
"Refeshing AWS credentials",
98+
metadata: ["aws-credential-provider": .string("\(self)(\(provider.description))")]
99+
)
60100
try Task.checkCancellation()
61101
let credential = try await provider.getCredential(logger: logger)
62-
logger.debug("AWS credentials ready", metadata: ["aws-credential-provider": .string("\(self)(\(provider.description))")])
102+
logger.debug(
103+
"AWS credentials ready",
104+
metadata: ["aws-credential-provider": .string("\(self)(\(provider.description))")]
105+
)
63106
if let expiringCredential = credential as? ExpiringCredential {
64-
return (expiringCredential, expiringCredential.expiration)
107+
return (
108+
expiringCredential,
109+
expiringCredential.expiration.addingTimeInterval(-validCredentialThreshold)
110+
)
65111
} else {
66112
return (credential, Date.distantFuture)
67113
}

Tests/SotoCoreTests/Credential/RotatingCredentialProviderTests.swift

Lines changed: 50 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -104,6 +104,56 @@ class RotatingCredentialProviderTests: XCTestCase {
104104
XCTAssertEqual(count.load(ordering: .sequentiallyConsistent), 1)
105105
}
106106

107+
func testGetCredentialAndGetNewOnesAsAboutToExpire() async throws {
108+
let group = MultiThreadedEventLoopGroup(numberOfThreads: 1)
109+
defer { XCTAssertNoThrow(try group.syncShutdownGracefully()) }
110+
let httpClient = HTTPClient(eventLoopGroupProvider: .shared(group))
111+
defer { XCTAssertNoThrow(try httpClient.syncShutdown()) }
112+
113+
let creds = [
114+
TestExpiringCredential(
115+
accessKeyId: "abc123",
116+
secretAccessKey: "abc123",
117+
sessionToken: "abc123",
118+
expiration: Date(timeIntervalSinceNow: 10)
119+
),
120+
TestExpiringCredential(
121+
accessKeyId: "def456",
122+
secretAccessKey: "def456",
123+
sessionToken: "def456",
124+
expiration: Date(timeIntervalSinceNow: 10)
125+
),
126+
]
127+
128+
let count = ManagedAtomic(0)
129+
let client = RotatingCredentialTestClient {
130+
let cred = creds[count.load(ordering: .sequentiallyConsistent)]
131+
count.wrappingIncrement(ordering: .sequentiallyConsistent)
132+
return cred
133+
}
134+
let context = CredentialProviderFactory.Context(httpClient: httpClient, logger: Logger(label: "soto"), options: .init())
135+
let provider = RotatingCredentialProvider(context: context, provider: client)
136+
137+
// get credentials for first time
138+
var returned = try await provider.getCredential(logger: Logger(label: "soto"))
139+
140+
XCTAssertEqual(returned.accessKeyId, creds[0].accessKeyId)
141+
XCTAssertEqual(returned.secretAccessKey, creds[0].secretAccessKey)
142+
XCTAssertEqual(returned.sessionToken, creds[0].sessionToken)
143+
XCTAssertEqual((returned as? TestExpiringCredential)?.expiration, creds[0].expiration)
144+
145+
// get credentials a second time, callback must not be hit
146+
returned = try await provider.getCredential(logger: Logger(label: "soto"))
147+
148+
XCTAssertEqual(returned.accessKeyId, creds[1].accessKeyId)
149+
XCTAssertEqual(returned.secretAccessKey, creds[1].secretAccessKey)
150+
XCTAssertEqual(returned.sessionToken, creds[1].sessionToken)
151+
XCTAssertEqual((returned as? TestExpiringCredential)?.expiration, creds[1].expiration)
152+
153+
// ensure callback was hit twice as we
154+
XCTAssertEqual(count.load(ordering: .sequentiallyConsistent), 2)
155+
}
156+
107157
func testGetCredentialHighlyConcurrent() async throws {
108158
let group = MultiThreadedEventLoopGroup(numberOfThreads: System.coreCount)
109159
defer { XCTAssertNoThrow(try group.syncShutdownGracefully()) }

0 commit comments

Comments
 (0)