Skip to content

Commit d67b19f

Browse files
authored
Refactor how token refreshing works to be more resilient (#4819)
* Refactor how token refreshing works to be more resilient 1. ensure we do use the new token if it is not explicitly inhibited by the caller 2. eagerly refresh token if we know it is expired 3. allow refreshing a token multiple times if e.g. on bad connection or the environment has been slept and sufficient time has passed since the last refresh attempt Signed-off-by: Michael Telatynski <7t3chguy@gmail.com> * Iterate Signed-off-by: Michael Telatynski <7t3chguy@gmail.com> * Iterate Signed-off-by: Michael Telatynski <7t3chguy@gmail.com> * Add exponential backoff Signed-off-by: Michael Telatynski <7t3chguy@gmail.com> * Ensure no timing effects on `authedRequest` method call Signed-off-by: Michael Telatynski <7t3chguy@gmail.com> * Iterate Signed-off-by: Michael Telatynski <7t3chguy@gmail.com> --------- Signed-off-by: Michael Telatynski <7t3chguy@gmail.com>
1 parent 6ec200a commit d67b19f

File tree

7 files changed

+249
-137
lines changed

7 files changed

+249
-137
lines changed

spec/unit/http-api/fetch.spec.ts

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -356,7 +356,9 @@ describe("FetchHttpApi", () => {
356356
accessToken,
357357
refreshToken,
358358
});
359-
const result = await api.authedRequest(Method.Post, "/account/password");
359+
const result = await api.authedRequest(Method.Post, "/account/password", undefined, undefined, {
360+
headers: {},
361+
});
360362
expect(result).toEqual(okayResponse);
361363
expect(tokenRefreshFunction).toHaveBeenCalledWith(refreshToken);
362364

@@ -372,6 +374,7 @@ describe("FetchHttpApi", () => {
372374
const tokenRefreshFunction = jest.fn().mockResolvedValue({
373375
accessToken: newAccessToken,
374376
refreshToken: newRefreshToken,
377+
expiry: new Date(Date.now() + 1000),
375378
});
376379

377380
// fetch doesn't like our new or old tokens

spec/unit/oidc/tokenRefresher.spec.ts

Lines changed: 30 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -130,10 +130,12 @@ describe("OidcTokenRefresher", () => {
130130
method: "POST",
131131
});
132132

133-
expect(result).toEqual({
134-
accessToken: "new-access-token",
135-
refreshToken: "new-refresh-token",
136-
});
133+
expect(result).toEqual(
134+
expect.objectContaining({
135+
accessToken: "new-access-token",
136+
refreshToken: "new-refresh-token",
137+
}),
138+
);
137139
});
138140

139141
it("should persist the new tokens", async () => {
@@ -144,10 +146,12 @@ describe("OidcTokenRefresher", () => {
144146

145147
await refresher.doRefreshAccessToken("refresh-token");
146148

147-
expect(refresher.persistTokens).toHaveBeenCalledWith({
148-
accessToken: "new-access-token",
149-
refreshToken: "new-refresh-token",
150-
});
149+
expect(refresher.persistTokens).toHaveBeenCalledWith(
150+
expect.objectContaining({
151+
accessToken: "new-access-token",
152+
refreshToken: "new-refresh-token",
153+
}),
154+
);
151155
});
152156

153157
it("should only have one inflight refresh request at once", async () => {
@@ -189,21 +193,25 @@ describe("OidcTokenRefresher", () => {
189193

190194
// only one call to token endpoint
191195
expect(fetchMock).toHaveFetchedTimes(1, config.token_endpoint);
192-
expect(result1).toEqual({
193-
accessToken: "first-new-access-token",
194-
refreshToken: "first-new-refresh-token",
195-
});
196+
expect(result1).toEqual(
197+
expect.objectContaining({
198+
accessToken: "first-new-access-token",
199+
refreshToken: "first-new-refresh-token",
200+
}),
201+
);
196202
// same response
197203
expect(result1).toEqual(result2);
198204

199205
// call again after first request resolves
200206
const third = await refresher.doRefreshAccessToken("first-new-refresh-token");
201207

202208
// called token endpoint, got new tokens
203-
expect(third).toEqual({
204-
accessToken: "second-new-access-token",
205-
refreshToken: "second-new-refresh-token",
206-
});
209+
expect(third).toEqual(
210+
expect.objectContaining({
211+
accessToken: "second-new-access-token",
212+
refreshToken: "second-new-refresh-token",
213+
}),
214+
);
207215
});
208216

209217
it("should log and rethrow when token refresh fails", async () => {
@@ -261,10 +269,12 @@ describe("OidcTokenRefresher", () => {
261269
const result = await refresher.doRefreshAccessToken("first-new-refresh-token");
262270

263271
// called token endpoint, got new tokens
264-
expect(result).toEqual({
265-
accessToken: "second-new-access-token",
266-
refreshToken: "second-new-refresh-token",
267-
});
272+
expect(result).toEqual(
273+
expect.objectContaining({
274+
accessToken: "second-new-access-token",
275+
refreshToken: "second-new-refresh-token",
276+
}),
277+
);
268278
});
269279

270280
it("should throw TokenRefreshLogoutError when expired", async () => {

src/http-api/fetch.ts

Lines changed: 35 additions & 76 deletions
Original file line numberDiff line numberDiff line change
@@ -18,10 +18,10 @@ limitations under the License.
1818
* This is an internal module. See {@link MatrixHttpApi} for the public class.
1919
*/
2020

21-
import { checkObjectHasKeys, encodeParams } from "../utils.ts";
21+
import { checkObjectHasKeys, deepCopy, encodeParams } from "../utils.ts";
2222
import { type TypedEventEmitter } from "../models/typed-event-emitter.ts";
2323
import { Method } from "./method.ts";
24-
import { ConnectionError, MatrixError, TokenRefreshError, TokenRefreshLogoutError } from "./errors.ts";
24+
import { ConnectionError, MatrixError, TokenRefreshError } from "./errors.ts";
2525
import {
2626
HttpApiEvent,
2727
type HttpApiEventHandlerMap,
@@ -31,7 +31,7 @@ import {
3131
} from "./interface.ts";
3232
import { anySignal, parseErrorResponse, timeoutSignal } from "./utils.ts";
3333
import { type QueryDict } from "../utils.ts";
34-
import { singleAsyncExecution } from "../utils/decorators.ts";
34+
import { TokenRefresher, TokenRefreshOutcome } from "./refresh.ts";
3535

3636
interface TypedResponse<T> extends Response {
3737
json(): Promise<T>;
@@ -43,14 +43,9 @@ export type ResponseType<T, O extends IHttpOpts> = O extends { json: false }
4343
? T
4444
: TypedResponse<T>;
4545

46-
const enum TokenRefreshOutcome {
47-
Success = "success",
48-
Failure = "failure",
49-
Logout = "logout",
50-
}
51-
5246
export class FetchHttpApi<O extends IHttpOpts> {
5347
private abortController = new AbortController();
48+
private readonly tokenRefresher: TokenRefresher;
5449

5550
public constructor(
5651
private eventEmitter: TypedEventEmitter<HttpApiEvent, HttpApiEventHandlerMap>,
@@ -59,6 +54,8 @@ export class FetchHttpApi<O extends IHttpOpts> {
5954
checkObjectHasKeys(opts, ["baseUrl", "prefix"]);
6055
opts.onlyData = !!opts.onlyData;
6156
opts.useAuthorizationHeader = opts.useAuthorizationHeader ?? true;
57+
58+
this.tokenRefresher = new TokenRefresher(opts);
6259
}
6360

6461
public abort(): void {
@@ -113,12 +110,6 @@ export class FetchHttpApi<O extends IHttpOpts> {
113110
return this.requestOtherUrl(method, fullUri, body, opts);
114111
}
115112

116-
/**
117-
* Promise used to block authenticated requests during a token refresh to avoid repeated expected errors.
118-
* @private
119-
*/
120-
private tokenRefreshPromise?: Promise<unknown>;
121-
122113
/**
123114
* Perform an authorised request to the homeserver.
124115
* @param method - The HTTP method e.g. "GET".
@@ -146,36 +137,45 @@ export class FetchHttpApi<O extends IHttpOpts> {
146137
* @returns Rejects with an error if a problem occurred.
147138
* This includes network problems and Matrix-specific error JSON.
148139
*/
149-
public async authedRequest<T>(
140+
public authedRequest<T>(
150141
method: Method,
151142
path: string,
152-
queryParams?: QueryDict,
143+
queryParams: QueryDict = {},
153144
body?: Body,
154-
paramOpts: IRequestOpts & { doNotAttemptTokenRefresh?: boolean } = {},
145+
paramOpts: IRequestOpts = {},
155146
): Promise<ResponseType<T, O>> {
156-
if (!queryParams) queryParams = {};
147+
return this.doAuthedRequest<T>(1, method, path, queryParams, body, paramOpts);
148+
}
157149

150+
// Wrapper around public method authedRequest to allow for tracking retry attempt counts
151+
private async doAuthedRequest<T>(
152+
attempt: number,
153+
method: Method,
154+
path: string,
155+
queryParams: QueryDict,
156+
body?: Body,
157+
paramOpts: IRequestOpts = {},
158+
): Promise<ResponseType<T, O>> {
158159
// avoid mutating paramOpts so they can be used on retry
159-
const opts = { ...paramOpts };
160-
161-
// Await any ongoing token refresh before we build the headers/params
162-
await this.tokenRefreshPromise;
160+
const opts = deepCopy(paramOpts);
161+
// we have to manually copy the abortSignal over as it is not a plain object
162+
opts.abortSignal = paramOpts.abortSignal;
163163

164-
// Take a copy of the access token so we have a record of the token we used for this request if it fails
165-
const accessToken = this.opts.accessToken;
166-
if (accessToken) {
164+
// Take a snapshot of the current token state before we start the request so we can reference it if we error
165+
const requestSnapshot = await this.tokenRefresher.prepareForRequest();
166+
if (requestSnapshot.accessToken) {
167167
if (this.opts.useAuthorizationHeader) {
168168
if (!opts.headers) {
169169
opts.headers = {};
170170
}
171171
if (!opts.headers.Authorization) {
172-
opts.headers.Authorization = `Bearer ${accessToken}`;
172+
opts.headers.Authorization = `Bearer ${requestSnapshot.accessToken}`;
173173
}
174174
if (queryParams.access_token) {
175175
delete queryParams.access_token;
176176
}
177177
} else if (!queryParams.access_token) {
178-
queryParams.access_token = accessToken;
178+
queryParams.access_token = requestSnapshot.accessToken;
179179
}
180180
}
181181

@@ -187,33 +187,19 @@ export class FetchHttpApi<O extends IHttpOpts> {
187187
throw error;
188188
}
189189

190-
if (error.errcode === "M_UNKNOWN_TOKEN" && !opts.doNotAttemptTokenRefresh) {
191-
// If the access token has changed since we started the request, but before we refreshed it,
192-
// then it was refreshed due to another request failing, so retry before refreshing again.
193-
let outcome: TokenRefreshOutcome | null = null;
194-
if (accessToken === this.opts.accessToken) {
195-
const tokenRefreshPromise = this.tryRefreshToken();
196-
this.tokenRefreshPromise = tokenRefreshPromise;
197-
outcome = await tokenRefreshPromise;
198-
}
199-
200-
if (outcome === TokenRefreshOutcome.Success || outcome === null) {
190+
if (error.errcode === "M_UNKNOWN_TOKEN") {
191+
const outcome = await this.tokenRefresher.handleUnknownToken(requestSnapshot, attempt);
192+
if (outcome === TokenRefreshOutcome.Success) {
201193
// if we got a new token retry the request
202-
return this.authedRequest(method, path, queryParams, body, {
203-
...paramOpts,
204-
// Only attempt token refresh once for each failed request
205-
doNotAttemptTokenRefresh: outcome !== null,
206-
});
194+
return this.doAuthedRequest(attempt + 1, method, path, queryParams, body, paramOpts);
207195
}
208196
if (outcome === TokenRefreshOutcome.Failure) {
209197
throw new TokenRefreshError(error);
210198
}
211-
// Fall through to SessionLoggedOut handler below
212-
}
213199

214-
// otherwise continue with error handling
215-
if (error.errcode == "M_UNKNOWN_TOKEN" && !opts?.inhibitLogoutEmit) {
216-
this.eventEmitter.emit(HttpApiEvent.SessionLoggedOut, error);
200+
if (!opts?.inhibitLogoutEmit) {
201+
this.eventEmitter.emit(HttpApiEvent.SessionLoggedOut, error);
202+
}
217203
} else if (error.errcode == "M_CONSENT_NOT_GIVEN") {
218204
this.eventEmitter.emit(HttpApiEvent.NoConsent, error.message, error.data.consent_uri);
219205
}
@@ -222,33 +208,6 @@ export class FetchHttpApi<O extends IHttpOpts> {
222208
}
223209
}
224210

225-
/**
226-
* Attempt to refresh access tokens.
227-
* On success, sets new access and refresh tokens in opts.
228-
* @returns Promise that resolves to a boolean - true when token was refreshed successfully
229-
*/
230-
@singleAsyncExecution
231-
private async tryRefreshToken(): Promise<TokenRefreshOutcome> {
232-
if (!this.opts.refreshToken || !this.opts.tokenRefreshFunction) {
233-
return TokenRefreshOutcome.Logout;
234-
}
235-
236-
try {
237-
const { accessToken, refreshToken } = await this.opts.tokenRefreshFunction(this.opts.refreshToken);
238-
this.opts.accessToken = accessToken;
239-
this.opts.refreshToken = refreshToken;
240-
// successfully got new tokens
241-
return TokenRefreshOutcome.Success;
242-
} catch (error) {
243-
this.opts.logger?.warn("Failed to refresh token", error);
244-
// If we get a TokenError or MatrixError, we should log out, otherwise assume transient
245-
if (error instanceof TokenRefreshLogoutError || error instanceof MatrixError) {
246-
return TokenRefreshOutcome.Logout;
247-
}
248-
return TokenRefreshOutcome.Failure;
249-
}
250-
}
251-
252211
/**
253212
* Perform a request to the homeserver without any credentials.
254213
* @param method - The HTTP method e.g. "GET".

src/http-api/interface.ts

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -24,9 +24,20 @@ export type Body = Record<string, any> | BodyInit;
2424
* Unencrypted access and (optional) refresh token
2525
*/
2626
export type AccessTokens = {
27+
/**
28+
* The new access token to use for authenticated requests
29+
*/
2730
accessToken: string;
31+
/**
32+
* The new refresh token to use for refreshing tokens, optional
33+
*/
2834
refreshToken?: string;
35+
/**
36+
* Approximate date when the access token will expire, optional
37+
*/
38+
expiry?: Date;
2939
};
40+
3041
/**
3142
* @experimental
3243
* Function that performs token refresh using the given refreshToken.

0 commit comments

Comments
 (0)