@@ -58,6 +58,24 @@ def __init__(
58
58
# Thread-safe cache for user contexts
59
59
self ._user_context_cache : dict [str , CachedUserContext ] = {}
60
60
self ._cache_lock = asyncio .Lock ()
61
+
62
+ # Per-MCP-token locks for refresh operations to prevent race conditions
63
+ self ._refresh_locks : dict [str , asyncio .Lock ] = {}
64
+ self ._refresh_locks_lock = asyncio .Lock ()
65
+
66
+ async def _get_refresh_lock (self , mcp_token : str ) -> asyncio .Lock :
67
+ """Get or create a refresh lock for the given MCP token.
68
+
69
+ Args:
70
+ mcp_token: MCP token to get lock for
71
+
72
+ Returns:
73
+ Lock specific to this MCP token
74
+ """
75
+ async with self ._refresh_locks_lock :
76
+ if mcp_token not in self ._refresh_locks :
77
+ self ._refresh_locks [mcp_token ] = asyncio .Lock ()
78
+ return self ._refresh_locks [mcp_token ]
61
79
62
80
async def _get_cached_user_context (self , mcp_token : str ) -> UserContext | None :
63
81
"""Get user context from cache if valid and not expired.
@@ -130,30 +148,43 @@ async def _attempt_token_refresh(self, mcp_token: str, external_token: str) -> s
130
148
logger .warning ("No OAuth server available for token refresh" )
131
149
return None
132
150
133
- try :
134
- # Invalidate cache for the expired MCP token
135
- async with self ._cache_lock :
136
- # Remove cached entry for this MCP token
137
- if mcp_token in self ._user_context_cache :
138
- del self ._user_context_cache [mcp_token ]
139
- logger .debug (f"🗑️ Removed expired token from cache: { mcp_token [:20 ]} ..." )
140
-
141
- # Attempt refresh through the OAuth server
142
- new_external_token = await self .oauth_server .refresh_external_token (mcp_token )
151
+ # Get the refresh lock for this specific MCP token to prevent race conditions
152
+ refresh_lock = await self ._get_refresh_lock (mcp_token )
153
+
154
+ async with refresh_lock :
155
+ try :
156
+ # Check if another request already refreshed the token
157
+ # by checking if we have a valid cached user context now
158
+ cached_context = await self ._get_cached_user_context (mcp_token )
159
+ if cached_context is not None :
160
+ logger .debug (f"🎯 Token already refreshed by another request for { mcp_token [:20 ]} ..." )
161
+ # Get the current external token from token mapping
162
+ current_external_token = self .oauth_server ._token_mapping .get (mcp_token )
163
+ return current_external_token
164
+
165
+ # Invalidate cache for the expired MCP token
166
+ async with self ._cache_lock :
167
+ # Remove cached entry for this MCP token
168
+ if mcp_token in self ._user_context_cache :
169
+ del self ._user_context_cache [mcp_token ]
170
+ logger .debug (f"🗑️ Removed expired token from cache: { mcp_token [:20 ]} ..." )
171
+
172
+ # Attempt refresh through the OAuth server
173
+ new_external_token = await self .oauth_server .refresh_external_token (mcp_token )
174
+
175
+ if new_external_token :
176
+ logger .info (
177
+ f"🔄 Successfully refreshed external token: { new_external_token [:20 ]} ..."
178
+ )
179
+ return new_external_token
180
+ else :
181
+ logger .warning ("Token refresh returned no new token" )
182
+ return None
143
183
144
- if new_external_token :
145
- logger .info (
146
- f"🔄 Successfully refreshed external token: { new_external_token [:20 ]} ..."
147
- )
148
- return new_external_token
149
- else :
150
- logger .warning ("Token refresh returned no new token" )
184
+ except Exception as e :
185
+ logger .error (f"Error during token refresh: { e } " )
151
186
return None
152
187
153
- except Exception as e :
154
- logger .error (f"Error during token refresh: { e } " )
155
- return None
156
-
157
188
async def check_authentication (self ) -> UserContext | None :
158
189
"""Check if the current request is authenticated.
159
190
0 commit comments