Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
21 commits
Select commit Hold shift + click to select a range
f4249ea
Fix OAuth redirect URI validation by using AnyUrl instead of AnyHttpUrl
bgaidioz Sep 9, 2025
8c8dc1c
Addressed comment
bgaidioz Sep 9, 2025
4debe2f
Addressed comments
bgaidioz Sep 9, 2025
28b1fd3
Addressed comments
bgaidioz Sep 10, 2025
fa99187
Addressed comments
bgaidioz Sep 10, 2025
ba24b2a
feat: implement user context caching in OAuth middleware
bgaidioz Sep 10, 2025
9bef9e0
refactor: adjust OAuth cache logging levels
bgaidioz Sep 10, 2025
7b22e04
feat: add configurable cache TTL for OAuth user context caching
bgaidioz Sep 10, 2025
3c1644b
Added doc
bgaidioz Sep 11, 2025
a74d83d
add refresh token fields to types and persistence
bgaidioz Sep 10, 2025
6b2fadb
add refresh token support to keycloak provider
bgaidioz Sep 10, 2025
2abe60b
add refresh token support to all oauth providers
bgaidioz Sep 10, 2025
181b25a
add refresh token management to base oauth system
bgaidioz Sep 10, 2025
1ec8a83
add refresh token logic to middleware for expired tokens
bgaidioz Sep 10, 2025
c95dda5
add database migration for refresh_token column
bgaidioz Sep 10, 2025
6064174
fix refresh token lookup to use external token key
bgaidioz Sep 10, 2025
83aa5d9
fix: use MCP token as cache key instead of external token
bgaidioz Sep 10, 2025
dc492d4
fix: prevent race conditions in token refresh with per-token locks
bgaidioz Sep 10, 2025
ae7a42b
fix: improve error handling with proper HTTPException type checking
bgaidioz Sep 10, 2025
143f7aa
fix: standardize OAuth logging levels for production readiness
bgaidioz Sep 10, 2025
5d42b9c
fixup! fix: standardize OAuth logging levels for production readiness
bgaidioz Sep 11, 2025
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
36 changes: 36 additions & 0 deletions docs/guides/authentication.md
Original file line number Diff line number Diff line change
Expand Up @@ -1185,6 +1185,42 @@ projects:
# ... other GitHub config
```

## User Context Caching

MXCP caches user context information to improve performance and reduce load on OAuth providers. When a user is authenticated, their user information (username, email, provider details) is cached to avoid making API calls to the OAuth provider on every tool execution.

### Cache TTL Configuration

You can configure the cache duration using the `cache_ttl` setting:

```yaml
projects:
my_project:
profiles:
dev:
auth:
provider: github
cache_ttl: 300 # Cache user context for 5 minutes (default: 300 seconds)

clients:
- client_id: "${CLIENT_ID}"
# ... client config
```

**Configuration Options:**

- `cache_ttl`: Cache duration in seconds for user context information
- **Default**: 300 seconds (5 minutes)
- **Purpose**: Reduces API calls to OAuth providers, improving performance and avoiding rate limits
- **Range**: Any positive integer (recommended: 60-1800 seconds)

**Security Considerations:**

- Cached user context expires automatically after the TTL period
- User information is cached in memory only (not persisted to disk)
- Cache is cleared when the server restarts
- Shorter TTL values provide more up-to-date user information but increase API calls

## Authorization Configuration

MXCP supports configurable scope-based authorization to control access to your endpoints and tools. You can specify which OAuth scopes are required for accessing your server's resources.
Expand Down
4 changes: 3 additions & 1 deletion src/mxcp/sdk/auth/_types.py
Original file line number Diff line number Diff line change
Expand Up @@ -133,6 +133,7 @@ class AuthConfig(TypedDict, total=False):
"""

provider: Literal["none", "github", "atlassian", "salesforce", "keycloak", "google"] | None
cache_ttl: int | None # Cache TTL in seconds for user context caching
clients: list[OAuthClientConfig] | None # Pre-configured OAuth clients
authorization: AuthorizationConfig | None # Authorization policies
persistence: AuthPersistenceConfig | None # Token/client persistence
Expand All @@ -144,8 +145,9 @@ class ExternalUserInfo:

id: str
scopes: list[str]
raw_token: str # original token from the IdP (JWT or opaque)
raw_token: str # original access token from the IdP (JWT or opaque)
provider: str
refresh_token: str | None = None # refresh token for renewing access tokens


@dataclass
Expand Down
140 changes: 122 additions & 18 deletions src/mxcp/sdk/auth/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -161,14 +161,14 @@ async def _load_clients_from_persistence(self) -> None:
persisted_clients = await self.persistence.list_clients()
for client_data in persisted_clients:
try:
# Convert string URLs back to AnyHttpUrl objects for OAuthClientInformationFull
# Convert string URLs back to AnyUrl objects for OAuthClientInformationFull

redirect_uris_pydantic = []

# Validate each redirect URI individually
for uri in client_data.redirect_uris:
try:
redirect_uris_pydantic.append(AnyHttpUrl(uri))
redirect_uris_pydantic.append(AnyUrl(uri))
except ValidationError as ve:
logger.warning(
f"Skipping malformed redirect URI for client {client_data.client_id}: {uri} - {ve}"
Expand All @@ -185,9 +185,7 @@ async def _load_clients_from_persistence(self) -> None:
client = OAuthClientInformationFull(
client_id=client_data.client_id,
client_secret=client_data.client_secret,
redirect_uris=cast(
list[AnyUrl], redirect_uris_pydantic
), # Use validated URIs
redirect_uris=redirect_uris_pydantic,
grant_types=cast(
list[Literal["authorization_code", "refresh_token"]],
client_data.grant_types,
Expand All @@ -214,7 +212,22 @@ async def _register_configured_clients(self, auth_config: AuthConfig) -> None:
for client_config in clients:
client_id = client_config["client_id"]
redirect_uris_str = client_config.get("redirect_uris", [])
redirect_uris_any = [cast(AnyUrl, uri) for uri in (redirect_uris_str or [])]

# Validate each redirect URI individually
redirect_uris_any = []
for uri in redirect_uris_str or []:
try:
redirect_uris_any.append(AnyUrl(uri))
except ValidationError as ve:
logger.warning(
f"Skipping malformed redirect URI in config for client {client_id}: {uri} - {ve}"
)
# Skip malformed URIs but continue loading the client

# Skip client if no valid redirect URIs remain
if not redirect_uris_any and redirect_uris_str:
logger.error(f"Skipping configured client {client_id}: no valid redirect URIs")
continue

client = OAuthClientInformationFull(
client_id=client_id,
Expand Down Expand Up @@ -244,7 +257,7 @@ async def get_client(self, client_id: str) -> OAuthClientInformationFull | None:
# First check memory cache
client = self._clients.get(client_id)
if client:
logger.info(f"Looking up client_id: {client_id}, found in memory cache")
logger.debug(f"Looking up client_id: {client_id}, found in memory cache")
return client

# If not in cache and persistence is available, check persistence
Expand All @@ -253,14 +266,14 @@ async def get_client(self, client_id: str) -> OAuthClientInformationFull | None:
persisted_client = await self.persistence.load_client(client_id)
if persisted_client:
# Load into memory cache
# Convert string URLs back to AnyHttpUrl objects for OAuthClientInformationFull
# Convert string URLs back to AnyUrl objects for OAuthClientInformationFull

redirect_uris_pydantic = []

# Validate each redirect URI individually
for uri in persisted_client.redirect_uris:
try:
redirect_uris_pydantic.append(AnyHttpUrl(uri))
redirect_uris_pydantic.append(AnyUrl(uri))
except ValidationError as ve:
logger.warning(
f"Skipping malformed redirect URI for client {client_id}: {uri} - {ve}"
Expand All @@ -275,9 +288,7 @@ async def get_client(self, client_id: str) -> OAuthClientInformationFull | None:
client = OAuthClientInformationFull(
client_id=persisted_client.client_id,
client_secret=persisted_client.client_secret,
redirect_uris=cast(
list[AnyUrl], redirect_uris_pydantic
), # Use validated URIs
redirect_uris=redirect_uris_pydantic,
grant_types=cast(
list[Literal["authorization_code", "refresh_token"]],
persisted_client.grant_types,
Expand All @@ -289,7 +300,7 @@ async def get_client(self, client_id: str) -> OAuthClientInformationFull | None:
client_name=persisted_client.client_name,
)
self._clients[client_id] = client
logger.info(f"Looking up client_id: {client_id}, found in persistence")
logger.debug(f"Looking up client_id: {client_id}, found in persistence")
return client
except Exception as e:
logger.error(f"Error loading client from persistence: {e}")
Expand All @@ -307,13 +318,13 @@ async def register_client(self, client_info: OAuthClientInformationFull) -> None
# Store in persistence if available
if self.persistence:
try:
# Convert Pydantic AnyHttpUrl objects to strings for JSON serialization
# Convert Pydantic AnyUrl objects to strings for JSON serialization
redirect_uris_str = [str(uri) for uri in client_info.redirect_uris]

persisted_client = PersistedClient(
client_id=client_info.client_id,
client_secret=client_info.client_secret,
redirect_uris=redirect_uris_str, # Convert AnyHttpUrl to strings
redirect_uris=redirect_uris_str, # Convert AnyUrl to strings
grant_types=cast(list[str], client_info.grant_types),
response_types=cast(list[str], client_info.response_types),
scope=client_info.scope or "",
Expand All @@ -337,12 +348,21 @@ async def register_client_dynamically(self, client_metadata: dict[str, Any]) ->
client_secret = secrets.token_urlsafe(64)

# Extract and validate metadata
redirect_uris = client_metadata.get("redirect_uris", [])
redirect_uris_raw = client_metadata.get("redirect_uris", [])
grant_types = client_metadata.get("grant_types", ["authorization_code"])
response_types = client_metadata.get("response_types", ["code"])
scope = client_metadata.get("scope", "mxcp:access")
client_name = client_metadata.get("client_name", "MCP Client")

# Validate redirect URIs
redirect_uris = []
for uri in redirect_uris_raw:
try:
redirect_uris.append(AnyUrl(uri))
except ValidationError as ve:
logger.error(f"Invalid redirect URI in dynamic registration: {uri} - {ve}")
raise HTTPException(400, f"Invalid redirect URI: {uri}") from ve

# Create client object
client_info = OAuthClientInformationFull(
client_id=client_id,
Expand Down Expand Up @@ -396,6 +416,7 @@ async def _store_token(
scopes: list[str],
expires_in: int | None,
external_token: str | None = None,
refresh_token: str | None = None,
) -> None:
expires_at = (time.time() + expires_in) if expires_in else None
access_token = AccessToken(
Expand All @@ -419,6 +440,7 @@ async def _store_token(
token=token,
client_id=client_id,
external_token=external_token,
refresh_token=refresh_token,
scopes=scopes,
expires_at=expires_at,
created_at=time.time(),
Expand Down Expand Up @@ -481,7 +503,12 @@ async def handle_callback(self, code: str, state: str) -> str:

# Store external token (temporary until exchanged for MCP token)
await self._store_token(
user_info.raw_token, meta.client_id, user_info.scopes, None, user_info.raw_token
user_info.raw_token,
meta.client_id,
user_info.scopes,
None,
user_info.raw_token,
user_info.refresh_token,
)
self._token_mapping[mcp_code] = user_info.raw_token

Expand Down Expand Up @@ -530,7 +557,7 @@ async def load_authorization_code(
code_challenge=persisted_code.code_challenge or "",
)
self._auth_codes[code] = auth_code
logger.info(f"Loaded auth code from persistence: {code}")
logger.debug(f"Loaded auth code from persistence: {code}")
except Exception as e:
logger.error(f"Error loading auth code from persistence: {e}")

Expand Down Expand Up @@ -680,3 +707,80 @@ async def revoke_token(self, token: str, token_type_hint: str | None = None) ->
logger.debug(f"Revoked token from persistence: {token[:10]}...")
except Exception as e:
logger.error(f"Error revoking token from persistence: {e}")

async def refresh_external_token(self, mcp_token: str) -> str | None:
"""Refresh an expired external token using its refresh token.

Args:
mcp_token: The MCP token whose external token needs refreshing

Returns:
New external access token if refresh successful, None if failed
"""
async with self._lock:
# Get the current external token
external_token = self._token_mapping.get(mcp_token)
if not external_token:
logger.warning(
f"No external token mapping found for MCP token: {mcp_token[:10]}..."
)
return None

# Load the persisted token to get the refresh token
if not self.persistence:
logger.warning("No persistence backend available for refresh token lookup")
return None

try:
# Look up the refresh token using the external token, not the MCP token
persisted_token = await self.persistence.load_token(external_token)
if not persisted_token or not persisted_token.refresh_token:
logger.warning(
f"No refresh token available for external token: {external_token[:20]}..."
)
return None

# Call the provider's refresh method
if not hasattr(self.handler, "refresh_access_token"):
logger.warning(
f"Provider {type(self.handler).__name__} does not support token refresh"
)
return None

refresh_response = await self.handler.refresh_access_token(
persisted_token.refresh_token
)

# Extract new tokens
new_access_token = refresh_response.get("access_token")
new_refresh_token = refresh_response.get(
"refresh_token", persisted_token.refresh_token
)

if not new_access_token or not isinstance(new_access_token, str):
logger.error("No access token received from refresh")
return None

# Update token mappings
self._token_mapping[mcp_token] = new_access_token

# Update persistence with new tokens
updated_token = PersistedAccessToken(
token=persisted_token.token,
client_id=persisted_token.client_id,
external_token=new_access_token,
refresh_token=new_refresh_token,
scopes=persisted_token.scopes,
expires_at=persisted_token.expires_at, # Could update with new expiry if provided
created_at=persisted_token.created_at,
)
await self.persistence.store_token(updated_token)

logger.info(
f"Successfully refreshed external token for MCP token: {mcp_token[:10]}..."
)
return str(new_access_token)

except Exception as e:
logger.error(f"Error refreshing external token: {e}")
return None
Loading
Loading