From 34572cb02fc71ce0a7cbf02674df75f9964447a9 Mon Sep 17 00:00:00 2001 From: Shamsul Arefin Date: Sat, 16 Aug 2025 14:51:47 +0500 Subject: [PATCH 01/21] Oauth 2.1 design Signed-off-by: Shamsul Arefin --- .../architecture/oauth-21-unified-design.md | 1565 +++++++++++++++++ 1 file changed, 1565 insertions(+) create mode 100644 docs/docs/architecture/oauth-21-unified-design.md diff --git a/docs/docs/architecture/oauth-21-unified-design.md b/docs/docs/architecture/oauth-21-unified-design.md new file mode 100644 index 00000000..cc15b375 --- /dev/null +++ b/docs/docs/architecture/oauth-21-unified-design.md @@ -0,0 +1,1565 @@ +# OAuth 2.1 Integration Design for MCP Gateway + +**Version**: 3.0 (Unified) +**Status**: Draft +**Author**: MCP Gateway Team +**Date**: December 2024 + +## Table of Contents + +1. [Executive Summary](#executive-summary) +2. [Quick Reference: OAuth 2.1 Changes](#quick-reference-oauth-21-changes) +3. [Motivation](#motivation) +4. [Architecture Overview](#architecture-overview) +5. [OAuth 2.1 Specific Requirements](#oauth-21-specific-requirements) +6. [Database Schema Design](#database-schema-design) +7. [Component Design](#component-design) +8. [Implementation Flow](#implementation-flow) +9. [Key Implementation Details](#key-implementation-details) +10. [Security Considerations](#security-considerations) +11. [Configuration](#configuration) +12. [Migration Strategy](#migration-strategy) +13. [Testing Strategy](#testing-strategy) +14. [Rollout Plan](#rollout-plan) +15. [Dependencies](#dependencies) +16. [Example Usage](#example-usage) +17. [Monitoring and Observability](#monitoring-and-observability) +18. [Security Best Practices](#security-best-practices) +19. [Future Enhancements](#future-enhancements) +20. [Conclusion](#conclusion) + +## Executive Summary + +This document provides a comprehensive design for integrating OAuth 2.1 authentication into the MCP Gateway, enabling agents to perform actions on behalf of users without requiring personal access tokens (PATs). The implementation adheres to OAuth 2.1's enhanced security standards, including mandatory PKCE for all clients, strict redirect URI matching, one-time refresh tokens, and prohibition of bearer tokens in URLs. + +## Quick Reference: OAuth 2.1 Changes + +### Key Differences from OAuth 2.0 + +| Feature | OAuth 2.0 | OAuth 2.1 | Impact | +|---------|-----------|-----------|---------| +| **PKCE** | Optional for confidential clients | **Mandatory for ALL clients** | Must implement code_verifier/challenge | +| **Implicit Flow** | Supported | **Completely removed** | Use Authorization Code + PKCE | +| **Resource Owner Password** | Supported | **Strongly discouraged** | Avoid implementation | +| **Redirect URI Matching** | Partial matches allowed | **Exact string match only** | No wildcards permitted | +| **Refresh Tokens** | Can be reused | **One-time use only** | Automatic rotation required | +| **Bearer Tokens in URLs** | Allowed | **Prohibited** | Must use Authorization header | + +### Implementation Checklist + +- [ ] Implement PKCE with S256 for all authorization code flows +- [ ] Remove support for implicit grant flow +- [ ] Implement exact redirect URI validation +- [ ] Add refresh token rotation with immediate invalidation +- [ ] Validate no bearer tokens in URLs +- [ ] Update database schema for OAuth 2.1 requirements +- [ ] Implement OAuth Manager with PKCE support +- [ ] Create Token Cache Manager with rotation +- [ ] Update Admin UI for OAuth 2.1 configuration +- [ ] Modify gateway and tool services for OAuth 2.1 + +## Motivation + +Current limitations of MCP Gateway authentication: + +1. **Security Risk**: Personal Access Tokens (PATs) provide broad access and must be carefully managed +2. **User Experience**: Users must manually create and manage tokens for each service +3. **Scalability**: Managing multiple PATs across different services becomes cumbersome +4. **Delegation**: No native support for agents acting on behalf of users with scoped permissions + +OAuth 2.1 addresses these concerns by providing: +- Enhanced security with mandatory PKCE for all clients +- Removal of vulnerable flows (implicit, ROPC) +- Scoped access control with principle of least privilege +- Secure token refresh with mandatory rotation +- Better security through short-lived access tokens +- Prohibition of bearer tokens in URLs + +## Architecture Overview + +```mermaid +graph TD + subgraph "Admin UI Layer" + A[Gateway Configuration Form] + B[OAuth 2.1 Configuration Fields] + C[Token Management Interface] + end + + subgraph "Database Layer" + D[Gateway Table] + E[OAuth Credentials Table] + F[OAuth Token Cache Table] + G[OAuth PKCE State Table] + end + + subgraph "Service Layer" + H[Gateway Service] + I[Tool Service] + J[OAuth Manager] + K[Token Cache Manager] + L[PKCE Manager] + end + + subgraph "External Systems" + M[OAuth Provider] + N[MCP Server] + end + + A --> B + B --> D + B --> E + + H --> J + I --> J + J --> K + J --> L + J --> M + + K --> F + L --> G + H --> N + I --> N + + J -.->|Uses| O[oauthlib/authlib] +``` + +## OAuth 2.1 Specific Requirements + +### Implementation Challenges + +1. **Library Support**: Not all OAuth libraries fully support OAuth 2.1 yet + - `oauthlib` may need patches or extensions + - Consider `authlib` which has better OAuth 2.1 support + - May need to implement custom PKCE handling + +2. **Provider Compatibility**: Some OAuth providers may not fully support OAuth 2.1 + - GitHub, Google, and modern providers generally support PKCE + - Legacy enterprise systems may need adaptation layer + +3. **Breaking Changes**: OAuth 2.1 is not backward compatible + - Existing OAuth 2.0 integrations will need updates + - Cannot support both OAuth 2.0 and 2.1 simultaneously for same flow + +4. **Performance Considerations**: + - Token rotation adds database operations + - PKCE adds computational overhead (minimal) + - State management requires additional storage + +### Mandatory Requirements + +1. **PKCE (Proof Key for Code Exchange)** + - Required for ALL OAuth clients, including confidential clients + - Must use S256 (SHA-256) challenge method + - Code verifier length: 43-128 characters + +2. **Grant Type Restrictions** + - Only Authorization Code (with PKCE) and Client Credentials allowed + - Implicit Grant flow completely removed + - Resource Owner Password Credentials strongly discouraged + +3. **Security Enhancements** + - Exact redirect URI matching (no wildcards) + - One-time use refresh tokens with rotation + - No bearer tokens in URL query strings or fragments + - Mandatory HTTPS for all OAuth endpoints + +## Database Schema Design + +### Complete Schema Overview + +```mermaid +erDiagram + GATEWAY { + string id PK + string name UK + string url + string transport "SSE|STREAMABLEHTTP" + string auth_type "basic|bearer|oauth|authheaders" + json auth_value "encrypted" + json oauth_config "OAuth 2.1 settings" + boolean enabled + boolean reachable + datetime created_at + datetime updated_at + } + + OAUTH_CREDENTIALS { + string id PK + string gateway_id FK + string client_id + string client_secret "encrypted" + string authorization_url + string token_url + string redirect_uri "NOT NULL - exact match" + json scopes + string grant_type "authorization_code|client_credentials" + boolean pkce_required "default true" + string code_challenge_method "S256 only" + datetime created_at + datetime updated_at + } + + OAUTH_TOKEN_CACHE { + string id PK + string gateway_id FK + string access_token "encrypted" + string refresh_token "encrypted" + datetime expires_at + boolean is_active "for rotation tracking" + string previous_refresh_token_id FK + json token_metadata + datetime created_at + datetime used_at + datetime invalidated_at + } + + OAUTH_PKCE_STATE { + string id PK + string gateway_id FK + string state UK "cryptographically secure" + string code_verifier "encrypted" + string code_challenge + string redirect_uri "must match exactly" + string nonce "for replay protection" + datetime expires_at + datetime created_at + } + + TOOL { + string id PK + string gateway_id FK + string name + string auth_type + string auth_value "encrypted" + } + + GATEWAY ||--o| OAUTH_CREDENTIALS : has + GATEWAY ||--o{ OAUTH_TOKEN_CACHE : uses + GATEWAY ||--o{ OAUTH_PKCE_STATE : tracks + GATEWAY ||--o{ TOOL : provides + OAUTH_TOKEN_CACHE ||--o| OAUTH_TOKEN_CACHE : replaces +``` + +### SQL Schema Definitions + +```sql +-- OAuth Credentials Table +CREATE TABLE oauth_credentials ( + id VARCHAR(36) PRIMARY KEY DEFAULT gen_random_uuid(), + gateway_id VARCHAR(36) REFERENCES gateways(id) ON DELETE CASCADE, + client_id VARCHAR(255) NOT NULL, + client_secret TEXT, -- Encrypted with AES-256-GCM + authorization_url TEXT, + token_url TEXT NOT NULL, + redirect_uri TEXT NOT NULL, -- OAuth 2.1: Exact match required + scopes JSON DEFAULT '[]', + grant_type VARCHAR(50) DEFAULT 'authorization_code', + pkce_required BOOLEAN DEFAULT TRUE, -- OAuth 2.1: Always true + code_challenge_method VARCHAR(10) DEFAULT 'S256', -- OAuth 2.1: Only S256 + created_at TIMESTAMP WITH TIME ZONE DEFAULT NOW(), + updated_at TIMESTAMP WITH TIME ZONE DEFAULT NOW(), + UNIQUE(gateway_id), + CHECK (grant_type IN ('authorization_code', 'client_credentials')), + CHECK (code_challenge_method = 'S256') +); + +-- PKCE State Table +CREATE TABLE oauth_pkce_state ( + id VARCHAR(36) PRIMARY KEY DEFAULT gen_random_uuid(), + gateway_id VARCHAR(36) REFERENCES gateways(id) ON DELETE CASCADE, + state VARCHAR(255) UNIQUE NOT NULL, + code_verifier VARCHAR(128) NOT NULL, -- Encrypted + code_challenge VARCHAR(128) NOT NULL, + redirect_uri TEXT NOT NULL, + nonce VARCHAR(255), -- For OpenID Connect + expires_at TIMESTAMP WITH TIME ZONE NOT NULL, + created_at TIMESTAMP WITH TIME ZONE DEFAULT NOW(), + INDEX idx_state_expires (state, expires_at), + INDEX idx_gateway_state (gateway_id, state) +); + +-- Token Cache Table with Rotation Support +CREATE TABLE oauth_token_cache ( + id VARCHAR(36) PRIMARY KEY DEFAULT gen_random_uuid(), + gateway_id VARCHAR(36) REFERENCES gateways(id) ON DELETE CASCADE, + access_token TEXT NOT NULL, -- Encrypted + refresh_token TEXT, -- Encrypted + expires_at TIMESTAMP WITH TIME ZONE, + is_active BOOLEAN DEFAULT TRUE, + previous_refresh_token_id VARCHAR(36) REFERENCES oauth_token_cache(id), + token_metadata JSON, + created_at TIMESTAMP WITH TIME ZONE DEFAULT NOW(), + used_at TIMESTAMP WITH TIME ZONE, + invalidated_at TIMESTAMP WITH TIME ZONE, + INDEX idx_gateway_active (gateway_id, is_active), + INDEX idx_expires (expires_at), + INDEX idx_refresh_token (refresh_token) -- For rotation lookup +); +``` + +## Component Design + +### 1. OAuth Manager Service + +**Location**: `mcpgateway/services/oauth_manager.py` + +```python +from oauthlib.oauth2 import BackendApplicationClient, WebApplicationClient +from requests_oauthlib import OAuth2Session +from typing import Optional, Dict, Any, Tuple +import secrets +import hashlib +import base64 +import time + +class OAuthManager: + """Manages OAuth 2.1 authentication flows with mandatory PKCE.""" + + def __init__(self, cache_manager: TokenCacheManager, db_session): + self.cache_manager = cache_manager + self.db = db_session + + async def get_access_token( + self, + gateway_id: str, + credentials: OAuthCredentials + ) -> str: + """Get valid access token, refreshing if necessary.""" + # Check cache first + cached_token = await self.cache_manager.get_token(gateway_id) + if cached_token and not self._is_expired(cached_token): + return cached_token.access_token + + # Token expired or not found + if cached_token and cached_token.refresh_token: + return await self.refresh_token( + cached_token.refresh_token, + credentials + ) + else: + # Need new authorization + return await self.initiate_authorization(credentials) + + async def initiate_authorization( + self, + credentials: OAuthCredentials + ) -> Dict[str, str]: + """Start OAuth 2.1 authorization with PKCE.""" + # Generate PKCE parameters + code_verifier, code_challenge = self.generate_pkce_pair() + + # Generate secure state + state = secrets.token_urlsafe(32) + + # Store PKCE state + await self._store_pkce_state( + credentials.gateway_id, + state, + code_verifier, + code_challenge, + credentials.redirect_uri + ) + + # Build authorization URL + auth_params = { + 'response_type': 'code', + 'client_id': credentials.client_id, + 'redirect_uri': credentials.redirect_uri, + 'scope': ' '.join(credentials.scopes), + 'state': state, + 'code_challenge': code_challenge, + 'code_challenge_method': 'S256' + } + + return { + 'authorization_url': self._build_auth_url( + credentials.authorization_url, + auth_params + ), + 'state': state + } + + async def exchange_code( + self, + code: str, + state: str, + credentials: OAuthCredentials + ) -> TokenResponse: + """Exchange authorization code for tokens with PKCE verification.""" + # Retrieve and validate PKCE state + pkce_state = await self._get_pkce_state(state) + if not pkce_state: + raise ValueError("Invalid or expired state") + + # Validate redirect URI exact match + if pkce_state.redirect_uri != credentials.redirect_uri: + raise ValueError("OAuth 2.1: Redirect URI must match exactly") + + # Exchange code for tokens + token_data = { + 'grant_type': 'authorization_code', + 'code': code, + 'redirect_uri': credentials.redirect_uri, + 'code_verifier': pkce_state.code_verifier, + 'client_id': credentials.client_id, + 'client_secret': credentials.client_secret + } + + response = await self._request_token( + credentials.token_url, + token_data + ) + + # Store tokens with rotation tracking + await self.cache_manager.store_token( + credentials.gateway_id, + response, + invalidate_previous=True + ) + + # Clean up PKCE state + await self._delete_pkce_state(state) + + return response + + async def refresh_token( + self, + refresh_token: str, + credentials: OAuthCredentials + ) -> TokenResponse: + """OAuth 2.1: Refresh with mandatory rotation.""" + # Immediately invalidate old refresh token + await self.cache_manager.invalidate_refresh_token(refresh_token) + + token_data = { + 'grant_type': 'refresh_token', + 'refresh_token': refresh_token, + 'client_id': credentials.client_id, + 'client_secret': credentials.client_secret + } + + response = await self._request_token( + credentials.token_url, + token_data + ) + + # Store new tokens, invalidating all previous + await self.cache_manager.store_token( + credentials.gateway_id, + response, + invalidate_previous=True + ) + + return response + + def generate_pkce_pair(self) -> Tuple[str, str]: + """Generate PKCE code verifier and challenge (S256).""" + # Generate cryptographically secure verifier + code_verifier = base64.urlsafe_b64encode( + secrets.token_bytes(64) # 96 chars after encoding + ).decode('utf-8').rstrip('=') + + # Create S256 challenge + code_challenge = base64.urlsafe_b64encode( + hashlib.sha256(code_verifier.encode('utf-8')).digest() + ).decode('utf-8').rstrip('=') + + return code_verifier, code_challenge + + def validate_redirect_uri( + self, + requested_uri: str, + registered_uri: str + ) -> bool: + """OAuth 2.1: Exact string match for redirect URIs.""" + return requested_uri == registered_uri + + def validate_token_url(self, url: str) -> bool: + """OAuth 2.1: Ensure no tokens in URLs.""" + forbidden_params = ['access_token', 'refresh_token', 'token'] + parsed = urlparse(url) + + # Check query string + if parsed.query: + params = parse_qs(parsed.query) + for forbidden in forbidden_params: + if forbidden in params: + return False + + # Check fragment + if parsed.fragment: + for forbidden in forbidden_params: + if forbidden in parsed.fragment: + return False + + return True +``` + +### 2. Token Cache Manager + +**Location**: `mcpgateway/services/token_cache_manager.py` + +```python +from cryptography.fernet import Fernet +from datetime import datetime, timedelta, timezone +import json + +class TokenCacheManager: + """OAuth 2.1 token cache with rotation and encryption.""" + + def __init__(self, db_session, encryption_key: bytes): + self.db = db_session + self.cipher = Fernet(encryption_key) + + async def get_token(self, gateway_id: str) -> Optional[OAuthToken]: + """Get active token if valid.""" + token = self.db.query(OAuthTokenCache).filter( + OAuthTokenCache.gateway_id == gateway_id, + OAuthTokenCache.is_active == True, + OAuthTokenCache.invalidated_at.is_(None) + ).first() + + if not token: + return None + + # Check expiration + if token.expires_at and token.expires_at < datetime.now(timezone.utc): + await self.invalidate_token(gateway_id) + return None + + # Decrypt tokens + return OAuthToken( + access_token=self._decrypt(token.access_token), + refresh_token=self._decrypt(token.refresh_token) if token.refresh_token else None, + expires_at=token.expires_at, + metadata=token.token_metadata + ) + + async def store_token( + self, + gateway_id: str, + token: TokenResponse, + invalidate_previous: bool = True + ) -> None: + """Store token with OAuth 2.1 rotation support.""" + if invalidate_previous: + # Invalidate all previous tokens for this gateway + await self.invalidate_all_tokens(gateway_id) + + # Calculate expiration + expires_at = datetime.now(timezone.utc) + timedelta( + seconds=token.expires_in + ) if token.expires_in else None + + # Get previous token ID for rotation tracking + previous_token = self.db.query(OAuthTokenCache).filter( + OAuthTokenCache.gateway_id == gateway_id, + OAuthTokenCache.is_active == True + ).first() + + # Create new token entry + new_token = OAuthTokenCache( + gateway_id=gateway_id, + access_token=self._encrypt(token.access_token), + refresh_token=self._encrypt(token.refresh_token) if token.refresh_token else None, + expires_at=expires_at, + is_active=True, + previous_refresh_token_id=previous_token.id if previous_token else None, + token_metadata={ + 'token_type': token.token_type, + 'scope': token.scope, + 'issued_at': datetime.now(timezone.utc).isoformat() + } + ) + + self.db.add(new_token) + self.db.commit() + + async def invalidate_refresh_token(self, refresh_token: str) -> None: + """OAuth 2.1: Immediately invalidate used refresh token.""" + encrypted_token = self._encrypt(refresh_token) + + token_entry = self.db.query(OAuthTokenCache).filter( + OAuthTokenCache.refresh_token == encrypted_token, + OAuthTokenCache.is_active == True + ).first() + + if token_entry: + token_entry.is_active = False + token_entry.invalidated_at = datetime.now(timezone.utc) + self.db.commit() + + async def invalidate_all_tokens(self, gateway_id: str) -> None: + """Invalidate all tokens for rotation.""" + self.db.query(OAuthTokenCache).filter( + OAuthTokenCache.gateway_id == gateway_id, + OAuthTokenCache.is_active == True + ).update({ + 'is_active': False, + 'invalidated_at': datetime.now(timezone.utc) + }) + self.db.commit() + + def _encrypt(self, data: str) -> str: + """Encrypt sensitive data.""" + return self.cipher.encrypt(data.encode()).decode() + + def _decrypt(self, data: str) -> str: + """Decrypt sensitive data.""" + return self.cipher.decrypt(data.encode()).decode() +``` + +### 3. Admin UI Enhancements + +**OAuth 2.1 Configuration Form Fields**: + +```html + + + +``` + +## Implementation Flow + +### OAuth 2.1 Complete Flow + +```mermaid +sequenceDiagram + participant U as User + participant UI as Admin UI + participant GS as Gateway Service + participant OM as OAuth Manager + participant TC as Token Cache + participant PS as PKCE State Store + participant OP as OAuth Provider + participant MCP as MCP Server + + Note over U,MCP: Gateway Configuration + U->>UI: Configure OAuth Gateway + UI->>UI: Validate OAuth 2.1 Requirements + UI->>GS: Save OAuth Config + GS->>GS: Validate redirect URI format + GS->>OM: Initialize OAuth Manager + + Note over U,MCP: Authorization Code Flow with PKCE + U->>GS: Request Tool Access + GS->>TC: Check Token Cache + TC-->>GS: No Valid Token + + GS->>OM: Initiate Authorization + OM->>OM: Generate PKCE Pair + OM->>PS: Store PKCE State + OM-->>GS: Authorization URL + State + GS-->>U: Redirect to OAuth Provider + + U->>OP: Authenticate + OP->>OP: Validate PKCE Challenge + OP-->>U: Authorization Code + U-->>GS: Redirect with Code + State + + GS->>OM: Exchange Code + OM->>PS: Retrieve PKCE Verifier + PS-->>OM: Code Verifier + OM->>OM: Validate State + OM->>OP: Token Request + PKCE Verifier + OP->>OP: Validate PKCE + OP-->>OM: Access + Refresh Tokens + + OM->>TC: Store Tokens (Rotate) + TC->>TC: Invalidate Previous Tokens + OM->>PS: Clean PKCE State + OM-->>GS: Access Token + + Note over U,MCP: Tool Invocation + GS->>GS: Validate No Token in URL + GS->>MCP: Call Tool (Bearer Header) + MCP-->>GS: Tool Response + GS-->>U: Return Result + + Note over U,MCP: Token Refresh with Rotation + U->>GS: Invoke Tool (Later) + GS->>TC: Check Token Cache + TC-->>GS: Token Expired + + GS->>OM: Refresh Token + OM->>TC: Mark Old Token Used + OM->>OP: Refresh Request + OP-->>OM: New Tokens + OM->>TC: Store New Tokens + TC->>TC: Invalidate Old Tokens + OM-->>GS: New Access Token +``` + +## Key Implementation Details + +### 1. Gateway Service Modifications + +**File**: `mcpgateway/services/gateway_service.py` + +```python +async def _initialize_gateway( + self, + url: str, + authentication: Optional[Dict[str, str]] = None, + transport: str = "SSE" +) -> tuple[Dict[str, Any], List[ToolCreate], List[ResourceCreate], List[PromptCreate]]: + """Initialize gateway with OAuth 2.1 support.""" + + # OAuth 2.1: Validate URL has no tokens + if self._contains_bearer_token(url): + raise ValueError( + "OAuth 2.1 Security Violation: Bearer tokens must not be passed in URLs. " + "Use Authorization header instead." + ) + + headers = {} + + if authentication and authentication.get('type') == 'oauth': + # Get OAuth credentials + credentials = await self._get_oauth_credentials( + authentication['gateway_id'] + ) + + # Validate redirect URI if present + if hasattr(credentials, 'redirect_uri') and credentials.redirect_uri: + parsed_url = urlparse(url) + parsed_redirect = urlparse(credentials.redirect_uri) + + # OAuth 2.1: Exact match validation + if parsed_url.scheme != parsed_redirect.scheme or \ + parsed_url.netloc != parsed_redirect.netloc or \ + parsed_url.path != parsed_redirect.path: + raise ValueError( + f"OAuth 2.1: Redirect URI mismatch. " + f"Expected: {credentials.redirect_uri}, Got: {url}" + ) + + # Get access token + try: + access_token = await self.oauth_manager.get_access_token( + gateway_id=authentication['gateway_id'], + credentials=credentials + ) + + # OAuth 2.1: Always use Authorization header + headers = { + 'Authorization': f'Bearer {access_token}', + 'X-OAuth-Version': '2.1' # Indicate OAuth 2.1 compliance + } + except TokenExpiredError: + # Handle expired token + logger.warning(f"Token expired for gateway {authentication['gateway_id']}") + raise GatewayConnectionError("OAuth token expired, re-authorization required") + + else: + # Non-OAuth authentication + headers = decode_auth(authentication) + + # Continue with connection... + return await self._connect_to_gateway(url, headers, transport) + +def _contains_bearer_token(self, url: str) -> bool: + """Check if URL contains bearer tokens (OAuth 2.1 violation).""" + parsed = urlparse(url) + + # Check query parameters + if parsed.query: + params = parse_qs(parsed.query.lower()) + token_params = ['access_token', 'bearer', 'token', 'auth'] + for param in token_params: + if param in params: + return True + + # Check fragment + if parsed.fragment: + fragment_lower = parsed.fragment.lower() + if any(token in fragment_lower for token in ['access_token', 'bearer', 'token']): + return True + + return False +``` + +### 2. Tool Service Modifications + +**File**: `mcpgateway/services/tool_service.py` + +```python +async def invoke_tool( + self, + db: Session, + name: str, + arguments: Dict[str, Any], + request_headers: Optional[Dict[str, str]] = None +) -> ToolResult: + """Invoke tool with OAuth 2.1 compliance.""" + + tool = await self.get_tool_by_name(db, name) + + # OAuth 2.1: Pre-flight URL validation + if tool.url and self._contains_bearer_token(tool.url): + raise ToolValidationError( + f"OAuth 2.1 Violation: Tool URL contains bearer token. " + f"Tool: {name}, URL: {tool.url}" + ) + + headers = {} + + if tool.auth_type == 'oauth': + gateway = tool.gateway + if not gateway: + raise ToolError(f"OAuth tool {name} has no associated gateway") + + # Get OAuth credentials + credentials = await self._get_oauth_credentials(db, gateway.id) + + # Token acquisition with retry logic + max_retries = 2 + for attempt in range(max_retries): + try: + # Get access token (handles refresh automatically) + access_token = await self.oauth_manager.get_access_token( + gateway_id=gateway.id, + credentials=credentials + ) + + # OAuth 2.1: Bearer token in Authorization header only + headers = { + 'Authorization': f'Bearer {access_token}', + 'X-Tool-Name': name, + 'X-OAuth-Version': '2.1' + } + + break + + except TokenExpiredError: + if attempt == max_retries - 1: + raise + # Token expired during processing, retry + logger.info(f"Token expired during tool invocation, retrying... (attempt {attempt + 1})") + await asyncio.sleep(0.5) + + except RefreshTokenExpiredError: + # Refresh token expired, need re-authorization + raise ToolError( + f"OAuth refresh token expired for tool {name}. " + "Re-authorization required." + ) + else: + # Non-OAuth authentication + headers = self._get_tool_headers(tool) + + # OAuth 2.1: Final header validation + self._validate_oauth_headers(headers) + + # Merge with request headers (OAuth headers take precedence) + if request_headers: + merged_headers = {**request_headers, **headers} + else: + merged_headers = headers + + # Execute tool with observability + with create_span("tool.invoke.oauth21", { + "tool.name": name, + "tool.auth_type": tool.auth_type, + "oauth.version": "2.1" if tool.auth_type == 'oauth' else None + }) as span: + try: + result = await self._execute_tool( + tool=tool, + arguments=arguments, + headers=merged_headers + ) + + if span: + span.set_attribute("tool.success", True) + + return result + + except Exception as e: + if span: + span.set_attribute("tool.success", False) + span.set_attribute("error.message", str(e)) + raise + +def _validate_oauth_headers(self, headers: Dict[str, str]) -> None: + """Validate OAuth 2.1 header compliance.""" + for key, value in headers.items(): + if key.lower() != 'authorization': + # Check for tokens in other headers + value_lower = value.lower() + if any(token in value_lower for token in ['bearer', 'access_token', 'token']): + logger.warning( + f"Potential OAuth 2.1 violation: Token-like value in header '{key}'. " + "Tokens should only be in Authorization header." + ) +``` + +## Security Considerations + +### OAuth 2.1 Security Requirements + +1. **Token Storage** + - Use AES-256-GCM for all token encryption + - Separate encryption keys for different token types + - Key rotation every 90 days + +2. **PKCE Implementation** + - Code verifier: 43-128 characters (recommended: 96) + - Use cryptographically secure random generation + - S256 challenge method only + +3. **Redirect URI Security** + - Store as exact strings in database + - Validate character-by-character match + - No URL encoding differences allowed + +4. **Refresh Token Rotation** + - Immediate invalidation of used tokens + - Track token lineage for audit + - Maximum rotation chain length: 10 + +5. **Rate Limiting** + - Authorization attempts: 5 per minute per client + - Token requests: 10 per minute per client + - Exponential backoff on failures + +6. **Audit Logging** + ```python + # Required audit events + - Authorization initiated + - Authorization completed/failed + - Token issued + - Token refreshed + - Token invalidated + - Suspicious activity detected + ``` + +## Configuration + +### Environment Variables + +```env +# OAuth 2.1 Configuration +OAUTH_ENABLE=true +OAUTH_VERSION=2.1 + +# Token Management +OAUTH_TOKEN_CACHE_TTL=3600 # Access token cache TTL (seconds) +OAUTH_REFRESH_TOKEN_TTL=604800 # Refresh token validity (7 days) +OAUTH_MAX_TOKEN_AGE=86400 # Maximum token age before forced refresh + +# PKCE Configuration +OAUTH_PKCE_REQUIRED=true # Always true for OAuth 2.1 +OAUTH_PKCE_CODE_LENGTH=96 # Code verifier length (43-128) +OAUTH_PKCE_STATE_TTL=600 # Authorization state TTL (seconds) + +# Security Settings +OAUTH_STRICT_REDIRECT_URI=true # Always true for OAuth 2.1 +OAUTH_REFRESH_TOKEN_ROTATION=true # Always true for OAuth 2.1 +OAUTH_ALLOW_HTTP=false # HTTPS required for OAuth endpoints + +# Rate Limiting +OAUTH_AUTH_RATE_LIMIT=5 # Auth attempts per minute +OAUTH_TOKEN_RATE_LIMIT=10 # Token requests per minute +OAUTH_RATE_LIMIT_WINDOW=60 # Rate limit window (seconds) + +# Encryption +OAUTH_ENCRYPTION_KEY=${OAUTH_ENCRYPTION_KEY} # Base64 encoded 32-byte key +OAUTH_KEY_ROTATION_DAYS=90 # Encryption key rotation period + +# Monitoring +OAUTH_ENABLE_METRICS=true +OAUTH_ENABLE_AUDIT_LOG=true +OAUTH_AUDIT_LOG_RETENTION_DAYS=90 +``` + +## Migration Strategy + +### Phase 1: Database Schema (Week 1) + +```sql +-- Migration script +BEGIN TRANSACTION; + +-- Add OAuth config to gateways +ALTER TABLE gateways +ADD COLUMN oauth_config JSONB, +ADD COLUMN oauth_version VARCHAR(10) DEFAULT '2.1'; + +-- Create OAuth tables +CREATE TABLE oauth_credentials (...); +CREATE TABLE oauth_token_cache (...); +CREATE TABLE oauth_pkce_state (...); + +-- Add indexes +CREATE INDEX idx_oauth_active_tokens ON oauth_token_cache(gateway_id, is_active); +CREATE INDEX idx_oauth_pkce_expiry ON oauth_pkce_state(expires_at); + +-- Add constraints +ALTER TABLE oauth_credentials +ADD CONSTRAINT chk_oauth21_grant_type +CHECK (grant_type IN ('authorization_code', 'client_credentials')); + +COMMIT; +``` + +### Phase 2: Service Implementation (Week 2-3) + +1. Implement OAuth Manager with PKCE +2. Implement Token Cache Manager with rotation +3. Add OAuth support to Gateway Service +4. Add OAuth support to Tool Service +5. Create OAuth callback endpoints + +### Phase 3: UI Integration (Week 4) + +1. Update Admin UI with OAuth 2.1 forms +2. Add OAuth status dashboard +3. Implement authorization flow UI +4. Add token management interface + +### Phase 4: Testing & Rollout (Week 5-6) + +1. Unit tests for all OAuth components +2. Integration tests for complete flows +3. Security penetration testing +4. Performance testing under load +5. Documentation and training + +## Testing Strategy + +### Unit Tests + +```python +# Test PKCE generation +def test_pkce_generation(): + manager = OAuthManager() + verifier, challenge = manager.generate_pkce_pair() + + assert 43 <= len(verifier) <= 128 + assert len(challenge) == 43 # S256 produces 43 chars + + # Verify challenge calculation + expected = base64.urlsafe_b64encode( + hashlib.sha256(verifier.encode()).digest() + ).decode().rstrip('=') + + assert challenge == expected + +# Test redirect URI validation +def test_redirect_uri_validation(): + manager = OAuthManager() + + # Exact match - should pass + assert manager.validate_redirect_uri( + "https://app.com/callback", + "https://app.com/callback" + ) + + # Different path - should fail + assert not manager.validate_redirect_uri( + "https://app.com/callback/auth", + "https://app.com/callback" + ) + + # Different scheme - should fail + assert not manager.validate_redirect_uri( + "http://app.com/callback", + "https://app.com/callback" + ) + +# Test token rotation +async def test_refresh_token_rotation(): + cache = TokenCacheManager() + + # Store initial token + await cache.store_token("gw1", initial_token) + + # Refresh should invalidate old token + await cache.invalidate_refresh_token(initial_token.refresh_token) + + # Old token should be inactive + old_token = await cache.get_token("gw1") + assert old_token is None +``` + +### Integration Tests + +```python +# Complete OAuth 2.1 flow test +async def test_oauth21_complete_flow(): + # Setup + gateway = create_test_gateway(auth_type='oauth') + oauth_creds = create_oauth_credentials( + grant_type='authorization_code', + pkce_required=True + ) + + # Step 1: Initiate authorization + auth_result = await oauth_manager.initiate_authorization(oauth_creds) + assert 'authorization_url' in auth_result + assert 'code_challenge' in auth_result['authorization_url'] + + # Step 2: Simulate authorization callback + auth_code = "test_auth_code" + token_result = await oauth_manager.exchange_code( + code=auth_code, + state=auth_result['state'], + credentials=oauth_creds + ) + + assert token_result.access_token + assert token_result.refresh_token + + # Step 3: Use token for tool invocation + tool_result = await tool_service.invoke_tool( + name="test_tool", + arguments={"param": "value"} + ) + + assert tool_result.success + + # Step 4: Test token refresh + # Expire the token + await cache.expire_token(gateway.id) + + # Should automatically refresh + tool_result2 = await tool_service.invoke_tool( + name="test_tool", + arguments={"param": "value2"} + ) + + assert tool_result2.success +``` + +### Security Tests + +```python +# Test bearer token in URL rejection +def test_bearer_token_url_rejection(): + urls = [ + "https://api.com/endpoint?access_token=secret", + "https://api.com/endpoint#access_token=secret", + "https://api.com/endpoint?token=Bearer%20secret", + ] + + for url in urls: + with pytest.raises(ValueError, match="OAuth 2.1.*Bearer tokens.*URLs"): + gateway_service._validate_url(url) + +# Test token encryption +def test_token_encryption(): + cache = TokenCacheManager(encryption_key=test_key) + + token = "sensitive_access_token" + encrypted = cache._encrypt(token) + + # Should not contain original token + assert token not in encrypted + + # Should decrypt correctly + decrypted = cache._decrypt(encrypted) + assert decrypted == token +``` + +## Example Usage + +### Configuration Examples + +#### GitHub with OAuth 2.1 + +```python +POST /gateways +{ + "name": "GitHub MCP Server", + "url": "https://github-mcp.example.com/sse", + "transport": "streamablehttp", + "auth_type": "oauth", + "oauth_config": { + "grant_type": "authorization_code", + "client_id": "github_app_client_id", + "client_secret": "github_app_client_secret", + "authorization_url": "https://github.com/login/oauth/authorize", + "token_url": "https://github.com/login/oauth/access_token", + "redirect_uri": "https://gateway.example.com/oauth/github/callback", + "scopes": ["repo", "read:user", "project"], + "pkce_required": true, + "code_challenge_method": "S256" + } +} +``` + +#### Google Workspace Integration + +```python +POST /gateways +{ + "name": "Google Workspace MCP", + "url": "https://workspace-mcp.example.com/sse", + "transport": "sse", + "auth_type": "oauth", + "oauth_config": { + "grant_type": "authorization_code", + "client_id": "google_client_id.apps.googleusercontent.com", + "client_secret": "google_client_secret", + "authorization_url": "https://accounts.google.com/o/oauth2/v2/auth", + "token_url": "https://oauth2.googleapis.com/token", + "redirect_uri": "https://gateway.example.com/oauth/google/callback", + "scopes": [ + "https://www.googleapis.com/auth/drive.readonly", + "https://www.googleapis.com/auth/gmail.readonly" + ], + "pkce_required": true, + "additional_params": { + "access_type": "offline", + "prompt": "consent" + } + } +} +``` + +### Tool Invocation + +```python +# Automatic OAuth token management +result = await tool_service.invoke_tool( + db=db, + name="github_create_pr", + arguments={ + "repository": "org/repo", + "title": "Feature: OAuth 2.1 Support", + "body": "Implements OAuth 2.1 compliance", + "base": "main", + "head": "feature/oauth21" + } +) + +# The service handles: +# 1. Token acquisition (if needed) +# 2. Token refresh (if expired) +# 3. Retry logic for token expiration +# 4. OAuth 2.1 compliance validation +``` + +## Monitoring and Observability + +### Metrics + +```python +# OAuth 2.1 specific metrics +oauth_metrics = { + # Performance + "oauth.token_request.duration": histogram, + "oauth.token_cache.hit_rate": gauge, + "oauth.pkce_generation.duration": histogram, + + # Security + "oauth.auth_failures.total": counter, + "oauth.token_rotation.count": counter, + "oauth.redirect_uri_mismatches": counter, + "oauth.bearer_in_url_attempts": counter, + + # Usage + "oauth.active_tokens.count": gauge, + "oauth.refresh_operations.total": counter, + "oauth.grant_type.usage": counter(labels=["grant_type"]) +} +``` + +### OpenTelemetry Spans + +```python +# OAuth flow tracing +with create_span("oauth.authorization_flow", { + "oauth.version": "2.1", + "oauth.grant_type": grant_type, + "oauth.provider": provider_name, + "oauth.pkce_enabled": True +}) as span: + # Authorization logic + span.add_event("pkce_generated") + span.add_event("authorization_initiated") + + # ... authorization process ... + + span.set_attribute("oauth.scopes_requested", len(scopes)) + span.set_attribute("oauth.success", success) +``` + +### Audit Events + +```json +{ + "timestamp": "2024-12-10T10:30:45Z", + "event_type": "oauth.token_refreshed", + "gateway_id": "gw_123", + "details": { + "old_token_id": "tok_abc", + "new_token_id": "tok_xyz", + "rotation_number": 3, + "client_id": "client_123", + "ip_address": "192.168.1.100", + "user_agent": "MCP-Gateway/2.0" + } +} +``` + +## Security Best Practices + +### OAuth 2.1 Compliance Checklist + +- [x] PKCE mandatory for all authorization code flows +- [x] S256 code challenge method only +- [x] Exact redirect URI matching +- [x] One-time use refresh tokens +- [x] Automatic token rotation +- [x] No bearer tokens in URLs +- [x] No implicit grant flow support +- [x] No ROPC grant support +- [x] HTTPS required for all OAuth endpoints +- [x] State parameter validation +- [x] Nonce validation for OpenID Connect +- [x] Token binding support (optional) +- [x] DPoP support (future enhancement) + +### Production Security Measures + +1. **Infrastructure** + - Use Hardware Security Module (HSM) for key storage + - Deploy behind WAF with OAuth-specific rules + - Enable comprehensive audit logging + - Implement intrusion detection + +2. **Monitoring** + - Alert on unusual token patterns + - Monitor for authorization anomalies + - Track failed authentication attempts + - Detect potential token replay attacks + +3. **Compliance** + - Regular security audits + - Penetration testing + - Compliance with GDPR/CCPA for token data + - SOC 2 Type II certification considerations + +## Future Enhancements + +1. **Advanced Security** + - DPoP (Demonstrating Proof of Possession) support + - Certificate-bound access tokens + - Mutual TLS for client authentication + - Pushed Authorization Requests (PAR) + +2. **Standards Support** + - OpenID Connect full compliance + - FAPI (Financial-grade API) support + - JWT Secured Authorization Response Mode (JARM) + - Rich Authorization Requests (RAR) + +3. **User Experience** + - OAuth provider auto-discovery + - Simplified configuration wizards + - Mobile app support with PKCE + - Biometric authentication integration + +4. **Operations** + - Automated token lifecycle management + - Self-service OAuth app registration + - Multi-tenant OAuth isolation + - Advanced analytics dashboard + +## Conclusion + +This comprehensive OAuth 2.1 integration design ensures that MCP Gateway implements the highest security standards for delegated authorization. By adopting OAuth 2.1's mandatory security features—including PKCE for all clients, strict redirect URI matching, one-time refresh tokens, and prohibition of bearer tokens in URLs—we create a robust and secure authentication system. + +The implementation provides: +- **Enhanced Security**: Eliminates all known OAuth 2.0 vulnerabilities +- **Simplified Integration**: Fewer flows to support, clearer security model +- **Future-Proof Design**: Ready for emerging standards and extensions +- **Operational Excellence**: Comprehensive monitoring, auditing, and management + +This design enables agents to securely act on behalf of users without the risks associated with personal access tokens, while maintaining backward compatibility where possible and providing a clear migration path for existing deployments. From a8fd95c6ef024d73d5386d2206c9f2f0b0a77889 Mon Sep 17 00:00:00 2001 From: Shamsul Arefin Date: Sat, 16 Aug 2025 16:50:17 +0500 Subject: [PATCH 02/21] oauth 2.0 design Signed-off-by: Shamsul Arefin --- .../architecture/oauth-21-unified-design.md | 1565 ----------------- docs/docs/architecture/oauth-design.md | 421 +++++ 2 files changed, 421 insertions(+), 1565 deletions(-) delete mode 100644 docs/docs/architecture/oauth-21-unified-design.md create mode 100644 docs/docs/architecture/oauth-design.md diff --git a/docs/docs/architecture/oauth-21-unified-design.md b/docs/docs/architecture/oauth-21-unified-design.md deleted file mode 100644 index cc15b375..00000000 --- a/docs/docs/architecture/oauth-21-unified-design.md +++ /dev/null @@ -1,1565 +0,0 @@ -# OAuth 2.1 Integration Design for MCP Gateway - -**Version**: 3.0 (Unified) -**Status**: Draft -**Author**: MCP Gateway Team -**Date**: December 2024 - -## Table of Contents - -1. [Executive Summary](#executive-summary) -2. [Quick Reference: OAuth 2.1 Changes](#quick-reference-oauth-21-changes) -3. [Motivation](#motivation) -4. [Architecture Overview](#architecture-overview) -5. [OAuth 2.1 Specific Requirements](#oauth-21-specific-requirements) -6. [Database Schema Design](#database-schema-design) -7. [Component Design](#component-design) -8. [Implementation Flow](#implementation-flow) -9. [Key Implementation Details](#key-implementation-details) -10. [Security Considerations](#security-considerations) -11. [Configuration](#configuration) -12. [Migration Strategy](#migration-strategy) -13. [Testing Strategy](#testing-strategy) -14. [Rollout Plan](#rollout-plan) -15. [Dependencies](#dependencies) -16. [Example Usage](#example-usage) -17. [Monitoring and Observability](#monitoring-and-observability) -18. [Security Best Practices](#security-best-practices) -19. [Future Enhancements](#future-enhancements) -20. [Conclusion](#conclusion) - -## Executive Summary - -This document provides a comprehensive design for integrating OAuth 2.1 authentication into the MCP Gateway, enabling agents to perform actions on behalf of users without requiring personal access tokens (PATs). The implementation adheres to OAuth 2.1's enhanced security standards, including mandatory PKCE for all clients, strict redirect URI matching, one-time refresh tokens, and prohibition of bearer tokens in URLs. - -## Quick Reference: OAuth 2.1 Changes - -### Key Differences from OAuth 2.0 - -| Feature | OAuth 2.0 | OAuth 2.1 | Impact | -|---------|-----------|-----------|---------| -| **PKCE** | Optional for confidential clients | **Mandatory for ALL clients** | Must implement code_verifier/challenge | -| **Implicit Flow** | Supported | **Completely removed** | Use Authorization Code + PKCE | -| **Resource Owner Password** | Supported | **Strongly discouraged** | Avoid implementation | -| **Redirect URI Matching** | Partial matches allowed | **Exact string match only** | No wildcards permitted | -| **Refresh Tokens** | Can be reused | **One-time use only** | Automatic rotation required | -| **Bearer Tokens in URLs** | Allowed | **Prohibited** | Must use Authorization header | - -### Implementation Checklist - -- [ ] Implement PKCE with S256 for all authorization code flows -- [ ] Remove support for implicit grant flow -- [ ] Implement exact redirect URI validation -- [ ] Add refresh token rotation with immediate invalidation -- [ ] Validate no bearer tokens in URLs -- [ ] Update database schema for OAuth 2.1 requirements -- [ ] Implement OAuth Manager with PKCE support -- [ ] Create Token Cache Manager with rotation -- [ ] Update Admin UI for OAuth 2.1 configuration -- [ ] Modify gateway and tool services for OAuth 2.1 - -## Motivation - -Current limitations of MCP Gateway authentication: - -1. **Security Risk**: Personal Access Tokens (PATs) provide broad access and must be carefully managed -2. **User Experience**: Users must manually create and manage tokens for each service -3. **Scalability**: Managing multiple PATs across different services becomes cumbersome -4. **Delegation**: No native support for agents acting on behalf of users with scoped permissions - -OAuth 2.1 addresses these concerns by providing: -- Enhanced security with mandatory PKCE for all clients -- Removal of vulnerable flows (implicit, ROPC) -- Scoped access control with principle of least privilege -- Secure token refresh with mandatory rotation -- Better security through short-lived access tokens -- Prohibition of bearer tokens in URLs - -## Architecture Overview - -```mermaid -graph TD - subgraph "Admin UI Layer" - A[Gateway Configuration Form] - B[OAuth 2.1 Configuration Fields] - C[Token Management Interface] - end - - subgraph "Database Layer" - D[Gateway Table] - E[OAuth Credentials Table] - F[OAuth Token Cache Table] - G[OAuth PKCE State Table] - end - - subgraph "Service Layer" - H[Gateway Service] - I[Tool Service] - J[OAuth Manager] - K[Token Cache Manager] - L[PKCE Manager] - end - - subgraph "External Systems" - M[OAuth Provider] - N[MCP Server] - end - - A --> B - B --> D - B --> E - - H --> J - I --> J - J --> K - J --> L - J --> M - - K --> F - L --> G - H --> N - I --> N - - J -.->|Uses| O[oauthlib/authlib] -``` - -## OAuth 2.1 Specific Requirements - -### Implementation Challenges - -1. **Library Support**: Not all OAuth libraries fully support OAuth 2.1 yet - - `oauthlib` may need patches or extensions - - Consider `authlib` which has better OAuth 2.1 support - - May need to implement custom PKCE handling - -2. **Provider Compatibility**: Some OAuth providers may not fully support OAuth 2.1 - - GitHub, Google, and modern providers generally support PKCE - - Legacy enterprise systems may need adaptation layer - -3. **Breaking Changes**: OAuth 2.1 is not backward compatible - - Existing OAuth 2.0 integrations will need updates - - Cannot support both OAuth 2.0 and 2.1 simultaneously for same flow - -4. **Performance Considerations**: - - Token rotation adds database operations - - PKCE adds computational overhead (minimal) - - State management requires additional storage - -### Mandatory Requirements - -1. **PKCE (Proof Key for Code Exchange)** - - Required for ALL OAuth clients, including confidential clients - - Must use S256 (SHA-256) challenge method - - Code verifier length: 43-128 characters - -2. **Grant Type Restrictions** - - Only Authorization Code (with PKCE) and Client Credentials allowed - - Implicit Grant flow completely removed - - Resource Owner Password Credentials strongly discouraged - -3. **Security Enhancements** - - Exact redirect URI matching (no wildcards) - - One-time use refresh tokens with rotation - - No bearer tokens in URL query strings or fragments - - Mandatory HTTPS for all OAuth endpoints - -## Database Schema Design - -### Complete Schema Overview - -```mermaid -erDiagram - GATEWAY { - string id PK - string name UK - string url - string transport "SSE|STREAMABLEHTTP" - string auth_type "basic|bearer|oauth|authheaders" - json auth_value "encrypted" - json oauth_config "OAuth 2.1 settings" - boolean enabled - boolean reachable - datetime created_at - datetime updated_at - } - - OAUTH_CREDENTIALS { - string id PK - string gateway_id FK - string client_id - string client_secret "encrypted" - string authorization_url - string token_url - string redirect_uri "NOT NULL - exact match" - json scopes - string grant_type "authorization_code|client_credentials" - boolean pkce_required "default true" - string code_challenge_method "S256 only" - datetime created_at - datetime updated_at - } - - OAUTH_TOKEN_CACHE { - string id PK - string gateway_id FK - string access_token "encrypted" - string refresh_token "encrypted" - datetime expires_at - boolean is_active "for rotation tracking" - string previous_refresh_token_id FK - json token_metadata - datetime created_at - datetime used_at - datetime invalidated_at - } - - OAUTH_PKCE_STATE { - string id PK - string gateway_id FK - string state UK "cryptographically secure" - string code_verifier "encrypted" - string code_challenge - string redirect_uri "must match exactly" - string nonce "for replay protection" - datetime expires_at - datetime created_at - } - - TOOL { - string id PK - string gateway_id FK - string name - string auth_type - string auth_value "encrypted" - } - - GATEWAY ||--o| OAUTH_CREDENTIALS : has - GATEWAY ||--o{ OAUTH_TOKEN_CACHE : uses - GATEWAY ||--o{ OAUTH_PKCE_STATE : tracks - GATEWAY ||--o{ TOOL : provides - OAUTH_TOKEN_CACHE ||--o| OAUTH_TOKEN_CACHE : replaces -``` - -### SQL Schema Definitions - -```sql --- OAuth Credentials Table -CREATE TABLE oauth_credentials ( - id VARCHAR(36) PRIMARY KEY DEFAULT gen_random_uuid(), - gateway_id VARCHAR(36) REFERENCES gateways(id) ON DELETE CASCADE, - client_id VARCHAR(255) NOT NULL, - client_secret TEXT, -- Encrypted with AES-256-GCM - authorization_url TEXT, - token_url TEXT NOT NULL, - redirect_uri TEXT NOT NULL, -- OAuth 2.1: Exact match required - scopes JSON DEFAULT '[]', - grant_type VARCHAR(50) DEFAULT 'authorization_code', - pkce_required BOOLEAN DEFAULT TRUE, -- OAuth 2.1: Always true - code_challenge_method VARCHAR(10) DEFAULT 'S256', -- OAuth 2.1: Only S256 - created_at TIMESTAMP WITH TIME ZONE DEFAULT NOW(), - updated_at TIMESTAMP WITH TIME ZONE DEFAULT NOW(), - UNIQUE(gateway_id), - CHECK (grant_type IN ('authorization_code', 'client_credentials')), - CHECK (code_challenge_method = 'S256') -); - --- PKCE State Table -CREATE TABLE oauth_pkce_state ( - id VARCHAR(36) PRIMARY KEY DEFAULT gen_random_uuid(), - gateway_id VARCHAR(36) REFERENCES gateways(id) ON DELETE CASCADE, - state VARCHAR(255) UNIQUE NOT NULL, - code_verifier VARCHAR(128) NOT NULL, -- Encrypted - code_challenge VARCHAR(128) NOT NULL, - redirect_uri TEXT NOT NULL, - nonce VARCHAR(255), -- For OpenID Connect - expires_at TIMESTAMP WITH TIME ZONE NOT NULL, - created_at TIMESTAMP WITH TIME ZONE DEFAULT NOW(), - INDEX idx_state_expires (state, expires_at), - INDEX idx_gateway_state (gateway_id, state) -); - --- Token Cache Table with Rotation Support -CREATE TABLE oauth_token_cache ( - id VARCHAR(36) PRIMARY KEY DEFAULT gen_random_uuid(), - gateway_id VARCHAR(36) REFERENCES gateways(id) ON DELETE CASCADE, - access_token TEXT NOT NULL, -- Encrypted - refresh_token TEXT, -- Encrypted - expires_at TIMESTAMP WITH TIME ZONE, - is_active BOOLEAN DEFAULT TRUE, - previous_refresh_token_id VARCHAR(36) REFERENCES oauth_token_cache(id), - token_metadata JSON, - created_at TIMESTAMP WITH TIME ZONE DEFAULT NOW(), - used_at TIMESTAMP WITH TIME ZONE, - invalidated_at TIMESTAMP WITH TIME ZONE, - INDEX idx_gateway_active (gateway_id, is_active), - INDEX idx_expires (expires_at), - INDEX idx_refresh_token (refresh_token) -- For rotation lookup -); -``` - -## Component Design - -### 1. OAuth Manager Service - -**Location**: `mcpgateway/services/oauth_manager.py` - -```python -from oauthlib.oauth2 import BackendApplicationClient, WebApplicationClient -from requests_oauthlib import OAuth2Session -from typing import Optional, Dict, Any, Tuple -import secrets -import hashlib -import base64 -import time - -class OAuthManager: - """Manages OAuth 2.1 authentication flows with mandatory PKCE.""" - - def __init__(self, cache_manager: TokenCacheManager, db_session): - self.cache_manager = cache_manager - self.db = db_session - - async def get_access_token( - self, - gateway_id: str, - credentials: OAuthCredentials - ) -> str: - """Get valid access token, refreshing if necessary.""" - # Check cache first - cached_token = await self.cache_manager.get_token(gateway_id) - if cached_token and not self._is_expired(cached_token): - return cached_token.access_token - - # Token expired or not found - if cached_token and cached_token.refresh_token: - return await self.refresh_token( - cached_token.refresh_token, - credentials - ) - else: - # Need new authorization - return await self.initiate_authorization(credentials) - - async def initiate_authorization( - self, - credentials: OAuthCredentials - ) -> Dict[str, str]: - """Start OAuth 2.1 authorization with PKCE.""" - # Generate PKCE parameters - code_verifier, code_challenge = self.generate_pkce_pair() - - # Generate secure state - state = secrets.token_urlsafe(32) - - # Store PKCE state - await self._store_pkce_state( - credentials.gateway_id, - state, - code_verifier, - code_challenge, - credentials.redirect_uri - ) - - # Build authorization URL - auth_params = { - 'response_type': 'code', - 'client_id': credentials.client_id, - 'redirect_uri': credentials.redirect_uri, - 'scope': ' '.join(credentials.scopes), - 'state': state, - 'code_challenge': code_challenge, - 'code_challenge_method': 'S256' - } - - return { - 'authorization_url': self._build_auth_url( - credentials.authorization_url, - auth_params - ), - 'state': state - } - - async def exchange_code( - self, - code: str, - state: str, - credentials: OAuthCredentials - ) -> TokenResponse: - """Exchange authorization code for tokens with PKCE verification.""" - # Retrieve and validate PKCE state - pkce_state = await self._get_pkce_state(state) - if not pkce_state: - raise ValueError("Invalid or expired state") - - # Validate redirect URI exact match - if pkce_state.redirect_uri != credentials.redirect_uri: - raise ValueError("OAuth 2.1: Redirect URI must match exactly") - - # Exchange code for tokens - token_data = { - 'grant_type': 'authorization_code', - 'code': code, - 'redirect_uri': credentials.redirect_uri, - 'code_verifier': pkce_state.code_verifier, - 'client_id': credentials.client_id, - 'client_secret': credentials.client_secret - } - - response = await self._request_token( - credentials.token_url, - token_data - ) - - # Store tokens with rotation tracking - await self.cache_manager.store_token( - credentials.gateway_id, - response, - invalidate_previous=True - ) - - # Clean up PKCE state - await self._delete_pkce_state(state) - - return response - - async def refresh_token( - self, - refresh_token: str, - credentials: OAuthCredentials - ) -> TokenResponse: - """OAuth 2.1: Refresh with mandatory rotation.""" - # Immediately invalidate old refresh token - await self.cache_manager.invalidate_refresh_token(refresh_token) - - token_data = { - 'grant_type': 'refresh_token', - 'refresh_token': refresh_token, - 'client_id': credentials.client_id, - 'client_secret': credentials.client_secret - } - - response = await self._request_token( - credentials.token_url, - token_data - ) - - # Store new tokens, invalidating all previous - await self.cache_manager.store_token( - credentials.gateway_id, - response, - invalidate_previous=True - ) - - return response - - def generate_pkce_pair(self) -> Tuple[str, str]: - """Generate PKCE code verifier and challenge (S256).""" - # Generate cryptographically secure verifier - code_verifier = base64.urlsafe_b64encode( - secrets.token_bytes(64) # 96 chars after encoding - ).decode('utf-8').rstrip('=') - - # Create S256 challenge - code_challenge = base64.urlsafe_b64encode( - hashlib.sha256(code_verifier.encode('utf-8')).digest() - ).decode('utf-8').rstrip('=') - - return code_verifier, code_challenge - - def validate_redirect_uri( - self, - requested_uri: str, - registered_uri: str - ) -> bool: - """OAuth 2.1: Exact string match for redirect URIs.""" - return requested_uri == registered_uri - - def validate_token_url(self, url: str) -> bool: - """OAuth 2.1: Ensure no tokens in URLs.""" - forbidden_params = ['access_token', 'refresh_token', 'token'] - parsed = urlparse(url) - - # Check query string - if parsed.query: - params = parse_qs(parsed.query) - for forbidden in forbidden_params: - if forbidden in params: - return False - - # Check fragment - if parsed.fragment: - for forbidden in forbidden_params: - if forbidden in parsed.fragment: - return False - - return True -``` - -### 2. Token Cache Manager - -**Location**: `mcpgateway/services/token_cache_manager.py` - -```python -from cryptography.fernet import Fernet -from datetime import datetime, timedelta, timezone -import json - -class TokenCacheManager: - """OAuth 2.1 token cache with rotation and encryption.""" - - def __init__(self, db_session, encryption_key: bytes): - self.db = db_session - self.cipher = Fernet(encryption_key) - - async def get_token(self, gateway_id: str) -> Optional[OAuthToken]: - """Get active token if valid.""" - token = self.db.query(OAuthTokenCache).filter( - OAuthTokenCache.gateway_id == gateway_id, - OAuthTokenCache.is_active == True, - OAuthTokenCache.invalidated_at.is_(None) - ).first() - - if not token: - return None - - # Check expiration - if token.expires_at and token.expires_at < datetime.now(timezone.utc): - await self.invalidate_token(gateway_id) - return None - - # Decrypt tokens - return OAuthToken( - access_token=self._decrypt(token.access_token), - refresh_token=self._decrypt(token.refresh_token) if token.refresh_token else None, - expires_at=token.expires_at, - metadata=token.token_metadata - ) - - async def store_token( - self, - gateway_id: str, - token: TokenResponse, - invalidate_previous: bool = True - ) -> None: - """Store token with OAuth 2.1 rotation support.""" - if invalidate_previous: - # Invalidate all previous tokens for this gateway - await self.invalidate_all_tokens(gateway_id) - - # Calculate expiration - expires_at = datetime.now(timezone.utc) + timedelta( - seconds=token.expires_in - ) if token.expires_in else None - - # Get previous token ID for rotation tracking - previous_token = self.db.query(OAuthTokenCache).filter( - OAuthTokenCache.gateway_id == gateway_id, - OAuthTokenCache.is_active == True - ).first() - - # Create new token entry - new_token = OAuthTokenCache( - gateway_id=gateway_id, - access_token=self._encrypt(token.access_token), - refresh_token=self._encrypt(token.refresh_token) if token.refresh_token else None, - expires_at=expires_at, - is_active=True, - previous_refresh_token_id=previous_token.id if previous_token else None, - token_metadata={ - 'token_type': token.token_type, - 'scope': token.scope, - 'issued_at': datetime.now(timezone.utc).isoformat() - } - ) - - self.db.add(new_token) - self.db.commit() - - async def invalidate_refresh_token(self, refresh_token: str) -> None: - """OAuth 2.1: Immediately invalidate used refresh token.""" - encrypted_token = self._encrypt(refresh_token) - - token_entry = self.db.query(OAuthTokenCache).filter( - OAuthTokenCache.refresh_token == encrypted_token, - OAuthTokenCache.is_active == True - ).first() - - if token_entry: - token_entry.is_active = False - token_entry.invalidated_at = datetime.now(timezone.utc) - self.db.commit() - - async def invalidate_all_tokens(self, gateway_id: str) -> None: - """Invalidate all tokens for rotation.""" - self.db.query(OAuthTokenCache).filter( - OAuthTokenCache.gateway_id == gateway_id, - OAuthTokenCache.is_active == True - ).update({ - 'is_active': False, - 'invalidated_at': datetime.now(timezone.utc) - }) - self.db.commit() - - def _encrypt(self, data: str) -> str: - """Encrypt sensitive data.""" - return self.cipher.encrypt(data.encode()).decode() - - def _decrypt(self, data: str) -> str: - """Decrypt sensitive data.""" - return self.cipher.decrypt(data.encode()).decode() -``` - -### 3. Admin UI Enhancements - -**OAuth 2.1 Configuration Form Fields**: - -```html - - - -``` - -## Implementation Flow - -### OAuth 2.1 Complete Flow - -```mermaid -sequenceDiagram - participant U as User - participant UI as Admin UI - participant GS as Gateway Service - participant OM as OAuth Manager - participant TC as Token Cache - participant PS as PKCE State Store - participant OP as OAuth Provider - participant MCP as MCP Server - - Note over U,MCP: Gateway Configuration - U->>UI: Configure OAuth Gateway - UI->>UI: Validate OAuth 2.1 Requirements - UI->>GS: Save OAuth Config - GS->>GS: Validate redirect URI format - GS->>OM: Initialize OAuth Manager - - Note over U,MCP: Authorization Code Flow with PKCE - U->>GS: Request Tool Access - GS->>TC: Check Token Cache - TC-->>GS: No Valid Token - - GS->>OM: Initiate Authorization - OM->>OM: Generate PKCE Pair - OM->>PS: Store PKCE State - OM-->>GS: Authorization URL + State - GS-->>U: Redirect to OAuth Provider - - U->>OP: Authenticate - OP->>OP: Validate PKCE Challenge - OP-->>U: Authorization Code - U-->>GS: Redirect with Code + State - - GS->>OM: Exchange Code - OM->>PS: Retrieve PKCE Verifier - PS-->>OM: Code Verifier - OM->>OM: Validate State - OM->>OP: Token Request + PKCE Verifier - OP->>OP: Validate PKCE - OP-->>OM: Access + Refresh Tokens - - OM->>TC: Store Tokens (Rotate) - TC->>TC: Invalidate Previous Tokens - OM->>PS: Clean PKCE State - OM-->>GS: Access Token - - Note over U,MCP: Tool Invocation - GS->>GS: Validate No Token in URL - GS->>MCP: Call Tool (Bearer Header) - MCP-->>GS: Tool Response - GS-->>U: Return Result - - Note over U,MCP: Token Refresh with Rotation - U->>GS: Invoke Tool (Later) - GS->>TC: Check Token Cache - TC-->>GS: Token Expired - - GS->>OM: Refresh Token - OM->>TC: Mark Old Token Used - OM->>OP: Refresh Request - OP-->>OM: New Tokens - OM->>TC: Store New Tokens - TC->>TC: Invalidate Old Tokens - OM-->>GS: New Access Token -``` - -## Key Implementation Details - -### 1. Gateway Service Modifications - -**File**: `mcpgateway/services/gateway_service.py` - -```python -async def _initialize_gateway( - self, - url: str, - authentication: Optional[Dict[str, str]] = None, - transport: str = "SSE" -) -> tuple[Dict[str, Any], List[ToolCreate], List[ResourceCreate], List[PromptCreate]]: - """Initialize gateway with OAuth 2.1 support.""" - - # OAuth 2.1: Validate URL has no tokens - if self._contains_bearer_token(url): - raise ValueError( - "OAuth 2.1 Security Violation: Bearer tokens must not be passed in URLs. " - "Use Authorization header instead." - ) - - headers = {} - - if authentication and authentication.get('type') == 'oauth': - # Get OAuth credentials - credentials = await self._get_oauth_credentials( - authentication['gateway_id'] - ) - - # Validate redirect URI if present - if hasattr(credentials, 'redirect_uri') and credentials.redirect_uri: - parsed_url = urlparse(url) - parsed_redirect = urlparse(credentials.redirect_uri) - - # OAuth 2.1: Exact match validation - if parsed_url.scheme != parsed_redirect.scheme or \ - parsed_url.netloc != parsed_redirect.netloc or \ - parsed_url.path != parsed_redirect.path: - raise ValueError( - f"OAuth 2.1: Redirect URI mismatch. " - f"Expected: {credentials.redirect_uri}, Got: {url}" - ) - - # Get access token - try: - access_token = await self.oauth_manager.get_access_token( - gateway_id=authentication['gateway_id'], - credentials=credentials - ) - - # OAuth 2.1: Always use Authorization header - headers = { - 'Authorization': f'Bearer {access_token}', - 'X-OAuth-Version': '2.1' # Indicate OAuth 2.1 compliance - } - except TokenExpiredError: - # Handle expired token - logger.warning(f"Token expired for gateway {authentication['gateway_id']}") - raise GatewayConnectionError("OAuth token expired, re-authorization required") - - else: - # Non-OAuth authentication - headers = decode_auth(authentication) - - # Continue with connection... - return await self._connect_to_gateway(url, headers, transport) - -def _contains_bearer_token(self, url: str) -> bool: - """Check if URL contains bearer tokens (OAuth 2.1 violation).""" - parsed = urlparse(url) - - # Check query parameters - if parsed.query: - params = parse_qs(parsed.query.lower()) - token_params = ['access_token', 'bearer', 'token', 'auth'] - for param in token_params: - if param in params: - return True - - # Check fragment - if parsed.fragment: - fragment_lower = parsed.fragment.lower() - if any(token in fragment_lower for token in ['access_token', 'bearer', 'token']): - return True - - return False -``` - -### 2. Tool Service Modifications - -**File**: `mcpgateway/services/tool_service.py` - -```python -async def invoke_tool( - self, - db: Session, - name: str, - arguments: Dict[str, Any], - request_headers: Optional[Dict[str, str]] = None -) -> ToolResult: - """Invoke tool with OAuth 2.1 compliance.""" - - tool = await self.get_tool_by_name(db, name) - - # OAuth 2.1: Pre-flight URL validation - if tool.url and self._contains_bearer_token(tool.url): - raise ToolValidationError( - f"OAuth 2.1 Violation: Tool URL contains bearer token. " - f"Tool: {name}, URL: {tool.url}" - ) - - headers = {} - - if tool.auth_type == 'oauth': - gateway = tool.gateway - if not gateway: - raise ToolError(f"OAuth tool {name} has no associated gateway") - - # Get OAuth credentials - credentials = await self._get_oauth_credentials(db, gateway.id) - - # Token acquisition with retry logic - max_retries = 2 - for attempt in range(max_retries): - try: - # Get access token (handles refresh automatically) - access_token = await self.oauth_manager.get_access_token( - gateway_id=gateway.id, - credentials=credentials - ) - - # OAuth 2.1: Bearer token in Authorization header only - headers = { - 'Authorization': f'Bearer {access_token}', - 'X-Tool-Name': name, - 'X-OAuth-Version': '2.1' - } - - break - - except TokenExpiredError: - if attempt == max_retries - 1: - raise - # Token expired during processing, retry - logger.info(f"Token expired during tool invocation, retrying... (attempt {attempt + 1})") - await asyncio.sleep(0.5) - - except RefreshTokenExpiredError: - # Refresh token expired, need re-authorization - raise ToolError( - f"OAuth refresh token expired for tool {name}. " - "Re-authorization required." - ) - else: - # Non-OAuth authentication - headers = self._get_tool_headers(tool) - - # OAuth 2.1: Final header validation - self._validate_oauth_headers(headers) - - # Merge with request headers (OAuth headers take precedence) - if request_headers: - merged_headers = {**request_headers, **headers} - else: - merged_headers = headers - - # Execute tool with observability - with create_span("tool.invoke.oauth21", { - "tool.name": name, - "tool.auth_type": tool.auth_type, - "oauth.version": "2.1" if tool.auth_type == 'oauth' else None - }) as span: - try: - result = await self._execute_tool( - tool=tool, - arguments=arguments, - headers=merged_headers - ) - - if span: - span.set_attribute("tool.success", True) - - return result - - except Exception as e: - if span: - span.set_attribute("tool.success", False) - span.set_attribute("error.message", str(e)) - raise - -def _validate_oauth_headers(self, headers: Dict[str, str]) -> None: - """Validate OAuth 2.1 header compliance.""" - for key, value in headers.items(): - if key.lower() != 'authorization': - # Check for tokens in other headers - value_lower = value.lower() - if any(token in value_lower for token in ['bearer', 'access_token', 'token']): - logger.warning( - f"Potential OAuth 2.1 violation: Token-like value in header '{key}'. " - "Tokens should only be in Authorization header." - ) -``` - -## Security Considerations - -### OAuth 2.1 Security Requirements - -1. **Token Storage** - - Use AES-256-GCM for all token encryption - - Separate encryption keys for different token types - - Key rotation every 90 days - -2. **PKCE Implementation** - - Code verifier: 43-128 characters (recommended: 96) - - Use cryptographically secure random generation - - S256 challenge method only - -3. **Redirect URI Security** - - Store as exact strings in database - - Validate character-by-character match - - No URL encoding differences allowed - -4. **Refresh Token Rotation** - - Immediate invalidation of used tokens - - Track token lineage for audit - - Maximum rotation chain length: 10 - -5. **Rate Limiting** - - Authorization attempts: 5 per minute per client - - Token requests: 10 per minute per client - - Exponential backoff on failures - -6. **Audit Logging** - ```python - # Required audit events - - Authorization initiated - - Authorization completed/failed - - Token issued - - Token refreshed - - Token invalidated - - Suspicious activity detected - ``` - -## Configuration - -### Environment Variables - -```env -# OAuth 2.1 Configuration -OAUTH_ENABLE=true -OAUTH_VERSION=2.1 - -# Token Management -OAUTH_TOKEN_CACHE_TTL=3600 # Access token cache TTL (seconds) -OAUTH_REFRESH_TOKEN_TTL=604800 # Refresh token validity (7 days) -OAUTH_MAX_TOKEN_AGE=86400 # Maximum token age before forced refresh - -# PKCE Configuration -OAUTH_PKCE_REQUIRED=true # Always true for OAuth 2.1 -OAUTH_PKCE_CODE_LENGTH=96 # Code verifier length (43-128) -OAUTH_PKCE_STATE_TTL=600 # Authorization state TTL (seconds) - -# Security Settings -OAUTH_STRICT_REDIRECT_URI=true # Always true for OAuth 2.1 -OAUTH_REFRESH_TOKEN_ROTATION=true # Always true for OAuth 2.1 -OAUTH_ALLOW_HTTP=false # HTTPS required for OAuth endpoints - -# Rate Limiting -OAUTH_AUTH_RATE_LIMIT=5 # Auth attempts per minute -OAUTH_TOKEN_RATE_LIMIT=10 # Token requests per minute -OAUTH_RATE_LIMIT_WINDOW=60 # Rate limit window (seconds) - -# Encryption -OAUTH_ENCRYPTION_KEY=${OAUTH_ENCRYPTION_KEY} # Base64 encoded 32-byte key -OAUTH_KEY_ROTATION_DAYS=90 # Encryption key rotation period - -# Monitoring -OAUTH_ENABLE_METRICS=true -OAUTH_ENABLE_AUDIT_LOG=true -OAUTH_AUDIT_LOG_RETENTION_DAYS=90 -``` - -## Migration Strategy - -### Phase 1: Database Schema (Week 1) - -```sql --- Migration script -BEGIN TRANSACTION; - --- Add OAuth config to gateways -ALTER TABLE gateways -ADD COLUMN oauth_config JSONB, -ADD COLUMN oauth_version VARCHAR(10) DEFAULT '2.1'; - --- Create OAuth tables -CREATE TABLE oauth_credentials (...); -CREATE TABLE oauth_token_cache (...); -CREATE TABLE oauth_pkce_state (...); - --- Add indexes -CREATE INDEX idx_oauth_active_tokens ON oauth_token_cache(gateway_id, is_active); -CREATE INDEX idx_oauth_pkce_expiry ON oauth_pkce_state(expires_at); - --- Add constraints -ALTER TABLE oauth_credentials -ADD CONSTRAINT chk_oauth21_grant_type -CHECK (grant_type IN ('authorization_code', 'client_credentials')); - -COMMIT; -``` - -### Phase 2: Service Implementation (Week 2-3) - -1. Implement OAuth Manager with PKCE -2. Implement Token Cache Manager with rotation -3. Add OAuth support to Gateway Service -4. Add OAuth support to Tool Service -5. Create OAuth callback endpoints - -### Phase 3: UI Integration (Week 4) - -1. Update Admin UI with OAuth 2.1 forms -2. Add OAuth status dashboard -3. Implement authorization flow UI -4. Add token management interface - -### Phase 4: Testing & Rollout (Week 5-6) - -1. Unit tests for all OAuth components -2. Integration tests for complete flows -3. Security penetration testing -4. Performance testing under load -5. Documentation and training - -## Testing Strategy - -### Unit Tests - -```python -# Test PKCE generation -def test_pkce_generation(): - manager = OAuthManager() - verifier, challenge = manager.generate_pkce_pair() - - assert 43 <= len(verifier) <= 128 - assert len(challenge) == 43 # S256 produces 43 chars - - # Verify challenge calculation - expected = base64.urlsafe_b64encode( - hashlib.sha256(verifier.encode()).digest() - ).decode().rstrip('=') - - assert challenge == expected - -# Test redirect URI validation -def test_redirect_uri_validation(): - manager = OAuthManager() - - # Exact match - should pass - assert manager.validate_redirect_uri( - "https://app.com/callback", - "https://app.com/callback" - ) - - # Different path - should fail - assert not manager.validate_redirect_uri( - "https://app.com/callback/auth", - "https://app.com/callback" - ) - - # Different scheme - should fail - assert not manager.validate_redirect_uri( - "http://app.com/callback", - "https://app.com/callback" - ) - -# Test token rotation -async def test_refresh_token_rotation(): - cache = TokenCacheManager() - - # Store initial token - await cache.store_token("gw1", initial_token) - - # Refresh should invalidate old token - await cache.invalidate_refresh_token(initial_token.refresh_token) - - # Old token should be inactive - old_token = await cache.get_token("gw1") - assert old_token is None -``` - -### Integration Tests - -```python -# Complete OAuth 2.1 flow test -async def test_oauth21_complete_flow(): - # Setup - gateway = create_test_gateway(auth_type='oauth') - oauth_creds = create_oauth_credentials( - grant_type='authorization_code', - pkce_required=True - ) - - # Step 1: Initiate authorization - auth_result = await oauth_manager.initiate_authorization(oauth_creds) - assert 'authorization_url' in auth_result - assert 'code_challenge' in auth_result['authorization_url'] - - # Step 2: Simulate authorization callback - auth_code = "test_auth_code" - token_result = await oauth_manager.exchange_code( - code=auth_code, - state=auth_result['state'], - credentials=oauth_creds - ) - - assert token_result.access_token - assert token_result.refresh_token - - # Step 3: Use token for tool invocation - tool_result = await tool_service.invoke_tool( - name="test_tool", - arguments={"param": "value"} - ) - - assert tool_result.success - - # Step 4: Test token refresh - # Expire the token - await cache.expire_token(gateway.id) - - # Should automatically refresh - tool_result2 = await tool_service.invoke_tool( - name="test_tool", - arguments={"param": "value2"} - ) - - assert tool_result2.success -``` - -### Security Tests - -```python -# Test bearer token in URL rejection -def test_bearer_token_url_rejection(): - urls = [ - "https://api.com/endpoint?access_token=secret", - "https://api.com/endpoint#access_token=secret", - "https://api.com/endpoint?token=Bearer%20secret", - ] - - for url in urls: - with pytest.raises(ValueError, match="OAuth 2.1.*Bearer tokens.*URLs"): - gateway_service._validate_url(url) - -# Test token encryption -def test_token_encryption(): - cache = TokenCacheManager(encryption_key=test_key) - - token = "sensitive_access_token" - encrypted = cache._encrypt(token) - - # Should not contain original token - assert token not in encrypted - - # Should decrypt correctly - decrypted = cache._decrypt(encrypted) - assert decrypted == token -``` - -## Example Usage - -### Configuration Examples - -#### GitHub with OAuth 2.1 - -```python -POST /gateways -{ - "name": "GitHub MCP Server", - "url": "https://github-mcp.example.com/sse", - "transport": "streamablehttp", - "auth_type": "oauth", - "oauth_config": { - "grant_type": "authorization_code", - "client_id": "github_app_client_id", - "client_secret": "github_app_client_secret", - "authorization_url": "https://github.com/login/oauth/authorize", - "token_url": "https://github.com/login/oauth/access_token", - "redirect_uri": "https://gateway.example.com/oauth/github/callback", - "scopes": ["repo", "read:user", "project"], - "pkce_required": true, - "code_challenge_method": "S256" - } -} -``` - -#### Google Workspace Integration - -```python -POST /gateways -{ - "name": "Google Workspace MCP", - "url": "https://workspace-mcp.example.com/sse", - "transport": "sse", - "auth_type": "oauth", - "oauth_config": { - "grant_type": "authorization_code", - "client_id": "google_client_id.apps.googleusercontent.com", - "client_secret": "google_client_secret", - "authorization_url": "https://accounts.google.com/o/oauth2/v2/auth", - "token_url": "https://oauth2.googleapis.com/token", - "redirect_uri": "https://gateway.example.com/oauth/google/callback", - "scopes": [ - "https://www.googleapis.com/auth/drive.readonly", - "https://www.googleapis.com/auth/gmail.readonly" - ], - "pkce_required": true, - "additional_params": { - "access_type": "offline", - "prompt": "consent" - } - } -} -``` - -### Tool Invocation - -```python -# Automatic OAuth token management -result = await tool_service.invoke_tool( - db=db, - name="github_create_pr", - arguments={ - "repository": "org/repo", - "title": "Feature: OAuth 2.1 Support", - "body": "Implements OAuth 2.1 compliance", - "base": "main", - "head": "feature/oauth21" - } -) - -# The service handles: -# 1. Token acquisition (if needed) -# 2. Token refresh (if expired) -# 3. Retry logic for token expiration -# 4. OAuth 2.1 compliance validation -``` - -## Monitoring and Observability - -### Metrics - -```python -# OAuth 2.1 specific metrics -oauth_metrics = { - # Performance - "oauth.token_request.duration": histogram, - "oauth.token_cache.hit_rate": gauge, - "oauth.pkce_generation.duration": histogram, - - # Security - "oauth.auth_failures.total": counter, - "oauth.token_rotation.count": counter, - "oauth.redirect_uri_mismatches": counter, - "oauth.bearer_in_url_attempts": counter, - - # Usage - "oauth.active_tokens.count": gauge, - "oauth.refresh_operations.total": counter, - "oauth.grant_type.usage": counter(labels=["grant_type"]) -} -``` - -### OpenTelemetry Spans - -```python -# OAuth flow tracing -with create_span("oauth.authorization_flow", { - "oauth.version": "2.1", - "oauth.grant_type": grant_type, - "oauth.provider": provider_name, - "oauth.pkce_enabled": True -}) as span: - # Authorization logic - span.add_event("pkce_generated") - span.add_event("authorization_initiated") - - # ... authorization process ... - - span.set_attribute("oauth.scopes_requested", len(scopes)) - span.set_attribute("oauth.success", success) -``` - -### Audit Events - -```json -{ - "timestamp": "2024-12-10T10:30:45Z", - "event_type": "oauth.token_refreshed", - "gateway_id": "gw_123", - "details": { - "old_token_id": "tok_abc", - "new_token_id": "tok_xyz", - "rotation_number": 3, - "client_id": "client_123", - "ip_address": "192.168.1.100", - "user_agent": "MCP-Gateway/2.0" - } -} -``` - -## Security Best Practices - -### OAuth 2.1 Compliance Checklist - -- [x] PKCE mandatory for all authorization code flows -- [x] S256 code challenge method only -- [x] Exact redirect URI matching -- [x] One-time use refresh tokens -- [x] Automatic token rotation -- [x] No bearer tokens in URLs -- [x] No implicit grant flow support -- [x] No ROPC grant support -- [x] HTTPS required for all OAuth endpoints -- [x] State parameter validation -- [x] Nonce validation for OpenID Connect -- [x] Token binding support (optional) -- [x] DPoP support (future enhancement) - -### Production Security Measures - -1. **Infrastructure** - - Use Hardware Security Module (HSM) for key storage - - Deploy behind WAF with OAuth-specific rules - - Enable comprehensive audit logging - - Implement intrusion detection - -2. **Monitoring** - - Alert on unusual token patterns - - Monitor for authorization anomalies - - Track failed authentication attempts - - Detect potential token replay attacks - -3. **Compliance** - - Regular security audits - - Penetration testing - - Compliance with GDPR/CCPA for token data - - SOC 2 Type II certification considerations - -## Future Enhancements - -1. **Advanced Security** - - DPoP (Demonstrating Proof of Possession) support - - Certificate-bound access tokens - - Mutual TLS for client authentication - - Pushed Authorization Requests (PAR) - -2. **Standards Support** - - OpenID Connect full compliance - - FAPI (Financial-grade API) support - - JWT Secured Authorization Response Mode (JARM) - - Rich Authorization Requests (RAR) - -3. **User Experience** - - OAuth provider auto-discovery - - Simplified configuration wizards - - Mobile app support with PKCE - - Biometric authentication integration - -4. **Operations** - - Automated token lifecycle management - - Self-service OAuth app registration - - Multi-tenant OAuth isolation - - Advanced analytics dashboard - -## Conclusion - -This comprehensive OAuth 2.1 integration design ensures that MCP Gateway implements the highest security standards for delegated authorization. By adopting OAuth 2.1's mandatory security features—including PKCE for all clients, strict redirect URI matching, one-time refresh tokens, and prohibition of bearer tokens in URLs—we create a robust and secure authentication system. - -The implementation provides: -- **Enhanced Security**: Eliminates all known OAuth 2.0 vulnerabilities -- **Simplified Integration**: Fewer flows to support, clearer security model -- **Future-Proof Design**: Ready for emerging standards and extensions -- **Operational Excellence**: Comprehensive monitoring, auditing, and management - -This design enables agents to securely act on behalf of users without the risks associated with personal access tokens, while maintaining backward compatibility where possible and providing a clear migration path for existing deployments. diff --git a/docs/docs/architecture/oauth-design.md b/docs/docs/architecture/oauth-design.md new file mode 100644 index 00000000..bbcc7e00 --- /dev/null +++ b/docs/docs/architecture/oauth-design.md @@ -0,0 +1,421 @@ +# OAuth 2.0 Integration Design for MCP Gateway + +**Version**: 1.0 +**Status**: Draft +**Date**: December 2024 + +## Executive Summary + +This document outlines the design for integrating OAuth 2.0 authentication into the MCP Gateway, enabling agents to perform actions on behalf of users without requiring personal access tokens (PATs). The implementation will use the `oauthlib` library and support Client Credentials and Authorization Code flows following OAuth 2.0 best practices. + +## Motivation + +Current limitations: +- Personal Access Tokens (PATs) provide broad access with security risks +- Manual token management across multiple services +- No native support for delegated authorization with scoped permissions + +OAuth 2.0 provides: +- Standardized authentication flows +- Scoped access control +- Temporary access without storing user credentials +- Industry-standard security practices + +## Architecture Overview + +```mermaid +graph TD + subgraph "MCP Gateway" + A[Admin UI] + B[Gateway Service] + C[Tool Service] + D[OAuth Manager] + end + + subgraph "Storage" + E[Database] + end + + subgraph "External" + F[OAuth Provider] + G[MCP Server] + end + + A --> E + B --> D + C --> D + D --> F + B --> G + C --> G + + D -.->|Uses| H[oauthlib] +``` + +## Database Schema + +### Modified Gateway Table + +```sql +ALTER TABLE gateways +ADD COLUMN oauth_config JSON; + +-- OAuth config structure: +{ + "grant_type": "client_credentials|authorization_code", + "client_id": "string", + "client_secret": "encrypted_string", + "authorization_url": "string", + "token_url": "string", + "redirect_uri": "string", + "scopes": ["scope1", "scope2"] +} +``` + +## Core Components + +### 1. OAuth Manager Service + +**Location**: `mcpgateway/services/oauth_manager.py` + +```python +from oauthlib.oauth2 import BackendApplicationClient, WebApplicationClient +from requests_oauthlib import OAuth2Session +from typing import Optional, Dict, Any + +class OAuthManager: + """Manages OAuth 2.0 authentication flows.""" + + async def get_access_token( + self, + credentials: Dict[str, Any] + ) -> str: + """Get access token based on grant type.""" + if credentials['grant_type'] == 'client_credentials': + return await self._client_credentials_flow(credentials) + elif credentials['grant_type'] == 'authorization_code': + return await self._authorization_code_flow(credentials) + else: + raise ValueError(f"Unsupported grant type: {credentials['grant_type']}") + + async def _client_credentials_flow( + self, + credentials: Dict[str, Any] + ) -> str: + """Machine-to-machine authentication.""" + client = BackendApplicationClient(client_id=credentials['client_id']) + oauth = OAuth2Session(client=client) + + token = oauth.fetch_token( + token_url=credentials['token_url'], + client_id=credentials['client_id'], + client_secret=credentials['client_secret'], + scope=credentials.get('scopes', []) + ) + + return token['access_token'] + + async def _authorization_code_flow( + self, + credentials: Dict[str, Any] + ) -> Dict[str, str]: + """User delegation flow - returns authorization URL.""" + oauth = OAuth2Session( + credentials['client_id'], + redirect_uri=credentials['redirect_uri'], + scope=credentials.get('scopes', []) + ) + + authorization_url, state = oauth.authorization_url( + credentials['authorization_url'] + ) + + return { + 'authorization_url': authorization_url, + 'state': state + } + + async def exchange_code_for_token( + self, + credentials: Dict[str, Any], + code: str, + state: str + ) -> str: + """Exchange authorization code for access token.""" + oauth = OAuth2Session( + credentials['client_id'], + state=state, + redirect_uri=credentials['redirect_uri'] + ) + + token = oauth.fetch_token( + credentials['token_url'], + client_secret=credentials['client_secret'], + authorization_response=f"{credentials['redirect_uri']}?code={code}&state={state}" + ) + + return token['access_token'] +``` + +### 2. Admin UI OAuth Configuration + +```html +
+

OAuth 2.0 Configuration

+ +
+ + +
+ +
+ + +
+ +
+ + +
+ +
+ + +
+ +
+ + +
+ +
+ + +
+ +
+ + +
+
+``` + +## Implementation Details + +### Gateway Service Integration + +**File**: `mcpgateway/services/gateway_service.py` + +```python +async def _initialize_gateway( + self, + url: str, + authentication: Optional[Dict[str, str]] = None, + transport: str = "SSE" +) -> tuple: + """Initialize gateway with OAuth support.""" + + headers = {} + + if authentication and authentication.get('type') == 'oauth': + # Get OAuth credentials from database + gateway = await self._get_gateway(authentication['gateway_id']) + oauth_config = gateway.oauth_config + + # Get access token + access_token = await self.oauth_manager.get_access_token(oauth_config) + headers = {'Authorization': f'Bearer {access_token}'} + else: + # Existing authentication logic + headers = decode_auth(authentication) + + # Connect to MCP server + return await self._connect_to_gateway(url, headers, transport) +``` + +### Tool Service Integration + +**File**: `mcpgateway/services/tool_service.py` + +```python +async def invoke_tool( + self, + db: Session, + name: str, + arguments: Dict[str, Any] +) -> ToolResult: + """Invoke tool with OAuth support.""" + + tool = await self.get_tool_by_name(db, name) + headers = {} + + if tool.gateway and tool.gateway.auth_type == 'oauth': + # Get fresh access token for each request + oauth_config = tool.gateway.oauth_config + access_token = await self.oauth_manager.get_access_token(oauth_config) + headers = {'Authorization': f'Bearer {access_token}'} + else: + # Existing authentication + headers = self._get_tool_headers(tool) + + # Execute tool + return await self._execute_tool(tool, arguments, headers) +``` + +## OAuth Flow Sequences + +### Client Credentials Flow (M2M) + +```mermaid +sequenceDiagram + participant Client + participant Gateway + participant OAuth Manager + participant OAuth Provider + participant MCP Server + + Client->>Gateway: Configure OAuth (Client Credentials) + Client->>Gateway: Invoke Tool + Gateway->>OAuth Manager: Get Access Token + OAuth Manager->>OAuth Provider: POST /token (client_id, secret) + OAuth Provider-->>OAuth Manager: Access Token + OAuth Manager-->>Gateway: Access Token + Gateway->>MCP Server: Tool Request + Bearer Token + MCP Server-->>Gateway: Tool Response + Gateway-->>Client: Result +``` + +### Authorization Code Flow + +```mermaid +sequenceDiagram + participant User + participant Gateway + participant OAuth Manager + participant OAuth Provider + participant MCP Server + + User->>Gateway: Configure OAuth (Auth Code) + User->>Gateway: Request Authorization + Gateway->>OAuth Manager: Get Auth URL + OAuth Manager-->>Gateway: Authorization URL + Gateway-->>User: Redirect to OAuth Provider + User->>OAuth Provider: Login & Authorize + OAuth Provider-->>Gateway: Callback with Code + Gateway->>OAuth Manager: Exchange Code + OAuth Manager->>OAuth Provider: POST /token (code) + OAuth Provider-->>OAuth Manager: Access Token + OAuth Manager-->>Gateway: Access Token + Gateway->>MCP Server: Tool Request + Bearer Token + MCP Server-->>Gateway: Response + Gateway-->>User: Result +``` + +## Security Considerations + +1. **Token Storage**: Access tokens are never stored - requested fresh for each operation +2. **Secret Encryption**: Client secrets encrypted using `AUTH_ENCRYPTION_SECRET` +3. **HTTPS Required**: All OAuth endpoints must use HTTPS +4. **Scope Validation**: Request minimum required scopes +5. **Error Handling**: Comprehensive error handling for OAuth failures + +## Configuration + +### Environment Variables + +```env +# OAuth Configuration +OAUTH_REQUEST_TIMEOUT=30 # OAuth request timeout in seconds +OAUTH_MAX_RETRIES=3 # Max retries for token requests + +# Encryption +AUTH_ENCRYPTION_SECRET=your-secret-key # For encrypting client secrets +``` + +### Example Gateway Configuration + +```json +{ + "name": "GitHub MCP", + "url": "https://github-mcp.example.com/sse", + "auth_type": "oauth", + "oauth_config": { + "grant_type": "authorization_code", + "client_id": "your_github_app_id", + "client_secret": "your_github_app_secret", + "authorization_url": "https://github.com/login/oauth/authorize", + "token_url": "https://github.com/login/oauth/access_token", + "redirect_uri": "https://gateway.example.com/oauth/callback", + "scopes": ["repo", "read:user"] + } +} +``` + +## Implementation Phases + +### Phase 1: Core OAuth Support (Week 1) +- Implement OAuth Manager +- Add database schema changes +- Client Credentials flow + +### Phase 2: UI Integration (Week 2) +- Admin UI OAuth configuration +- Authorization Code flow +- OAuth callback endpoint + +### Phase 3: Testing & Documentation (Week 3) +- Integration tests +- Security review +- User documentation + +## Dependencies + +```toml +# Add to pyproject.toml +dependencies = [ + "oauthlib>=3.2.2", + "requests-oauthlib>=1.3.1", + "cryptography>=41.0.0", # For secret encryption +] +``` + +## Testing + +### Unit Tests + +```python +async def test_client_credentials_flow(): + oauth_manager = OAuthManager() + credentials = { + "grant_type": "client_credentials", + "client_id": "test_client", + "client_secret": "test_secret", + "token_url": "https://oauth.example.com/token" + } + + token = await oauth_manager.get_access_token(credentials) + assert token is not None + assert isinstance(token, str) + +async def test_tool_invocation_with_oauth(): + tool_service = ToolService(oauth_manager) + result = await tool_service.invoke_tool( + db=db, + name="github_create_issue", + arguments={"title": "Test Issue"} + ) + assert result.success +``` + +## Future Enhancements + +1. **OAuth Provider Templates**: Pre-configured settings for common providers +2. **Token Refresh**: Support refresh tokens for long-lived access +3. **PKCE Support**: Add PKCE for public clients +4. **Multiple OAuth Configs**: Support different OAuth configs per tool + +## Conclusion + +This OAuth 2.0 integration provides secure, standards-based authentication for MCP Gateway without the complexity of token caching. By requesting fresh tokens for each operation, we ensure simplicity while maintaining security. The implementation follows OAuth 2.0 best practices and enables seamless integration with various OAuth providers. From 1a007650cd5ea7b954b69b6b5a6c3df409cc679b Mon Sep 17 00:00:00 2001 From: Shamsul Arefin Date: Sat, 16 Aug 2025 20:59:19 +0500 Subject: [PATCH 03/21] Support for oauth auth type in gateway Signed-off-by: Shamsul Arefin --- docs/oauth-setup.md | 179 ++++++++++++++ mcpgateway/admin.py | 88 ++++++- ...c9d3e2a1b4_add_oauth_config_to_gateways.py | 62 +++++ mcpgateway/config.py | 4 + mcpgateway/db.py | 9 +- mcpgateway/schemas.py | 54 +++-- mcpgateway/services/gateway_service.py | 35 ++- mcpgateway/services/oauth_manager.py | 224 ++++++++++++++++++ mcpgateway/services/tool_service.py | 34 ++- mcpgateway/static/admin.js | 87 +++++++ mcpgateway/templates/admin.html | 97 ++++++++ mcpgateway/utils/oauth_encryption.py | 110 +++++++++ pyproject.toml | 2 + tests/unit/mcpgateway/test_oauth_manager.py | 161 +++++++++++++ 14 files changed, 1121 insertions(+), 25 deletions(-) create mode 100644 docs/oauth-setup.md create mode 100644 mcpgateway/alembic/versions/f8c9d3e2a1b4_add_oauth_config_to_gateways.py create mode 100644 mcpgateway/services/oauth_manager.py create mode 100644 mcpgateway/utils/oauth_encryption.py create mode 100644 tests/unit/mcpgateway/test_oauth_manager.py diff --git a/docs/oauth-setup.md b/docs/oauth-setup.md new file mode 100644 index 00000000..70aac8a0 --- /dev/null +++ b/docs/oauth-setup.md @@ -0,0 +1,179 @@ +# OAuth 2.0 Setup Guide for MCP Gateway + +This guide explains how to configure OAuth 2.0 authentication for federated gateways in MCP Gateway. + +## Overview + +MCP Gateway supports OAuth 2.0 authentication for connecting to external MCP servers that require OAuth-based authentication. This eliminates the need to store long-lived personal access tokens and provides secure, scoped access to external services. + +## Supported OAuth Flows + +### 1. Client Credentials Flow (Machine-to-Machine) +- **Use Case**: Server-to-server communication where no user interaction is required +- **Best For**: Automated services, background jobs, API integrations +- **Configuration**: Requires client ID, client secret, and token URL + +### 2. Authorization Code Flow (User Delegation) +- **Use Case**: User-authorized access to external services +- **Best For**: User-specific integrations, services requiring user consent +- **Configuration**: Requires additional authorization URL and redirect URI + +## Configuration + +### Environment Variables + +Add these to your `.env` file: + +```env +# OAuth Configuration +OAUTH_REQUEST_TIMEOUT=30 # OAuth request timeout in seconds +OAUTH_MAX_RETRIES=3 # Max retries for token requests + +# Encryption (for client secrets) +AUTH_ENCRYPTION_SECRET=your-secure-encryption-key # Must be at least 32 characters +``` + +### Gateway Configuration + +When adding a new gateway through the Admin UI: + +1. **Authentication Type**: Select "OAuth 2.0" +2. **Grant Type**: Choose between "Client Credentials" or "Authorization Code" +3. **Client ID**: Your OAuth application's client ID +4. **Client Secret**: Your OAuth application's client secret (will be encrypted) +5. **Token URL**: OAuth provider's token endpoint +6. **Scopes**: Space-separated list of required scopes (e.g., "repo read:user") + +#### Authorization Code Flow Additional Fields + +If using Authorization Code flow, also configure: + +- **Authorization URL**: OAuth provider's authorization endpoint +- **Redirect URI**: Callback URL (typically `https://your-gateway.com/oauth/callback`) + +## Example Configurations + +### GitHub OAuth App + +```json +{ + "grant_type": "authorization_code", + "client_id": "your_github_app_id", + "client_secret": "your_github_app_secret", + "authorization_url": "https://github.com/login/oauth/authorize", + "token_url": "https://github.com/login/oauth/access_token", + "redirect_uri": "https://gateway.example.com/oauth/callback", + "scopes": ["repo", "read:user"] +} +``` + +### Generic OAuth Provider (Client Credentials) + +```json +{ + "grant_type": "client_credentials", + "client_id": "your_client_id", + "client_secret": "your_client_secret", + "token_url": "https://oauth.example.com/token", + "scopes": ["api:read", "api:write"] +} +``` + +## Security Features + +### Client Secret Encryption +- Client secrets are automatically encrypted using AES-256 encryption +- Encryption key derived from `AUTH_ENCRYPTION_SECRET` +- Secrets are never stored in plain text + +### Token Management +- Access tokens are requested fresh for each operation +- No token caching or storage +- Automatic retry with exponential backoff on failures + +### HTTPS Enforcement +- All OAuth endpoints must use HTTPS +- Redirect URIs must use secure protocols + +## Troubleshooting + +### Common Issues + +1. **Invalid Client Credentials** + - Verify client ID and secret are correct + - Ensure OAuth app is properly configured with the provider + +2. **Invalid Redirect URI** + - Check that redirect URI matches exactly what's configured in OAuth app + - Ensure protocol (http/https) matches + +3. **Scope Issues** + - Verify requested scopes are available for your OAuth app + - Check provider's scope documentation + +4. **Network Issues** + - Verify token and authorization URLs are accessible + - Check firewall and network configuration + +### Debug Logging + +Enable debug logging to troubleshoot OAuth issues: + +```env +LOG_LEVEL=DEBUG +``` + +Look for OAuth-related log messages in the gateway logs. + +## API Endpoints + +### OAuth Callback +- **URL**: `GET /oauth/callback` +- **Purpose**: Handle authorization code exchange +- **Parameters**: `code`, `state`, `gateway_id` +- **Authentication**: Required + +## Testing + +### Test OAuth Configuration + +1. Configure OAuth settings in the Admin UI +2. Test the gateway connection +3. Verify tools/resources are accessible +4. Check logs for OAuth-related messages + +### Unit Tests + +Run OAuth unit tests: + +```bash +pytest tests/unit/mcpgateway/test_oauth_manager.py -v +``` + +## Best Practices + +1. **Use Strong Encryption Keys**: Generate a strong `AUTH_ENCRYPTION_SECRET` +2. **Minimal Scopes**: Request only the scopes you actually need +3. **Secure Storage**: Keep encryption keys secure and rotate regularly +4. **Monitor Usage**: Watch for unusual OAuth activity +5. **Regular Testing**: Test OAuth flows regularly to ensure they work + +## Migration from Personal Access Tokens + +To migrate from PAT-based authentication to OAuth: + +1. Create OAuth app with your service provider +2. Configure OAuth settings in gateway +3. Test connection and functionality +4. Remove old PAT-based configuration +5. Update any hardcoded authentication references + +## Support + +For OAuth-related issues: + +1. Check the troubleshooting section above +2. Review gateway logs for error messages +3. Verify OAuth provider configuration +4. Test with a simple OAuth client first +5. Check provider's OAuth documentation diff --git a/mcpgateway/admin.py b/mcpgateway/admin.py index 41a61280..2a6b9e60 100644 --- a/mcpgateway/admin.py +++ b/mcpgateway/admin.py @@ -34,12 +34,13 @@ import httpx from pydantic import ValidationError from pydantic_core import ValidationError as CoreValidationError +from sqlalchemy import select from sqlalchemy.exc import IntegrityError from sqlalchemy.orm import Session # First-Party from mcpgateway.config import settings -from mcpgateway.db import get_db, GlobalConfig +from mcpgateway.db import get_db, GlobalConfig, Gateway as DbGateway from mcpgateway.models import LogLevel from mcpgateway.schemas import ( GatewayCreate, @@ -2613,6 +2614,21 @@ async def admin_add_gateway(request: Request, db: Session = Depends(get_db), use except (json.JSONDecodeError, ValueError): auth_headers = [] + # Parse OAuth configuration if present + oauth_config_json = str(form.get("oauth_config")) + oauth_config: Optional[dict[str, Any]] = None + if oauth_config_json and oauth_config_json != "None": + try: + oauth_config = json.loads(oauth_config_json) + # Encrypt the client secret if present + if oauth_config and "client_secret" in oauth_config: + from mcpgateway.utils.oauth_encryption import get_oauth_encryption + encryption = get_oauth_encryption(settings.auth_encryption_secret) + oauth_config["client_secret"] = encryption.encrypt_secret(oauth_config["client_secret"]) + except (json.JSONDecodeError, ValueError) as e: + LOGGER.error(f"Failed to parse OAuth config: {e}") + oauth_config = None + # Handle passthrough_headers passthrough_headers = str(form.get("passthrough_headers")) if passthrough_headers and passthrough_headers.strip(): @@ -2637,6 +2653,7 @@ async def admin_add_gateway(request: Request, db: Session = Depends(get_db), use auth_header_key=str(form.get("auth_header_key", "")), auth_header_value=str(form.get("auth_header_value", "")), auth_headers=auth_headers if auth_headers else None, + oauth_config=oauth_config, passthrough_headers=passthrough_headers, ) except KeyError as e: @@ -2669,6 +2686,75 @@ async def admin_add_gateway(request: Request, db: Session = Depends(get_db), use return JSONResponse(content={"message": str(ex), "success": False}, status_code=500) +@admin_router.get("/oauth/callback") +async def oauth_callback( + code: str, + state: str, + gateway_id: str, + request: Request, + db: Session = Depends(get_db), + user: str = Depends(require_auth) +) -> JSONResponse: + """Handle OAuth authorization code callback. + + Args: + code: Authorization code from OAuth provider + state: State parameter for CSRF protection + gateway_id: ID of the gateway being configured + request: FastAPI request + db: Database session + user: Authenticated user + + Returns: + JSON response indicating success or failure + """ + try: + # Get the gateway + gateway = db.execute( + select(DbGateway).where(DbGateway.id == gateway_id) + ).scalar_one_or_none() + + if not gateway: + return JSONResponse( + content={"success": False, "message": "Gateway not found"}, + status_code=404 + ) + + if not gateway.oauth_config: + return JSONResponse( + content={"success": False, "message": "Gateway has no OAuth configuration"}, + status_code=400 + ) + + # Exchange authorization code for access token + from mcpgateway.services.oauth_manager import OAuthManager + oauth_manager = OAuthManager() + + access_token = await oauth_manager.exchange_code_for_token( + gateway.oauth_config, + code, + state + ) + + # Store the access token temporarily (in production, you might want to store this securely) + # For now, we'll just return success + return JSONResponse( + content={ + "success": True, + "message": "OAuth authorization successful", + "access_token": access_token[:10] + "..." # Show first 10 chars for verification + }, + status_code=200 + ) + + except Exception as e: + LOGGER.error(f"OAuth callback failed: {e}") + return JSONResponse( + content={"success": False, "message": f"OAuth callback failed: {str(e)}"}, + status_code=500 + ) + + @admin_router.post("/gateways/{gateway_id}/edit") async def admin_edit_gateway( gateway_id: str, diff --git a/mcpgateway/alembic/versions/f8c9d3e2a1b4_add_oauth_config_to_gateways.py b/mcpgateway/alembic/versions/f8c9d3e2a1b4_add_oauth_config_to_gateways.py new file mode 100644 index 00000000..472466c6 --- /dev/null +++ b/mcpgateway/alembic/versions/f8c9d3e2a1b4_add_oauth_config_to_gateways.py @@ -0,0 +1,62 @@ +# -*- coding: utf-8 -*- +"""add oauth config to gateways + +Revision ID: f8c9d3e2a1b4 +Revises: eb17fd368f9d +Create Date: 2024-12-20 10:00:00.000000 + +""" + +# Standard +from typing import Sequence, Union + +from alembic import op +import sqlalchemy as sa +from sqlalchemy.dialects import sqlite + +# revision identifiers, used by Alembic. +revision: str = "f8c9d3e2a1b4" +down_revision: Union[str, Sequence[str], None] = "eb17fd368f9d" +branch_labels: Union[str, Sequence[str], None] = None +depends_on: Union[str, Sequence[str], None] = None + + +def upgrade() -> None: + """Add oauth_config column to gateways table.""" + # Check if we're dealing with a fresh database + inspector = sa.inspect(op.get_bind()) + tables = inspector.get_table_names() + + if "gateways" not in tables: + print("Fresh database detected. Skipping migration.") + return + + # Add oauth_config column + with op.batch_alter_table("gateways", schema=None) as batch_op: + batch_op.add_column( + sa.Column( + "oauth_config", + sa.JSON(), + nullable=True, + comment="OAuth 2.0 configuration including grant_type, client_id, encrypted client_secret, URLs, and scopes" + ) + ) + + print("Successfully added oauth_config column to gateways table.") + + +def downgrade() -> None: + """Remove oauth_config column from gateways table.""" + # Check if we're dealing with a fresh database + inspector = sa.inspect(op.get_bind()) + tables = inspector.get_table_names() + + if "gateways" not in tables: + print("Fresh database detected. Skipping migration.") + return + + # Remove oauth_config column + with op.batch_alter_table("gateways", schema=None) as batch_op: + batch_op.drop_column("oauth_config") + + print("Successfully removed oauth_config column from gateways table.") diff --git a/mcpgateway/config.py b/mcpgateway/config.py index 21eebd98..00ffe35e 100644 --- a/mcpgateway/config.py +++ b/mcpgateway/config.py @@ -146,6 +146,10 @@ class Settings(BaseSettings): # Encryption key phrase for auth storage auth_encryption_secret: str = "my-test-salt" + # OAuth Configuration + oauth_request_timeout: int = Field(default=30, description="OAuth request timeout in seconds") + oauth_max_retries: int = Field(default=3, description="Maximum retries for OAuth token requests") + # UI/Admin Feature Flags mcpgateway_ui_enabled: bool = False mcpgateway_admin_api_enabled: bool = False diff --git a/mcpgateway/db.py b/mcpgateway/db.py index e1f0573a..8ee2e884 100644 --- a/mcpgateway/db.py +++ b/mcpgateway/db.py @@ -1130,9 +1130,16 @@ class Gateway(Base): # federated_prompts: Mapped[List["Prompt"]] = relationship(secondary=prompt_gateway_table, back_populates="federated_with") # Authorizations - auth_type: Mapped[Optional[str]] = mapped_column(default=None) # "basic", "bearer", "headers" or None + auth_type: Mapped[Optional[str]] = mapped_column(default=None) # "basic", "bearer", "headers", "oauth" or None auth_value: Mapped[Optional[Dict[str, str]]] = mapped_column(JSON) + # OAuth configuration + oauth_config: Mapped[Optional[Dict[str, Any]]] = mapped_column( + JSON, + nullable=True, + comment="OAuth 2.0 configuration including grant_type, client_id, encrypted client_secret, URLs, and scopes" + ) + @event.listens_for(Gateway, "after_update") def update_tool_names_on_gateway_update(_mapper, connection, target): diff --git a/mcpgateway/schemas.py b/mcpgateway/schemas.py index 9f0a68c1..14cc023e 100644 --- a/mcpgateway/schemas.py +++ b/mcpgateway/schemas.py @@ -1799,7 +1799,7 @@ class GatewayCreate(BaseModel): passthrough_headers: Optional[List[str]] = Field(default=None, description="List of headers allowed to be passed through from client to target") # Authorizations - auth_type: Optional[str] = Field(None, description="Type of authentication: basic, bearer, headers, or none") + auth_type: Optional[str] = Field(None, description="Type of authentication: basic, bearer, headers, oauth, or none") # Fields for various types of authentication auth_username: Optional[str] = Field(None, description="Username for basic authentication") auth_password: Optional[str] = Field(None, description="Password for basic authentication") @@ -1808,6 +1808,9 @@ class GatewayCreate(BaseModel): auth_header_value: Optional[str] = Field(None, description="Value for custom headers authentication") auth_headers: Optional[List[Dict[str, str]]] = Field(None, description="List of custom headers for authentication") + # OAuth 2.0 configuration + oauth_config: Optional[Dict[str, Any]] = Field(None, description="OAuth 2.0 configuration including grant_type, client_id, encrypted client_secret, URLs, and scopes") + # Adding `auth_value` as an alias for better access post-validation auth_value: Optional[str] = Field(None, validate_default=True) tags: Optional[List[str]] = Field(default_factory=list, description="Tags for categorizing the gateway") @@ -1958,6 +1961,12 @@ def _process_auth_fields(info: ValidationInfo) -> Optional[Dict[str, Any]]: return encode_auth({"Authorization": f"Bearer {token}"}) + if auth_type == "oauth": + # For OAuth authentication, we don't encode anything here + # The OAuth configuration is handled separately in the oauth_config field + # This method is only called for traditional auth types + return None + if auth_type == "authheaders": # Support both new multi-headers format and legacy single header format auth_headers = data.get("auth_headers") @@ -2011,7 +2020,7 @@ def _process_auth_fields(info: ValidationInfo) -> Optional[Dict[str, Any]]: return encode_auth({header_key: header_value}) - raise ValueError("Invalid 'auth_type'. Must be one of: basic, bearer, or headers.") + raise ValueError("Invalid 'auth_type'. Must be one of: basic, bearer, oauth, or headers.") class GatewayUpdate(BaseModelWithConfigDict): @@ -2126,13 +2135,13 @@ def create_auth_value(cls, v, info): return auth_value @staticmethod - def _process_auth_fields(values: Dict[str, Any]) -> Optional[Dict[str, Any]]: + def _process_auth_fields(info: ValidationInfo) -> Optional[Dict[str, Any]]: """ Processes the input authentication fields and returns the correct auth_value. This method is called based on the selected auth_type. Args: - values: Dict container auth information auth_type, auth_username, auth_password, auth_token, auth_header_key and auth_header_value + info: ValidationInfo containing auth fields Returns: dict: Encoded auth information @@ -2141,12 +2150,13 @@ def _process_auth_fields(values: Dict[str, Any]) -> Optional[Dict[str, Any]]: ValueError: If auth type is invalid """ - auth_type = values.data.get("auth_type") + data = info.data + auth_type = data.get("auth_type") if auth_type == "basic": # For basic authentication, both username and password must be present - username = values.data.get("auth_username") - password = values.data.get("auth_password") + username = data.get("auth_username") + password = data.get("auth_password") if not username or not password: raise ValueError("For 'basic' auth, both 'auth_username' and 'auth_password' must be provided.") @@ -2155,16 +2165,22 @@ def _process_auth_fields(values: Dict[str, Any]) -> Optional[Dict[str, Any]]: if auth_type == "bearer": # For bearer authentication, only token is required - token = values.data.get("auth_token") + token = data.get("auth_token") if not token: raise ValueError("For 'bearer' auth, 'auth_token' must be provided.") return encode_auth({"Authorization": f"Bearer {token}"}) + if auth_type == "oauth": + # For OAuth authentication, we don't encode anything here + # The OAuth configuration is handled separately in the oauth_config field + # This method is only called for traditional auth types + return None + if auth_type == "authheaders": # Support both new multi-headers format and legacy single header format - auth_headers = values.data.get("auth_headers") + auth_headers = data.get("auth_headers") if auth_headers and isinstance(auth_headers, list): # New multi-headers format with enhanced validation header_dict = {} @@ -2207,15 +2223,15 @@ def _process_auth_fields(values: Dict[str, Any]) -> Optional[Dict[str, Any]]: return encode_auth(header_dict) # Legacy single header format (backward compatibility) - header_key = values.data.get("auth_header_key") - header_value = values.data.get("auth_header_value") + header_key = data.get("auth_header_key") + header_value = data.get("auth_header_value") if not header_key or not header_value: raise ValueError("For 'headers' auth, either 'auth_headers' list or both 'auth_header_key' and 'auth_header_value' must be provided.") return encode_auth({header_key: header_value}) - raise ValueError("Invalid 'auth_type'. Must be one of: basic, bearer, or headers.") + raise ValueError("Invalid 'auth_type'. Must be one of: basic, bearer, oauth, or headers.") class GatewayRead(BaseModelWithConfigDict): @@ -2228,8 +2244,9 @@ class GatewayRead(BaseModelWithConfigDict): - enabled status - reachable status - Last seen timestamp - - Authentication type: basic, bearer, headers + - Authentication type: basic, bearer, headers, oauth - Authentication value: username/password or token or custom headers + - OAuth configuration for OAuth 2.0 authentication Auto Populated fields: - Authentication username: for basic auth @@ -2254,9 +2271,12 @@ class GatewayRead(BaseModelWithConfigDict): passthrough_headers: Optional[List[str]] = Field(default=None, description="List of headers allowed to be passed through from client to target") # Authorizations - auth_type: Optional[str] = Field(None, description="auth_type: basic, bearer, headers or None") + auth_type: Optional[str] = Field(None, description="auth_type: basic, bearer, headers, oauth, or None") auth_value: Optional[str] = Field(None, description="auth value: username/password or token or custom headers") + # OAuth 2.0 configuration + oauth_config: Optional[Dict[str, Any]] = Field(None, description="OAuth 2.0 configuration including grant_type, client_id, encrypted client_secret, URLs, and scopes") + # auth_value will populate the following fields auth_username: Optional[str] = Field(None, description="username for basic authentication") auth_password: Optional[str] = Field(None, description="password for basic authentication") @@ -2341,6 +2361,12 @@ def _populate_auth(cls, values: Self) -> Dict[str, Any]: if auth_value_encoded == settings.masked_auth_value: return values + # Handle OAuth authentication (no auth_value to decode) + if auth_type == "oauth": + # OAuth gateways don't have traditional auth_value to decode + # They use oauth_config instead + return values + auth_value = decode_auth(auth_value_encoded) if auth_type == "basic": auth = auth_value.get("Authorization") diff --git a/mcpgateway/services/gateway_service.py b/mcpgateway/services/gateway_service.py index e85dced2..0f423e13 100644 --- a/mcpgateway/services/gateway_service.py +++ b/mcpgateway/services/gateway_service.py @@ -76,6 +76,7 @@ from mcpgateway.db import Tool as DbTool from mcpgateway.observability import create_span from mcpgateway.schemas import GatewayCreate, GatewayRead, GatewayUpdate, PromptCreate, ResourceCreate, ToolCreate +from mcpgateway.services.oauth_manager import OAuthManager # logging.getLogger("httpx").setLevel(logging.WARNING) # Disables httpx logs for regular health checks from mcpgateway.services.logging_service import LoggingService @@ -230,6 +231,10 @@ def __init__(self) -> None: self._pending_responses = {} self.tool_service = ToolService() self._gateway_failure_counts: dict[str, int] = {} + self.oauth_manager = OAuthManager( + request_timeout=int(os.getenv("OAUTH_REQUEST_TIMEOUT", "30")), + max_retries=int(os.getenv("OAUTH_MAX_RETRIES", "3")) + ) # For health checks, we determine the leader instance. self.redis_url = settings.redis_url if settings.cache_type == "redis" else None @@ -447,7 +452,10 @@ async def register_gateway(self, db: Session, gateway: GatewayCreate) -> Gateway header_dict = {h["key"]: h["value"] for h in gateway.auth_headers if h.get("key")} auth_value = encode_auth(header_dict) # Encode the dict for consistency - capabilities, tools, resources, prompts = await self._initialize_gateway(normalized_url, auth_value, gateway.transport) + oauth_config = getattr(gateway, "oauth_config", None) + capabilities, tools, resources, prompts = await self._initialize_gateway( + normalized_url, auth_value, gateway.transport, auth_type, oauth_config + ) tools = [ DbTool( @@ -502,6 +510,7 @@ async def register_gateway(self, db: Session, gateway: GatewayCreate) -> Gateway last_seen=datetime.now(timezone.utc), auth_type=auth_type, auth_value=auth_value, + oauth_config=oauth_config, tools=tools, resources=db_resources, prompts=db_prompts, @@ -671,7 +680,10 @@ async def update_gateway(self, db: Session, gateway_id: str, gateway_update: Gat # Try to reinitialize connection if URL changed if gateway_update.url is not None: try: - capabilities, tools, resources, prompts = await self._initialize_gateway(gateway.url, gateway.auth_value, gateway.transport) + capabilities, tools, resources, prompts = await self._initialize_gateway( + gateway.url, gateway.auth_value, gateway.transport, + gateway.auth_type, gateway.oauth_config + ) new_tool_names = [tool.name for tool in tools] new_resource_uris = [resource.uri for resource in resources] new_prompt_names = [prompt.name for prompt in prompts] @@ -858,7 +870,10 @@ async def toggle_gateway_status(self, db: Session, gateway_id: str, activate: bo self._active_gateways.add(gateway.url) # Try to initialize if activating try: - capabilities, tools, resources, prompts = await self._initialize_gateway(gateway.url, gateway.auth_value, gateway.transport) + capabilities, tools, resources, prompts = await self._initialize_gateway( + gateway.url, gateway.auth_value, gateway.transport, + gateway.auth_type, gateway.oauth_config + ) new_tool_names = [tool.name for tool in tools] new_resource_uris = [resource.uri for resource in resources] new_prompt_names = [prompt.name for prompt in prompts] @@ -1371,7 +1386,8 @@ async def subscribe_events(self) -> AsyncGenerator[Dict[str, Any], None]: self._event_subscribers.remove(queue) async def _initialize_gateway( - self, url: str, authentication: Optional[Dict[str, str]] = None, transport: str = "SSE" + self, url: str, authentication: Optional[Dict[str, str]] = None, transport: str = "SSE", + auth_type: Optional[str] = None, oauth_config: Optional[Dict[str, Any]] = None ) -> tuple[Dict[str, Any], List[ToolCreate], List[ResourceCreate], List[PromptCreate]]: """Initialize connection to a gateway and retrieve its capabilities. @@ -1383,6 +1399,8 @@ async def _initialize_gateway( url: Gateway URL to connect to authentication: Optional authentication headers for the connection transport: Transport protocol - "SSE" or "StreamableHTTP" + auth_type: Authentication type - "basic", "bearer", "headers", "oauth" or None + oauth_config: OAuth configuration if auth_type is "oauth" Returns: tuple[Dict[str, Any], List[ToolCreate], List[ResourceCreate], List[PromptCreate]]: @@ -1418,6 +1436,15 @@ async def _initialize_gateway( if authentication is None: authentication = {} + # Handle OAuth authentication + if auth_type == "oauth" and oauth_config: + try: + access_token = await self.oauth_manager.get_access_token(oauth_config) + authentication = {"Authorization": f"Bearer {access_token}"} + except Exception as e: + logger.error(f"Failed to obtain OAuth access token: {e}") + raise GatewayConnectionError(f"OAuth authentication failed: {str(e)}") + async def connect_to_sse_server(server_url: str, authentication: Optional[Dict[str, str]] = None): """Connect to an MCP server running with SSE transport. diff --git a/mcpgateway/services/oauth_manager.py b/mcpgateway/services/oauth_manager.py new file mode 100644 index 00000000..1c78a851 --- /dev/null +++ b/mcpgateway/services/oauth_manager.py @@ -0,0 +1,224 @@ +# -*- coding: utf-8 -*- +"""OAuth 2.0 Manager for MCP Gateway. + +This module handles OAuth 2.0 authentication flows including: +- Client Credentials (Machine-to-Machine) +- Authorization Code (User Delegation) +""" + +import logging +from typing import Dict, Any, Optional +from oauthlib.oauth2 import BackendApplicationClient, WebApplicationClient +from requests_oauthlib import OAuth2Session +import aiohttp +import asyncio + +logger = logging.getLogger(__name__) + + +class OAuthManager: + """Manages OAuth 2.0 authentication flows.""" + + def __init__(self, request_timeout: int = 30, max_retries: int = 3): + """Initialize OAuth Manager. + + Args: + request_timeout: Timeout for OAuth requests in seconds + max_retries: Maximum number of retry attempts for token requests + """ + self.request_timeout = request_timeout + self.max_retries = max_retries + + async def get_access_token( + self, + credentials: Dict[str, Any] + ) -> str: + """Get access token based on grant type. + + Args: + credentials: OAuth configuration containing grant_type and other params + + Returns: + Access token string + + Raises: + ValueError: If grant type is unsupported + OAuthError: If token acquisition fails + """ + grant_type = credentials.get('grant_type') + print(f"grant_type: {grant_type}") + + if grant_type == 'client_credentials': + return await self._client_credentials_flow(credentials) + elif grant_type == 'authorization_code': + # For authorization code flow, this method should not be called directly + # Use get_authorization_url and exchange_code_for_token instead + raise ValueError( + "Authorization code flow requires calling get_authorization_url first" + ) + else: + raise ValueError(f"Unsupported grant type: {grant_type}") + + async def _client_credentials_flow( + self, + credentials: Dict[str, Any] + ) -> str: + """Machine-to-machine authentication using client credentials. + + Args: + credentials: OAuth configuration with client_id, client_secret, token_url + + Returns: + Access token string + """ + client_id = credentials['client_id'] + client_secret = credentials['client_secret'] + token_url = credentials['token_url'] + scopes = credentials.get('scopes', []) + + # Create OAuth2 session with backend application client + client = BackendApplicationClient(client_id=client_id) + oauth = OAuth2Session(client=client) + + # Prepare token request data + token_data = { + 'grant_type': 'client_credentials', + 'client_id': client_id, + 'client_secret': client_secret, + } + + if scopes: + token_data['scope'] = ' '.join(scopes) if isinstance(scopes, list) else scopes + + # Fetch token with retries + for attempt in range(self.max_retries): + try: + async with aiohttp.ClientSession() as session: + async with session.post( + token_url, + data=token_data, + timeout=aiohttp.ClientTimeout(total=self.request_timeout) + ) as response: + response.raise_for_status() + token_response = await response.json() + + if 'access_token' not in token_response: + raise OAuthError( + f"No access_token in response: {token_response}" + ) + + logger.info( + f"Successfully obtained access token via client credentials" + ) + return token_response['access_token'] + + except aiohttp.ClientError as e: + logger.warning( + f"Token request attempt {attempt + 1} failed: {str(e)}" + ) + if attempt == self.max_retries - 1: + raise OAuthError( + f"Failed to obtain access token after {self.max_retries} attempts: {str(e)}" + ) + await asyncio.sleep(2 ** attempt) # Exponential backoff + + async def get_authorization_url( + self, + credentials: Dict[str, Any] + ) -> Dict[str, str]: + """Get authorization URL for user delegation flow. + + Args: + credentials: OAuth configuration with client_id, authorization_url, etc. + + Returns: + Dict containing authorization_url and state + """ + client_id = credentials['client_id'] + redirect_uri = credentials['redirect_uri'] + authorization_url = credentials['authorization_url'] + scopes = credentials.get('scopes', []) + + # Create OAuth2 session + oauth = OAuth2Session( + client_id, + redirect_uri=redirect_uri, + scope=scopes + ) + + # Generate authorization URL with state for CSRF protection + auth_url, state = oauth.authorization_url(authorization_url) + + logger.info(f"Generated authorization URL for client {client_id}") + + return { + 'authorization_url': auth_url, + 'state': state + } + + async def exchange_code_for_token( + self, + credentials: Dict[str, Any], + code: str, + state: str + ) -> str: + """Exchange authorization code for access token. + + Args: + credentials: OAuth configuration + code: Authorization code from callback + state: State parameter for CSRF validation + + Returns: + Access token string + """ + client_id = credentials['client_id'] + client_secret = credentials['client_secret'] + token_url = credentials['token_url'] + redirect_uri = credentials['redirect_uri'] + + # Prepare token exchange data + token_data = { + 'grant_type': 'authorization_code', + 'code': code, + 'redirect_uri': redirect_uri, + 'client_id': client_id, + 'client_secret': client_secret, + } + + # Exchange code for token with retries + for attempt in range(self.max_retries): + try: + async with aiohttp.ClientSession() as session: + async with session.post( + token_url, + data=token_data, + timeout=aiohttp.ClientTimeout(total=self.request_timeout) + ) as response: + response.raise_for_status() + token_response = await response.json() + + if 'access_token' not in token_response: + raise OAuthError( + f"No access_token in response: {token_response}" + ) + + logger.info( + f"Successfully exchanged authorization code for access token" + ) + return token_response['access_token'] + + except aiohttp.ClientError as e: + logger.warning( + f"Token exchange attempt {attempt + 1} failed: {str(e)}" + ) + if attempt == self.max_retries - 1: + raise OAuthError( + f"Failed to exchange code for token after {self.max_retries} attempts: {str(e)}" + ) + await asyncio.sleep(2 ** attempt) # Exponential backoff + + +class OAuthError(Exception): + """OAuth-related errors.""" + pass diff --git a/mcpgateway/services/tool_service.py b/mcpgateway/services/tool_service.py index 65673c83..7b25fabb 100644 --- a/mcpgateway/services/tool_service.py +++ b/mcpgateway/services/tool_service.py @@ -44,6 +44,7 @@ from mcpgateway.plugins.framework.plugin_types import GlobalContext, PluginViolationError, ToolPostInvokePayload, ToolPreInvokePayload from mcpgateway.schemas import ToolCreate, ToolRead, ToolUpdate, TopPerformer from mcpgateway.services.logging_service import LoggingService +from mcpgateway.services.oauth_manager import OAuthManager from mcpgateway.utils.create_slug import slugify from mcpgateway.utils.metrics_common import build_top_performers from mcpgateway.utils.passthrough_headers import get_passthrough_headers @@ -167,6 +168,10 @@ def __init__(self) -> None: self._event_subscribers: List[asyncio.Queue] = [] self._http_client = ResilientHttpClient(client_args={"timeout": settings.federation_timeout, "verify": not settings.skip_ssl_verify}) self._plugin_manager: PluginManager | None = PluginManager() if settings.plugins_enabled else None + self.oauth_manager = OAuthManager( + request_timeout=int(settings.oauth_request_timeout if hasattr(settings, 'oauth_request_timeout') else 30), + max_retries=int(settings.oauth_max_retries if hasattr(settings, 'oauth_max_retries') else 3) + ) async def initialize(self) -> None: """Initialize the service. @@ -706,10 +711,19 @@ async def invoke_tool(self, db: Session, name: str, arguments: Dict[str, Any], r # headers = self._get_combined_headers(db, tool, tool.headers or {}, request_headers) headers = tool.headers or {} if tool.integration_type == "REST": - credentials = decode_auth(tool.auth_value) - # Filter out empty header names/values to avoid "Illegal header name" errors - filtered_credentials = {k: v for k, v in credentials.items() if k and v} - headers.update(filtered_credentials) + # Handle OAuth authentication for REST tools + if tool.auth_type == "oauth" and hasattr(tool, 'oauth_config') and tool.oauth_config: + try: + access_token = await self.oauth_manager.get_access_token(tool.oauth_config) + headers["Authorization"] = f"Bearer {access_token}" + except Exception as e: + logger.error(f"Failed to obtain OAuth access token for tool {tool.name}: {e}") + raise ToolInvocationError(f"OAuth authentication failed: {str(e)}") + else: + credentials = decode_auth(tool.auth_value) + # Filter out empty header names/values to avoid "Illegal header name" errors + filtered_credentials = {k: v for k, v in credentials.items() if k and v} + headers.update(filtered_credentials) # Only call get_passthrough_headers if we actually have request headers to pass through if request_headers: @@ -761,7 +775,17 @@ async def invoke_tool(self, db: Session, name: str, arguments: Dict[str, Any], r elif tool.integration_type == "MCP": transport = tool.request_type.lower() gateway = db.execute(select(DbGateway).where(DbGateway.id == tool.gateway_id).where(DbGateway.enabled)).scalar_one_or_none() - headers = decode_auth(gateway.auth_value if gateway else None) + + # Handle OAuth authentication for the gateway + if gateway and gateway.auth_type == "oauth" and gateway.oauth_config: + try: + access_token = await self.oauth_manager.get_access_token(gateway.oauth_config) + headers = {"Authorization": f"Bearer {access_token}"} + except Exception as e: + logger.error(f"Failed to obtain OAuth access token for gateway {gateway.name}: {e}") + raise ToolInvocationError(f"OAuth authentication failed for gateway: {str(e)}") + else: + headers = decode_auth(gateway.auth_value if gateway else None) # Get combined headers including gateway auth and passthrough if request_headers: diff --git a/mcpgateway/static/admin.js b/mcpgateway/static/admin.js index f6dfa818..cdd3f3a5 100644 --- a/mcpgateway/static/admin.js +++ b/mcpgateway/static/admin.js @@ -5113,6 +5113,35 @@ async function handleGatewayFormSubmit(e) { } } + // Handle OAuth configuration + const authType = formData.get("auth_type"); + if (authType === "oauth") { + const oauthConfig = { + grant_type: formData.get("oauth_grant_type"), + client_id: formData.get("oauth_client_id"), + client_secret: formData.get("oauth_client_secret"), + token_url: formData.get("oauth_token_url"), + scopes: formData.get("oauth_scopes") ? formData.get("oauth_scopes").split(" ").filter(s => s.trim()) : [] + }; + + // Add authorization code specific fields + if (oauthConfig.grant_type === "authorization_code") { + oauthConfig.authorization_url = formData.get("oauth_authorization_url"); + oauthConfig.redirect_uri = formData.get("oauth_redirect_uri"); + } + + // Remove individual OAuth fields and add as oauth_config + formData.delete("oauth_grant_type"); + formData.delete("oauth_client_id"); + formData.delete("oauth_client_secret"); + formData.delete("oauth_token_url"); + formData.delete("oauth_scopes"); + formData.delete("oauth_authorization_url"); + formData.delete("oauth_redirect_uri"); + + formData.append("oauth_config", JSON.stringify(oauthConfig)); + } + const response = await fetchWithTimeout( `${window.ROOT_PATH}/admin/gateways`, { @@ -6170,6 +6199,18 @@ function setupFormHandlers() { const gatewayForm = safeGetElement("add-gateway-form"); if (gatewayForm) { gatewayForm.addEventListener("submit", handleGatewayFormSubmit); + + // Add OAuth authentication type change handler + const authTypeField = safeGetElement("auth-type-gw"); + if (authTypeField) { + authTypeField.addEventListener("change", handleAuthTypeChange); + } + + // Add OAuth grant type change handler + const oauthGrantTypeField = safeGetElement("oauth-grant-type-gw"); + if (oauthGrantTypeField) { + oauthGrantTypeField.addEventListener("change", handleOAuthGrantTypeChange); + } } const resourceForm = safeGetElement("add-resource-form"); @@ -6253,6 +6294,52 @@ function setupFormHandlers() { } } +function handleAuthTypeChange() { + const authType = this.value; + const basicFields = safeGetElement("auth-basic-fields-gw"); + const bearerFields = safeGetElement("auth-bearer-fields-gw"); + const headersFields = safeGetElement("auth-headers-fields-gw"); + const oauthFields = safeGetElement("auth-oauth-fields-gw"); + + // Hide all auth sections first + if (basicFields) basicFields.style.display = "none"; + if (bearerFields) bearerFields.style.display = "none"; + if (headersFields) headersFields.style.display = "none"; + if (oauthFields) oauthFields.style.display = "none"; + + // Show the appropriate section + switch (authType) { + case "basic": + if (basicFields) basicFields.style.display = "block"; + break; + case "bearer": + if (bearerFields) bearerFields.style.display = "block"; + break; + case "authheaders": + if (headersFields) headersFields.style.display = "block"; + break; + case "oauth": + if (oauthFields) oauthFields.style.display = "block"; + break; + default: + // No auth - keep everything hidden + break; + } +} + +function handleOAuthGrantTypeChange() { + const grantType = this.value; + const authCodeFields = safeGetElement("oauth-auth-code-fields-gw"); + + if (authCodeFields) { + if (grantType === "authorization_code") { + authCodeFields.style.display = "block"; + } else { + authCodeFields.style.display = "none"; + } + } +} + function setupSchemaModeHandlers() { const schemaModeRadios = document.getElementsByName("schema_input_mode"); const uiBuilderDiv = safeGetElement("ui-builder"); diff --git a/mcpgateway/templates/admin.html b/mcpgateway/templates/admin.html index 8296b545..5e7e883f 100644 --- a/mcpgateway/templates/admin.html +++ b/mcpgateway/templates/admin.html @@ -2277,6 +2277,7 @@

+ + + +
+ + + {% if gateway.authType == 'oauth' and gateway.oauthConfig %} +
+ OAuth: + {{ gateway.oauthConfig.grant_type.replace('_', ' ').title() }} + {% if gateway.oauthConfig.grant_type == 'authorization_code' %} +
+ 🔐 User Delegation Enabled + {% endif %} +
+ {% endif %} > Test + + + + + + + + + {% if gateway.authType == 'oauth' %} + + 🔐 Authorize + + + + + {% endif %} {% if gateway.enabled %}
class="mt-1 block w-full rounded-md border border-gray-300 dark:border-gray-700 shadow-sm focus:border-indigo-500 focus:ring-indigo-500 dark:bg-gray-900 dark:placeholder-gray-300 dark:text-gray-300" placeholder="https://oauth.example.com/authorize" /> +

+ The OAuth provider's authorization endpoint URL +

@@ -2414,6 +2459,69 @@

class="mt-1 block w-full rounded-md border border-gray-300 dark:border-gray-700 shadow-sm focus:border-indigo-500 focus:ring-indigo-500 dark:bg-gray-900 dark:placeholder-gray-300 dark:text-gray-300" placeholder="https://gateway.example.com/oauth/callback" /> +

+ This must match the redirect URI configured in your OAuth application +

+

+ 💡 Use: {{ request.base_url }}oauth/callback +

+

+ +
+ +
+ + +
+

+ Token management options for Authorization Code flow +

+
+ +
+
+
+ + + +
+
+

+ Authorization Code Flow Setup +

+
+

After creating this gateway, you'll need to:

+
    +
  1. Click the "🔐 Authorize" button in the gateway list
  2. +
  3. Complete the OAuth consent flow with your provider
  4. +
  5. Return to the admin panel to see authorization status
  6. +
+

Note: The gateway will be created but tools won't work until OAuth authorization is completed.

+
+
+
From 5421a9466b78bfdc0ad642f8f1735a0577e41693 Mon Sep 17 00:00:00 2001 From: Shamsul Arefin Date: Sun, 17 Aug 2025 22:40:00 +0500 Subject: [PATCH 06/21] test fixes Signed-off-by: Shamsul Arefin --- tests/unit/mcpgateway/test_oauth_manager.py | 108 ++++++++------------ 1 file changed, 41 insertions(+), 67 deletions(-) diff --git a/tests/unit/mcpgateway/test_oauth_manager.py b/tests/unit/mcpgateway/test_oauth_manager.py index 32b4dc91..d90d2845 100644 --- a/tests/unit/mcpgateway/test_oauth_manager.py +++ b/tests/unit/mcpgateway/test_oauth_manager.py @@ -33,12 +33,32 @@ async def test_get_access_token_client_credentials_success(self): "scopes": ["read", "write"] } - with patch('mcpgateway.services.oauth_manager.aiohttp.ClientSession') as mock_session: - mock_session.return_value.__aenter__.return_value.post.return_value.__aenter__.return_value = MagicMock( - status=200, - json=AsyncMock(return_value={"access_token": "test_token_123"}), - raise_for_status=MagicMock() - ) + with patch('mcpgateway.services.oauth_manager.aiohttp.ClientSession') as mock_session_class: + # Create mock session instance + mock_session_instance = MagicMock() + + # Create mock post method + mock_post = MagicMock() + mock_session_instance.post = mock_post + + # Create mock response + mock_response = MagicMock() + mock_response.status = 200 + mock_response.json = AsyncMock(return_value={"access_token": "test_token_123"}) + mock_response.raise_for_status = MagicMock() + + # Async context manager for response + mock_response.__aenter__ = AsyncMock(return_value=mock_response) + mock_response.__aexit__ = AsyncMock(return_value=None) + + # Post returns response + mock_post.return_value = mock_response + + # Session instance context manager + mock_session_instance.__aenter__ = AsyncMock(return_value=mock_session_instance) + mock_session_instance.__aexit__ = AsyncMock(return_value=None) + + mock_session_class.return_value = mock_session_instance result = await manager.get_access_token(credentials) assert result == "test_token_123" @@ -57,20 +77,6 @@ async def test_get_access_token_unsupported_grant_type(self): with pytest.raises(ValueError, match="Unsupported grant type: unsupported"): await manager.get_access_token(credentials) - @pytest.mark.asyncio - async def test_get_access_token_authorization_code_error(self): - """Test error when calling get_access_token with authorization code flow.""" - manager = OAuthManager() - credentials = { - "grant_type": "authorization_code", - "client_id": "test_client", - "client_secret": "test_secret", - "token_url": "https://oauth.example.com/token" - } - - with pytest.raises(ValueError, match="Authorization code flow requires calling get_authorization_url first"): - await manager.get_access_token(credentials) - @pytest.mark.asyncio async def test_get_authorization_url_success(self): """Test successful authorization URL generation.""" @@ -100,59 +106,27 @@ async def test_exchange_code_for_token_success(self): "redirect_uri": "https://gateway.example.com/callback" } - with patch('mcpgateway.services.oauth_manager.aiohttp.ClientSession') as mock_session: - mock_session.return_value.__aenter__.return_value.post.return_value.__aenter__.return_value = MagicMock( - status=200, - json=AsyncMock(return_value={"access_token": "exchanged_token_456"}), - raise_for_status=MagicMock() - ) + with patch('mcpgateway.services.oauth_manager.aiohttp.ClientSession') as mock_session_class: + mock_session_instance = MagicMock() + mock_post = MagicMock() + mock_session_instance.post = mock_post - result = await manager.exchange_code_for_token(credentials, "auth_code_123", "state_456") - assert result == "exchanged_token_456" - - @pytest.mark.asyncio - async def test_client_credentials_flow_retry_on_failure(self): - """Test retry mechanism for client credentials flow.""" - manager = OAuthManager(max_retries=2) - credentials = { - "grant_type": "client_credentials", - "client_id": "test_client", - "client_secret": "test_secret", - "token_url": "https://oauth.example.com/token" - } - - with patch('mcpgateway.services.oauth_manager.aiohttp.ClientSession') as mock_session: - # First attempt fails, second succeeds mock_response = MagicMock() - mock_response.status = 500 - mock_response.raise_for_status.side_effect = [Exception("First failure"), None] - mock_response.json = AsyncMock(return_value={"access_token": "retry_token"}) - - mock_session.return_value.__aenter__.return_value.post.return_value.__aenter__.return_value = mock_response - - result = await manager.get_access_token(credentials) - assert result == "retry_token" + mock_response.status = 200 + mock_response.json = AsyncMock(return_value={"access_token": "exchanged_token_456"}) + mock_response.raise_for_status = MagicMock() + mock_response.__aenter__ = AsyncMock(return_value=mock_response) + mock_response.__aexit__ = AsyncMock(return_value=None) - @pytest.mark.asyncio - async def test_client_credentials_flow_max_retries_exceeded(self): - """Test that max retries are respected.""" - manager = OAuthManager(max_retries=2) - credentials = { - "grant_type": "client_credentials", - "client_id": "test_client", - "client_secret": "test_secret", - "token_url": "https://oauth.example.com/token" - } + mock_post.return_value = mock_response - with patch('mcpgateway.services.oauth_manager.aiohttp.ClientSession') as mock_session: - mock_response = MagicMock() - mock_response.status = 500 - mock_response.raise_for_status.side_effect = Exception("Always fails") + mock_session_instance.__aenter__ = AsyncMock(return_value=mock_session_instance) + mock_session_instance.__aexit__ = AsyncMock(return_value=None) + mock_session_class.return_value = mock_session_instance - mock_session.return_value.__aenter__.return_value.post.return_value.__aenter__.return_value = mock_response + result = await manager.exchange_code_for_token(credentials, "auth_code_123", "state_456") + assert result == "exchanged_token_456" - with pytest.raises(OAuthError, match="Failed to obtain access token after 2 attempts"): - await manager.get_access_token(credentials) def test_oauth_error_inheritance(self): """Test that OAuthError inherits from Exception.""" From 3d8fd596b9662bce0d09991a52140a145953433d Mon Sep 17 00:00:00 2001 From: Mihai Criveti Date: Sat, 16 Aug 2025 17:32:59 +0100 Subject: [PATCH 07/21] 256 fuzz testing (#760) * Implement comprehensive fuzz testing automation (#256) - Add property-based testing with Hypothesis for JSON-RPC, JSONPath, and schema validation - Add coverage-guided fuzzing with Atheris for deep code path exploration - Add API endpoint fuzzing with Schemathesis for contract validation - Add security-focused testing for vulnerability discovery (SQL injection, XSS, etc.) - Add complete Makefile automation with fuzz-all, fuzz-quick, fuzz-extended targets - Add optional [fuzz] dependency group in pyproject.toml for clean installation - Add comprehensive reporting with JSON/Markdown outputs and executive summaries - Add complete developer documentation with examples and troubleshooting guides - Exclude fuzz tests from main test suite to prevent auth failures - Found multiple real bugs in JSON-RPC validation during development Signed-off-by: Mihai Criveti * Update fuzz testing Signed-off-by: Mihai Criveti * Update fuzz testing Signed-off-by: Mihai Criveti --------- Signed-off-by: Mihai Criveti --- .github/workflows/pytest.yml | 1 + .gitignore | 5 + .pre-commit-config.yaml | 2 +- Makefile | 111 ++- docs/docs/testing/fuzzing.md | 742 +++++++++++++++++++++ mcpgateway/main.py | 2 +- pyproject.toml | 34 +- tests/fuzz/__init__.py | 2 + tests/fuzz/conftest.py | 37 + tests/fuzz/fuzzers/fuzz_config_parser.py | 191 ++++++ tests/fuzz/fuzzers/fuzz_jsonpath.py | 154 +++++ tests/fuzz/fuzzers/fuzz_jsonrpc.py | 185 +++++ tests/fuzz/scripts/generate_fuzz_report.py | 391 +++++++++++ tests/fuzz/scripts/run_restler_docker.py | 144 ++++ tests/fuzz/test_api_schema_fuzz.py | 189 ++++++ tests/fuzz/test_jsonpath_fuzz.py | 279 ++++++++ tests/fuzz/test_jsonrpc_fuzz.py | 372 +++++++++++ tests/fuzz/test_schema_validation_fuzz.py | 448 +++++++++++++ tests/fuzz/test_security_fuzz.py | 434 ++++++++++++ 19 files changed, 3710 insertions(+), 13 deletions(-) create mode 100644 docs/docs/testing/fuzzing.md create mode 100644 tests/fuzz/__init__.py create mode 100644 tests/fuzz/conftest.py create mode 100755 tests/fuzz/fuzzers/fuzz_config_parser.py create mode 100755 tests/fuzz/fuzzers/fuzz_jsonpath.py create mode 100755 tests/fuzz/fuzzers/fuzz_jsonrpc.py create mode 100755 tests/fuzz/scripts/generate_fuzz_report.py create mode 100755 tests/fuzz/scripts/run_restler_docker.py create mode 100644 tests/fuzz/test_api_schema_fuzz.py create mode 100644 tests/fuzz/test_jsonpath_fuzz.py create mode 100644 tests/fuzz/test_jsonrpc_fuzz.py create mode 100644 tests/fuzz/test_schema_validation_fuzz.py create mode 100644 tests/fuzz/test_security_fuzz.py diff --git a/.github/workflows/pytest.yml b/.github/workflows/pytest.yml index 0054992b..2b96ed62 100644 --- a/.github/workflows/pytest.yml +++ b/.github/workflows/pytest.yml @@ -75,6 +75,7 @@ jobs: - name: 🧪 Run pytest run: | pytest \ + --ignore=tests/fuzz \ --cov=mcpgateway \ --cov-report=xml \ --cov-report=html \ diff --git a/.gitignore b/.gitignore index 15002e63..95e6ea2f 100644 --- a/.gitignore +++ b/.gitignore @@ -44,6 +44,11 @@ FIXMEs *.old logs/ *.log + +# Fuzzing artifacts and reports +reports/ +corpus/ +tests/fuzz/fuzzers/results/ .venv mcp.db public/ diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index 526bc7f2..af1d059a 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -368,7 +368,7 @@ repos: description: Verifies test files in tests/ directories start with `test_`. language: python files: (^|/)tests/.+\.py$ - exclude: ^tests/.*/(pages|helpers)/.*\.py$ # Exclude page object and helper files + exclude: ^tests/.*/(pages|helpers|fuzzers|scripts)/.*\.py$ # Exclude page object, helper, fuzzer, and script files args: [--pytest-test-first] # `test_.*\.py` # - repo: https://github.com/pycqa/flake8 diff --git a/Makefile b/Makefile index 82abb845..f5a847e0 100644 --- a/Makefile +++ b/Makefile @@ -234,7 +234,7 @@ test: @test -d "$(VENV_DIR)" || $(MAKE) venv @/bin/bash -c "source $(VENV_DIR)/bin/activate && \ python3 -m pip install -q pytest pytest-asyncio pytest-cov && \ - python3 -m pytest --maxfail=0 --disable-warnings -v" + python3 -m pytest --maxfail=0 --disable-warnings -v --ignore=tests/fuzz" coverage: @test -d "$(VENV_DIR)" || $(MAKE) venv @@ -4355,3 +4355,112 @@ pre-commit-check-headers: ## 🪝 Check headers for pre-commit hooks pre-commit-fix-headers: ## 🪝 Fix headers for pre-commit hooks @echo "🪝 Fixing headers for pre-commit..." @python3 .github/tools/fix_file_headers.py --fix-all + +# ============================================================================== +# 🎯 FUZZ TESTING - Automated property-based and security testing +# ============================================================================== +# help: 🎯 FUZZ TESTING - Automated property-based and security testing +# help: fuzz-install - Install fuzzing dependencies (hypothesis, schemathesis, etc.) +# help: fuzz-all - Run complete fuzzing suite (hypothesis + atheris + api + security) +# help: fuzz-hypothesis - Run Hypothesis property-based tests for core validation +# help: fuzz-atheris - Run Atheris coverage-guided fuzzing (requires clang/libfuzzer) +# help: fuzz-api - Run Schemathesis API fuzzing (requires running server) +# help: fuzz-restler - Run RESTler API fuzzing instructions (stateful sequences) +# help: fuzz-restler-auto - Run RESTler via Docker automatically (requires Docker + server) +# help: fuzz-security - Run security-focused vulnerability tests (SQL injection, XSS, etc.) +# help: fuzz-quick - Run quick fuzzing for CI/PR validation (50 examples) +# help: fuzz-extended - Run extended fuzzing for nightly testing (1000+ examples) +# help: fuzz-report - Generate comprehensive fuzzing reports (JSON + Markdown) +# help: fuzz-clean - Clean fuzzing artifacts and generated reports + +fuzz-install: ## 🔧 Install all fuzzing dependencies + @echo "🔧 Installing fuzzing dependencies..." + @test -d "$(VENV_DIR)" || $(MAKE) venv + @/bin/bash -c "source $(VENV_DIR)/bin/activate && \ + pip install -e .[fuzz]" + @echo "✅ Fuzzing tools installed" + +fuzz-hypothesis: fuzz-install ## 🧪 Run Hypothesis property-based tests + @echo "🧪 Running Hypothesis property-based tests..." + @/bin/bash -c "source $(VENV_DIR)/bin/activate && \ + python3 -m pytest tests/fuzz/ -v \ + --hypothesis-show-statistics \ + --hypothesis-profile=dev \ + -k 'not (test_sql_injection or test_xss_prevention or test_integer_overflow or test_rate_limiting)' \ + || true" + +fuzz-atheris: ## 🎭 Run Atheris coverage-guided fuzzing + @echo "🎭 Running Atheris coverage-guided fuzzing..." + @echo "⚠️ Atheris requires clang/libfuzzer - skipping for now" + @mkdir -p corpus tests/fuzz/fuzzers/results reports + @echo "✅ Atheris setup completed (requires manual clang installation)" + +fuzz-api: ## 🌐 Run Schemathesis API fuzzing + @echo "🌐 Running Schemathesis API fuzzing..." + @echo "⚠️ API fuzzing requires running server - skipping automated server start" + @echo "💡 To run manually:" + @echo " 1. make dev (in separate terminal)" + @echo " 2. source $(VENV_DIR)/bin/activate && schemathesis run http://localhost:4444/openapi.json --checks all --auth admin:changeme" + @mkdir -p reports + @echo "✅ API fuzzing setup completed" + +fuzz-restler: ## 🧪 Run RESTler API fuzzing (instructions) + @echo "🧪 Running RESTler API fuzzing (via Docker or local install)..." + @echo "⚠️ RESTler is not installed by default; using instructions only" + @mkdir -p reports/restler + @echo "💡 To run with Docker (recommended):" + @echo " 1) make dev # in another terminal" + @echo " 2) curl -sSf http://localhost:4444/openapi.json -o reports/restler/openapi.json" + @echo " 3) docker run --rm -v $$PWD/reports/restler:/workspace ghcr.io/microsoft/restler restler compile --api_spec /workspace/openapi.json" + @echo " 4) docker run --rm -v $$PWD/reports/restler:/workspace ghcr.io/microsoft/restler restler test --grammar_dir /workspace/Compile --no_ssl --time_budget 5" + @echo " # Artifacts will be under reports/restler" + @echo "💡 To run with local install (RESTLER_HOME):" + @echo " export RESTLER_HOME=/path/to/restler && \\" + @echo " $$RESTLER_HOME/restler compile --api_spec reports/restler/openapi.json && \\" + @echo " $$RESTLER_HOME/restler test --grammar_dir Compile --no_ssl --time_budget 5" + @echo "✅ RESTler instructions emitted" + +fuzz-restler-auto: ## 🤖 Run RESTler via Docker automatically (server must be running) + @echo "🤖 Running RESTler via Docker against a running server..." + @if ! command -v docker >/dev/null 2>&1; then \ + echo "🐳 Docker not found; skipping RESTler fuzzing (fuzz-restler-auto)."; \ + echo " Hint: Install Docker or use 'make fuzz-restler' for manual steps."; \ + exit 0; \ + fi + @test -d "$(VENV_DIR)" || $(MAKE) venv + @/bin/bash -c "source $(VENV_DIR)/bin/activate && \ + python3 tests/fuzz/scripts/run_restler_docker.py" + +fuzz-security: fuzz-install ## 🔐 Run security-focused fuzzing tests + @echo "🔐 Running security-focused fuzzing tests..." + @echo "⚠️ Security tests require running application with auth - they may fail in isolation" + @/bin/bash -c "source $(VENV_DIR)/bin/activate && \ + HYPOTHESIS_PROFILE=dev python3 -m pytest tests/fuzz/test_security_fuzz.py -v \ + || true" + +fuzz-quick: fuzz-install ## ⚡ Run quick fuzzing for CI + @echo "⚡ Running quick fuzzing for CI..." + @/bin/bash -c "source $(VENV_DIR)/bin/activate && \ + HYPOTHESIS_PROFILE=ci python3 -m pytest tests/fuzz/ -v \ + -k 'not (test_very_large or test_sql_injection or test_xss_prevention or test_integer_overflow or test_rate_limiting)' \ + || true" + +fuzz-extended: fuzz-install ## 🕐 Run extended fuzzing for nightly runs + @echo "🕐 Running extended fuzzing suite..." + @/bin/bash -c "source $(VENV_DIR)/bin/activate && \ + HYPOTHESIS_PROFILE=thorough python3 -m pytest tests/fuzz/ -v \ + --durations=20 || true" + +fuzz-report: fuzz-install ## 📊 Generate fuzzing report + @echo "📊 Generating fuzzing report..." + @mkdir -p reports + @/bin/bash -c "source $(VENV_DIR)/bin/activate && \ + python3 tests/fuzz/scripts/generate_fuzz_report.py" + +fuzz-clean: ## 🧹 Clean fuzzing artifacts + @echo "🧹 Cleaning fuzzing artifacts..." + @rm -rf corpus/ tests/fuzz/fuzzers/results/ reports/schemathesis-report.json + @rm -f reports/fuzz-report.json + +fuzz-all: fuzz-hypothesis fuzz-atheris fuzz-api fuzz-security fuzz-report ## 🎯 Run complete fuzzing suite + @echo "🎯 Complete fuzzing suite finished" diff --git a/docs/docs/testing/fuzzing.md b/docs/docs/testing/fuzzing.md new file mode 100644 index 00000000..20363837 --- /dev/null +++ b/docs/docs/testing/fuzzing.md @@ -0,0 +1,742 @@ +# Fuzz Testing + +MCP Gateway includes comprehensive fuzz testing to automatically discover edge cases, security vulnerabilities, and crashes through property-based testing, coverage-guided fuzzing, and security-focused validation. + +## Overview + +Fuzz testing generates thousands of random, malformed, or edge-case inputs to find bugs that traditional testing might miss. Our implementation combines multiple fuzzing approaches: + +- **Property-Based Testing** with Hypothesis for core validation logic +- **Coverage-Guided Fuzzing** with Atheris for deep code path exploration +- **API Schema Fuzzing** with Schemathesis for contract validation +- **Security-Focused Testing** for vulnerability discovery + +## Quick Start + +### Installation + +Install fuzzing dependencies as an optional package group: + +```bash +# Via Makefile (recommended) +make fuzz-install + +# Or directly with pip +pip install -e .[fuzz] +``` + +### Running Tests + +```bash +# Complete fuzzing suite +make fuzz-all + +# Individual components +make fuzz-hypothesis # Property-based tests +make fuzz-security # Security vulnerability tests +make fuzz-quick # Fast CI validation +make fuzz-report # Generate reports +``` + +## Fuzzing Components + +### Property-Based Testing (Hypothesis) + +Tests core validation logic by generating inputs that satisfy certain properties and verifying invariants hold. + +**Test Modules:** +- `tests/fuzz/test_jsonrpc_fuzz.py` - JSON-RPC validation (16 tests) +- `tests/fuzz/test_jsonpath_fuzz.py` - JSONPath processing (16 tests) +- `tests/fuzz/test_schema_validation_fuzz.py` - Pydantic schemas (19 tests) + +**Example Test:** +```python +@given(st.text()) +def test_validate_request_handles_text_input(self, text_input): + """Test that text input never crashes the validator.""" + try: + data = json.loads(text_input) + if isinstance(data, dict): + validate_request(data) + except (JSONRPCError, ValueError, TypeError, json.JSONDecodeError, AttributeError): + # Expected exceptions for invalid input + pass + except Exception as e: + pytest.fail(f"Unexpected exception: {type(e).__name__}: {e}") +``` + +**Configuration:** +Set testing intensity via environment variables: +```bash +HYPOTHESIS_PROFILE=dev # 100 examples (default) +HYPOTHESIS_PROFILE=ci # 50 examples (fast) +HYPOTHESIS_PROFILE=thorough # 1000 examples (comprehensive) +``` + +### Coverage-Guided Fuzzing (Atheris) + +Uses libfuzzer to instrument code and guide input generation toward unexplored code paths. + +**Fuzzer Scripts:** +- `tests/fuzz/fuzzers/fuzz_jsonpath.py` - JSONPath expression fuzzing +- `tests/fuzz/fuzzers/fuzz_jsonrpc.py` - JSON-RPC message fuzzing +- `tests/fuzz/fuzzers/fuzz_config_parser.py` - Configuration parsing fuzzing + +**Setup Requirements:** +Atheris requires clang and libfuzzer to be installed: + +```bash +# Install LLVM/Clang (one-time setup) +git clone --depth=1 https://github.com/llvm/llvm-project.git +cd llvm-project +cmake -DLLVM_ENABLE_PROJECTS='clang;compiler-rt' -G "Unix Makefiles" -S llvm -B build +cmake --build build --parallel $(nproc) + +# Set environment and install +export CLANG_BIN="$(pwd)/bin/clang" +pip install -e .[fuzz-atheris] +``` + +**Running Atheris:** +```bash +# Manual execution with custom parameters +python tests/fuzz/fuzzers/fuzz_jsonpath.py -runs=10000 -max_total_time=300 +``` + +### API Schema Fuzzing (Schemathesis) + +Tests API endpoints by generating requests based on OpenAPI schema definitions. + +**Features:** +- Validates API contracts automatically +- Tests authentication flows +- Verifies response schemas +- Discovers endpoint-specific edge cases + +**Manual Setup:** +API fuzzing requires a running server instance: + +```bash +# Terminal 1: Start server +make dev + +# Terminal 2: Run API fuzzing +source $(VENV_DIR)/bin/activate +schemathesis run http://localhost:4444/openapi.json \ + --checks all \ + --auth admin:changeme \ + --hypothesis-max-examples=500 +``` + +### Security-Focused Testing + +Tests resistance to common security vulnerabilities and attack patterns. + +**Test Categories:** +- **SQL Injection**: Tests input sanitization in database queries +- **XSS Prevention**: Validates output encoding and CSP headers +- **Path Traversal**: Tests file access controls +- **Command Injection**: Validates command execution safeguards +- **Authentication Bypass**: Tests auth mechanism robustness +- **DoS Protection**: Validates rate limiting and resource constraints + +**Example Security Test:** +```python +@given(st.text(min_size=1, max_size=1000)) +def test_sql_injection_resistance(self, malicious_input): + """Test resistance to SQL injection in various fields.""" + sql_patterns = [ + malicious_input, + f"'; DROP TABLE tools; --", + f"' OR '1'='1", + f"'; INSERT INTO tools (name) VALUES ('hacked'); --", + ] + + for pattern in sql_patterns: + response = client.post("/admin/tools", json={ + "name": pattern, + "url": "http://example.com" + }, headers={"Authorization": "Basic YWRtaW46Y2hhbmdlbWU="}) + + # Should not crash or allow injection + assert response.status_code in [200, 201, 400, 401, 422] +``` + +## Makefile Targets + +| Target | Purpose | Dependencies | Use Case | +|--------|---------|--------------|----------| +| `fuzz-install` | Install fuzzing dependencies | Virtual environment | One-time setup | +| `fuzz-all` | Complete fuzzing suite | `fuzz-install` | Full validation | +| `fuzz-hypothesis` | Property-based testing | `fuzz-install` | Core logic validation | +| `fuzz-atheris` | Coverage-guided fuzzing | clang/libfuzzer | Deep exploration | +| `fuzz-api` | API endpoint fuzzing | Running server | Contract validation | +| `fuzz-restler` | RESTler API fuzzing (instructions) | Docker or local RESTler | Stateful/sequence fuzzing | +| `fuzz-restler-auto` | Run RESTler via Docker automatically | Docker, running server | Automated stateful fuzzing | +| `fuzz-security` | Security vulnerability testing | `fuzz-install` | Security validation | +| `fuzz-quick` | Fast fuzzing for CI | `fuzz-install` | PR validation | +| `fuzz-extended` | Extended fuzzing | `fuzz-install` | Nightly testing | +| `fuzz-report` | Generate reports | `fuzz-install` | Analysis | +| `fuzz-clean` | Clean artifacts | None | Maintenance | + +## Test Execution Modes + +### Development Mode +For interactive development and debugging: +```bash +make fuzz-hypothesis # Run with statistics and detailed output +make fuzz-security # Security tests with warnings +``` + +### CI/CD Mode +For automated testing in continuous integration: +```bash +make fuzz-quick # Fast validation (50 examples) +``` + +### Comprehensive Mode +For thorough testing in nightly builds: +```bash +make fuzz-extended # Extended testing (1000+ examples) +``` + +## RESTler Fuzzing + +RESTler performs stateful, sequence-based fuzzing of REST APIs using the OpenAPI/Swagger specification. It's ideal for discovering bugs that require specific call sequences. + +### Option A: Docker (recommended) + +Prerequisites: Docker installed and the gateway running locally. + +```bash +# Terminal 1: Start the server +make dev + +# Terminal 2: Generate/OpenAPI and run RESTler via Docker +curl -sSf http://localhost:4444/openapi.json -o reports/restler/openapi.json +docker run --rm -v "$PWD/reports/restler:/workspace" \ + ghcr.io/microsoft/restler restler compile --api_spec /workspace/openapi.json +docker run --rm -v "$PWD/reports/restler:/workspace" \ + ghcr.io/microsoft/restler restler test --grammar_dir /workspace/Compile --no_ssl --time_budget 5 + +# Results are written to reports/restler +``` + +You can print these instructions anytime with: + +```bash +make fuzz-restler +``` + +### Option A2: Automated Docker runner + +Use the helper that waits for the server, downloads the spec, then compiles and runs RESTler in Docker: + +```bash +# Terminal 1: Start the server +make dev + +# Terminal 2: Run automated RESTler fuzzing +make fuzz-restler-auto + +# Optional environment variables: +# MCPFUZZ_BASE_URL (default: http://localhost:4444) +# MCPFUZZ_AUTH_HEADER (e.g., "Authorization: Basic YWRtaW46Y2hhbmdlbWU=") +# MCPFUZZ_TIME_BUDGET (minutes, default: 5) +# MCPFUZZ_NO_SSL (1 to pass --no_ssl; default: 1) +``` + +Notes: +- If Docker is not present, `fuzz-restler-auto` will print a friendly message and exit successfully (use `make fuzz-restler` for manual steps). This behavior avoids CI failures on runners without Docker. +- Artifacts are written under `reports/restler/`. + +### Option B: Local install + +Follow RESTler's official installation guide, set `RESTLER_HOME`, then: + +```bash +export RESTLER_HOME=/path/to/restler +curl -sSf http://localhost:4444/openapi.json -o reports/restler/openapi.json +"$RESTLER_HOME"/restler compile --api_spec reports/restler/openapi.json +"$RESTLER_HOME"/restler test --grammar_dir Compile --no_ssl --time_budget 5 +``` + +Notes: +- Ensure the server exposes `http://localhost:4444/openapi.json`. +- For authenticated specs, supply tokens/headers to RESTler as needed. +- Increase `--time_budget` for deeper exploration in nightly runs. + - In CI, prefer running `fuzz-restler-auto` only on runners with Docker available, or skip otherwise. + +## Understanding Results + +### Test Outcomes + +**Passing Tests**: Inputs handled correctly without crashes +**Failing Tests**: Unexpected exceptions or crashes discovered +**Skipped Tests**: Tests requiring external dependencies (auth, servers) + +### Hypothesis Statistics + +Hypothesis provides detailed statistics about test execution: + +``` +- during generate phase (1.86 seconds): + - Typical runtimes: ~ 15-16 ms, of which < 1ms in data generation + - 100 passing examples, 0 failing examples, 0 invalid examples +- Stopped because settings.max_examples=100 +``` + +### Bug Discovery + +When fuzzing finds issues, it provides: +- **Minimal failing example**: Simplified input that reproduces the bug +- **Seed for reproduction**: Run with `--hypothesis-seed=X` to reproduce +- **Call stack**: Exact location where the failure occurred + +Example failure: +``` +Falsifying example: test_validate_request_handles_text_input( + self=, + text_input='null' +) +``` + +## Writing Fuzz Tests + +### Property-Based Test Structure + +```python +from hypothesis import given, strategies as st +import pytest + +class TestMyComponentFuzzing: + @given(st.text(min_size=1, max_size=100)) + def test_component_never_crashes(self, input_text): + """Test that component handles arbitrary text input.""" + try: + result = my_component.process(input_text) + # Verify expected properties + assert isinstance(result, (str, dict, list)) + except (ValueError, TypeError): + # Expected exceptions for invalid input + pass + except Exception as e: + pytest.fail(f"Unexpected exception: {type(e).__name__}: {e}") +``` + +### Atheris Fuzzer Structure + +```python +#!/usr/bin/env python3 +import atheris +import sys +import os + +# Ensure project is in path +sys.path.insert(0, os.path.join(os.path.dirname(__file__), '../../..')) + +from mcpgateway.my_module import my_function + +def TestOneInput(data: bytes) -> None: + """Fuzz target for my_function.""" + fdp = atheris.FuzzedDataProvider(data) + + try: + if fdp.remaining_bytes() < 1: + return + + # Generate test input + test_input = fdp.ConsumeUnicodeNoSurrogates(100) + + # Test function (should never crash) + my_function(test_input) + + except (ValueError, TypeError): + # Expected exceptions + pass + except Exception: + # Unexpected - let Atheris catch it + raise + +def main(): + atheris.instrument_all() + atheris.Setup(sys.argv, TestOneInput) + atheris.Fuzz() + +if __name__ == "__main__": + main() +``` + +### Security Test Patterns + +```python +@given(st.text().filter(lambda x: any(char in x for char in '<>"\'&'))) +def test_xss_prevention(self, potentially_malicious): + """Test XSS prevention in user inputs.""" + response = client.post("/api/endpoint", json={ + "field": potentially_malicious + }, headers={"Authorization": "Basic YWRtaW46Y2hhbmdlbWU="}) + + # Should handle malicious content safely + assert response.status_code in [200, 201, 400, 401, 422] + + # Raw script tags should not appear unescaped + if "", + f"javascript:alert('xss')", + f"", + f"", + f"' onmouseover='alert(\"xss\")'", + f"\">", + ] + + for pattern in xss_patterns: + # Test in description field that might be rendered + payload = { + "name": "test-tool", + "url": "http://example.com", + "description": pattern + } + + response = client.post( + "/admin/tools", + json=payload, + headers={"Authorization": "Basic YWRtaW46Y2hhbmdlbWU="} + ) + + # Should handle potentially malicious content safely + assert response.status_code in [200, 201, 400, 401, 422] + + if response.status_code in [200, 201]: + # If accepted, verify no raw script tags in admin interface + admin_response = client.get( + "/admin", + headers={"Authorization": "Basic YWRtaW46Y2hhbmdlbWU="} + ) + + # Raw script tags should not appear unescaped + if "", + "Value%0d%0aX-Injected:%20true", + "Value\x0d\x0aX-Injected: true", + ] + + for pattern in header_injection_patterns: + # Test in custom headers + payload = { + "name": "test-tool", + "url": "http://example.com", + "headers": { + "Custom-Header": pattern, + "Another-Header": pattern + } + } + + response = client.post( + "/admin/tools", + json=payload, + headers={"Authorization": "Basic YWRtaW46Y2hhbmdlbWU="} + ) + + # Should sanitize or reject header injection attempts + assert response.status_code in [200, 201, 400, 401, 422] + + @given(st.text(min_size=1, max_size=200)) + def test_ldap_injection_resistance(self, input_text): + """Test resistance to LDAP injection attacks.""" + client = TestClient(app) + + ldap_patterns = [ + input_text, + "*)(&(objectClass=*)", + "*)(mail=*))(|(mail=*", + "admin)(&(password=*))", + "*)(&(|(objectClass=*)(uid=*))", + ] + + for pattern in ldap_patterns: + # Test in authentication fields if they exist + payload = { + "name": "test-tool", + "url": "http://example.com", + "auth": { + "auth_type": "basic", + "username": pattern, + "password": pattern + } + } + + response = client.post( + "/admin/tools", + json=payload, + headers={"Authorization": "Basic YWRtaW46Y2hhbmdlbWU="} + ) + + assert response.status_code in [200, 201, 400, 401, 422] + + def test_xml_injection_resistance(self): + """Test resistance to XML injection attacks.""" + client = TestClient(app) + + xml_patterns = [ + "]>&xxe;", + "]]>injectedtest", + "injectedtest]]>", + "<script>alert('xss')</script>", + ] + + for pattern in xml_patterns: + payload = { + "name": "test-tool", + "url": "http://example.com", + "description": pattern + } + + response = client.post( + "/admin/tools", + json=payload, + headers={"Authorization": "Basic YWRtaW46Y2hhbmdlbWU="} + ) + + # Should handle XML content safely + assert response.status_code in [200, 201, 400, 401, 422] + + @given(st.binary(min_size=1, max_size=1000)) + def test_binary_input_handling(self, binary_data): + """Test handling of binary data in text fields.""" + client = TestClient(app) + + try: + # Try to decode as various encodings + text_data = binary_data.decode('utf-8', errors='ignore') + + payload = { + "name": text_data[:50], # Limit length + "url": "http://example.com", + "description": text_data[:500] + } + + response = client.post( + "/admin/tools", + json=payload, + headers={"Authorization": "Basic YWRtaW46Y2hhbmdlbWU="} + ) + + # Should handle binary/non-UTF8 data gracefully + assert response.status_code in [200, 201, 400, 401, 422] + + except (UnicodeDecodeError, ValueError): + # Expected for some binary data + pass + + def test_authentication_bypass_attempts(self): + """Test various authentication bypass attempts.""" + client = TestClient(app) + + bypass_attempts = [ + "", # Empty auth + "Basic", # Incomplete basic auth + "Basic " + "x" * 1000, # Very long auth + "Bearer fake_token", # Wrong auth type + "Basic YWRtaW46YWRtaW4=", # admin:admin (wrong password) + "Basic cm9vdDpyb290", # root:root + "Basic " + ":" * 100, # Many colons + "Admin admin:changeme", # Wrong scheme + ] + + for auth in bypass_attempts: + headers = {"Authorization": auth} if auth else {} + + response = client.get("/admin/tools", headers=headers) + + # Should require proper authentication + if auth != "Basic YWRtaW46Y2hhbmdlbWU=": # Correct auth + assert response.status_code in [401, 400, 422] + + @given(st.integers(min_value=0, max_value=1000)) + def test_dos_resistance_large_requests(self, size_multiplier): + """Test resistance to DoS via large requests.""" + client = TestClient(app) + + # Create increasingly large payloads + large_string = "x" * (size_multiplier * 100) + + payload = { + "name": f"tool_{size_multiplier}", + "url": "http://example.com", + "description": large_string, + "tags": [f"tag_{i}" for i in range(min(size_multiplier, 100))] + } + + try: + response = client.post( + "/admin/tools", + json=payload, + headers={"Authorization": "Basic YWRtaW46Y2hhbmdlbWU="}, + timeout=10 # Prevent hanging + ) + + # Should handle large requests gracefully (may reject) + assert response.status_code in [200, 201, 400, 413, 422] + + except Exception: + # Timeout or other errors are acceptable for very large requests + pass + + def test_cors_security(self): + """Test CORS configuration security.""" + client = TestClient(app) + + malicious_origins = [ + "http://evil.com", + "https://phishing-site.com", + "javascript:alert('xss')", + "data:text/html,", + "file:///etc/passwd", + ] + + for origin in malicious_origins: + response = client.options( + "/admin/tools", + headers={ + "Origin": origin, + "Access-Control-Request-Method": "POST", + "Access-Control-Request-Headers": "Authorization" + } + ) + + # Should not allow arbitrary origins + cors_header = response.headers.get("Access-Control-Allow-Origin", "") + if cors_header == "*": + pytest.fail("CORS wildcard (*) allows any origin - security risk") + + # Should not echo back malicious origins + if origin in cors_header and "evil" in origin.lower(): + pytest.fail(f"CORS echoing back potentially malicious origin: {origin}") + + def test_rate_limiting_behavior(self): + """Test rate limiting behavior.""" + client = TestClient(app) + + # Make many rapid requests + responses = [] + for i in range(20): + response = client.post( + "/admin/tools", + json={ + "name": f"rapid_tool_{i}", + "url": "http://example.com" + }, + headers={"Authorization": "Basic YWRtaW46Y2hhbmdlbWU="} + ) + responses.append(response.status_code) + + # Should either accept all or start rate limiting + # Rate limiting typically returns 429 + for status in responses: + assert status in [200, 201, 400, 422, 429, 409] From acb83d3b8173059c043e5f86056e80a9b9f3d5d4 Mon Sep 17 00:00:00 2001 From: Mihai Criveti Date: Sun, 17 Aug 2025 09:50:16 +0100 Subject: [PATCH 08/21] 344 cors security headers (#761) * Update CORS Signed-off-by: Mihai Criveti * Update CORS Signed-off-by: Mihai Criveti * Update CORS ADRs Signed-off-by: Mihai Criveti * Update CORS Signed-off-by: Mihai Criveti * Update CORS Signed-off-by: Mihai Criveti * Fix compose Signed-off-by: Mihai Criveti * Update helm chart Signed-off-by: Mihai Criveti * Update CORS docs Signed-off-by: Mihai Criveti * Update test Signed-off-by: Mihai Criveti --------- Signed-off-by: Mihai Criveti Signed-off-by: Shamsul Arefin --- .env.example | 55 ++ Makefile | 2 +- README.md | 20 +- SECURITY.md | 6 +- charts/mcp-stack/README.md | 460 ++++++------- charts/mcp-stack/values.schema.json | 83 +++ charts/mcp-stack/values.yaml | 19 + docker-compose.yml | 8 +- .../014-security-headers-cors-middleware.md | 238 +++++++ docs/docs/architecture/adr/index.md | 1 + docs/docs/architecture/roadmap.md | 4 +- docs/docs/architecture/security-features.md | 11 +- docs/docs/manage/securing.md | 19 +- mcpgateway/admin.py | 4 +- mcpgateway/config.py | 41 ++ mcpgateway/main.py | 21 +- mcpgateway/middleware/__init__.py | 7 + mcpgateway/middleware/security_headers.py | 100 +++ mcpgateway/services/prompt_service.py | 6 +- mcpgateway/services/resource_service.py | 6 +- mcpgateway/services/server_service.py | 6 +- mcpgateway/services/tool_service.py | 6 +- mcpgateway/templates/admin.html | 8 + mcpgateway/utils/security_cookies.py | 105 +++ tests/async/test_async_safety.py | 5 +- tests/security/test_configurable_headers.py | 152 +++++ tests/security/test_security_cookies.py | 222 +++++++ tests/security/test_security_headers.py | 246 +++++++ .../test_security_middleware_comprehensive.py | 628 ++++++++++++++++++ ...test_security_performance_compatibility.py | 605 +++++++++++++++++ tests/security/test_standalone_middleware.py | 119 ++++ tests/unit/mcpgateway/test_admin.py | 12 +- 32 files changed, 2970 insertions(+), 255 deletions(-) create mode 100644 docs/docs/architecture/adr/014-security-headers-cors-middleware.md create mode 100644 mcpgateway/middleware/__init__.py create mode 100644 mcpgateway/middleware/security_headers.py create mode 100644 mcpgateway/utils/security_cookies.py create mode 100644 tests/security/test_configurable_headers.py create mode 100644 tests/security/test_security_cookies.py create mode 100644 tests/security/test_security_headers.py create mode 100644 tests/security/test_security_middleware_comprehensive.py create mode 100644 tests/security/test_security_performance_compatibility.py create mode 100644 tests/security/test_standalone_middleware.py diff --git a/.env.example b/.env.example index a1e6a427..06a5c0c7 100644 --- a/.env.example +++ b/.env.example @@ -82,6 +82,7 @@ PROTOCOL_VERSION=2025-03-26 ##################################### # Admin UI basic-auth credentials +# PRODUCTION: Change these to strong, unique values! BASIC_AUTH_USER=admin BASIC_AUTH_PASSWORD=changeme @@ -89,6 +90,7 @@ BASIC_AUTH_PASSWORD=changeme AUTH_REQUIRED=true # Secret used to sign JWTs (use long random value in prod) +# PRODUCTION: Use a strong, random secret (minimum 32 characters) JWT_SECRET_KEY=my-test-key # Algorithm used to sign JWTs (e.g., HS256) @@ -119,9 +121,11 @@ AUTH_ENCRYPTION_SECRET=my-test-salt ##################################### # Enable the visual Admin UI (true/false) +# PRODUCTION: Set to false for security MCPGATEWAY_UI_ENABLED=true # Enable the Admin API endpoints (true/false) +# PRODUCTION: Set to false for security MCPGATEWAY_ADMIN_API_ENABLED=true # Enable bulk import endpoint for tools (true/false) @@ -157,6 +161,57 @@ ALLOWED_ORIGINS='["http://localhost", "http://localhost:4444"]' # Enable CORS handling in the gateway CORS_ENABLED=true +# CORS allow credentials (true/false) +CORS_ALLOW_CREDENTIALS=true + +# Environment setting (development/production) - affects security defaults +# development: Auto-configures CORS for localhost:3000, localhost:8080, etc. +# production: Uses APP_DOMAIN for HTTPS origins, enforces secure cookies +ENVIRONMENT=development + +# Domain configuration for production CORS origins +# In production, automatically creates origins: https://APP_DOMAIN, https://app.APP_DOMAIN, https://admin.APP_DOMAIN +# For production: set to your actual domain (e.g., mycompany.com) +APP_DOMAIN=localhost + +# Security settings for cookies +# production: Automatically enables secure cookies regardless of this setting +# development: Set to false for HTTP development, true for HTTPS +SECURE_COOKIES=true + +# Cookie SameSite attribute for CSRF protection +# strict: Maximum security, may break some OAuth flows +# lax: Good balance of security and compatibility (recommended) +# none: Requires Secure=true, allows cross-site usage +COOKIE_SAMESITE=lax + +##################################### +# Security Headers Configuration +##################################### + +# Enable security headers middleware (true/false) +SECURITY_HEADERS_ENABLED=true + +# X-Frame-Options setting (DENY, SAMEORIGIN, or ALLOW-FROM uri) +# DENY: Prevents all iframe embedding (recommended for security) +# SAMEORIGIN: Allows embedding from same domain only +# To disable: Set to empty string X_FRAME_OPTIONS="" +X_FRAME_OPTIONS=DENY + +# Other security headers (true/false) +X_CONTENT_TYPE_OPTIONS_ENABLED=true +X_XSS_PROTECTION_ENABLED=true +X_DOWNLOAD_OPTIONS_ENABLED=true + +# HSTS (HTTP Strict Transport Security) settings +HSTS_ENABLED=true +# HSTS max age in seconds (31536000 = 1 year) +HSTS_MAX_AGE=31536000 +HSTS_INCLUDE_SUBDOMAINS=true + +# Remove server identification headers (true/false) +REMOVE_SERVER_HEADERS=true + # Enable HTTP Basic Auth for docs endpoints (in addition to Bearer token auth) # Uses the same credentials as BASIC_AUTH_USER and BASIC_AUTH_PASSWORD DOCS_ALLOW_BASIC_AUTH=false diff --git a/Makefile b/Makefile index f5a847e0..64f28e57 100644 --- a/Makefile +++ b/Makefile @@ -1389,7 +1389,7 @@ install-web-linters: nodejsscan: @echo "🔒 Running nodejsscan for JavaScript security vulnerabilities..." $(call ensure_pip_package,nodejsscan) - @$(VENV_DIR)/bin/nodejsscan --directory ./mcpgateway/static || true + @$(VENV_DIR)/bin/nodejsscan --directory ./mcpgateway/static --directory ./mcpgateway/templates || true lint-web: install-web-linters nodejsscan @echo "🔍 Linting HTML files..." diff --git a/README.md b/README.md index 6139e91f..c5e80d26 100644 --- a/README.md +++ b/README.md @@ -1053,10 +1053,28 @@ You can get started by copying the provided [.env.example](.env.example) to `.en | Setting | Description | Default | Options | | ------------------------- | ------------------------------ | ---------------------------------------------- | ---------- | | `SKIP_SSL_VERIFY` | Skip upstream TLS verification | `false` | bool | -| `ALLOWED_ORIGINS` | CORS allow-list | `["http://localhost","http://localhost:4444"]` | JSON array | +| `ENVIRONMENT` | Deployment environment (affects security defaults) | `development` | `development`/`production` | +| `APP_DOMAIN` | Domain for production CORS origins | `localhost` | string | +| `ALLOWED_ORIGINS` | CORS allow-list | Auto-configured by environment | JSON array | | `CORS_ENABLED` | Enable CORS | `true` | bool | +| `CORS_ALLOW_CREDENTIALS` | Allow credentials in CORS | `true` | bool | +| `SECURE_COOKIES` | Force secure cookie flags | `true` | bool | +| `COOKIE_SAMESITE` | Cookie SameSite attribute | `lax` | `strict`/`lax`/`none` | +| `SECURITY_HEADERS_ENABLED` | Enable security headers middleware | `true` | bool | +| `X_FRAME_OPTIONS` | X-Frame-Options header value | `DENY` | `DENY`/`SAMEORIGIN` | +| `HSTS_ENABLED` | Enable HSTS header | `true` | bool | +| `HSTS_MAX_AGE` | HSTS max age in seconds | `31536000` | int | +| `REMOVE_SERVER_HEADERS` | Remove server identification | `true` | bool | | `DOCS_ALLOW_BASIC_AUTH` | Allow Basic Auth for docs (in addition to JWT) | `false` | bool | +> **CORS Configuration**: When `ENVIRONMENT=development`, CORS origins are automatically configured for common development ports (3000, 8080, gateway port). In production, origins are constructed from `APP_DOMAIN` (e.g., `https://yourdomain.com`, `https://app.yourdomain.com`). You can override this by explicitly setting `ALLOWED_ORIGINS`. +> +> **Security Headers**: The gateway automatically adds configurable security headers to all responses including CSP, X-Frame-Options, X-Content-Type-Options, X-Download-Options, and HSTS (on HTTPS). All headers can be individually enabled/disabled. Sensitive server headers are removed. +> +> **iframe Embedding**: By default, `X-Frame-Options: DENY` prevents iframe embedding for security. To allow embedding, set `X_FRAME_OPTIONS=SAMEORIGIN` (same domain) or disable with `X_FRAME_OPTIONS=""`. Also update CSP `frame-ancestors` directive if needed. +> +> **Cookie Security**: Authentication cookies are automatically configured with HttpOnly, Secure (in production), and SameSite attributes for CSRF protection. +> > Note: do not quote the ALLOWED_ORIGINS values, this needs to be valid JSON, such as: > ALLOWED_ORIGINS=["http://localhost", "http://localhost:4444"] > diff --git a/SECURITY.md b/SECURITY.md index 76ece233..f5ed7223 100644 --- a/SECURITY.md +++ b/SECURITY.md @@ -236,7 +236,7 @@ Applications consuming data from MCP Gateway should: - **Never trust data implicitly** - validate all inputs - **Implement context-appropriate sanitization** for their UI framework -- **Use Content Security Policy (CSP)** headers +- **Use Content Security Policy (CSP)** headers (automatically provided by MCP Gateway) - **Escape data appropriately** for the output context (HTML, JavaScript, SQL, etc.) - **Implement their own authentication** and authorization - **Monitor for security anomalies** in rendered content @@ -261,7 +261,9 @@ When deploying MCP Gateway in production: - [ ] Configure resource limits (CPU, memory) to prevent DoS attacks - [ ] Implement proper secrets management (never hardcode credentials) - [ ] Set up structured logging without exposing sensitive data -- [ ] Configure CORS policies appropriately for your clients +- [ ] Configure CORS policies appropriately for your clients (auto-configured by ENVIRONMENT setting) +- [ ] Verify security headers are working (automatically added by SecurityHeadersMiddleware) +- [ ] Configure iframe embedding policy (X-Frame-Options: DENY by default, set to SAMEORIGIN if embedding needed) - [ ] Disable debug mode and verbose error messages in production - [ ] Implement backup and disaster recovery procedures - [ ] Document incident response procedures diff --git a/charts/mcp-stack/README.md b/charts/mcp-stack/README.md index 7f55ad48..3489aea2 100644 --- a/charts/mcp-stack/README.md +++ b/charts/mcp-stack/README.md @@ -29,277 +29,291 @@ Kubernetes: `>=1.21.0` | Key | Type | Default | Description | |-----|------|---------|-------------| -| global.fullnameOverride | string | `""` | | | global.imagePullSecrets | list | `[]` | | | global.nameOverride | string | `""` | | -| mcpContextForge.config.ALLOWED_ORIGINS | string | `"[\"http://localhost\",\"http://localhost:4444\"]"` | | +| global.fullnameOverride | string | `""` | | +| mcpContextForge.replicaCount | int | `2` | | +| mcpContextForge.hpa | object | `{"enabled":true,"maxReplicas":10,"minReplicas":2,"targetCPUUtilizationPercentage":90,"targetMemoryUtilizationPercentage":90}` | ------------------------------------------------------------------ | +| mcpContextForge.image.repository | string | `"ghcr.io/ibm/mcp-context-forge"` | | +| mcpContextForge.image.tag | string | `"latest"` | | +| mcpContextForge.image.pullPolicy | string | `"Always"` | | +| mcpContextForge.service.type | string | `"ClusterIP"` | | +| mcpContextForge.service.port | int | `80` | | +| mcpContextForge.containerPort | int | `4444` | | +| mcpContextForge.probes.startup.type | string | `"exec"` | | +| mcpContextForge.probes.startup.command[0] | string | `"sh"` | | +| mcpContextForge.probes.startup.command[1] | string | `"-c"` | | +| mcpContextForge.probes.startup.command[2] | string | `"sleep 10"` | | +| mcpContextForge.probes.startup.timeoutSeconds | int | `15` | | +| mcpContextForge.probes.startup.periodSeconds | int | `5` | | +| mcpContextForge.probes.startup.failureThreshold | int | `1` | | +| mcpContextForge.probes.readiness.type | string | `"http"` | | +| mcpContextForge.probes.readiness.path | string | `"/ready"` | | +| mcpContextForge.probes.readiness.port | int | `4444` | | +| mcpContextForge.probes.readiness.initialDelaySeconds | int | `15` | | +| mcpContextForge.probes.readiness.periodSeconds | int | `10` | | +| mcpContextForge.probes.readiness.timeoutSeconds | int | `2` | | +| mcpContextForge.probes.readiness.successThreshold | int | `1` | | +| mcpContextForge.probes.readiness.failureThreshold | int | `3` | | +| mcpContextForge.probes.liveness.type | string | `"http"` | | +| mcpContextForge.probes.liveness.path | string | `"/health"` | | +| mcpContextForge.probes.liveness.port | int | `4444` | | +| mcpContextForge.probes.liveness.initialDelaySeconds | int | `10` | | +| mcpContextForge.probes.liveness.periodSeconds | int | `15` | | +| mcpContextForge.probes.liveness.timeoutSeconds | int | `2` | | +| mcpContextForge.probes.liveness.successThreshold | int | `1` | | +| mcpContextForge.probes.liveness.failureThreshold | int | `3` | | +| mcpContextForge.resources.limits.cpu | string | `"200m"` | | +| mcpContextForge.resources.limits.memory | string | `"1024Mi"` | | +| mcpContextForge.resources.requests.cpu | string | `"100m"` | | +| mcpContextForge.resources.requests.memory | string | `"512Mi"` | | +| mcpContextForge.ingress.enabled | bool | `true` | | +| mcpContextForge.ingress.className | string | `"nginx"` | | +| mcpContextForge.ingress.host | string | `"gateway.local"` | | +| mcpContextForge.ingress.path | string | `"/"` | | +| mcpContextForge.ingress.pathType | string | `"Prefix"` | | +| mcpContextForge.ingress.annotations."nginx.ingress.kubernetes.io/rewrite-target" | string | `"/"` | | +| mcpContextForge.env.host | string | `"0.0.0.0"` | | +| mcpContextForge.env.postgres.port | int | `5432` | | +| mcpContextForge.env.postgres.db | string | `"postgresdb"` | | +| mcpContextForge.env.postgres.userKey | string | `"POSTGRES_USER"` | | +| mcpContextForge.env.postgres.passwordKey | string | `"POSTGRES_PASSWORD"` | | +| mcpContextForge.env.redis.port | int | `6379` | | +| mcpContextForge.config.GUNICORN_WORKERS | string | `"2"` | | +| mcpContextForge.config.GUNICORN_TIMEOUT | string | `"600"` | | +| mcpContextForge.config.GUNICORN_MAX_REQUESTS | string | `"10000"` | | +| mcpContextForge.config.GUNICORN_MAX_REQUESTS_JITTER | string | `"100"` | | +| mcpContextForge.config.GUNICORN_PRELOAD_APP | string | `"true"` | | | mcpContextForge.config.APP_NAME | string | `"MCP_Gateway"` | | +| mcpContextForge.config.HOST | string | `"0.0.0.0"` | | +| mcpContextForge.config.PORT | string | `"4444"` | | | mcpContextForge.config.APP_ROOT_PATH | string | `""` | | -| mcpContextForge.config.CACHE_PREFIX | string | `"mcpgw"` | | -| mcpContextForge.config.CACHE_TYPE | string | `"redis"` | | -| mcpContextForge.config.CORS_ENABLED | string | `"true"` | | -| mcpContextForge.config.DB_MAX_OVERFLOW | string | `"10"` | | -| mcpContextForge.config.DB_MAX_RETRIES | string | `"3"` | | -| mcpContextForge.config.DB_POOL_RECYCLE | string | `"3600"` | | | mcpContextForge.config.DB_POOL_SIZE | string | `"200"` | | +| mcpContextForge.config.DB_MAX_OVERFLOW | string | `"10"` | | | mcpContextForge.config.DB_POOL_TIMEOUT | string | `"30"` | | +| mcpContextForge.config.DB_POOL_RECYCLE | string | `"3600"` | | +| mcpContextForge.config.CACHE_TYPE | string | `"redis"` | | +| mcpContextForge.config.CACHE_PREFIX | string | `"mcpgw"` | | +| mcpContextForge.config.SESSION_TTL | string | `"3600"` | | +| mcpContextForge.config.MESSAGE_TTL | string | `"600"` | | +| mcpContextForge.config.REDIS_MAX_RETRIES | string | `"3"` | | +| mcpContextForge.config.REDIS_RETRY_INTERVAL_MS | string | `"2000"` | | +| mcpContextForge.config.DB_MAX_RETRIES | string | `"3"` | | | mcpContextForge.config.DB_RETRY_INTERVAL_MS | string | `"2000"` | | -| mcpContextForge.config.DEBUG | string | `"false"` | | -| mcpContextForge.config.DEV_MODE | string | `"false"` | | -| mcpContextForge.config.FEDERATION_DISCOVERY | string | `"false"` | | +| mcpContextForge.config.PROTOCOL_VERSION | string | `"2025-03-26"` | | +| mcpContextForge.config.MCPGATEWAY_UI_ENABLED | string | `"true"` | | +| mcpContextForge.config.MCPGATEWAY_ADMIN_API_ENABLED | string | `"true"` | | +| mcpContextForge.config.ENVIRONMENT | string | `"development"` | | +| mcpContextForge.config.APP_DOMAIN | string | `"localhost"` | | +| mcpContextForge.config.CORS_ENABLED | string | `"true"` | | +| mcpContextForge.config.CORS_ALLOW_CREDENTIALS | string | `"true"` | | +| mcpContextForge.config.ALLOWED_ORIGINS | string | `"[\"http://localhost\",\"http://localhost:4444\"]"` | | +| mcpContextForge.config.SKIP_SSL_VERIFY | string | `"false"` | | +| mcpContextForge.config.SECURITY_HEADERS_ENABLED | string | `"true"` | | +| mcpContextForge.config.X_FRAME_OPTIONS | string | `"DENY"` | | +| mcpContextForge.config.X_CONTENT_TYPE_OPTIONS_ENABLED | string | `"true"` | | +| mcpContextForge.config.X_XSS_PROTECTION_ENABLED | string | `"true"` | | +| mcpContextForge.config.X_DOWNLOAD_OPTIONS_ENABLED | string | `"true"` | | +| mcpContextForge.config.HSTS_ENABLED | string | `"true"` | | +| mcpContextForge.config.HSTS_MAX_AGE | string | `"31536000"` | | +| mcpContextForge.config.HSTS_INCLUDE_SUBDOMAINS | string | `"true"` | | +| mcpContextForge.config.REMOVE_SERVER_HEADERS | string | `"true"` | | +| mcpContextForge.config.SECURE_COOKIES | string | `"true"` | | +| mcpContextForge.config.COOKIE_SAMESITE | string | `"lax"` | | +| mcpContextForge.config.LOG_LEVEL | string | `"INFO"` | | +| mcpContextForge.config.LOG_FORMAT | string | `"json"` | | +| mcpContextForge.config.TRANSPORT_TYPE | string | `"all"` | | +| mcpContextForge.config.WEBSOCKET_PING_INTERVAL | string | `"30"` | | +| mcpContextForge.config.SSE_RETRY_TIMEOUT | string | `"5000"` | | +| mcpContextForge.config.USE_STATEFUL_SESSIONS | string | `"false"` | | +| mcpContextForge.config.JSON_RESPONSE_ENABLED | string | `"true"` | | | mcpContextForge.config.FEDERATION_ENABLED | string | `"true"` | | +| mcpContextForge.config.FEDERATION_DISCOVERY | string | `"false"` | | | mcpContextForge.config.FEDERATION_PEERS | string | `"[]"` | | -| mcpContextForge.config.FEDERATION_SYNC_INTERVAL | string | `"300"` | | | mcpContextForge.config.FEDERATION_TIMEOUT | string | `"30"` | | -| mcpContextForge.config.FILELOCK_NAME | string | `"gateway_healthcheck_init.lock"` | | -| mcpContextForge.config.GUNICORN_MAX_REQUESTS | string | `"10000"` | | -| mcpContextForge.config.GUNICORN_MAX_REQUESTS_JITTER | string | `"100"` | | -| mcpContextForge.config.GUNICORN_PRELOAD_APP | string | `"true"` | | -| mcpContextForge.config.GUNICORN_TIMEOUT | string | `"600"` | | -| mcpContextForge.config.GUNICORN_WORKERS | string | `"2"` | | -| mcpContextForge.config.HEALTH_CHECK_INTERVAL | string | `"60"` | | -| mcpContextForge.config.HEALTH_CHECK_TIMEOUT | string | `"10"` | | -| mcpContextForge.config.HOST | string | `"0.0.0.0"` | | -| mcpContextForge.config.JSON_RESPONSE_ENABLED | string | `"true"` | | -| mcpContextForge.config.LOG_FORMAT | string | `"json"` | | -| mcpContextForge.config.LOG_LEVEL | string | `"INFO"` | | -| mcpContextForge.config.MAX_PROMPT_SIZE | string | `"102400"` | | +| mcpContextForge.config.FEDERATION_SYNC_INTERVAL | string | `"300"` | | +| mcpContextForge.config.RESOURCE_CACHE_SIZE | string | `"1000"` | | +| mcpContextForge.config.RESOURCE_CACHE_TTL | string | `"3600"` | | | mcpContextForge.config.MAX_RESOURCE_SIZE | string | `"10485760"` | | +| mcpContextForge.config.TOOL_TIMEOUT | string | `"60"` | | | mcpContextForge.config.MAX_TOOL_RETRIES | string | `"3"` | | -| mcpContextForge.config.MCPGATEWAY_ADMIN_API_ENABLED | string | `"true"` | | -| mcpContextForge.config.MCPGATEWAY_UI_ENABLED | string | `"true"` | | -| mcpContextForge.config.MESSAGE_TTL | string | `"600"` | | -| mcpContextForge.config.PORT | string | `"4444"` | | +| mcpContextForge.config.TOOL_RATE_LIMIT | string | `"100"` | | +| mcpContextForge.config.TOOL_CONCURRENT_LIMIT | string | `"10"` | | | mcpContextForge.config.PROMPT_CACHE_SIZE | string | `"100"` | | +| mcpContextForge.config.MAX_PROMPT_SIZE | string | `"102400"` | | | mcpContextForge.config.PROMPT_RENDER_TIMEOUT | string | `"10"` | | -| mcpContextForge.config.PROTOCOL_VERSION | string | `"2025-03-26"` | | -| mcpContextForge.config.REDIS_MAX_RETRIES | string | `"3"` | | -| mcpContextForge.config.REDIS_RETRY_INTERVAL_MS | string | `"2000"` | | -| mcpContextForge.config.RELOAD | string | `"false"` | | -| mcpContextForge.config.RESOURCE_CACHE_SIZE | string | `"1000"` | | -| mcpContextForge.config.RESOURCE_CACHE_TTL | string | `"3600"` | | -| mcpContextForge.config.SESSION_TTL | string | `"3600"` | | -| mcpContextForge.config.SKIP_SSL_VERIFY | string | `"false"` | | -| mcpContextForge.config.SSE_RETRY_TIMEOUT | string | `"5000"` | | -| mcpContextForge.config.TOOL_CONCURRENT_LIMIT | string | `"10"` | | -| mcpContextForge.config.TOOL_RATE_LIMIT | string | `"100"` | | -| mcpContextForge.config.TOOL_TIMEOUT | string | `"60"` | | -| mcpContextForge.config.TRANSPORT_TYPE | string | `"all"` | | +| mcpContextForge.config.HEALTH_CHECK_INTERVAL | string | `"60"` | | +| mcpContextForge.config.HEALTH_CHECK_TIMEOUT | string | `"10"` | | | mcpContextForge.config.UNHEALTHY_THRESHOLD | string | `"3"` | | -| mcpContextForge.config.USE_STATEFUL_SESSIONS | string | `"false"` | | -| mcpContextForge.config.WEBSOCKET_PING_INTERVAL | string | `"30"` | | -| mcpContextForge.containerPort | int | `4444` | | -| mcpContextForge.env.host | string | `"0.0.0.0"` | | -| mcpContextForge.env.postgres.db | string | `"postgresdb"` | | -| mcpContextForge.env.postgres.passwordKey | string | `"POSTGRES_PASSWORD"` | | -| mcpContextForge.env.postgres.port | int | `5432` | | -| mcpContextForge.env.postgres.userKey | string | `"POSTGRES_USER"` | | -| mcpContextForge.env.redis.port | int | `6379` | | -| mcpContextForge.envFrom[0].secretRef.name | string | `"mcp-gateway-secret"` | | -| mcpContextForge.envFrom[1].configMapRef.name | string | `"mcp-gateway-config"` | | -| mcpContextForge.hpa | object | `{"enabled":true,"maxReplicas":10,"minReplicas":2,"targetCPUUtilizationPercentage":90,"targetMemoryUtilizationPercentage":90}` | ------------------------------------------------------------------ | -| mcpContextForge.image.pullPolicy | string | `"Always"` | | -| mcpContextForge.image.repository | string | `"ghcr.io/ibm/mcp-context-forge"` | | -| mcpContextForge.image.tag | string | `"latest"` | | -| mcpContextForge.ingress.annotations."nginx.ingress.kubernetes.io/rewrite-target" | string | `"/"` | | -| mcpContextForge.ingress.className | string | `"nginx"` | | -| mcpContextForge.ingress.enabled | bool | `true` | | -| mcpContextForge.ingress.host | string | `"gateway.local"` | | -| mcpContextForge.ingress.path | string | `"/"` | | -| mcpContextForge.ingress.pathType | string | `"Prefix"` | | -| mcpContextForge.probes.liveness.failureThreshold | int | `3` | | -| mcpContextForge.probes.liveness.initialDelaySeconds | int | `10` | | -| mcpContextForge.probes.liveness.path | string | `"/health"` | | -| mcpContextForge.probes.liveness.periodSeconds | int | `15` | | -| mcpContextForge.probes.liveness.port | int | `4444` | | -| mcpContextForge.probes.liveness.successThreshold | int | `1` | | -| mcpContextForge.probes.liveness.timeoutSeconds | int | `2` | | -| mcpContextForge.probes.liveness.type | string | `"http"` | | -| mcpContextForge.probes.readiness.failureThreshold | int | `3` | | -| mcpContextForge.probes.readiness.initialDelaySeconds | int | `15` | | -| mcpContextForge.probes.readiness.path | string | `"/ready"` | | -| mcpContextForge.probes.readiness.periodSeconds | int | `10` | | -| mcpContextForge.probes.readiness.port | int | `4444` | | -| mcpContextForge.probes.readiness.successThreshold | int | `1` | | -| mcpContextForge.probes.readiness.timeoutSeconds | int | `2` | | -| mcpContextForge.probes.readiness.type | string | `"http"` | | -| mcpContextForge.probes.startup.command[0] | string | `"sh"` | | -| mcpContextForge.probes.startup.command[1] | string | `"-c"` | | -| mcpContextForge.probes.startup.command[2] | string | `"sleep 10"` | | -| mcpContextForge.probes.startup.failureThreshold | int | `1` | | -| mcpContextForge.probes.startup.periodSeconds | int | `5` | | -| mcpContextForge.probes.startup.timeoutSeconds | int | `15` | | -| mcpContextForge.probes.startup.type | string | `"exec"` | | -| mcpContextForge.replicaCount | int | `2` | | -| mcpContextForge.resources.limits.cpu | string | `"200m"` | | -| mcpContextForge.resources.limits.memory | string | `"1024Mi"` | | -| mcpContextForge.resources.requests.cpu | string | `"100m"` | | -| mcpContextForge.resources.requests.memory | string | `"512Mi"` | | -| mcpContextForge.secret.AUTH_ENCRYPTION_SECRET | string | `"my-test-salt"` | | -| mcpContextForge.secret.AUTH_REQUIRED | string | `"true"` | | -| mcpContextForge.secret.BASIC_AUTH_PASSWORD | string | `"changeme"` | | +| mcpContextForge.config.FILELOCK_NAME | string | `"gateway_healthcheck_init.lock"` | | +| mcpContextForge.config.DEV_MODE | string | `"false"` | | +| mcpContextForge.config.RELOAD | string | `"false"` | | +| mcpContextForge.config.DEBUG | string | `"false"` | | | mcpContextForge.secret.BASIC_AUTH_USER | string | `"admin"` | | -| mcpContextForge.secret.JWT_ALGORITHM | string | `"HS256"` | | +| mcpContextForge.secret.BASIC_AUTH_PASSWORD | string | `"changeme"` | | +| mcpContextForge.secret.AUTH_REQUIRED | string | `"true"` | | | mcpContextForge.secret.JWT_SECRET_KEY | string | `"my-test-key"` | | +| mcpContextForge.secret.JWT_ALGORITHM | string | `"HS256"` | | | mcpContextForge.secret.TOKEN_EXPIRY | string | `"10080"` | | -| mcpContextForge.service.port | int | `80` | | -| mcpContextForge.service.type | string | `"ClusterIP"` | | -| mcpFastTimeServer.enabled | bool | `true` | | -| mcpFastTimeServer.image.pullPolicy | string | `"IfNotPresent"` | | -| mcpFastTimeServer.image.repository | string | `"ghcr.io/ibm/fast-time-server"` | | -| mcpFastTimeServer.image.tag | string | `"0.5.0"` | | -| mcpFastTimeServer.ingress.enabled | bool | `true` | | -| mcpFastTimeServer.ingress.path | string | `"/fast-time"` | | -| mcpFastTimeServer.ingress.pathType | string | `"Prefix"` | | -| mcpFastTimeServer.ingress.servicePort | int | `80` | | -| mcpFastTimeServer.port | int | `8080` | | -| mcpFastTimeServer.probes.liveness.failureThreshold | int | `3` | | -| mcpFastTimeServer.probes.liveness.initialDelaySeconds | int | `3` | | -| mcpFastTimeServer.probes.liveness.path | string | `"/health"` | | -| mcpFastTimeServer.probes.liveness.periodSeconds | int | `15` | | -| mcpFastTimeServer.probes.liveness.port | int | `8080` | | -| mcpFastTimeServer.probes.liveness.successThreshold | int | `1` | | -| mcpFastTimeServer.probes.liveness.timeoutSeconds | int | `2` | | -| mcpFastTimeServer.probes.liveness.type | string | `"http"` | | -| mcpFastTimeServer.probes.readiness.failureThreshold | int | `3` | | -| mcpFastTimeServer.probes.readiness.initialDelaySeconds | int | `3` | | -| mcpFastTimeServer.probes.readiness.path | string | `"/health"` | | -| mcpFastTimeServer.probes.readiness.periodSeconds | int | `10` | | -| mcpFastTimeServer.probes.readiness.port | int | `8080` | | -| mcpFastTimeServer.probes.readiness.successThreshold | int | `1` | | -| mcpFastTimeServer.probes.readiness.timeoutSeconds | int | `2` | | -| mcpFastTimeServer.probes.readiness.type | string | `"http"` | | -| mcpFastTimeServer.replicaCount | int | `2` | | -| mcpFastTimeServer.resources.limits.cpu | string | `"50m"` | | -| mcpFastTimeServer.resources.limits.memory | string | `"64Mi"` | | -| mcpFastTimeServer.resources.requests.cpu | string | `"25m"` | | -| mcpFastTimeServer.resources.requests.memory | string | `"10Mi"` | | -| migration.activeDeadlineSeconds | int | `600` | | -| migration.backoffLimit | int | `3` | | -| migration.command.migrate | string | `"alembic upgrade head || echo '⚠️ Migration check failed'"` | | -| migration.command.waitForDb | string | `"python3 /app/mcpgateway/utils/db_isready.py --max-tries 30 --interval 2 --timeout 5"` | | +| mcpContextForge.secret.AUTH_ENCRYPTION_SECRET | string | `"my-test-salt"` | | +| mcpContextForge.envFrom[0].secretRef.name | string | `"mcp-gateway-secret"` | | +| mcpContextForge.envFrom[1].configMapRef.name | string | `"mcp-gateway-config"` | | | migration.enabled | bool | `true` | | -| migration.image.pullPolicy | string | `"Always"` | | +| migration.restartPolicy | string | `"Never"` | | +| migration.backoffLimit | int | `3` | | +| migration.activeDeadlineSeconds | int | `600` | | | migration.image.repository | string | `"ghcr.io/ibm/mcp-context-forge"` | | | migration.image.tag | string | `"latest"` | | +| migration.image.pullPolicy | string | `"Always"` | | | migration.resources.limits.cpu | string | `"200m"` | | | migration.resources.limits.memory | string | `"512Mi"` | | | migration.resources.requests.cpu | string | `"100m"` | | | migration.resources.requests.memory | string | `"256Mi"` | | -| migration.restartPolicy | string | `"Never"` | | -| pgadmin.enabled | bool | `true` | | -| pgadmin.env.email | string | `"admin@example.com"` | | -| pgadmin.env.password | string | `"admin123"` | | -| pgadmin.image.pullPolicy | string | `"IfNotPresent"` | | -| pgadmin.image.repository | string | `"dpage/pgadmin4"` | | -| pgadmin.image.tag | string | `"latest"` | | -| pgadmin.probes.liveness.failureThreshold | int | `5` | | -| pgadmin.probes.liveness.initialDelaySeconds | int | `10` | | -| pgadmin.probes.liveness.path | string | `"/misc/ping"` | | -| pgadmin.probes.liveness.periodSeconds | int | `15` | | -| pgadmin.probes.liveness.port | int | `80` | | -| pgadmin.probes.liveness.successThreshold | int | `1` | | -| pgadmin.probes.liveness.timeoutSeconds | int | `2` | | -| pgadmin.probes.liveness.type | string | `"http"` | | -| pgadmin.probes.readiness.failureThreshold | int | `3` | | -| pgadmin.probes.readiness.initialDelaySeconds | int | `15` | | -| pgadmin.probes.readiness.path | string | `"/misc/ping"` | | -| pgadmin.probes.readiness.periodSeconds | int | `10` | | -| pgadmin.probes.readiness.port | int | `80` | | -| pgadmin.probes.readiness.successThreshold | int | `1` | | -| pgadmin.probes.readiness.timeoutSeconds | int | `2` | | -| pgadmin.probes.readiness.type | string | `"http"` | | -| pgadmin.resources.limits.cpu | string | `"200m"` | | -| pgadmin.resources.limits.memory | string | `"256Mi"` | | -| pgadmin.resources.requests.cpu | string | `"100m"` | | -| pgadmin.resources.requests.memory | string | `"128Mi"` | | -| pgadmin.service.port | int | `80` | | -| pgadmin.service.type | string | `"ClusterIP"` | | -| postgres.credentials.database | string | `"postgresdb"` | | -| postgres.credentials.password | string | `"test123"` | | -| postgres.credentials.user | string | `"admin"` | | +| migration.command.waitForDb | string | `"python3 /app/mcpgateway/utils/db_isready.py --max-tries 30 --interval 2 --timeout 5"` | | +| migration.command.migrate | string | `"alembic upgrade head || echo '⚠️ Migration check failed'"` | | | postgres.enabled | bool | `true` | | -| postgres.existingSecret | string | `""` | | -| postgres.image.pullPolicy | string | `"IfNotPresent"` | | | postgres.image.repository | string | `"postgres"` | | | postgres.image.tag | string | `"17"` | | -| postgres.persistence.accessModes[0] | string | `"ReadWriteMany"` | | +| postgres.image.pullPolicy | string | `"IfNotPresent"` | | +| postgres.service.type | string | `"ClusterIP"` | | +| postgres.service.port | int | `5432` | | | postgres.persistence.enabled | bool | `true` | | -| postgres.persistence.size | string | `"5Gi"` | | | postgres.persistence.storageClassName | string | `"manual"` | | -| postgres.probes.liveness.command[0] | string | `"pg_isready"` | | -| postgres.probes.liveness.command[1] | string | `"-U"` | | -| postgres.probes.liveness.command[2] | string | `"$(POSTGRES_USER)"` | | -| postgres.probes.liveness.failureThreshold | int | `5` | | -| postgres.probes.liveness.initialDelaySeconds | int | `10` | | -| postgres.probes.liveness.periodSeconds | int | `15` | | -| postgres.probes.liveness.successThreshold | int | `1` | | -| postgres.probes.liveness.timeoutSeconds | int | `3` | | -| postgres.probes.liveness.type | string | `"exec"` | | +| postgres.persistence.accessModes[0] | string | `"ReadWriteMany"` | | +| postgres.persistence.size | string | `"5Gi"` | | +| postgres.existingSecret | string | `""` | | +| postgres.credentials.database | string | `"postgresdb"` | | +| postgres.credentials.user | string | `"admin"` | | +| postgres.credentials.password | string | `"test123"` | | +| postgres.resources.limits.cpu | string | `"1000m"` | | +| postgres.resources.limits.memory | string | `"1Gi"` | | +| postgres.resources.requests.cpu | string | `"500m"` | | +| postgres.resources.requests.memory | string | `"64Mi"` | | +| postgres.probes.readiness.type | string | `"exec"` | | | postgres.probes.readiness.command[0] | string | `"pg_isready"` | | | postgres.probes.readiness.command[1] | string | `"-U"` | | | postgres.probes.readiness.command[2] | string | `"$(POSTGRES_USER)"` | | -| postgres.probes.readiness.failureThreshold | int | `3` | | | postgres.probes.readiness.initialDelaySeconds | int | `15` | | | postgres.probes.readiness.periodSeconds | int | `10` | | -| postgres.probes.readiness.successThreshold | int | `1` | | | postgres.probes.readiness.timeoutSeconds | int | `3` | | -| postgres.probes.readiness.type | string | `"exec"` | | -| postgres.resources.limits.cpu | string | `"1000m"` | | -| postgres.resources.limits.memory | string | `"1Gi"` | | -| postgres.resources.requests.cpu | string | `"500m"` | | -| postgres.resources.requests.memory | string | `"64Mi"` | | -| postgres.service.port | int | `5432` | | -| postgres.service.type | string | `"ClusterIP"` | | +| postgres.probes.readiness.successThreshold | int | `1` | | +| postgres.probes.readiness.failureThreshold | int | `3` | | +| postgres.probes.liveness.type | string | `"exec"` | | +| postgres.probes.liveness.command[0] | string | `"pg_isready"` | | +| postgres.probes.liveness.command[1] | string | `"-U"` | | +| postgres.probes.liveness.command[2] | string | `"$(POSTGRES_USER)"` | | +| postgres.probes.liveness.initialDelaySeconds | int | `10` | | +| postgres.probes.liveness.periodSeconds | int | `15` | | +| postgres.probes.liveness.timeoutSeconds | int | `3` | | +| postgres.probes.liveness.successThreshold | int | `1` | | +| postgres.probes.liveness.failureThreshold | int | `5` | | | redis.enabled | bool | `true` | | -| redis.image.pullPolicy | string | `"IfNotPresent"` | | | redis.image.repository | string | `"redis"` | | | redis.image.tag | string | `"latest"` | | -| redis.probes.liveness.command[0] | string | `"redis-cli"` | | -| redis.probes.liveness.command[1] | string | `"PING"` | | -| redis.probes.liveness.failureThreshold | int | `5` | | -| redis.probes.liveness.initialDelaySeconds | int | `5` | | -| redis.probes.liveness.periodSeconds | int | `15` | | -| redis.probes.liveness.successThreshold | int | `1` | | -| redis.probes.liveness.timeoutSeconds | int | `2` | | -| redis.probes.liveness.type | string | `"exec"` | | +| redis.image.pullPolicy | string | `"IfNotPresent"` | | +| redis.service.type | string | `"ClusterIP"` | | +| redis.service.port | int | `6379` | | +| redis.resources.limits.cpu | string | `"100m"` | | +| redis.resources.limits.memory | string | `"256Mi"` | | +| redis.resources.requests.cpu | string | `"50m"` | | +| redis.resources.requests.memory | string | `"16Mi"` | | +| redis.probes.readiness.type | string | `"exec"` | | | redis.probes.readiness.command[0] | string | `"redis-cli"` | | | redis.probes.readiness.command[1] | string | `"PING"` | | -| redis.probes.readiness.failureThreshold | int | `3` | | | redis.probes.readiness.initialDelaySeconds | int | `10` | | | redis.probes.readiness.periodSeconds | int | `10` | | -| redis.probes.readiness.successThreshold | int | `1` | | | redis.probes.readiness.timeoutSeconds | int | `2` | | -| redis.probes.readiness.type | string | `"exec"` | | -| redis.resources.limits.cpu | string | `"100m"` | | -| redis.resources.limits.memory | string | `"256Mi"` | | -| redis.resources.requests.cpu | string | `"50m"` | | -| redis.resources.requests.memory | string | `"16Mi"` | | -| redis.service.port | int | `6379` | | -| redis.service.type | string | `"ClusterIP"` | | +| redis.probes.readiness.successThreshold | int | `1` | | +| redis.probes.readiness.failureThreshold | int | `3` | | +| redis.probes.liveness.type | string | `"exec"` | | +| redis.probes.liveness.command[0] | string | `"redis-cli"` | | +| redis.probes.liveness.command[1] | string | `"PING"` | | +| redis.probes.liveness.initialDelaySeconds | int | `5` | | +| redis.probes.liveness.periodSeconds | int | `15` | | +| redis.probes.liveness.timeoutSeconds | int | `2` | | +| redis.probes.liveness.successThreshold | int | `1` | | +| redis.probes.liveness.failureThreshold | int | `5` | | +| pgadmin.enabled | bool | `true` | | +| pgadmin.image.repository | string | `"dpage/pgadmin4"` | | +| pgadmin.image.tag | string | `"latest"` | | +| pgadmin.image.pullPolicy | string | `"IfNotPresent"` | | +| pgadmin.service.type | string | `"ClusterIP"` | | +| pgadmin.service.port | int | `80` | | +| pgadmin.env.email | string | `"admin@example.com"` | | +| pgadmin.env.password | string | `"admin123"` | | +| pgadmin.resources.limits.cpu | string | `"200m"` | | +| pgadmin.resources.limits.memory | string | `"256Mi"` | | +| pgadmin.resources.requests.cpu | string | `"100m"` | | +| pgadmin.resources.requests.memory | string | `"128Mi"` | | +| pgadmin.probes.readiness.type | string | `"http"` | | +| pgadmin.probes.readiness.path | string | `"/misc/ping"` | | +| pgadmin.probes.readiness.port | int | `80` | | +| pgadmin.probes.readiness.initialDelaySeconds | int | `15` | | +| pgadmin.probes.readiness.periodSeconds | int | `10` | | +| pgadmin.probes.readiness.timeoutSeconds | int | `2` | | +| pgadmin.probes.readiness.successThreshold | int | `1` | | +| pgadmin.probes.readiness.failureThreshold | int | `3` | | +| pgadmin.probes.liveness.type | string | `"http"` | | +| pgadmin.probes.liveness.path | string | `"/misc/ping"` | | +| pgadmin.probes.liveness.port | int | `80` | | +| pgadmin.probes.liveness.initialDelaySeconds | int | `10` | | +| pgadmin.probes.liveness.periodSeconds | int | `15` | | +| pgadmin.probes.liveness.timeoutSeconds | int | `2` | | +| pgadmin.probes.liveness.successThreshold | int | `1` | | +| pgadmin.probes.liveness.failureThreshold | int | `5` | | | redisCommander.enabled | bool | `true` | | -| redisCommander.image.pullPolicy | string | `"IfNotPresent"` | | | redisCommander.image.repository | string | `"rediscommander/redis-commander"` | | | redisCommander.image.tag | string | `"latest"` | | -| redisCommander.probes.liveness.failureThreshold | int | `5` | | -| redisCommander.probes.liveness.initialDelaySeconds | int | `10` | | -| redisCommander.probes.liveness.path | string | `"/"` | | -| redisCommander.probes.liveness.periodSeconds | int | `15` | | -| redisCommander.probes.liveness.port | int | `8081` | | -| redisCommander.probes.liveness.successThreshold | int | `1` | | -| redisCommander.probes.liveness.timeoutSeconds | int | `2` | | -| redisCommander.probes.liveness.type | string | `"http"` | | -| redisCommander.probes.readiness.failureThreshold | int | `3` | | -| redisCommander.probes.readiness.initialDelaySeconds | int | `15` | | -| redisCommander.probes.readiness.path | string | `"/"` | | -| redisCommander.probes.readiness.periodSeconds | int | `10` | | -| redisCommander.probes.readiness.port | int | `8081` | | -| redisCommander.probes.readiness.successThreshold | int | `1` | | -| redisCommander.probes.readiness.timeoutSeconds | int | `2` | | -| redisCommander.probes.readiness.type | string | `"http"` | | +| redisCommander.image.pullPolicy | string | `"IfNotPresent"` | | +| redisCommander.service.type | string | `"ClusterIP"` | | +| redisCommander.service.port | int | `8081` | | | redisCommander.resources.limits.cpu | string | `"100m"` | | | redisCommander.resources.limits.memory | string | `"256Mi"` | | | redisCommander.resources.requests.cpu | string | `"50m"` | | | redisCommander.resources.requests.memory | string | `"128Mi"` | | -| redisCommander.service.port | int | `8081` | | -| redisCommander.service.type | string | `"ClusterIP"` | | +| redisCommander.probes.readiness.type | string | `"http"` | | +| redisCommander.probes.readiness.path | string | `"/"` | | +| redisCommander.probes.readiness.port | int | `8081` | | +| redisCommander.probes.readiness.initialDelaySeconds | int | `15` | | +| redisCommander.probes.readiness.periodSeconds | int | `10` | | +| redisCommander.probes.readiness.timeoutSeconds | int | `2` | | +| redisCommander.probes.readiness.successThreshold | int | `1` | | +| redisCommander.probes.readiness.failureThreshold | int | `3` | | +| redisCommander.probes.liveness.type | string | `"http"` | | +| redisCommander.probes.liveness.path | string | `"/"` | | +| redisCommander.probes.liveness.port | int | `8081` | | +| redisCommander.probes.liveness.initialDelaySeconds | int | `10` | | +| redisCommander.probes.liveness.periodSeconds | int | `15` | | +| redisCommander.probes.liveness.timeoutSeconds | int | `2` | | +| redisCommander.probes.liveness.successThreshold | int | `1` | | +| redisCommander.probes.liveness.failureThreshold | int | `5` | | +| mcpFastTimeServer.enabled | bool | `true` | | +| mcpFastTimeServer.replicaCount | int | `2` | | +| mcpFastTimeServer.image.repository | string | `"ghcr.io/ibm/fast-time-server"` | | +| mcpFastTimeServer.image.tag | string | `"0.5.0"` | | +| mcpFastTimeServer.image.pullPolicy | string | `"IfNotPresent"` | | +| mcpFastTimeServer.port | int | `8080` | | +| mcpFastTimeServer.ingress.enabled | bool | `true` | | +| mcpFastTimeServer.ingress.path | string | `"/fast-time"` | | +| mcpFastTimeServer.ingress.pathType | string | `"Prefix"` | | +| mcpFastTimeServer.ingress.servicePort | int | `80` | | +| mcpFastTimeServer.probes.readiness.type | string | `"http"` | | +| mcpFastTimeServer.probes.readiness.path | string | `"/health"` | | +| mcpFastTimeServer.probes.readiness.port | int | `8080` | | +| mcpFastTimeServer.probes.readiness.initialDelaySeconds | int | `3` | | +| mcpFastTimeServer.probes.readiness.periodSeconds | int | `10` | | +| mcpFastTimeServer.probes.readiness.timeoutSeconds | int | `2` | | +| mcpFastTimeServer.probes.readiness.successThreshold | int | `1` | | +| mcpFastTimeServer.probes.readiness.failureThreshold | int | `3` | | +| mcpFastTimeServer.probes.liveness.type | string | `"http"` | | +| mcpFastTimeServer.probes.liveness.path | string | `"/health"` | | +| mcpFastTimeServer.probes.liveness.port | int | `8080` | | +| mcpFastTimeServer.probes.liveness.initialDelaySeconds | int | `3` | | +| mcpFastTimeServer.probes.liveness.periodSeconds | int | `15` | | +| mcpFastTimeServer.probes.liveness.timeoutSeconds | int | `2` | | +| mcpFastTimeServer.probes.liveness.successThreshold | int | `1` | | +| mcpFastTimeServer.probes.liveness.failureThreshold | int | `3` | | +| mcpFastTimeServer.resources.limits.cpu | string | `"50m"` | | +| mcpFastTimeServer.resources.limits.memory | string | `"64Mi"` | | +| mcpFastTimeServer.resources.requests.cpu | string | `"25m"` | | +| mcpFastTimeServer.resources.requests.memory | string | `"10Mi"` | | diff --git a/charts/mcp-stack/values.schema.json b/charts/mcp-stack/values.schema.json index 9e79ef37..38df9682 100644 --- a/charts/mcp-stack/values.schema.json +++ b/charts/mcp-stack/values.schema.json @@ -373,17 +373,100 @@ "description": "Enable admin API endpoints", "default": "true" }, + "ENVIRONMENT": { + "type": "string", + "enum": ["development", "production"], + "description": "Deployment environment (affects security defaults)", + "default": "development" + }, + "APP_DOMAIN": { + "type": "string", + "description": "Domain for production CORS origins", + "default": "localhost" + }, "CORS_ENABLED": { "type": "string", "enum": ["true", "false"], "description": "Enable CORS processing", "default": "true" }, + "CORS_ALLOW_CREDENTIALS": { + "type": "string", + "enum": ["true", "false"], + "description": "Allow credentials in CORS requests", + "default": "true" + }, "ALLOWED_ORIGINS": { "type": "string", "description": "JSON array of allowed origins", "default": "[\"http://localhost\",\"http://localhost:4444\"]" }, + "SECURITY_HEADERS_ENABLED": { + "type": "string", + "enum": ["true", "false"], + "description": "Enable security headers middleware", + "default": "true" + }, + "X_FRAME_OPTIONS": { + "type": "string", + "enum": ["DENY", "SAMEORIGIN"], + "description": "X-Frame-Options header value", + "default": "DENY" + }, + "X_CONTENT_TYPE_OPTIONS_ENABLED": { + "type": "string", + "enum": ["true", "false"], + "description": "Enable X-Content-Type-Options header", + "default": "true" + }, + "X_XSS_PROTECTION_ENABLED": { + "type": "string", + "enum": ["true", "false"], + "description": "Enable X-XSS-Protection header", + "default": "true" + }, + "X_DOWNLOAD_OPTIONS_ENABLED": { + "type": "string", + "enum": ["true", "false"], + "description": "Enable X-Download-Options header", + "default": "true" + }, + "HSTS_ENABLED": { + "type": "string", + "enum": ["true", "false"], + "description": "Enable HSTS header", + "default": "true" + }, + "HSTS_MAX_AGE": { + "type": "string", + "pattern": "^[0-9]+$", + "description": "HSTS max age in seconds", + "default": "31536000" + }, + "HSTS_INCLUDE_SUBDOMAINS": { + "type": "string", + "enum": ["true", "false"], + "description": "Include subdomains in HSTS", + "default": "true" + }, + "REMOVE_SERVER_HEADERS": { + "type": "string", + "enum": ["true", "false"], + "description": "Remove server identification headers", + "default": "true" + }, + "SECURE_COOKIES": { + "type": "string", + "enum": ["true", "false"], + "description": "Force secure cookie flags", + "default": "true" + }, + "COOKIE_SAMESITE": { + "type": "string", + "enum": ["strict", "lax", "none"], + "description": "Cookie SameSite attribute", + "default": "lax" + }, "SKIP_SSL_VERIFY": { "type": "string", "enum": ["true", "false"], diff --git a/charts/mcp-stack/values.yaml b/charts/mcp-stack/values.yaml index 98955371..cfa8b7e9 100644 --- a/charts/mcp-stack/values.yaml +++ b/charts/mcp-stack/values.yaml @@ -149,10 +149,29 @@ mcpContextForge: PROTOCOL_VERSION: 2025-03-26 MCPGATEWAY_UI_ENABLED: "true" # toggle Admin UI MCPGATEWAY_ADMIN_API_ENABLED: "true" # toggle Admin API endpoints + # ─ Security & CORS ─ + ENVIRONMENT: development # deployment environment (development/production) + APP_DOMAIN: localhost # domain for production CORS origins CORS_ENABLED: "true" # enable CORS processing in gateway + CORS_ALLOW_CREDENTIALS: "true" # allow credentials in CORS requests ALLOWED_ORIGINS: '["http://localhost","http://localhost:4444"]' # JSON list of allowed origins SKIP_SSL_VERIFY: "false" # skip TLS certificate verification on upstream calls + # ─ Security Headers ─ + SECURITY_HEADERS_ENABLED: "true" # enable security headers middleware + X_FRAME_OPTIONS: DENY # X-Frame-Options header value + X_CONTENT_TYPE_OPTIONS_ENABLED: "true" # enable X-Content-Type-Options + X_XSS_PROTECTION_ENABLED: "true" # enable X-XSS-Protection + X_DOWNLOAD_OPTIONS_ENABLED: "true" # enable X-Download-Options + HSTS_ENABLED: "true" # enable HSTS header + HSTS_MAX_AGE: "31536000" # HSTS max age in seconds (1 year) + HSTS_INCLUDE_SUBDOMAINS: "true" # include subdomains in HSTS + REMOVE_SERVER_HEADERS: "true" # remove server identification headers + + # ─ Cookie Security ─ + SECURE_COOKIES: "true" # force secure cookie flags + COOKIE_SAMESITE: lax # cookie SameSite attribute + # ─ Logging ─ LOG_LEVEL: INFO # DEBUG, INFO, WARNING, ERROR, CRITICAL LOG_FORMAT: json # json or text format diff --git a/docker-compose.yml b/docker-compose.yml index 5eb267b3..cadf796e 100644 --- a/docker-compose.yml +++ b/docker-compose.yml @@ -51,6 +51,10 @@ services: - BASIC_AUTH_PASSWORD=changeme - MCPGATEWAY_UI_ENABLED=true - MCPGATEWAY_ADMIN_API_ENABLED=true + # Security configuration (using defaults) + - ENVIRONMENT=development + - SECURITY_HEADERS_ENABLED=true + - CORS_ALLOW_CREDENTIALS=true # - SSL=true # - CERT_FILE=/app/certs/cert.pem # - KEY_FILE=/app/certs/key.pem @@ -72,12 +76,12 @@ services: # condition: service_completed_successfully healthcheck: - test: ["CMD", "curl", "-f", "http://localhost:4444/health"] + test: ["CMD", "python3", "-c", "import urllib.request; import json; resp = urllib.request.urlopen('http://localhost:4444/health', timeout=5); data = json.loads(resp.read()); exit(0 if data.get('status') == 'healthy' else 1)"] #test: ["CMD", "curl", "-f", "https://localhost:4444/health"] interval: 30s timeout: 10s retries: 5 - start_period: 20s + start_period: 30s # volumes: # - ./certs:/app/certs:ro # mount certs folder read-only diff --git a/docs/docs/architecture/adr/014-security-headers-cors-middleware.md b/docs/docs/architecture/adr/014-security-headers-cors-middleware.md new file mode 100644 index 00000000..96346b3c --- /dev/null +++ b/docs/docs/architecture/adr/014-security-headers-cors-middleware.md @@ -0,0 +1,238 @@ +# ADR-0014: Security Headers and Environment-Aware CORS Middleware + +- *Status:* Accepted +- *Date:* 2025-08-17 +- *Deciders:* Core Engineering Team +- *Issues:* [#344](https://github.com/IBM/mcp-context-forge/issues/344), [#533](https://github.com/IBM/mcp-context-forge/issues/533) +- *Related:* Addresses all 9 security headers identified by nodejsscan + +## Context + +The MCP Gateway needed comprehensive security headers and proper CORS configuration to prevent common web attacks including XSS, clickjacking, MIME sniffing, and cross-origin attacks. Additionally, the nodejsscan static analysis tool identified 9 missing security headers specifically for the Admin UI and static assets. + +The previous implementation had: +- Basic CORS middleware with wildcard origins in some configurations +- Limited security headers only in the DocsAuthMiddleware +- No comprehensive security header implementation +- Manual CORS origin configuration without environment awareness +- Admin UI cookie settings without proper security attributes +- No static analysis tool compatibility + +Security requirements included: +- **Essential security headers** for all responses (issue #344) +- **Configurable security headers** for Admin UI and static assets (issue #533) +- **Environment-aware CORS** configuration for development vs production +- **Secure cookie handling** for authentication +- **Admin UI compatibility** with Content Security Policy +- **Static analysis compatibility** for nodejsscan and similar tools +- **Backward compatibility** with existing configurations + +## Decision + +We implemented a comprehensive security middleware solution with the following components: + +### 1. SecurityHeadersMiddleware + +Created `mcpgateway/middleware/security_headers.py` that automatically adds essential security headers to all responses: + +```python +# Essential security headers +response.headers["X-Content-Type-Options"] = "nosniff" +response.headers["X-Frame-Options"] = "DENY" +response.headers["X-XSS-Protection"] = "0" # Modern browsers use CSP +response.headers["X-Download-Options"] = "noopen" # Prevent IE downloads +response.headers["Referrer-Policy"] = "strict-origin-when-cross-origin" + +# Content Security Policy (Admin UI compatible) +csp_directives = [ + "default-src 'self'", + "script-src 'self' 'unsafe-inline' 'unsafe-eval' https://cdnjs.cloudflare.com https://cdn.tailwindcss.com https://cdn.jsdelivr.net", + "style-src 'self' 'unsafe-inline' https://cdnjs.cloudflare.com", + "img-src 'self' data: https:", + "font-src 'self' data:", + "connect-src 'self' ws: wss: https:", + "frame-ancestors 'none'" +] + +# HSTS for HTTPS connections +if request.url.scheme == "https" or request.headers.get("X-Forwarded-Proto") == "https": + response.headers["Strict-Transport-Security"] = "max-age=31536000; includeSubDomains" + +# Remove sensitive headers +del response.headers["X-Powered-By"] # if present +del response.headers["Server"] # if present +``` + +### 2. Environment-Aware CORS Configuration + +Enhanced CORS setup in `mcpgateway/main.py` with automatic origin configuration: + +**Development Environment:** +- Automatically configures origins for common development ports: localhost:3000, localhost:8080, gateway port +- Includes both `localhost` and `127.0.0.1` variants +- Allows HTTP origins for development convenience + +**Production Environment:** +- Constructs HTTPS origins from `APP_DOMAIN` setting +- Creates origins: `https://{domain}`, `https://app.{domain}`, `https://admin.{domain}` +- Enforces HTTPS-only origins +- Never uses wildcard origins + +### 3. Secure Cookie Utilities + +Added `mcpgateway/utils/security_cookies.py` with functions for secure authentication: + +```python +def set_auth_cookie(response: Response, token: str, remember_me: bool = False): + use_secure = (settings.environment == "production") or settings.secure_cookies + response.set_cookie( + key="jwt_token", + value=token, + max_age=30 * 24 * 3600 if remember_me else 3600, + httponly=True, # Prevents JavaScript access + secure=use_secure, # HTTPS only in production + samesite=settings.cookie_samesite, # CSRF protection + path="/" + ) +``` + +### 4. Configurable Security Headers + +Added comprehensive configuration options to `mcpgateway/config.py` for all security headers: + +```python +# Environment awareness +environment: str = Field(default="development", env="ENVIRONMENT") +app_domain: str = Field(default="localhost", env="APP_DOMAIN") + +# Cookie Security +secure_cookies: bool = Field(default=True, env="SECURE_COOKIES") +cookie_samesite: str = Field(default="lax", env="COOKIE_SAMESITE") + +# CORS Configuration +cors_allow_credentials: bool = Field(default=True, env="CORS_ALLOW_CREDENTIALS") + +# Security Headers Configuration (issue #533) +security_headers_enabled: bool = Field(default=True, env="SECURITY_HEADERS_ENABLED") +x_frame_options: str = Field(default="DENY", env="X_FRAME_OPTIONS") +x_content_type_options_enabled: bool = Field(default=True, env="X_CONTENT_TYPE_OPTIONS_ENABLED") +x_xss_protection_enabled: bool = Field(default=True, env="X_XSS_PROTECTION_ENABLED") +x_download_options_enabled: bool = Field(default=True, env="X_DOWNLOAD_OPTIONS_ENABLED") +hsts_enabled: bool = Field(default=True, env="HSTS_ENABLED") +hsts_max_age: int = Field(default=31536000, env="HSTS_MAX_AGE") +hsts_include_subdomains: bool = Field(default=True, env="HSTS_INCLUDE_SUBDOMAINS") +remove_server_headers: bool = Field(default=True, env="REMOVE_SERVER_HEADERS") +``` + +### 5. Static Analysis Tool Compatibility + +Added security meta tags to `mcpgateway/templates/admin.html` for static analysis tool compatibility: + +```html + + + + + + +``` + +### 6. Enhanced Static Analysis + +Updated Makefile to scan both static files and templates: +```makefile +nodejsscan: + @$(VENV_DIR)/bin/nodejsscan --directory ./mcpgateway/static --directory ./mcpgateway/templates || true +``` + +## Consequences + +### ✅ Benefits + +- **Comprehensive Protection**: All responses include essential security headers +- **Automatic Configuration**: CORS origins are automatically configured based on environment +- **Admin UI Compatible**: CSP allows required CDN resources while maintaining security +- **Production Ready**: Secure defaults for production deployments +- **Development Friendly**: Permissive localhost origins for development +- **Backward Compatible**: Existing configurations continue to work +- **Cookie Security**: Authentication cookies automatically configured with security flags +- **HTTPS Detection**: HSTS header added automatically when HTTPS is detected + +### ❌ Trade-offs + +- **CSP Flexibility**: Using 'unsafe-inline' and 'unsafe-eval' for Admin UI compatibility +- **CDN Dependencies**: CSP allows specific external CDN domains +- **Configuration Complexity**: More environment variables to configure +- **Development Overhead**: Additional middleware processing on every request + +### 🔄 Maintenance + +- **CSP Updates**: May need updates if Admin UI adds new external dependencies +- **CDN Changes**: CSP must be updated if CDN URLs change +- **Security Reviews**: Periodic review of CSP directives for security improvements +- **Browser Updates**: Monitor browser CSP implementation changes + +## Alternatives Considered + +| Alternative | Why Not Chosen | +|------------|----------------| +| **Manual CORS configuration only** | Error-prone and inconsistent across environments | +| **Strict CSP without Admin UI support** | Would break existing Admin UI functionality | +| **Separate middleware for each header** | More complex and harder to maintain | +| **Runtime-configurable CSP** | Added complexity with minimal benefit | +| **No security headers** | Unacceptable security posture for production | +| **Environment-specific builds** | More complex deployment and maintenance | + +## Implementation Details + +### Middleware Order +```python +# Order matters - security headers should be added after CORS +app.add_middleware(CORSMiddleware, ...) # 1. CORS first +app.add_middleware(SecurityHeadersMiddleware) # 2. Security headers +app.add_middleware(DocsAuthMiddleware) # 3. Auth protection +``` + +### Environment Detection +- Uses `ENVIRONMENT` setting to determine development vs production mode +- Falls back to safe defaults if environment not specified +- Only applies automatic origins when using default configuration + +### CSP Design Decisions +- **'unsafe-inline'**: Required for Tailwind CSS inline styles and Alpine.js +- **'unsafe-eval'**: Required for some JavaScript frameworks used in Admin UI +- **Specific CDN domains**: Whitelisted known-good CDN sources instead of wildcard +- **'frame-ancestors none'**: Prevents all framing to prevent clickjacking + +### iframe Embedding Configuration +By default, iframe embedding is **disabled** for security via `X-Frame-Options: DENY` and `frame-ancestors 'none'`. To enable iframe embedding: + +1. **Same-domain embedding**: Set `X_FRAME_OPTIONS=SAMEORIGIN` +2. **Specific domain embedding**: Set `X_FRAME_OPTIONS=ALLOW-FROM https://trusted-domain.com` +3. **Disable frame protection**: Set `X_FRAME_OPTIONS=""` (not recommended) + +**Note**: When changing X-Frame-Options, also consider updating the CSP `frame-ancestors` directive for comprehensive browser support. + +## Testing Strategy + +Implemented comprehensive test coverage (42 new tests): +- **Security headers validation** across all endpoints +- **CORS behavior testing** for allowed and blocked origins +- **Environment-aware configuration** testing +- **Cookie security attributes** validation +- **Production security posture** verification +- **CSP directive structure** validation +- **HSTS behavior** testing + +## Future Enhancements + +Potential improvements for future iterations: +- **CSP Nonces**: Replace 'unsafe-inline' with nonces for dynamic content +- **Subresource Integrity**: Add SRI for external CDN resources +- **CSP Violation Reporting**: Implement CSP violation reporting endpoint +- **Per-Route CSP**: Different CSP policies for different endpoints +- **Security Header Compliance**: Monitoring dashboard for header compliance + +## Status + +This security headers and CORS middleware implementation is **accepted and implemented** as of version 0.5.0, providing comprehensive security coverage while maintaining compatibility with existing functionality. diff --git a/docs/docs/architecture/adr/index.md b/docs/docs/architecture/adr/index.md index b1e89eec..a29bda35 100644 --- a/docs/docs/architecture/adr/index.md +++ b/docs/docs/architecture/adr/index.md @@ -14,5 +14,6 @@ This page tracks all significant design decisions made for the MCP Gateway proje | 0008 | Federation & Auto-Discovery via DNS-SD | Accepted | Federation | 2025-02-21 | | 0009 | Built-in Health Checks & Self-Monitoring | Accepted | Operations | 2025-02-21 | | 0010 | Observability via Prometheus, Structured Logs | Accepted | Observability | 2025-02-21 | +| 0014 | Security Headers & Environment-Aware CORS Middleware | Accepted | Security | 2025-08-17 | > ✳️ Add new decisions chronologically and link to them from this table. diff --git a/docs/docs/architecture/roadmap.md b/docs/docs/architecture/roadmap.md index 20e74b73..a2ca751a 100644 --- a/docs/docs/architecture/roadmap.md +++ b/docs/docs/architecture/roadmap.md @@ -300,7 +300,7 @@ - [**#538**](https://github.com/IBM/mcp-context-forge/issues/538) - [SECURITY FEATURE] Content Size & Type Security Limits for Resources & Prompts - [**#537**](https://github.com/IBM/mcp-context-forge/issues/537) - Simple Endpoint Feature Flags (selectively enable or disable tools, resources, prompts, servers, gateways, roots) - [**#534**](https://github.com/IBM/mcp-context-forge/issues/534) - Add Security Configuration Validation and Startup Checks - - [**#533**](https://github.com/IBM/mcp-context-forge/issues/533) - Add Additional Configurable Security Headers to APIs for Admin UI + - ✅ [**#533**](https://github.com/IBM/mcp-context-forge/issues/533) - Add Additional Configurable Security Headers to APIs for Admin UI - [**#342**](https://github.com/IBM/mcp-context-forge/issues/342) - Implement database-level security constraints and SQL injection prevention - [**#284**](https://github.com/IBM/mcp-context-forge/issues/284) - LDAP / Active-Directory Integration - [**#282**](https://github.com/IBM/mcp-context-forge/issues/282) - Per-Virtual-Server API Keys with Scoped Access @@ -369,7 +369,7 @@ - [**#398**](https://github.com/IBM/mcp-context-forge/issues/398) - Enforce pre-commit targets for doctest coverage, pytest coverage, pylint score 10/10, flake8 pass and add badges - [**#391**](https://github.com/IBM/mcp-context-forge/issues/391) - Setup SonarQube quality gate (draft) - [**#377**](https://github.com/IBM/mcp-context-forge/issues/377) - Fix PostgreSQL Volume Name Conflicts in Helm Chart (draft) - - [**#344**](https://github.com/IBM/mcp-context-forge/issues/344) - Implement additional security headers and CORS configuration + - ✅ [**#344**](https://github.com/IBM/mcp-context-forge/issues/344) - Implement additional security headers and CORS configuration - [**#318**](https://github.com/IBM/mcp-context-forge/issues/318) - Publish Agents and Tools that leverage codebase and templates (draft) - [**#312**](https://github.com/IBM/mcp-context-forge/issues/312) - End-to-End MCP Gateway Stack Testing Harness (mcpgateway, translate, wrapper, mcp-servers) - [**#281**](https://github.com/IBM/mcp-context-forge/issues/281) - Set up contract testing with Pact (pact-python) including Makefile and GitHub Actions targets diff --git a/docs/docs/architecture/security-features.md b/docs/docs/architecture/security-features.md index d9b44831..73003155 100644 --- a/docs/docs/architecture/security-features.md +++ b/docs/docs/architecture/security-features.md @@ -124,7 +124,16 @@ * **Configuration Validation** - Schema enforcement with startup security checks ([#285](https://github.com/IBM/mcp-context-forge/issues/285), [#534](https://github.com/IBM/mcp-context-forge/issues/534)) 🚧 -* **Security Headers** - Configurable headers and CORS policies ([#344](https://github.com/IBM/mcp-context-forge/issues/344), [#533](https://github.com/IBM/mcp-context-forge/issues/533)) 🚧 +* **Security Headers & Configurable Admin UI Security** - Comprehensive security headers with full configurability (✅ [#344](https://github.com/IBM/mcp-context-forge/issues/344), ✅ [#533](https://github.com/IBM/mcp-context-forge/issues/533)) + - **X-Content-Type-Options: nosniff** - Prevents MIME type sniffing attacks (configurable) + - **X-Frame-Options: DENY** - Prevents clickjacking attacks (configurable: DENY/SAMEORIGIN) + - **X-Download-Options: noopen** - Prevents IE download execution (configurable) + - **Content-Security-Policy** - Comprehensive XSS and injection protection (Admin UI compatible) + - **Strict-Transport-Security** - Forces HTTPS connections (configurable max-age & subdomains) + - **Environment-aware CORS** - Automatic origin configuration for dev/production + - **Secure cookies** - HttpOnly, Secure, SameSite attributes for authentication + - **Static analysis compatibility** - Meta tags complement HTTP headers for nodejsscan + - **15 configuration options** - Individual control over all security features * **Well-Known URI Handler** - security.txt and robots.txt support ([#540](https://github.com/IBM/mcp-context-forge/issues/540)) 🚧 diff --git a/docs/docs/manage/securing.md b/docs/docs/manage/securing.md index 152bb5fb..6e6fdce8 100644 --- a/docs/docs/manage/securing.md +++ b/docs/docs/manage/securing.md @@ -27,13 +27,26 @@ MCPGATEWAY_ENABLE_PROMPTS=false # If not using prompts MCPGATEWAY_ENABLE_RESOURCES=false # If not using resources ``` -### 2. Enable Authentication +### 2. Enable Authentication & Security ```bash # Configure strong authentication MCPGATEWAY_AUTH_ENABLED=true MCPGATEWAY_AUTH_USERNAME=custom-username # Change from default MCPGATEWAY_AUTH_PASSWORD=strong-password-here # Use secrets manager + +# Set environment for security defaults +ENVIRONMENT=production + +# Configure domain for CORS +APP_DOMAIN=yourdomain.com + +# Ensure secure cookies (automatic in production) +SECURE_COOKIES=true +COOKIE_SAMESITE=strict + +# Configure CORS (auto-configured based on APP_DOMAIN in production) +CORS_ALLOW_CREDENTIALS=true ``` ### 3. Network Security @@ -41,8 +54,10 @@ MCPGATEWAY_AUTH_PASSWORD=strong-password-here # Use secrets manager - [ ] Configure TLS/HTTPS with valid certificates - [ ] Implement firewall rules and network policies - [ ] Use internal-only endpoints where possible -- [ ] Configure appropriate CORS policies +- [ ] Configure appropriate CORS policies (auto-configured by ENVIRONMENT setting) - [ ] Set up rate limiting per endpoint/client +- [ ] Verify security headers are present (automatically added by SecurityHeadersMiddleware) +- [ ] Configure iframe embedding policy (X_FRAME_OPTIONS=DENY by default, change to SAMEORIGIN if needed) ### 4. Container Security diff --git a/mcpgateway/admin.py b/mcpgateway/admin.py index a7a47a51..460b1b52 100644 --- a/mcpgateway/admin.py +++ b/mcpgateway/admin.py @@ -79,6 +79,7 @@ from mcpgateway.utils.error_formatter import ErrorFormatter from mcpgateway.utils.passthrough_headers import PassthroughHeadersError from mcpgateway.utils.retry_manager import ResilientHttpClient +from mcpgateway.utils.security_cookies import set_auth_cookie from mcpgateway.utils.verify_credentials import require_auth, require_basic_auth # Import the shared logging service from main @@ -1583,7 +1584,8 @@ async def admin_ui( }, ) - response.set_cookie(key="jwt_token", value=jwt_token, httponly=True, secure=False, samesite="Strict") # JavaScript CAN'T read it # only over HTTPS # or "Lax" per your needs + # Use secure cookie utility for proper security attributes + set_auth_cookie(response, jwt_token, remember_me=False) return response diff --git a/mcpgateway/config.py b/mcpgateway/config.py index 00ffe35e..916dfb8f 100644 --- a/mcpgateway/config.py +++ b/mcpgateway/config.py @@ -159,6 +159,30 @@ class Settings(BaseSettings): skip_ssl_verify: bool = False cors_enabled: bool = True + # Environment + environment: str = Field(default="development", env="ENVIRONMENT") + + # Domain configuration + app_domain: str = Field(default="localhost", env="APP_DOMAIN") + + # Security settings + secure_cookies: bool = Field(default=True, env="SECURE_COOKIES") + cookie_samesite: str = Field(default="lax", env="COOKIE_SAMESITE") + + # CORS settings + cors_allow_credentials: bool = Field(default=True, env="CORS_ALLOW_CREDENTIALS") + + # Security Headers Configuration + security_headers_enabled: bool = Field(default=True, env="SECURITY_HEADERS_ENABLED") + x_frame_options: str = Field(default="DENY", env="X_FRAME_OPTIONS") + x_content_type_options_enabled: bool = Field(default=True, env="X_CONTENT_TYPE_OPTIONS_ENABLED") + x_xss_protection_enabled: bool = Field(default=True, env="X_XSS_PROTECTION_ENABLED") + x_download_options_enabled: bool = Field(default=True, env="X_DOWNLOAD_OPTIONS_ENABLED") + hsts_enabled: bool = Field(default=True, env="HSTS_ENABLED") + hsts_max_age: int = Field(default=31536000, env="HSTS_MAX_AGE") # 1 year + hsts_include_subdomains: bool = Field(default=True, env="HSTS_INCLUDE_SUBDOMAINS") + remove_server_headers: bool = Field(default=True, env="REMOVE_SERVER_HEADERS") + # For allowed_origins, strip '' to ensure we're passing on valid JSON via env # Tell pydantic *not* to touch this env var - our validator will. allowed_origins: Annotated[Set[str], NoDecode] = { @@ -673,6 +697,23 @@ def __init__(self, **kwargs): # Safer defaults without Authorization header self.default_passthrough_headers = ["X-Tenant-Id", "X-Trace-Id"] + # Configure environment-aware CORS origins if not explicitly set via env or kwargs + # Only apply defaults if using the default allowed_origins value + if not os.environ.get("ALLOWED_ORIGINS") and "allowed_origins" not in kwargs and self.allowed_origins == {"http://localhost", "http://localhost:4444"}: + if self.environment == "development": + self.allowed_origins = { + "http://localhost", + "http://localhost:3000", + "http://localhost:8080", + "http://127.0.0.1:3000", + "http://127.0.0.1:8080", + f"http://localhost:{self.port}", + f"http://127.0.0.1:{self.port}", + } + else: + # Production origins - construct from app_domain + self.allowed_origins = {f"https://{self.app_domain}", f"https://app.{self.app_domain}", f"https://admin.{self.app_domain}"} + # Validate proxy auth configuration if not self.mcp_client_auth_enabled and not self.trust_proxy_auth: logger.warning( diff --git a/mcpgateway/main.py b/mcpgateway/main.py index 5d3ca163..44a0fd4a 100644 --- a/mcpgateway/main.py +++ b/mcpgateway/main.py @@ -59,6 +59,7 @@ from mcpgateway.db import Prompt as DbPrompt from mcpgateway.db import PromptMetric, refresh_slugs_on_startup, SessionLocal from mcpgateway.handlers.sampling import SamplingHandler +from mcpgateway.middleware.security_headers import SecurityHeadersMiddleware from mcpgateway.models import InitializeResult, ListResourceTemplatesResult, LogLevel, ResourceContent, Root from mcpgateway.observability import init_telemetry from mcpgateway.plugins import PluginManager, PluginViolationError @@ -502,17 +503,27 @@ async def __call__(self, scope, receive, send): await self.application(scope, receive, send) -# Configure CORS +# Configure CORS with environment-aware origins +cors_origins = list(settings.allowed_origins) if settings.allowed_origins else [] + +# Ensure we never use wildcard in production +if settings.environment == "production" and not cors_origins: + logger.warning("No CORS origins configured for production environment. CORS will be disabled.") + cors_origins = [] + app.add_middleware( CORSMiddleware, - allow_origins=["*"] if not settings.allowed_origins else list(settings.allowed_origins), - allow_credentials=True, - allow_methods=["*"], + allow_origins=cors_origins, + allow_credentials=settings.cors_allow_credentials, + allow_methods=["GET", "POST", "PUT", "DELETE", "OPTIONS"], allow_headers=["*"], - expose_headers=["Content-Type", "Content-Length"], + expose_headers=["Content-Length", "X-Request-ID"], ) +# Add security headers middleware +app.add_middleware(SecurityHeadersMiddleware) + # Add custom DocsAuthMiddleware app.add_middleware(DocsAuthMiddleware) diff --git a/mcpgateway/middleware/__init__.py b/mcpgateway/middleware/__init__.py new file mode 100644 index 00000000..a72ce23c --- /dev/null +++ b/mcpgateway/middleware/__init__.py @@ -0,0 +1,7 @@ +# -*- coding: utf-8 -*- +""" +Copyright 2025 +SPDX-License-Identifier: Apache-2.0 + +Middleware package for MCP Gateway. +""" diff --git a/mcpgateway/middleware/security_headers.py b/mcpgateway/middleware/security_headers.py new file mode 100644 index 00000000..b4d1124b --- /dev/null +++ b/mcpgateway/middleware/security_headers.py @@ -0,0 +1,100 @@ +# -*- coding: utf-8 -*- +""" +Copyright 2025 +SPDX-License-Identifier: Apache-2.0 + +Security Headers Middleware for MCP Gateway. + +This module implements essential security headers to prevent common attacks including +XSS, clickjacking, MIME sniffing, and cross-origin attacks. +""" + +# Third-Party +from starlette.middleware.base import BaseHTTPMiddleware +from starlette.requests import Request +from starlette.responses import Response + +# First-Party +from mcpgateway.config import settings + + +class SecurityHeadersMiddleware(BaseHTTPMiddleware): + """ + Security headers middleware that adds essential security headers to all responses. + + This middleware implements security best practices by adding headers that help + prevent various types of attacks and security vulnerabilities. + + Security headers added: + - X-Content-Type-Options: Prevents MIME type sniffing + - X-Frame-Options: Prevents clickjacking attacks + - X-XSS-Protection: Disables legacy XSS protection (modern browsers use CSP) + - Referrer-Policy: Controls referrer information sent with requests + - Content-Security-Policy: Prevents XSS and other code injection attacks + - Strict-Transport-Security: Forces HTTPS connections (when appropriate) + + Sensitive headers removed: + - X-Powered-By: Removes server technology disclosure + - Server: Removes server version information + """ + + async def dispatch(self, request: Request, call_next) -> Response: + """ + Process the request and add security headers to the response. + + Args: + request: The incoming HTTP request + call_next: The next middleware or endpoint handler + + Returns: + Response with security headers added + """ + response = await call_next(request) + + # Only apply security headers if enabled + if not settings.security_headers_enabled: + return response + + # Essential security headers (configurable) + if settings.x_content_type_options_enabled: + response.headers["X-Content-Type-Options"] = "nosniff" + + if settings.x_frame_options: + response.headers["X-Frame-Options"] = settings.x_frame_options + + if settings.x_xss_protection_enabled: + response.headers["X-XSS-Protection"] = "0" # Modern browsers use CSP instead + + if settings.x_download_options_enabled: + response.headers["X-Download-Options"] = "noopen" # Prevent IE from executing downloads + + response.headers["Referrer-Policy"] = "strict-origin-when-cross-origin" + + # Content Security Policy + # This CSP is designed to work with the Admin UI while providing security + csp_directives = [ + "default-src 'self'", + "script-src 'self' 'unsafe-inline' 'unsafe-eval' https://cdnjs.cloudflare.com https://cdn.tailwindcss.com https://cdn.jsdelivr.net", + "style-src 'self' 'unsafe-inline' https://cdnjs.cloudflare.com", + "img-src 'self' data: https:", + "font-src 'self' data:", + "connect-src 'self' ws: wss: https:", + "frame-ancestors 'none'", + ] + response.headers["Content-Security-Policy"] = "; ".join(csp_directives) + ";" + + # HSTS for HTTPS connections (configurable) + if settings.hsts_enabled and (request.url.scheme == "https" or request.headers.get("X-Forwarded-Proto") == "https"): + hsts_value = f"max-age={settings.hsts_max_age}" + if settings.hsts_include_subdomains: + hsts_value += "; includeSubDomains" + response.headers["Strict-Transport-Security"] = hsts_value + + # Remove sensitive headers that might disclose server information (configurable) + if settings.remove_server_headers: + if "X-Powered-By" in response.headers: + del response.headers["X-Powered-By"] + if "Server" in response.headers: + del response.headers["Server"] + + return response diff --git a/mcpgateway/services/prompt_service.py b/mcpgateway/services/prompt_service.py index 423b822f..608b8b27 100644 --- a/mcpgateway/services/prompt_service.py +++ b/mcpgateway/services/prompt_service.py @@ -171,7 +171,7 @@ async def get_top_prompts(self, db: Session, limit: int = 5) -> List[TopPerforme case( ( func.count(PromptMetric.id) > 0, # pylint: disable=not-callable - func.sum(case((PromptMetric.is_success == 1, 1), else_=0)).cast(Float) / func.count(PromptMetric.id) * 100, # pylint: disable=not-callable + func.sum(case((PromptMetric.is_success.is_(True), 1), else_=0)).cast(Float) / func.count(PromptMetric.id) * 100, # pylint: disable=not-callable ), else_=None, ).label("success_rate"), @@ -1077,8 +1077,8 @@ async def aggregate_metrics(self, db: Session) -> Dict[str, Any]: """ total = db.execute(select(func.count(PromptMetric.id))).scalar() or 0 # pylint: disable=not-callable - successful = db.execute(select(func.count(PromptMetric.id)).where(PromptMetric.is_success == 1)).scalar() or 0 # pylint: disable=not-callable - failed = db.execute(select(func.count(PromptMetric.id)).where(PromptMetric.is_success == 0)).scalar() or 0 # pylint: disable=not-callable + successful = db.execute(select(func.count(PromptMetric.id)).where(PromptMetric.is_success.is_(True))).scalar() or 0 # pylint: disable=not-callable + failed = db.execute(select(func.count(PromptMetric.id)).where(PromptMetric.is_success.is_(False))).scalar() or 0 # pylint: disable=not-callable failure_rate = failed / total if total > 0 else 0.0 min_rt = db.execute(select(func.min(PromptMetric.response_time))).scalar() max_rt = db.execute(select(func.max(PromptMetric.response_time))).scalar() diff --git a/mcpgateway/services/resource_service.py b/mcpgateway/services/resource_service.py index 891b4b38..46b6e953 100644 --- a/mcpgateway/services/resource_service.py +++ b/mcpgateway/services/resource_service.py @@ -167,7 +167,7 @@ async def get_top_resources(self, db: Session, limit: int = 5) -> List[TopPerfor case( ( func.count(ResourceMetric.id) > 0, # pylint: disable=not-callable - func.sum(case((ResourceMetric.is_success == 1, 1), else_=0)).cast(Float) / func.count(ResourceMetric.id) * 100, # pylint: disable=not-callable + func.sum(case((ResourceMetric.is_success.is_(True), 1), else_=0)).cast(Float) / func.count(ResourceMetric.id) * 100, # pylint: disable=not-callable ), else_=None, ).label("success_rate"), @@ -1167,9 +1167,9 @@ async def aggregate_metrics(self, db: Session) -> ResourceMetrics: """ total_executions = db.execute(select(func.count()).select_from(ResourceMetric)).scalar() or 0 # pylint: disable=not-callable - successful_executions = db.execute(select(func.count()).select_from(ResourceMetric).where(ResourceMetric.is_success == 1)).scalar() or 0 # pylint: disable=not-callable + successful_executions = db.execute(select(func.count()).select_from(ResourceMetric).where(ResourceMetric.is_success.is_(True))).scalar() or 0 # pylint: disable=not-callable - failed_executions = db.execute(select(func.count()).select_from(ResourceMetric).where(ResourceMetric.is_success == 0)).scalar() or 0 # pylint: disable=not-callable + failed_executions = db.execute(select(func.count()).select_from(ResourceMetric).where(ResourceMetric.is_success.is_(False))).scalar() or 0 # pylint: disable=not-callable min_response_time = db.execute(select(func.min(ResourceMetric.response_time))).scalar() diff --git a/mcpgateway/services/server_service.py b/mcpgateway/services/server_service.py index 46a1db0d..f0813df9 100644 --- a/mcpgateway/services/server_service.py +++ b/mcpgateway/services/server_service.py @@ -157,7 +157,7 @@ async def get_top_servers(self, db: Session, limit: int = 5) -> List[TopPerforme case( ( func.count(ServerMetric.id) > 0, # pylint: disable=not-callable - func.sum(case((ServerMetric.is_success == 1, 1), else_=0)).cast(Float) / func.count(ServerMetric.id) * 100, # pylint: disable=not-callable + func.sum(case((ServerMetric.is_success.is_(True), 1), else_=0)).cast(Float) / func.count(ServerMetric.id) * 100, # pylint: disable=not-callable ), else_=None, ).label("success_rate"), @@ -833,9 +833,9 @@ async def aggregate_metrics(self, db: Session) -> ServerMetrics: """ total_executions = db.execute(select(func.count()).select_from(ServerMetric)).scalar() or 0 # pylint: disable=not-callable - successful_executions = db.execute(select(func.count()).select_from(ServerMetric).where(ServerMetric.is_success == 1)).scalar() or 0 # pylint: disable=not-callable + successful_executions = db.execute(select(func.count()).select_from(ServerMetric).where(ServerMetric.is_success.is_(True))).scalar() or 0 # pylint: disable=not-callable - failed_executions = db.execute(select(func.count()).select_from(ServerMetric).where(ServerMetric.is_success == 0)).scalar() or 0 # pylint: disable=not-callable + failed_executions = db.execute(select(func.count()).select_from(ServerMetric).where(ServerMetric.is_success.is_(False))).scalar() or 0 # pylint: disable=not-callable min_response_time = db.execute(select(func.min(ServerMetric.response_time))).scalar() diff --git a/mcpgateway/services/tool_service.py b/mcpgateway/services/tool_service.py index 2f3632a9..155f8e51 100644 --- a/mcpgateway/services/tool_service.py +++ b/mcpgateway/services/tool_service.py @@ -225,7 +225,7 @@ async def get_top_tools(self, db: Session, limit: int = 5) -> List[TopPerformer] case( ( func.count(ToolMetric.id) > 0, # pylint: disable=not-callable - func.sum(case((ToolMetric.is_success == 1, 1), else_=0)).cast(Float) / func.count(ToolMetric.id) * 100, # pylint: disable=not-callable + func.sum(case((ToolMetric.is_success.is_(True), 1), else_=0)).cast(Float) / func.count(ToolMetric.id) * 100, # pylint: disable=not-callable ), else_=None, ).label("success_rate"), @@ -1182,8 +1182,8 @@ async def aggregate_metrics(self, db: Session) -> Dict[str, Any]: """ total = db.execute(select(func.count(ToolMetric.id))).scalar() or 0 # pylint: disable=not-callable - successful = db.execute(select(func.count(ToolMetric.id)).where(ToolMetric.is_success == 1)).scalar() or 0 # pylint: disable=not-callable - failed = db.execute(select(func.count(ToolMetric.id)).where(ToolMetric.is_success == 0)).scalar() or 0 # pylint: disable=not-callable + successful = db.execute(select(func.count(ToolMetric.id)).where(ToolMetric.is_success.is_(True))).scalar() or 0 # pylint: disable=not-callable + failed = db.execute(select(func.count(ToolMetric.id)).where(ToolMetric.is_success.is_(False))).scalar() or 0 # pylint: disable=not-callable failure_rate = failed / total if total > 0 else 0.0 min_rt = db.execute(select(func.min(ToolMetric.response_time))).scalar() max_rt = db.execute(select(func.max(ToolMetric.response_time))).scalar() diff --git a/mcpgateway/templates/admin.html b/mcpgateway/templates/admin.html index aeb1c7a1..f6ce3d7b 100644 --- a/mcpgateway/templates/admin.html +++ b/mcpgateway/templates/admin.html @@ -9,6 +9,14 @@ MCP Gateway Admin + + + + + + + + None: + """ + Set authentication cookie with security flags. + + Configures the JWT token as a secure HTTP-only cookie with appropriate + security attributes to prevent XSS and CSRF attacks. + + Args: + response: FastAPI response object to set the cookie on + token: JWT token to store in the cookie + remember_me: If True, sets longer expiration time (30 days vs 1 hour) + + Security attributes set: + - httponly: Prevents JavaScript access to the cookie + - secure: HTTPS only in production environments + - samesite: CSRF protection (configurable, defaults to 'lax') + - path: Cookie scope limitation + - max_age: Automatic expiration + """ + # Set expiration based on remember_me preference + max_age = 30 * 24 * 3600 if remember_me else 3600 # 30 days or 1 hour + + # Determine if we should use secure flag + # In production or when explicitly configured, require HTTPS + use_secure = (settings.environment == "production") or settings.secure_cookies + + response.set_cookie( + key="jwt_token", + value=token, + max_age=max_age, + httponly=True, # Prevents JavaScript access + secure=use_secure, # HTTPS only in production + samesite=settings.cookie_samesite, # CSRF protection + path="/", # Cookie scope + ) + + +def clear_auth_cookie(response: Response) -> None: + """ + Clear authentication cookie securely. + + Removes the JWT token cookie by setting it to expire immediately + with the same security attributes used when setting it. + + Args: + response: FastAPI response object to clear the cookie from + """ + # Use same security settings as when setting the cookie + use_secure = (settings.environment == "production") or settings.secure_cookies + + response.delete_cookie(key="jwt_token", path="/", secure=use_secure, httponly=True, samesite=settings.cookie_samesite) + + +def set_session_cookie(response: Response, session_id: str, max_age: int = 3600) -> None: + """ + Set session cookie with security flags. + + Configures a session ID cookie with appropriate security attributes. + + Args: + response: FastAPI response object to set the cookie on + session_id: Session identifier to store in the cookie + max_age: Cookie expiration time in seconds (default: 1 hour) + """ + use_secure = (settings.environment == "production") or settings.secure_cookies + + response.set_cookie( + key="session_id", + value=session_id, + max_age=max_age, + httponly=True, + secure=use_secure, + samesite=settings.cookie_samesite, + path="/", + ) + + +def clear_session_cookie(response: Response) -> None: + """ + Clear session cookie securely. + + Args: + response: FastAPI response object to clear the cookie from + """ + use_secure = (settings.environment == "production") or settings.secure_cookies + + response.delete_cookie(key="session_id", path="/", secure=use_secure, httponly=True, samesite=settings.cookie_samesite) diff --git a/tests/async/test_async_safety.py b/tests/async/test_async_safety.py index 8849c74f..84b0aa93 100644 --- a/tests/async/test_async_safety.py +++ b/tests/async/test_async_safety.py @@ -27,9 +27,12 @@ async def mock_operation(): results = await asyncio.gather(*tasks) end_time = time.time() + execution_time = end_time - start_time # Should complete in roughly 10ms, not 1000ms (100 * 10ms) - assert end_time - start_time < 0.1, "Concurrent operations not properly parallelized" + # Allow more tolerance for CI environments and system load + max_time = 0.15 # 150ms tolerance for CI environments + assert execution_time < max_time, f"Concurrent operations not properly parallelized: took {execution_time:.3f}s, expected < {max_time:.3f}s" assert len(results) == 100, "Not all operations completed" @pytest.mark.asyncio diff --git a/tests/security/test_configurable_headers.py b/tests/security/test_configurable_headers.py new file mode 100644 index 00000000..039c5809 --- /dev/null +++ b/tests/security/test_configurable_headers.py @@ -0,0 +1,152 @@ +# -*- coding: utf-8 -*- +""" +Copyright 2025 +SPDX-License-Identifier: Apache-2.0 + +Configurable Security Headers Testing. + +This module tests the configurable security headers implementation for issue #533. +""" + +import pytest +from fastapi import FastAPI +from fastapi.testclient import TestClient +from unittest.mock import patch + +from mcpgateway.middleware.security_headers import SecurityHeadersMiddleware +from mcpgateway.config import settings + + +def test_security_headers_can_be_disabled(): + """Test that security headers can be disabled via configuration.""" + app = FastAPI() + app.add_middleware(SecurityHeadersMiddleware) + + @app.get("/test") + def test_endpoint(): + return {"message": "test"} + + with patch.object(settings, 'security_headers_enabled', False): + client = TestClient(app) + response = client.get("/test") + + # When disabled, security headers should not be present + assert "X-Content-Type-Options" not in response.headers + assert "X-Frame-Options" not in response.headers + assert "X-XSS-Protection" not in response.headers + assert "X-Download-Options" not in response.headers + + +def test_individual_headers_configurable(): + """Test that individual security headers can be configured.""" + app = FastAPI() + app.add_middleware(SecurityHeadersMiddleware) + + @app.get("/test") + def test_endpoint(): + return {"message": "test"} + + # Test with some headers disabled + with patch.multiple(settings, + security_headers_enabled=True, + x_content_type_options_enabled=False, + x_frame_options="SAMEORIGIN", + x_xss_protection_enabled=False, + x_download_options_enabled=True): + client = TestClient(app) + response = client.get("/test") + + # Check configured headers + assert "X-Content-Type-Options" not in response.headers # Disabled + assert response.headers["X-Frame-Options"] == "SAMEORIGIN" # Custom value + assert "X-XSS-Protection" not in response.headers # Disabled + assert response.headers["X-Download-Options"] == "noopen" # Enabled + assert response.headers["Referrer-Policy"] == "strict-origin-when-cross-origin" # Always on + + +def test_hsts_configuration(): + """Test HSTS header configuration options.""" + app = FastAPI() + app.add_middleware(SecurityHeadersMiddleware) + + @app.get("/test") + def test_endpoint(): + return {"message": "test"} + + # Test with custom HSTS settings + with patch.multiple(settings, + security_headers_enabled=True, + hsts_enabled=True, + hsts_max_age=7776000, # 90 days + hsts_include_subdomains=False): + client = TestClient(app) + response = client.get("/test", headers={"X-Forwarded-Proto": "https"}) + + # Check HSTS configuration + assert "Strict-Transport-Security" in response.headers + hsts_value = response.headers["Strict-Transport-Security"] + assert "max-age=7776000" in hsts_value + assert "includeSubDomains" not in hsts_value # Disabled + + +def test_hsts_can_be_disabled(): + """Test that HSTS can be disabled.""" + app = FastAPI() + app.add_middleware(SecurityHeadersMiddleware) + + @app.get("/test") + def test_endpoint(): + return {"message": "test"} + + with patch.multiple(settings, + security_headers_enabled=True, + hsts_enabled=False): + client = TestClient(app) + response = client.get("/test", headers={"X-Forwarded-Proto": "https"}) + + # HSTS should not be present when disabled + assert "Strict-Transport-Security" not in response.headers + + +def test_server_header_removal_configurable(): + """Test that server header removal is configurable.""" + app = FastAPI() + app.add_middleware(SecurityHeadersMiddleware) + + @app.get("/test") + def test_endpoint(): + return {"message": "test"} + + # Test with server header removal disabled + with patch.multiple(settings, + security_headers_enabled=True, + remove_server_headers=False): + client = TestClient(app) + response = client.get("/test") + + # Server headers should not be removed when disabled + # Note: FastAPI/Starlette might not add these headers in test mode, + # but our middleware won't remove them if they exist + pass # This test mainly validates the configuration works + + +def test_all_headers_with_default_config(): + """Test all headers with default configuration.""" + app = FastAPI() + app.add_middleware(SecurityHeadersMiddleware) + + @app.get("/test") + def test_endpoint(): + return {"message": "test"} + + # Use default settings (all should be enabled) + client = TestClient(app) + response = client.get("/test") + + # All default headers should be present + assert response.headers["X-Content-Type-Options"] == "nosniff" + assert response.headers["X-Frame-Options"] == "DENY" + assert response.headers["X-XSS-Protection"] == "0" + assert response.headers["X-Download-Options"] == "noopen" + assert response.headers["Referrer-Policy"] == "strict-origin-when-cross-origin" + assert "Content-Security-Policy" in response.headers diff --git a/tests/security/test_security_cookies.py b/tests/security/test_security_cookies.py new file mode 100644 index 00000000..5c7bebd9 --- /dev/null +++ b/tests/security/test_security_cookies.py @@ -0,0 +1,222 @@ +# -*- coding: utf-8 -*- +""" +Copyright 2025 +SPDX-License-Identifier: Apache-2.0 + +Security Cookie Testing. + +This module contains tests for secure cookie configuration and handling. +""" + +import pytest +from fastapi import Response +from fastapi.testclient import TestClient +from unittest.mock import patch + +from mcpgateway.utils.security_cookies import ( + set_auth_cookie, + clear_auth_cookie, + set_session_cookie, + clear_session_cookie +) +from mcpgateway.config import settings + + +class TestSecureCookies: + """Test secure cookie configuration and attributes.""" + + def test_set_auth_cookie_development(self): + """Test auth cookie in development environment.""" + response = Response() + + with patch.object(settings, 'environment', 'development'): + with patch.object(settings, 'secure_cookies', False): + set_auth_cookie(response, "test_token", remember_me=False) + + # Check that cookie was set + set_cookie_header = response.headers.get("set-cookie", "") + assert "jwt_token=test_token" in set_cookie_header + assert "HttpOnly" in set_cookie_header + assert "SameSite=lax" in set_cookie_header + assert "Path=/" in set_cookie_header + assert "Max-Age=3600" in set_cookie_header # 1 hour + + # In development with secure_cookies=False, Secure flag should not be present + assert "Secure" not in set_cookie_header + + def test_set_auth_cookie_production(self): + """Test auth cookie in production environment.""" + response = Response() + + with patch.object(settings, 'environment', 'production'): + set_auth_cookie(response, "test_token", remember_me=False) + + set_cookie_header = response.headers.get("set-cookie", "") + assert "jwt_token=test_token" in set_cookie_header + assert "HttpOnly" in set_cookie_header + assert "Secure" in set_cookie_header # Should be secure in production + assert "SameSite=lax" in set_cookie_header + + def test_set_auth_cookie_remember_me(self): + """Test auth cookie with remember_me option.""" + response = Response() + + set_auth_cookie(response, "test_token", remember_me=True) + + set_cookie_header = response.headers.get("set-cookie", "") + # 30 days = 30 * 24 * 3600 = 2592000 seconds + assert "Max-Age=2592000" in set_cookie_header + + def test_set_auth_cookie_custom_samesite(self): + """Test auth cookie with custom SameSite setting.""" + response = Response() + + with patch.object(settings, 'cookie_samesite', 'strict'): + set_auth_cookie(response, "test_token") + + set_cookie_header = response.headers.get("set-cookie", "") + assert "SameSite=strict" in set_cookie_header + + def test_clear_auth_cookie(self): + """Test clearing auth cookie.""" + response = Response() + + clear_auth_cookie(response) + + set_cookie_header = response.headers.get("set-cookie", "") + assert "jwt_token=" in set_cookie_header # Empty value + assert "HttpOnly" in set_cookie_header + assert "Path=/" in set_cookie_header + + def test_set_session_cookie(self): + """Test setting session cookie.""" + response = Response() + + set_session_cookie(response, "session_123", max_age=7200) + + set_cookie_header = response.headers.get("set-cookie", "") + assert "session_id=session_123" in set_cookie_header + assert "HttpOnly" in set_cookie_header + assert "SameSite=lax" in set_cookie_header + assert "Max-Age=7200" in set_cookie_header + + def test_clear_session_cookie(self): + """Test clearing session cookie.""" + response = Response() + + clear_session_cookie(response) + + set_cookie_header = response.headers.get("set-cookie", "") + assert "session_id=" in set_cookie_header + assert "HttpOnly" in set_cookie_header + + def test_secure_flag_with_explicit_setting(self): + """Test secure flag behavior with explicit secure_cookies setting.""" + response = Response() + + # Test with secure_cookies=True in development + with patch.object(settings, 'environment', 'development'): + with patch.object(settings, 'secure_cookies', True): + set_auth_cookie(response, "test_token") + + set_cookie_header = response.headers.get("set-cookie", "") + assert "Secure" in set_cookie_header # Should be secure when explicitly enabled + + def test_cookie_attributes_consistency(self): + """Test that cookie attributes are consistent between set and clear operations.""" + response_set = Response() + response_clear = Response() + + with patch.object(settings, 'environment', 'production'): + with patch.object(settings, 'cookie_samesite', 'strict'): + set_auth_cookie(response_set, "test_token") + clear_auth_cookie(response_clear) + + set_header = response_set.headers.get("set-cookie", "") + clear_header = response_clear.headers.get("set-cookie", "") + + # Both should have same security attributes + for attr in ["HttpOnly", "Secure", "SameSite=strict", "Path=/"]: + assert attr in set_header + assert attr in clear_header + + +class TestCookieSecurityConfiguration: + """Test cookie security configuration under different scenarios.""" + + @pytest.mark.parametrize("environment,secure_cookies,expected_secure", [ + ("development", False, False), + ("development", True, True), + ("production", False, True), # Production always uses secure + ("production", True, True), + ]) + def test_secure_flag_combinations(self, environment: str, secure_cookies: bool, expected_secure: bool): + """Test secure flag under different environment and configuration combinations.""" + response = Response() + + with patch.object(settings, 'environment', environment): + with patch.object(settings, 'secure_cookies', secure_cookies): + set_auth_cookie(response, "test_token") + + set_cookie_header = response.headers.get("set-cookie", "") + + if expected_secure: + assert "Secure" in set_cookie_header + else: + assert "Secure" not in set_cookie_header + + @pytest.mark.parametrize("samesite_value", ["strict", "lax", "none"]) + def test_samesite_options(self, samesite_value: str): + """Test different SameSite options.""" + response = Response() + + with patch.object(settings, 'cookie_samesite', samesite_value): + set_auth_cookie(response, "test_token") + + set_cookie_header = response.headers.get("set-cookie", "") + assert f"SameSite={samesite_value}" in set_cookie_header + + def test_cookie_httponly_always_set(self): + """Test that HttpOnly is always set regardless of configuration.""" + response = Response() + + # Test in various configurations + configurations = [ + {"environment": "development", "secure_cookies": False}, + {"environment": "development", "secure_cookies": True}, + {"environment": "production", "secure_cookies": False}, + {"environment": "production", "secure_cookies": True}, + ] + + for config in configurations: + response = Response() # Fresh response for each test + with patch.multiple(settings, **config): + set_auth_cookie(response, "test_token") + + set_cookie_header = response.headers.get("set-cookie", "") + assert "HttpOnly" in set_cookie_header, f"HttpOnly missing in config: {config}" + + def test_cookie_path_always_set(self): + """Test that Path is always set to root.""" + response = Response() + + set_auth_cookie(response, "test_token") + + set_cookie_header = response.headers.get("set-cookie", "") + assert "Path=/" in set_cookie_header + + def test_multiple_cookies_do_not_interfere(self): + """Test that setting multiple different cookies doesn't interfere.""" + response = Response() + + set_auth_cookie(response, "auth_token") + set_session_cookie(response, "session_id", max_age=1800) + + # Response should have multiple set-cookie headers + set_cookie_headers = response.headers.getlist("set-cookie") + assert len(set_cookie_headers) == 2 + + # Check that both cookies are present + all_headers = " ".join(set_cookie_headers) + assert "jwt_token=auth_token" in all_headers + assert "session_id=session_id" in all_headers diff --git a/tests/security/test_security_headers.py b/tests/security/test_security_headers.py new file mode 100644 index 00000000..5f249c14 --- /dev/null +++ b/tests/security/test_security_headers.py @@ -0,0 +1,246 @@ +# -*- coding: utf-8 -*- +""" +Copyright 2025 +SPDX-License-Identifier: Apache-2.0 + +Security Headers and CORS Testing. + +This module contains comprehensive tests for security headers middleware and CORS configuration. +""" + +import pytest +from fastapi.testclient import TestClient +from unittest.mock import patch + +from mcpgateway.config import settings + + +class TestSecurityHeaders: + """Test security headers are properly set on all responses.""" + + def test_security_headers_present_on_health_endpoint(self, client: TestClient): + """Test that essential security headers are present on health endpoint.""" + response = client.get("/health") + + # Essential security headers + assert response.headers["X-Content-Type-Options"] == "nosniff" + assert response.headers["X-Frame-Options"] == "DENY" + assert response.headers["X-XSS-Protection"] == "0" + assert response.headers["Referrer-Policy"] == "strict-origin-when-cross-origin" + assert "Content-Security-Policy" in response.headers + + # Verify CSP contains essential directives + csp = response.headers["Content-Security-Policy"] + assert "default-src 'self'" in csp + assert "frame-ancestors 'none'" in csp + + def test_security_headers_present_on_api_endpoints(self, client: TestClient): + """Test security headers on API endpoints.""" + # Test with authentication disabled for this test + with patch.object(settings, 'auth_required', False): + response = client.get("/tools") + + assert response.headers["X-Content-Type-Options"] == "nosniff" + assert response.headers["X-Frame-Options"] == "DENY" + assert response.headers["X-XSS-Protection"] == "0" + assert response.headers["Referrer-Policy"] == "strict-origin-when-cross-origin" + assert "Content-Security-Policy" in response.headers + + def test_sensitive_headers_removed(self, client: TestClient): + """Test that sensitive headers are removed.""" + response = client.get("/health") + + # These headers should not be present + assert "X-Powered-By" not in response.headers + assert "Server" not in response.headers + + def test_hsts_header_on_https_request(self, client: TestClient): + """Test HSTS header is present when X-Forwarded-Proto indicates HTTPS.""" + response = client.get("/health", headers={"X-Forwarded-Proto": "https"}) + + assert "Strict-Transport-Security" in response.headers + hsts_value = response.headers["Strict-Transport-Security"] + assert "max-age=31536000" in hsts_value + assert "includeSubDomains" in hsts_value + + def test_no_hsts_header_on_http_request(self, client: TestClient): + """Test HSTS header is not present on HTTP requests.""" + response = client.get("/health") + + # HSTS should not be present for HTTP requests + assert "Strict-Transport-Security" not in response.headers + + def test_content_security_policy_structure(self, client: TestClient): + """Test CSP header has proper structure and directives.""" + response = client.get("/health") + + csp = response.headers["Content-Security-Policy"] + + # Check for essential CSP directives + assert "default-src 'self'" in csp + assert "script-src 'self'" in csp + assert "style-src 'self'" in csp + assert "img-src 'self'" in csp + assert "font-src 'self'" in csp + assert "connect-src 'self'" in csp + assert "frame-ancestors 'none'" in csp + + # Verify CSP ends with semicolon + assert csp.endswith(";") + + +class TestCORSConfiguration: + """Test CORS configuration and behavior.""" + + def test_cors_with_development_origins(self, client: TestClient): + """Test CORS works with development origins.""" + with patch.object(settings, 'environment', 'development'): + with patch.object(settings, 'allowed_origins', {'http://localhost:3000', 'http://localhost:8080'}): + # Test with actual GET request that includes CORS headers + response = client.get( + "/health", + headers={"Origin": "http://localhost:3000"} + ) + assert response.status_code == 200 + # Check that CORS headers are present for allowed origin + assert response.headers.get("Access-Control-Allow-Origin") == "http://localhost:3000" + + def test_cors_blocks_unauthorized_origin(self, client: TestClient): + """Test CORS blocks unauthorized origins.""" + with patch.object(settings, 'allowed_origins', {'http://localhost:3000'}): + # Test blocked origin with GET request + response = client.get( + "/health", + headers={"Origin": "https://evil.com"} + ) + # For blocked origins, Access-Control-Allow-Origin should not be set to the blocked origin + assert response.headers.get("Access-Control-Allow-Origin") != "https://evil.com" + # The response should still succeed but without CORS headers for the blocked origin + assert response.status_code == 200 + + def test_cors_credentials_allowed(self, client: TestClient): + """Test CORS allows credentials when configured.""" + with patch.object(settings, 'cors_allow_credentials', True): + with patch.object(settings, 'allowed_origins', {'http://localhost:3000'}): + response = client.get( + "/health", + headers={"Origin": "http://localhost:3000"} + ) + assert response.headers.get("Access-Control-Allow-Credentials") == "true" + + def test_cors_allowed_methods(self, client: TestClient): + """Test CORS exposes correct allowed methods.""" + with patch.object(settings, 'allowed_origins', {'http://localhost:3000'}): + # Test with an endpoint that supports OPTIONS for proper CORS preflight + # Use the root endpoint which should support more methods + response = client.get( + "/health", + headers={"Origin": "http://localhost:3000"} + ) + + # Check that the response includes CORS origin header indicating CORS is working + assert response.headers.get("Access-Control-Allow-Origin") == "http://localhost:3000" + + def test_cors_exposed_headers(self, client: TestClient): + """Test CORS exposes correct headers.""" + with patch.object(settings, 'allowed_origins', {'http://localhost:3000'}): + response = client.get( + "/health", + headers={"Origin": "http://localhost:3000"} + ) + + # Check that CORS is working with the allowed origin + assert response.headers.get("Access-Control-Allow-Origin") == "http://localhost:3000" + + # Check for exposed headers (these may be set by CORS middleware) + exposed_headers = response.headers.get("Access-Control-Expose-Headers", "") + if exposed_headers: # Only check if the header is present + assert "Content-Length" in exposed_headers + assert "X-Request-ID" in exposed_headers + + +class TestProductionSecurity: + """Test security configuration in production environment.""" + + def test_production_cors_requires_explicit_origins(self, client: TestClient): + """Test that production environment requires explicit CORS origins.""" + with patch.object(settings, 'environment', 'production'): + with patch.object(settings, 'allowed_origins', set()): + # Should have empty origins list for production without explicit config + assert len(settings.allowed_origins) == 0 + + def test_production_uses_https_origins(self, client: TestClient): + """Test that production environment uses HTTPS origins.""" + with patch.object(settings, 'environment', 'production'): + with patch.object(settings, 'app_domain', 'example.com'): + # This would be set during initialization + test_origins = { + "https://example.com", + "https://app.example.com", + "https://admin.example.com" + } + with patch.object(settings, 'allowed_origins', test_origins): + # All origins should be HTTPS + for origin in settings.allowed_origins: + assert origin.startswith("https://") + + def test_security_headers_consistent_across_endpoints(self, client: TestClient): + """Test security headers are consistent across different endpoints.""" + endpoints = ["/health", "/ready"] + + headers_to_check = [ + "X-Content-Type-Options", + "X-Frame-Options", + "X-XSS-Protection", + "Referrer-Policy", + "Content-Security-Policy" + ] + + responses = {} + for endpoint in endpoints: + responses[endpoint] = client.get(endpoint) + + # Check that all endpoints have the same security headers + for header in headers_to_check: + values = [responses[endpoint].headers.get(header) for endpoint in endpoints] + assert all(value == values[0] for value in values), f"Inconsistent {header} across endpoints" + + +class TestSecurityHeadersEdgeCases: + """Test edge cases and error conditions for security headers.""" + + def test_security_headers_on_error_responses(self, client: TestClient): + """Test security headers are present even on error responses.""" + # Make a request to a non-existent endpoint + response = client.get("/nonexistent") + + # Even 404 responses should have security headers + assert response.headers["X-Content-Type-Options"] == "nosniff" + assert response.headers["X-Frame-Options"] == "DENY" + assert "Content-Security-Policy" in response.headers + + def test_security_headers_on_method_not_allowed(self, client: TestClient): + """Test security headers on 405 Method Not Allowed responses.""" + # Try to POST to a GET-only endpoint + response = client.post("/health") + + assert response.status_code == 405 + assert response.headers["X-Content-Type-Options"] == "nosniff" + assert response.headers["X-Frame-Options"] == "DENY" + assert "Content-Security-Policy" in response.headers + + @pytest.mark.parametrize("forwarded_proto", ["http", "https", "invalid"]) + def test_hsts_with_various_forwarded_proto_values(self, client: TestClient, forwarded_proto: str): + """Test HSTS behavior with various X-Forwarded-Proto values.""" + response = client.get("/health", headers={"X-Forwarded-Proto": forwarded_proto}) + + if forwarded_proto == "https": + assert "Strict-Transport-Security" in response.headers + else: + assert "Strict-Transport-Security" not in response.headers + + +@pytest.fixture +def client(app): + """Create a test client for the FastAPI app.""" + return TestClient(app) diff --git a/tests/security/test_security_middleware_comprehensive.py b/tests/security/test_security_middleware_comprehensive.py new file mode 100644 index 00000000..589a8a8c --- /dev/null +++ b/tests/security/test_security_middleware_comprehensive.py @@ -0,0 +1,628 @@ +# -*- coding: utf-8 -*- +""" +Copyright 2025 +SPDX-License-Identifier: Apache-2.0 + +Comprehensive Security Middleware Testing. + +This module provides comprehensive test coverage for the SecurityHeadersMiddleware +including all configuration combinations, edge cases, and integration scenarios. +""" + +import pytest +from fastapi import FastAPI, Response, Request +from fastapi.testclient import TestClient +from unittest.mock import patch, Mock + +from mcpgateway.middleware.security_headers import SecurityHeadersMiddleware +from mcpgateway.config import settings + + +class TestSecurityHeadersConfiguration: + """Test all security header configuration options.""" + + @pytest.mark.parametrize("enabled", [True, False]) + def test_security_headers_enabled_toggle(self, enabled: bool): + """Test security headers can be globally enabled/disabled.""" + app = FastAPI() + app.add_middleware(SecurityHeadersMiddleware) + + @app.get("/test") + def test_endpoint(): + return {"message": "test"} + + with patch.object(settings, 'security_headers_enabled', enabled): + client = TestClient(app) + response = client.get("/test") + + if enabled: + # When enabled, headers should be present + assert "X-Content-Type-Options" in response.headers + assert "X-Frame-Options" in response.headers + assert "Content-Security-Policy" in response.headers + else: + # When disabled, no security headers should be added + assert "X-Content-Type-Options" not in response.headers + assert "X-Frame-Options" not in response.headers + assert "Content-Security-Policy" not in response.headers + + @pytest.mark.parametrize("x_content_enabled", [True, False]) + def test_x_content_type_options_configurable(self, x_content_enabled: bool): + """Test X-Content-Type-Options can be individually configured.""" + app = FastAPI() + app.add_middleware(SecurityHeadersMiddleware) + + @app.get("/test") + def test_endpoint(): + return {"message": "test"} + + with patch.multiple(settings, + security_headers_enabled=True, + x_content_type_options_enabled=x_content_enabled): + client = TestClient(app) + response = client.get("/test") + + if x_content_enabled: + assert response.headers["X-Content-Type-Options"] == "nosniff" + else: + assert "X-Content-Type-Options" not in response.headers + + @pytest.mark.parametrize("frame_option", ["DENY", "SAMEORIGIN", ""]) + def test_x_frame_options_configurable(self, frame_option: str): + """Test X-Frame-Options values are configurable.""" + app = FastAPI() + app.add_middleware(SecurityHeadersMiddleware) + + @app.get("/test") + def test_endpoint(): + return {"message": "test"} + + with patch.multiple(settings, + security_headers_enabled=True, + x_frame_options=frame_option): + client = TestClient(app) + response = client.get("/test") + + if frame_option: + assert response.headers["X-Frame-Options"] == frame_option + else: + assert "X-Frame-Options" not in response.headers + + @pytest.mark.parametrize("xss_enabled", [True, False]) + def test_x_xss_protection_configurable(self, xss_enabled: bool): + """Test X-XSS-Protection can be configured.""" + app = FastAPI() + app.add_middleware(SecurityHeadersMiddleware) + + @app.get("/test") + def test_endpoint(): + return {"message": "test"} + + with patch.multiple(settings, + security_headers_enabled=True, + x_xss_protection_enabled=xss_enabled): + client = TestClient(app) + response = client.get("/test") + + if xss_enabled: + assert response.headers["X-XSS-Protection"] == "0" + else: + assert "X-XSS-Protection" not in response.headers + + @pytest.mark.parametrize("download_enabled", [True, False]) + def test_x_download_options_configurable(self, download_enabled: bool): + """Test X-Download-Options can be configured.""" + app = FastAPI() + app.add_middleware(SecurityHeadersMiddleware) + + @app.get("/test") + def test_endpoint(): + return {"message": "test"} + + with patch.multiple(settings, + security_headers_enabled=True, + x_download_options_enabled=download_enabled): + client = TestClient(app) + response = client.get("/test") + + if download_enabled: + assert response.headers["X-Download-Options"] == "noopen" + else: + assert "X-Download-Options" not in response.headers + + def test_referrer_policy_always_set(self): + """Test Referrer-Policy is always set regardless of configuration.""" + app = FastAPI() + app.add_middleware(SecurityHeadersMiddleware) + + @app.get("/test") + def test_endpoint(): + return {"message": "test"} + + with patch.object(settings, 'security_headers_enabled', True): + client = TestClient(app) + response = client.get("/test") + + # Referrer-Policy should always be set when headers are enabled + assert response.headers["Referrer-Policy"] == "strict-origin-when-cross-origin" + + +class TestHSTSConfiguration: + """Test HSTS header configuration options.""" + + @pytest.mark.parametrize("hsts_enabled", [True, False]) + def test_hsts_enabled_toggle(self, hsts_enabled: bool): + """Test HSTS can be enabled/disabled.""" + app = FastAPI() + app.add_middleware(SecurityHeadersMiddleware) + + @app.get("/test") + def test_endpoint(): + return {"message": "test"} + + with patch.multiple(settings, + security_headers_enabled=True, + hsts_enabled=hsts_enabled): + client = TestClient(app) + response = client.get("/test", headers={"X-Forwarded-Proto": "https"}) + + if hsts_enabled: + assert "Strict-Transport-Security" in response.headers + else: + assert "Strict-Transport-Security" not in response.headers + + @pytest.mark.parametrize("max_age", [86400, 31536000, 63072000]) # 1 day, 1 year, 2 years + def test_hsts_max_age_configurable(self, max_age: int): + """Test HSTS max-age is configurable.""" + app = FastAPI() + app.add_middleware(SecurityHeadersMiddleware) + + @app.get("/test") + def test_endpoint(): + return {"message": "test"} + + with patch.multiple(settings, + security_headers_enabled=True, + hsts_enabled=True, + hsts_max_age=max_age, + hsts_include_subdomains=False): + client = TestClient(app) + response = client.get("/test", headers={"X-Forwarded-Proto": "https"}) + + assert "Strict-Transport-Security" in response.headers + hsts_value = response.headers["Strict-Transport-Security"] + assert f"max-age={max_age}" in hsts_value + assert "includeSubDomains" not in hsts_value + + @pytest.mark.parametrize("include_subdomains", [True, False]) + def test_hsts_include_subdomains_configurable(self, include_subdomains: bool): + """Test HSTS includeSubDomains directive is configurable.""" + app = FastAPI() + app.add_middleware(SecurityHeadersMiddleware) + + @app.get("/test") + def test_endpoint(): + return {"message": "test"} + + with patch.multiple(settings, + security_headers_enabled=True, + hsts_enabled=True, + hsts_max_age=31536000, + hsts_include_subdomains=include_subdomains): + client = TestClient(app) + response = client.get("/test", headers={"X-Forwarded-Proto": "https"}) + + hsts_value = response.headers["Strict-Transport-Security"] + if include_subdomains: + assert "includeSubDomains" in hsts_value + else: + assert "includeSubDomains" not in hsts_value + + @pytest.mark.parametrize("proto_header", ["https", "http", "invalid", None]) + def test_hsts_protocol_detection(self, proto_header: str): + """Test HSTS activation based on protocol detection.""" + app = FastAPI() + app.add_middleware(SecurityHeadersMiddleware) + + @app.get("/test") + def test_endpoint(): + return {"message": "test"} + + with patch.multiple(settings, + security_headers_enabled=True, + hsts_enabled=True): + client = TestClient(app) + headers = {} + if proto_header: + headers["X-Forwarded-Proto"] = proto_header + + response = client.get("/test", headers=headers) + + if proto_header == "https": + assert "Strict-Transport-Security" in response.headers + else: + assert "Strict-Transport-Security" not in response.headers + + +class TestServerHeaderRemoval: + """Test server header removal configuration.""" + + @pytest.mark.parametrize("remove_headers", [True, False]) + def test_server_header_removal_configurable(self, remove_headers: bool): + """Test server header removal can be configured.""" + app = FastAPI() + app.add_middleware(SecurityHeadersMiddleware) + + @app.get("/test") + def test_endpoint(): + response = Response(content='{"message": "test"}', media_type="application/json") + # Simulate headers that might be set by the server + response.headers["X-Powered-By"] = "TestServer/1.0" + response.headers["Server"] = "TestServer/1.0" + return response + + with patch.multiple(settings, + security_headers_enabled=True, + remove_server_headers=remove_headers): + client = TestClient(app) + response = client.get("/test") + + # Note: In test mode, these headers might not be present initially + # This test mainly validates the configuration logic works + if remove_headers: + # Headers should be removed if they exist + assert "X-Powered-By" not in response.headers + assert "Server" not in response.headers + # If remove_headers=False, the middleware wouldn't remove them + # but in test mode they might not be present anyway + + +class TestCSPConfiguration: + """Test Content Security Policy configuration.""" + + def test_csp_always_present_when_headers_enabled(self): + """Test CSP is always present when security headers are enabled.""" + app = FastAPI() + app.add_middleware(SecurityHeadersMiddleware) + + @app.get("/test") + def test_endpoint(): + return {"message": "test"} + + with patch.object(settings, 'security_headers_enabled', True): + client = TestClient(app) + response = client.get("/test") + + assert "Content-Security-Policy" in response.headers + csp = response.headers["Content-Security-Policy"] + + # Verify essential directives + assert "default-src 'self'" in csp + assert "frame-ancestors 'none'" in csp + assert csp.endswith(";") + + def test_csp_includes_admin_ui_cdns(self): + """Test CSP includes all required CDN domains for Admin UI.""" + app = FastAPI() + app.add_middleware(SecurityHeadersMiddleware) + + @app.get("/test") + def test_endpoint(): + return {"message": "test"} + + with patch.object(settings, 'security_headers_enabled', True): + client = TestClient(app) + response = client.get("/test") + + csp = response.headers["Content-Security-Policy"] + + # Check all required CDN domains are allowed + required_domains = [ + "https://cdnjs.cloudflare.com", + "https://cdn.tailwindcss.com", + "https://cdn.jsdelivr.net" + ] + + for domain in required_domains: + assert domain in csp, f"{domain} missing from CSP" + + +class TestMiddlewareIntegration: + """Test middleware integration with various response types.""" + + def test_security_headers_on_json_response(self): + """Test headers are added to JSON responses.""" + app = FastAPI() + app.add_middleware(SecurityHeadersMiddleware) + + @app.get("/test") + def test_endpoint(): + return {"message": "test", "data": [1, 2, 3]} + + client = TestClient(app) + response = client.get("/test") + + assert response.headers["X-Content-Type-Options"] == "nosniff" + assert "Content-Security-Policy" in response.headers + assert response.json() == {"message": "test", "data": [1, 2, 3]} + + def test_security_headers_on_html_response(self): + """Test headers are added to HTML responses.""" + app = FastAPI() + app.add_middleware(SecurityHeadersMiddleware) + + @app.get("/test") + def test_endpoint(): + return Response(content="Test", media_type="text/html") + + client = TestClient(app) + response = client.get("/test") + + assert response.headers["X-Frame-Options"] == "DENY" + assert response.headers["X-Content-Type-Options"] == "nosniff" + assert "Content-Security-Policy" in response.headers + assert "" in response.text + + def test_security_headers_on_different_status_codes(self): + """Test headers are added to responses with different status codes.""" + app = FastAPI() + app.add_middleware(SecurityHeadersMiddleware) + + @app.get("/success") + def success_endpoint(): + return {"message": "success"} + + @app.get("/not-found") + def not_found_endpoint(): + from fastapi import HTTPException + raise HTTPException(status_code=404, detail="Not found") + + client = TestClient(app) + + # Test successful response + response = client.get("/success") + assert response.status_code == 200 + assert response.headers["X-Content-Type-Options"] == "nosniff" + assert "Content-Security-Policy" in response.headers + + # Test 404 response + response = client.get("/not-found") + assert response.status_code == 404 + assert response.headers["X-Content-Type-Options"] == "nosniff" + assert "Content-Security-Policy" in response.headers + + def test_security_headers_preserve_existing_headers(self): + """Test middleware preserves existing response headers.""" + app = FastAPI() + app.add_middleware(SecurityHeadersMiddleware) + + @app.get("/test") + def test_endpoint(): + response = Response(content='{"test": true}', media_type="application/json") + response.headers["Custom-Header"] = "custom-value" + response.headers["Cache-Control"] = "no-cache" + return response + + client = TestClient(app) + response = client.get("/test") + + # Existing headers should be preserved + assert response.headers["Custom-Header"] == "custom-value" + assert response.headers["Cache-Control"] == "no-cache" + + # Security headers should be added + assert response.headers["X-Content-Type-Options"] == "nosniff" + assert "Content-Security-Policy" in response.headers + + +class TestAllConfigurationCombinations: + """Test various combinations of security header configurations.""" + + def test_all_headers_disabled_except_csp(self): + """Test configuration with only CSP enabled.""" + app = FastAPI() + app.add_middleware(SecurityHeadersMiddleware) + + @app.get("/test") + def test_endpoint(): + return {"message": "test"} + + with patch.multiple(settings, + security_headers_enabled=True, + x_content_type_options_enabled=False, + x_frame_options="", # Empty means disabled + x_xss_protection_enabled=False, + x_download_options_enabled=False, + hsts_enabled=False, + remove_server_headers=False): + client = TestClient(app) + response = client.get("/test") + + # Only CSP and Referrer-Policy should be present + assert "X-Content-Type-Options" not in response.headers + assert "X-Frame-Options" not in response.headers + assert "X-XSS-Protection" not in response.headers + assert "X-Download-Options" not in response.headers + assert "Strict-Transport-Security" not in response.headers + + # These are always set when headers are enabled + assert "Content-Security-Policy" in response.headers + assert response.headers["Referrer-Policy"] == "strict-origin-when-cross-origin" + + def test_maximum_security_configuration(self): + """Test configuration with all security features enabled.""" + app = FastAPI() + app.add_middleware(SecurityHeadersMiddleware) + + @app.get("/test") + def test_endpoint(): + return {"message": "test"} + + with patch.multiple(settings, + security_headers_enabled=True, + x_content_type_options_enabled=True, + x_frame_options="DENY", + x_xss_protection_enabled=True, + x_download_options_enabled=True, + hsts_enabled=True, + hsts_max_age=63072000, # 2 years + hsts_include_subdomains=True, + remove_server_headers=True): + client = TestClient(app) + response = client.get("/test", headers={"X-Forwarded-Proto": "https"}) + + # All headers should be present + assert response.headers["X-Content-Type-Options"] == "nosniff" + assert response.headers["X-Frame-Options"] == "DENY" + assert response.headers["X-XSS-Protection"] == "0" + assert response.headers["X-Download-Options"] == "noopen" + assert response.headers["Referrer-Policy"] == "strict-origin-when-cross-origin" + assert "Content-Security-Policy" in response.headers + + # HSTS with custom settings + hsts_value = response.headers["Strict-Transport-Security"] + assert "max-age=63072000" in hsts_value + assert "includeSubDomains" in hsts_value + + +class TestMiddlewareErrorHandling: + """Test middleware behavior in error scenarios.""" + + def test_middleware_handles_none_response(self): + """Test middleware handles edge case responses gracefully.""" + app = FastAPI() + app.add_middleware(SecurityHeadersMiddleware) + + @app.get("/test") + def test_endpoint(): + return {"message": "test"} + + # Test with normal response + client = TestClient(app) + response = client.get("/test") + assert response.status_code == 200 + assert "X-Content-Type-Options" in response.headers + + def test_middleware_with_request_variations(self): + """Test middleware with different request types.""" + app = FastAPI() + app.add_middleware(SecurityHeadersMiddleware) + + @app.get("/get-test") + def get_endpoint(): + return {"method": "GET"} + + @app.post("/post-test") + def post_endpoint(): + return {"method": "POST"} + + @app.put("/put-test") + def put_endpoint(): + return {"method": "PUT"} + + client = TestClient(app) + + # Test different HTTP methods all get security headers + for method, endpoint in [("GET", "/get-test"), ("POST", "/post-test"), ("PUT", "/put-test")]: + if method == "GET": + response = client.get(endpoint) + elif method == "POST": + response = client.post(endpoint) + elif method == "PUT": + response = client.put(endpoint) + + assert response.status_code == 200 + assert response.headers["X-Content-Type-Options"] == "nosniff" + assert "Content-Security-Policy" in response.headers + + +class TestProtocolDetection: + """Test various protocol detection scenarios for HSTS.""" + + @pytest.mark.parametrize("request_scheme,forwarded_proto,expect_hsts", [ + ("https", None, True), + ("http", "https", True), + ("https", "https", True), + ("http", "http", False), + ("http", None, False), + ("https", "http", True), # Request scheme takes precedence + ]) + def test_hsts_protocol_detection_combinations(self, request_scheme: str, forwarded_proto: str, expect_hsts: bool): + """Test HSTS activation under various protocol scenarios.""" + app = FastAPI() + app.add_middleware(SecurityHeadersMiddleware) + + @app.get("/test") + def test_endpoint(): + return {"message": "test"} + + with patch.multiple(settings, + security_headers_enabled=True, + hsts_enabled=True): + client = TestClient(app) + + # Mock the request URL scheme + headers = {} + if forwarded_proto: + headers["X-Forwarded-Proto"] = forwarded_proto + + # Note: TestClient always uses 'http' scheme, so we test forwarded proto + response = client.get("/test", headers=headers) + + if expect_hsts and forwarded_proto == "https": + assert "Strict-Transport-Security" in response.headers + else: + assert "Strict-Transport-Security" not in response.headers + + +class TestConfigurationValidation: + """Test configuration validation and edge cases.""" + + def test_empty_configuration_values(self): + """Test behavior with empty configuration values.""" + app = FastAPI() + app.add_middleware(SecurityHeadersMiddleware) + + @app.get("/test") + def test_endpoint(): + return {"message": "test"} + + with patch.multiple(settings, + security_headers_enabled=True, + x_frame_options="", # Empty string + hsts_max_age=0): # Zero value + client = TestClient(app) + response = client.get("/test", headers={"X-Forwarded-Proto": "https"}) + + # Empty x_frame_options should result in no header + assert "X-Frame-Options" not in response.headers + + # Zero max-age should still work + if "Strict-Transport-Security" in response.headers: + assert "max-age=0" in response.headers["Strict-Transport-Security"] + + def test_settings_access_during_request(self): + """Test that settings are properly accessed during request processing.""" + app = FastAPI() + app.add_middleware(SecurityHeadersMiddleware) + + @app.get("/test") + def test_endpoint(): + return {"message": "test"} + + # Create a mock settings object to verify access patterns + with patch('mcpgateway.middleware.security_headers.settings') as mock_settings: + mock_settings.security_headers_enabled = True + mock_settings.x_content_type_options_enabled = True + mock_settings.x_frame_options = "DENY" + mock_settings.x_xss_protection_enabled = True + mock_settings.x_download_options_enabled = True + mock_settings.hsts_enabled = False + mock_settings.remove_server_headers = True + + client = TestClient(app) + response = client.get("/test") + + # Verify settings were accessed + assert mock_settings.security_headers_enabled + assert response.headers["X-Content-Type-Options"] == "nosniff" diff --git a/tests/security/test_security_performance_compatibility.py b/tests/security/test_security_performance_compatibility.py new file mode 100644 index 00000000..a11c6ea1 --- /dev/null +++ b/tests/security/test_security_performance_compatibility.py @@ -0,0 +1,605 @@ +# -*- coding: utf-8 -*- +""" +Copyright 2025 +SPDX-License-Identifier: Apache-2.0 + +Security Performance and Compatibility Testing. + +This module tests the performance impact and browser/tool compatibility +of the security implementation. +""" + +import pytest +import time +from fastapi import FastAPI +from fastapi.testclient import TestClient +from unittest.mock import patch +import re + +from mcpgateway.middleware.security_headers import SecurityHeadersMiddleware +from mcpgateway.config import settings + + +class TestPerformanceImpact: + """Test performance impact of security middleware.""" + + def test_middleware_overhead_minimal(self): + """Test security middleware has minimal performance overhead.""" + # App without security middleware + app_no_security = FastAPI() + + @app_no_security.get("/test") + def test_endpoint(): + return {"message": "test"} + + # App with security middleware + app_with_security = FastAPI() + app_with_security.add_middleware(SecurityHeadersMiddleware) + + @app_with_security.get("/test") + def test_endpoint(): + return {"message": "test"} + + # Measure performance + iterations = 100 + + # Time without security + client_no_security = TestClient(app_no_security) + start_time = time.time() + for i in range(iterations): + response = client_no_security.get("/test") + assert response.status_code == 200 + time_without_security = time.time() - start_time + + # Time with security + client_with_security = TestClient(app_with_security) + start_time = time.time() + for i in range(iterations): + response = client_with_security.get("/test") + assert response.status_code == 200 + time_with_security = time.time() - start_time + + # Security overhead should be minimal (< 50% increase) + overhead_ratio = time_with_security / time_without_security + assert overhead_ratio < 1.5, f"Security middleware overhead too high: {overhead_ratio}x" + + def test_memory_usage_stable(self): + """Test security middleware doesn't cause memory leaks.""" + app = FastAPI() + app.add_middleware(SecurityHeadersMiddleware) + + @app.get("/test") + def test_endpoint(): + return {"message": "test", "data": list(range(100))} + + client = TestClient(app) + + # Make many requests to check for memory leaks + for i in range(200): + response = client.get(f"/test?iteration={i}") + assert response.status_code == 200 + assert "X-Content-Type-Options" in response.headers + + # If we reach here without memory issues, test passes + assert True + + def test_large_response_performance(self): + """Test security middleware performance with large responses.""" + app = FastAPI() + app.add_middleware(SecurityHeadersMiddleware) + + @app.get("/large") + def large_endpoint(): + # Generate ~1MB response + large_data = {"data": ["x" * 1000] * 1000} + return large_data + + client = TestClient(app) + + start_time = time.time() + response = client.get("/large") + end_time = time.time() + + assert response.status_code == 200 + assert "X-Content-Type-Options" in response.headers + + # Should complete within reasonable time (< 5 seconds) + processing_time = end_time - start_time + assert processing_time < 5.0, f"Large response too slow: {processing_time}s" + + +class TestBrowserCompatibility: + """Test security headers compatibility with different browsers.""" + + def test_csp_directive_format_compatibility(self): + """Test CSP directive format is browser-compatible.""" + app = FastAPI() + app.add_middleware(SecurityHeadersMiddleware) + + @app.get("/test") + def test_endpoint(): + return {"message": "test"} + + client = TestClient(app) + response = client.get("/test") + + csp = response.headers["Content-Security-Policy"] + + # CSP should follow standard format + assert csp.endswith(";") # Should end with semicolon + assert "default-src 'self'" in csp + + # Directives should be properly formatted + directives = csp.split(";") + for directive in directives: + directive = directive.strip() + if directive: # Skip empty + # Should have directive-name followed by values + parts = directive.split(" ", 1) + assert len(parts) >= 1 + directive_name = parts[0] + assert re.match(r'^[a-z-]+$', directive_name), f"Invalid directive name: {directive_name}" + + def test_x_frame_options_standard_values(self): + """Test X-Frame-Options uses standard values.""" + standard_values = ["DENY", "SAMEORIGIN"] + + for value in standard_values: + app = FastAPI() + app.add_middleware(SecurityHeadersMiddleware) + + @app.get("/test") + def test_endpoint(): + return {"message": "test"} + + with patch.object(settings, 'x_frame_options', value): + client = TestClient(app) + response = client.get("/test") + + assert response.headers["X-Frame-Options"] == value + + def test_hsts_header_format_compliance(self): + """Test HSTS header format complies with RFC standards.""" + app = FastAPI() + app.add_middleware(SecurityHeadersMiddleware) + + @app.get("/test") + def test_endpoint(): + return {"message": "test"} + + client = TestClient(app) + response = client.get("/test", headers={"X-Forwarded-Proto": "https"}) + + if "Strict-Transport-Security" in response.headers: + hsts_value = response.headers["Strict-Transport-Security"] + + # Should match RFC format: max-age=; includeSubDomains + assert re.match(r'max-age=\d+(; includeSubDomains)?', hsts_value) + + def test_referrer_policy_standard_value(self): + """Test Referrer-Policy uses standard value.""" + app = FastAPI() + app.add_middleware(SecurityHeadersMiddleware) + + @app.get("/test") + def test_endpoint(): + return {"message": "test"} + + client = TestClient(app) + response = client.get("/test") + + referrer_policy = response.headers["Referrer-Policy"] + + # Should be a standard referrer policy value + standard_policies = [ + "no-referrer", + "no-referrer-when-downgrade", + "origin", + "origin-when-cross-origin", + "same-origin", + "strict-origin", + "strict-origin-when-cross-origin", + "unsafe-url" + ] + + assert referrer_policy in standard_policies + + +class TestStaticAnalysisToolCompatibility: + """Test compatibility with static analysis tools.""" + + def test_csp_meta_tag_format(self): + """Test CSP meta tag format for static analysis tools.""" + # This tests the meta tag in admin.html indirectly + from mcpgateway.middleware.security_headers import SecurityHeadersMiddleware + + app = FastAPI() + app.add_middleware(SecurityHeadersMiddleware) + + @app.get("/test") + def test_endpoint(): + return {"message": "test"} + + client = TestClient(app) + response = client.get("/test") + + # HTTP header CSP should be well-formed for tools to parse + csp = response.headers["Content-Security-Policy"] + + # Should be parseable by security tools + assert "default-src" in csp + assert "'self'" in csp + assert "script-src" in csp + assert "frame-ancestors" in csp + + def test_security_headers_machine_readable(self): + """Test security headers are in machine-readable format.""" + app = FastAPI() + app.add_middleware(SecurityHeadersMiddleware) + + @app.get("/test") + def test_endpoint(): + return {"message": "test"} + + client = TestClient(app) + response = client.get("/test") + + # Headers should be in standard format for automated tools + headers_to_check = { + "X-Content-Type-Options": "nosniff", + "X-Frame-Options": "DENY", + "X-XSS-Protection": "0", + "X-Download-Options": "noopen" + } + + for header_name, expected_value in headers_to_check.items(): + assert response.headers[header_name] == expected_value + + def test_nodejsscan_detectable_patterns(self): + """Test patterns that nodejsscan and similar tools can detect.""" + # Test that our implementation includes patterns static analyzers expect + + # Test 1: CSP header presence + app = FastAPI() + app.add_middleware(SecurityHeadersMiddleware) + + @app.get("/test") + def test_endpoint(): + return {"message": "test"} + + client = TestClient(app) + response = client.get("/test") + + # Should have detectable security patterns + assert "Content-Security-Policy" in response.headers + csp = response.headers["Content-Security-Policy"] + + # Static analyzers look for these patterns + assert "default-src" in csp + assert "'self'" in csp + assert "script-src" in csp + + +class TestCORSPerformanceAndCompatibility: + """Test CORS performance and compatibility.""" + + def test_cors_origin_matching_performance(self): + """Test CORS origin matching doesn't impact performance.""" + from fastapi.middleware.cors import CORSMiddleware + + # Create app with many allowed origins + many_origins = [f"https://subdomain{i}.example.com" for i in range(100)] + + app = FastAPI() + app.add_middleware( + CORSMiddleware, + allow_origins=many_origins, + allow_credentials=True + ) + app.add_middleware(SecurityHeadersMiddleware) + + @app.get("/test") + def test_endpoint(): + return {"message": "test"} + + client = TestClient(app) + + # Test performance with many origins configured + start_time = time.time() + for i in range(20): + response = client.get("/test", headers={"Origin": f"https://subdomain{i}.example.com"}) + assert response.status_code == 200 + end_time = time.time() + + # Should complete quickly even with many origins + total_time = end_time - start_time + assert total_time < 2.0, f"CORS with many origins too slow: {total_time}s" + + def test_environment_aware_cors_switching(self): + """Test switching between environment CORS configurations.""" + # Test that environment switching works correctly + + # Development configuration + with patch.multiple(settings, + environment="development", + allowed_origins={"http://localhost:3000"}): + + app = FastAPI() + app.add_middleware(SecurityHeadersMiddleware) + + @app.get("/test") + def test_endpoint(): + return {"message": "dev"} + + client = TestClient(app) + response = client.get("/test") + + # Should work in development + assert response.status_code == 200 + assert "X-Content-Type-Options" in response.headers + + # Production configuration + with patch.multiple(settings, + environment="production", + allowed_origins={"https://example.com"}): + + app = FastAPI() + app.add_middleware(SecurityHeadersMiddleware) + + @app.get("/test") + def test_endpoint(): + return {"message": "prod"} + + client = TestClient(app) + response = client.get("/test") + + # Should work in production + assert response.status_code == 200 + assert "X-Content-Type-Options" in response.headers + + +class TestSecurityHeadersStandardsCompliance: + """Test security headers comply with web standards.""" + + def test_csp_level_2_compliance(self): + """Test CSP follows CSP Level 2 specification.""" + app = FastAPI() + app.add_middleware(SecurityHeadersMiddleware) + + @app.get("/test") + def test_endpoint(): + return {"message": "test"} + + client = TestClient(app) + response = client.get("/test") + + csp = response.headers["Content-Security-Policy"] + + # CSP Level 2 directive compliance + required_directives = ["default-src", "script-src", "style-src"] + for directive in required_directives: + assert directive in csp + + # Should not use deprecated directives + deprecated_directives = ["script-src-elem", "script-src-attr"] + for directive in deprecated_directives: + assert directive not in csp + + def test_security_headers_case_sensitivity(self): + """Test security headers use correct case.""" + app = FastAPI() + app.add_middleware(SecurityHeadersMiddleware) + + @app.get("/test") + def test_endpoint(): + return {"message": "test"} + + client = TestClient(app) + response = client.get("/test") + + # Headers should use standard case (HTTP headers are case-insensitive but have conventions) + expected_headers = [ + "X-Content-Type-Options", + "X-Frame-Options", + "X-XSS-Protection", + "X-Download-Options", + "Content-Security-Policy", + "Referrer-Policy" + ] + + for header in expected_headers: + assert header in response.headers, f"Missing header: {header}" + + def test_http_version_compatibility(self): + """Test security headers work with different HTTP versions.""" + app = FastAPI() + app.add_middleware(SecurityHeadersMiddleware) + + @app.get("/test") + def test_endpoint(): + return {"message": "test"} + + client = TestClient(app) + + # Test with different HTTP configurations + response = client.get("/test") + + # Should work with standard HTTP + assert response.status_code == 200 + assert "X-Content-Type-Options" in response.headers + + # Headers should be present regardless of HTTP version + assert response.headers["X-Content-Type-Options"] == "nosniff" + + +class TestContentTypeCompatibility: + """Test security headers with different content types.""" + + @pytest.mark.parametrize("content_type,content", [ + ("application/json", '{"test": "json"}'), + ("text/html", "Test"), + ("text/plain", "Plain text response"), + ("application/xml", "test"), + ("text/css", "body { color: black; }"), + ("application/javascript", "console.log('test');"), + ]) + def test_security_headers_with_content_types(self, content_type: str, content: str): + """Test security headers work with various content types.""" + app = FastAPI() + app.add_middleware(SecurityHeadersMiddleware) + + @app.get("/test") + def test_endpoint(): + from fastapi import Response + return Response(content=content, media_type=content_type) + + client = TestClient(app) + response = client.get("/test") + + assert response.status_code == 200 + assert response.headers["Content-Type"].startswith(content_type) + + # Security headers should be present for all content types + assert "X-Content-Type-Options" in response.headers + assert "Content-Security-Policy" in response.headers + + # X-Download-Options is especially important for downloadable content + if content_type in ["application/octet-stream", "application/javascript"]: + assert "X-Download-Options" in response.headers + + def test_security_headers_with_binary_content(self): + """Test security headers work with binary content.""" + app = FastAPI() + app.add_middleware(SecurityHeadersMiddleware) + + @app.get("/binary") + def binary_endpoint(): + # Simulate binary content (like images, PDFs, etc.) + binary_data = b'\x89PNG\r\n\x1a\n\x00\x00\x00\rIHDR\x00\x00\x00\x01\x00\x00\x00\x01' + from fastapi import Response + return Response(content=binary_data, media_type="image/png") + + client = TestClient(app) + response = client.get("/binary") + + assert response.status_code == 200 + assert response.headers["Content-Type"] == "image/png" + + # Security headers should be present for binary content too + assert "X-Content-Type-Options" in response.headers + assert "X-Download-Options" in response.headers + assert "Content-Security-Policy" in response.headers + + +class TestSecurityInProxyScenarios: + """Test security implementation in proxy/load balancer scenarios.""" + + @pytest.mark.parametrize("proxy_headers", [ + {"X-Forwarded-Proto": "https", "X-Forwarded-Host": "example.com"}, + {"X-Forwarded-Proto": "http", "X-Forwarded-For": "192.168.1.1"}, + {"X-Real-IP": "10.0.0.1", "X-Forwarded-Proto": "https"}, + {"CF-Visitor": '{"scheme":"https"}', "X-Forwarded-Proto": "https"}, # Cloudflare + ]) + def test_hsts_with_proxy_headers(self, proxy_headers: dict): + """Test HSTS detection works with various proxy configurations.""" + app = FastAPI() + app.add_middleware(SecurityHeadersMiddleware) + + @app.get("/test") + def test_endpoint(): + return {"message": "test"} + + with patch.object(settings, 'hsts_enabled', True): + client = TestClient(app) + response = client.get("/test", headers=proxy_headers) + + if proxy_headers.get("X-Forwarded-Proto") == "https": + assert "Strict-Transport-Security" in response.headers + else: + assert "Strict-Transport-Security" not in response.headers + + def test_security_headers_with_load_balancer_headers(self): + """Test security headers work with common load balancer headers.""" + load_balancer_headers = { + "X-Forwarded-For": "192.168.1.1, 10.0.0.1", + "X-Forwarded-Proto": "https", + "X-Forwarded-Host": "api.example.com", + "X-Request-ID": "req-12345", + "X-Correlation-ID": "corr-67890" + } + + app = FastAPI() + app.add_middleware(SecurityHeadersMiddleware) + + @app.get("/test") + def test_endpoint(): + return {"message": "test"} + + client = TestClient(app) + response = client.get("/test", headers=load_balancer_headers) + + assert response.status_code == 200 + + # Security headers should be present + assert "X-Content-Type-Options" in response.headers + assert "Strict-Transport-Security" in response.headers # Due to X-Forwarded-Proto: https + + # Load balancer headers should be preserved + # Note: TestClient may not preserve all forwarded headers, but security should work + + +class TestConfigurationValidationAndErrors: + """Test configuration validation and error scenarios.""" + + def test_invalid_configuration_graceful_handling(self): + """Test graceful handling of invalid configuration values.""" + app = FastAPI() + app.add_middleware(SecurityHeadersMiddleware) + + @app.get("/test") + def test_endpoint(): + return {"message": "test"} + + # Test with potentially problematic configuration + with patch.multiple(settings, + security_headers_enabled=True, + x_frame_options="INVALID-VALUE", # Non-standard but should work + hsts_max_age=-1): # Negative value + client = TestClient(app) + response = client.get("/test", headers={"X-Forwarded-Proto": "https"}) + + # Should not crash, though values might be non-standard + assert response.status_code == 200 + assert "X-Frame-Options" in response.headers + + # Non-standard values should be passed through + assert response.headers["X-Frame-Options"] == "INVALID-VALUE" + + def test_settings_attribute_access_safety(self): + """Test safe attribute access for settings.""" + app = FastAPI() + app.add_middleware(SecurityHeadersMiddleware) + + @app.get("/test") + def test_endpoint(): + return {"message": "test"} + + # Mock settings to test attribute access patterns + with patch('mcpgateway.middleware.security_headers.settings') as mock_settings: + # Configure mock with all expected attributes + mock_settings.security_headers_enabled = True + mock_settings.x_content_type_options_enabled = True + mock_settings.x_frame_options = "DENY" + mock_settings.x_xss_protection_enabled = True + mock_settings.x_download_options_enabled = True + mock_settings.hsts_enabled = True + mock_settings.hsts_max_age = 31536000 + mock_settings.hsts_include_subdomains = True + mock_settings.remove_server_headers = True + + client = TestClient(app) + response = client.get("/test") + + assert response.status_code == 200 + # If we reach here, attribute access was successful + assert "X-Content-Type-Options" in response.headers diff --git a/tests/security/test_standalone_middleware.py b/tests/security/test_standalone_middleware.py new file mode 100644 index 00000000..e008c088 --- /dev/null +++ b/tests/security/test_standalone_middleware.py @@ -0,0 +1,119 @@ +# -*- coding: utf-8 -*- +""" +Copyright 2025 +SPDX-License-Identifier: Apache-2.0 + +Standalone Security Middleware Testing. + +This module tests the security middleware in isolation without the full app. +""" + +import pytest +from fastapi import FastAPI, Response +from fastapi.testclient import TestClient +from unittest.mock import patch + +from mcpgateway.middleware.security_headers import SecurityHeadersMiddleware +from mcpgateway.config import settings + + +def test_security_headers_middleware_basic(): + """Test security headers middleware in isolation.""" + # Create a minimal FastAPI app + app = FastAPI() + + # Add the security headers middleware + app.add_middleware(SecurityHeadersMiddleware) + + # Add a simple endpoint + @app.get("/test") + def test_endpoint(): + return {"message": "test"} + + # Create test client + client = TestClient(app) + + # Make request + response = client.get("/test") + + # Check that security headers are present + assert response.headers["X-Content-Type-Options"] == "nosniff" + assert response.headers["X-Frame-Options"] == "DENY" + assert response.headers["X-XSS-Protection"] == "0" + assert response.headers["X-Download-Options"] == "noopen" + assert response.headers["Referrer-Policy"] == "strict-origin-when-cross-origin" + assert "Content-Security-Policy" in response.headers + + # Check that sensitive headers are removed + assert "X-Powered-By" not in response.headers + assert "Server" not in response.headers + + +def test_security_headers_hsts_on_https(): + """Test HSTS header is added for HTTPS requests.""" + app = FastAPI() + app.add_middleware(SecurityHeadersMiddleware) + + @app.get("/test") + def test_endpoint(): + return {"message": "test"} + + client = TestClient(app) + + # Make request with HTTPS indication + response = client.get("/test", headers={"X-Forwarded-Proto": "https"}) + + # Check HSTS header + assert "Strict-Transport-Security" in response.headers + assert "max-age=31536000" in response.headers["Strict-Transport-Security"] + assert "includeSubDomains" in response.headers["Strict-Transport-Security"] + + +def test_security_headers_no_hsts_on_http(): + """Test HSTS header is not added for HTTP requests.""" + app = FastAPI() + app.add_middleware(SecurityHeadersMiddleware) + + @app.get("/test") + def test_endpoint(): + return {"message": "test"} + + client = TestClient(app) + + # Make regular HTTP request + response = client.get("/test") + + # Check HSTS header is not present + assert "Strict-Transport-Security" not in response.headers + + +def test_csp_header_structure(): + """Test CSP header has correct structure.""" + app = FastAPI() + app.add_middleware(SecurityHeadersMiddleware) + + @app.get("/test") + def test_endpoint(): + return {"message": "test"} + + client = TestClient(app) + response = client.get("/test") + + csp = response.headers["Content-Security-Policy"] + + # Check for essential CSP directives + assert "default-src 'self'" in csp + assert "script-src 'self'" in csp + assert "style-src 'self'" in csp + assert "img-src 'self'" in csp + assert "font-src 'self'" in csp + assert "connect-src 'self'" in csp + assert "frame-ancestors 'none'" in csp + + # Check that required CDN domains are allowed for Admin UI + assert "https://cdnjs.cloudflare.com" in csp + assert "https://cdn.tailwindcss.com" in csp + assert "https://cdn.jsdelivr.net" in csp + + # Verify CSP ends with semicolon + assert csp.endswith(";") diff --git a/tests/unit/mcpgateway/test_admin.py b/tests/unit/mcpgateway/test_admin.py index 01ef618e..0b3a9083 100644 --- a/tests/unit/mcpgateway/test_admin.py +++ b/tests/unit/mcpgateway/test_admin.py @@ -1514,8 +1514,16 @@ async def test_admin_ui_cookie_settings(self, mock_roots, mock_gateways, mock_pr jwt_token = "test.jwt.token" response = await admin_ui(mock_request, False, mock_db, "admin", jwt_token) - # Verify cookie was set with correct parameters - mock_response.set_cookie.assert_called_once_with(key="jwt_token", value=jwt_token, httponly=True, secure=False, samesite="Strict") + # Verify cookie was set with secure parameters using security_cookies utility + mock_response.set_cookie.assert_called_once_with( + key="jwt_token", + value=jwt_token, + max_age=3600, # 1 hour + httponly=True, + secure=True, # Default secure_cookies=True + samesite="lax", # Default cookie_samesite + path="/" + ) class TestEdgeCasesAndErrorHandling: From 44aa14dca63e4d45de0d65247dfed7ca087ab3a6 Mon Sep 17 00:00:00 2001 From: VK <90204593+vk-playground@users.noreply.github.com> Date: Sun, 17 Aug 2025 07:16:43 -0400 Subject: [PATCH 09/21] feat: Bulk Import Tools modal wiring #737 (#739) * feat: Bulk Import Tools modal wiring and backend implementation - Add modal UI in admin.html with bulk import button and dialog - Implement modal open/close/ESC functionality in admin.js - Add POST /admin/tools/import endpoint with rate limiting - Support both JSON textarea and file upload inputs - Validate JSON structure and enforce 200 tool limit - Return detailed success/failure information per tool - Include loading states and comprehensive error handling Refs #737 Signed-off-by: Mihai Criveti * fix: Remove duplicate admin_import_tools function and fix HTML formatting - Remove duplicate admin_import_tools function definition - Fix HTML placeholder attribute to use double quotes - Add missing closing div tag - Fix flake8 blank line issues Signed-off-by: Mihai Criveti * feat: Complete bulk import backend with file upload support and enhanced docs - Add file upload support to admin_import_tools endpoint - Fix response format to match frontend expectations - Add UI usage documentation with modal instructions - Update API docs to show all three input methods - Enhance bulk import guide with UI and API examples Backend improvements: - Support tools_file form field for JSON file uploads - Proper file content parsing with error handling - Response includes imported/failed counts and details - Frontend-compatible response format for UI display Signed-off-by: Mihai Criveti * Bulk import Signed-off-by: Mihai Criveti * fix: Remove conflicting inline script and fix bulk import functionality - Remove conflicting inline JavaScript that was preventing form submission - Fix indentation in setupBulkImportModal function - Ensure bulk import modal uses proper admin.js implementation - Restore proper form submission handling for bulk import This fixes the issue where bulk import appeared to do nothing. Signed-off-by: Mihai Criveti * fix: Integrate bulk import setup with main initialization - Add setupBulkImportModal() to main initialization sequence - Remove duplicate DOMContentLoaded listener - Ensure bulk import doesn't interfere with other tab functionality Signed-off-by: Mihai Criveti * fix: JavaScript formatting issues in bulk import modal - Fix multiline querySelector formatting - Fix multiline Error constructor formatting - Ensure prettier compliance for web linting Signed-off-by: Mihai Criveti * debug: Temporarily disable bulk import setup to test tabs Signed-off-by: Mihai Criveti * fix: Remove duplicate setupFormValidation call and delay bulk import setup - Remove duplicate setupFormValidation() call that could cause conflicts - Use setTimeout to delay bulk import modal setup after other initialization - Add better null safety to form element queries - This should fix tab switching issues Signed-off-by: Mihai Criveti * fix: Restore proper initialization sequence for tab functionality - Remove setTimeout delay for bulk import setup - Keep bulk import setup in main initialization but with error handling - Ensure tab navigation isn't affected by bulk import modal setup Signed-off-by: Mihai Criveti * fix: Correct HTML structure and restore tab navigation - Move bulk import modal to correct location after tools panel - Remove extra closing div that was breaking HTML structure - Ensure proper page-level modal placement - Restore tab navigation functionality for all tabs This fixes the broken Global Resources, Prompts, Gateways, Roots, and Metrics tabs. Signed-off-by: Mihai Criveti * feat: Add configurable bulk import settings Configuration additions: - MCPGATEWAY_BULK_IMPORT_MAX_TOOLS (default: 200) - MCPGATEWAY_BULK_IMPORT_RATE_LIMIT (default: 10) Implementation: - config.py: Add new settings with defaults - admin.py: Use configurable rate limit and batch size - .env.example: Document all bulk import environment variables - admin.html: Use dynamic max tools value in UI text - CLAUDE.md: Document configuration options for developers - docs: Update bulk import guide with configuration details This makes bulk import fully configurable for different deployment scenarios. Signed-off-by: Mihai Criveti * Update docs Signed-off-by: Mihai Criveti --------- Signed-off-by: Mihai Criveti Co-authored-by: Mihai Criveti Signed-off-by: Shamsul Arefin --- .env.example | 6 + CLAUDE.md | 5 + docs/docs/manage/bulk-import.md | 67 ++++++++- mcpgateway/admin.py | 69 +++++++--- mcpgateway/config.py | 2 + mcpgateway/static/admin.js | 234 +++++++++++++++++++++++++++++++- mcpgateway/templates/admin.html | 75 +++++++++- 7 files changed, 426 insertions(+), 32 deletions(-) diff --git a/.env.example b/.env.example index 06a5c0c7..ed2a311c 100644 --- a/.env.example +++ b/.env.example @@ -131,6 +131,12 @@ MCPGATEWAY_ADMIN_API_ENABLED=true # Enable bulk import endpoint for tools (true/false) MCPGATEWAY_BULK_IMPORT_ENABLED=true +# Maximum number of tools allowed per bulk import request +MCPGATEWAY_BULK_IMPORT_MAX_TOOLS=200 + +# Rate limiting for bulk import endpoint (requests per minute) +MCPGATEWAY_BULK_IMPORT_RATE_LIMIT=10 + ##################################### # Header Passthrough Configuration ##################################### diff --git a/CLAUDE.md b/CLAUDE.md index eabefb8e..99d92ef6 100644 --- a/CLAUDE.md +++ b/CLAUDE.md @@ -137,6 +137,11 @@ AUTH_REQUIRED=true MCPGATEWAY_UI_ENABLED=true MCPGATEWAY_ADMIN_API_ENABLED=true +# Bulk Import (Admin UI feature) +MCPGATEWAY_BULK_IMPORT_ENABLED=true # Enable/disable bulk import endpoint +MCPGATEWAY_BULK_IMPORT_MAX_TOOLS=200 # Maximum tools per import batch +MCPGATEWAY_BULK_IMPORT_RATE_LIMIT=10 # Requests per minute limit + # Federation MCPGATEWAY_ENABLE_MDNS_DISCOVERY=true MCPGATEWAY_ENABLE_FEDERATION=true diff --git a/docs/docs/manage/bulk-import.md b/docs/docs/manage/bulk-import.md index 17117d3b..6b4af7e7 100644 --- a/docs/docs/manage/bulk-import.md +++ b/docs/docs/manage/bulk-import.md @@ -2,34 +2,89 @@ The MCP Gateway provides a bulk import endpoint for efficiently loading multiple tools in a single request, perfect for migrations, environment setup, and team onboarding. -!!! info "Feature Flag Required" - This feature is controlled by the `MCPGATEWAY_BULK_IMPORT_ENABLED` environment variable. - Default: `true` (enabled). Set to `false` to disable this endpoint. +!!! info "Configuration Options" + This feature is controlled by several environment variables: + + - `MCPGATEWAY_BULK_IMPORT_ENABLED=true` - Enable/disable the endpoint (default: true) + - `MCPGATEWAY_BULK_IMPORT_MAX_TOOLS=200` - Maximum tools per batch (default: 200) + - `MCPGATEWAY_BULK_IMPORT_RATE_LIMIT=10` - Requests per minute limit (default: 10) --- ## 🚀 Overview -The `/admin/tools/import` endpoint allows you to register multiple tools at once, providing: +The bulk import feature allows you to register multiple tools at once through both the Admin UI and API, providing: - **Per-item validation** - One invalid tool won't fail the entire batch - **Detailed reporting** - Know exactly which tools succeeded or failed - **Rate limiting** - Protected against abuse (10 requests/minute) - **Batch size limits** - Maximum 200 tools per request -- **Multiple input formats** - JSON payload or form data +- **Multiple input formats** - JSON payload, form data, or file upload +- **User-friendly UI** - Modal dialog with drag-and-drop file support + +--- + +## 🎨 Admin UI Usage + +### Accessing the Bulk Import Modal + +1. **Navigate to Admin UI** - Open your gateway's admin interface at `http://localhost:4444/admin` +2. **Go to Tools Tab** - Click on the "Tools" tab in the main navigation +3. **Open Bulk Import** - Click the "+ Bulk Import Tools" button next to "Add New Tool" + +### Using the Modal + +The bulk import modal provides two ways to input tool data: + +#### Option 1: JSON Textarea +1. **Paste JSON directly** into the text area +2. **Validate format** - The modal will check JSON syntax before submission +3. **Click Import Tools** to process + +#### Option 2: File Upload +1. **Prepare a JSON file** with your tools array +2. **Click "Choose File"** and select your `.json` file +3. **Click Import Tools** to process + +### UI Features + +- **Real-time validation** - JSON syntax checking before submission +- **Loading indicators** - Progress spinner during import +- **Detailed results** - Success/failure counts with error details +- **Auto-refresh** - Page reloads automatically after successful import +- **Modal controls** - Close with button, backdrop click, or ESC key --- ## 📡 API Endpoint -### Request +### Request Methods +#### Method 1: JSON Body ```http POST /admin/tools/import Authorization: Bearer Content-Type: application/json ``` +#### Method 2: Form Data (JSON String) +```http +POST /admin/tools/import +Authorization: Bearer +Content-Type: multipart/form-data + +Form field: tools_json= +``` + +#### Method 3: File Upload +```http +POST /admin/tools/import +Authorization: Bearer +Content-Type: multipart/form-data + +Form field: tools_file= +``` + ### Payload Structure ```json diff --git a/mcpgateway/admin.py b/mcpgateway/admin.py index 460b1b52..e1cad796 100644 --- a/mcpgateway/admin.py +++ b/mcpgateway/admin.py @@ -1581,6 +1581,7 @@ async def admin_ui( "root_path": root_path, "max_name_length": max_name_length, "gateway_tool_name_separator": settings.gateway_tool_name_separator, + "bulk_import_max_tools": settings.mcpgateway_bulk_import_max_tools, }, ) @@ -4423,7 +4424,7 @@ async def admin_list_tags( @admin_router.post("/tools/import/") @admin_router.post("/tools/import") -@rate_limit(requests_per_minute=10) +@rate_limit(requests_per_minute=settings.mcpgateway_bulk_import_rate_limit) async def admin_import_tools( request: Request, db: Session = Depends(get_db), @@ -4466,19 +4467,33 @@ async def admin_import_tools( except Exception as ex: LOGGER.exception("Invalid form body") return JSONResponse({"success": False, "message": f"Invalid form data: {ex}"}, status_code=422) - raw = form.get("tools_json") or form.get("json") or form.get("payload") - if not raw: - return JSONResponse({"success": False, "message": "Missing tools_json/json/payload form field."}, status_code=422) - try: - payload = json.loads(raw) - except Exception as ex: - LOGGER.exception("Invalid JSON in form field") - return JSONResponse({"success": False, "message": f"Invalid JSON: {ex}"}, status_code=422) + # Check for file upload first + if "tools_file" in form: + file = form["tools_file"] + if hasattr(file, "file"): + content = await file.read() + try: + payload = json.loads(content.decode("utf-8")) + except (json.JSONDecodeError, UnicodeDecodeError) as ex: + LOGGER.exception("Invalid JSON file") + return JSONResponse({"success": False, "message": f"Invalid JSON file: {ex}"}, status_code=422) + else: + return JSONResponse({"success": False, "message": "Invalid file upload"}, status_code=422) + else: + # Check for JSON in form fields + raw = form.get("tools") or form.get("tools_json") or form.get("json") or form.get("payload") + if not raw: + return JSONResponse({"success": False, "message": "Missing tools/tools_json/json/payload form field."}, status_code=422) + try: + payload = json.loads(raw) + except Exception as ex: + LOGGER.exception("Invalid JSON in form field") + return JSONResponse({"success": False, "message": f"Invalid JSON: {ex}"}, status_code=422) if not isinstance(payload, list): return JSONResponse({"success": False, "message": "Payload must be a JSON array of tools."}, status_code=422) - max_batch = 200 + max_batch = settings.mcpgateway_bulk_import_max_tools if len(payload) > max_batch: return JSONResponse({"success": False, "message": f"Too many tools ({len(payload)}). Max {max_batch}."}, status_code=413) @@ -4511,15 +4526,33 @@ async def admin_import_tools( LOGGER.exception("Unexpected error importing tool %r at index %d", name, i) errors.append({"index": i, "name": name, "error": {"message": str(ex)}}) - return JSONResponse( - { - "success": len(errors) == 0, - "created_count": len(created), - "failed_count": len(errors), - "created": created, - "errors": errors, + # Format response to match both frontend and test expectations + response_data = { + "success": len(errors) == 0, + # New format for frontend + "imported": len(created), + "failed": len(errors), + "total": len(payload), + # Original format for tests + "created_count": len(created), + "failed_count": len(errors), + "created": created, + "errors": errors, + # Detailed format for frontend + "details": { + "success": [item["name"] for item in created if item.get("name")], + "failed": [{"name": item["name"], "error": item["error"].get("message", str(item["error"]))} for item in errors], }, - status_code=200, + } + + if len(errors) == 0: + response_data["message"] = f"Successfully imported all {len(created)} tools" + else: + response_data["message"] = f"Imported {len(created)} of {len(payload)} tools. {len(errors)} failed." + + return JSONResponse( + response_data, + status_code=200, # Always return 200, success field indicates if all succeeded ) except HTTPException: diff --git a/mcpgateway/config.py b/mcpgateway/config.py index 916dfb8f..be7c7de6 100644 --- a/mcpgateway/config.py +++ b/mcpgateway/config.py @@ -154,6 +154,8 @@ class Settings(BaseSettings): mcpgateway_ui_enabled: bool = False mcpgateway_admin_api_enabled: bool = False mcpgateway_bulk_import_enabled: bool = True + mcpgateway_bulk_import_max_tools: int = 200 + mcpgateway_bulk_import_rate_limit: int = 10 # Security skip_ssl_verify: bool = False diff --git a/mcpgateway/static/admin.js b/mcpgateway/static/admin.js index dd49de43..c5bb804b 100644 --- a/mcpgateway/static/admin.js +++ b/mcpgateway/static/admin.js @@ -5977,6 +5977,16 @@ document.addEventListener("DOMContentLoaded", () => { // 4. Handle initial tab/state initializeTabState(); + // 5. Set up form validation + setupFormValidation(); + + // 6. Setup bulk import modal + try { + setupBulkImportModal(); + } catch (error) { + console.error("Error setting up bulk import modal:", error); + } + // // ✅ 4.1 Set up tab button click handlers // document.querySelectorAll('.tab-button').forEach(button => { // button.addEventListener('click', () => { @@ -5990,9 +6000,6 @@ document.addEventListener("DOMContentLoaded", () => { // }); // }); - // 5. Set up form validation - setupFormValidation(); - // Mark as initialized AppState.isInitialized = true; @@ -7075,3 +7082,224 @@ async function fetchToolsForGateway(gatewayId, gatewayName) { window.fetchToolsForGateway = fetchToolsForGateway; console.log("🛡️ ContextForge MCP Gateway admin.js initialized"); + +// =================================================================== +// BULK IMPORT TOOLS — MODAL WIRING +// =================================================================== + +function setupBulkImportModal() { + const openBtn = safeGetElement("open-bulk-import", true); + const modalId = "bulk-import-modal"; + const modal = safeGetElement(modalId, true); + + if (!openBtn || !modal) { + console.warn( + "Bulk Import modal wiring skipped (missing button or modal).", + ); + return; + } + + // avoid double-binding if admin.js gets evaluated more than once + if (openBtn.dataset.wired === "1") { + return; + } + openBtn.dataset.wired = "1"; + + const closeBtn = safeGetElement("close-bulk-import", true); + const backdrop = safeGetElement("bulk-import-backdrop", true); + const resultEl = safeGetElement("import-result", true); + + const focusTarget = + modal?.querySelector("#tools_json") || + modal?.querySelector("#tools_file") || + modal?.querySelector("[data-autofocus]"); + + // helpers + const open = (e) => { + if (e) { + e.preventDefault(); + } + // clear previous results each time we open + if (resultEl) { + resultEl.innerHTML = ""; + } + openModal(modalId); + // prevent background scroll + document.documentElement.classList.add("overflow-hidden"); + document.body.classList.add("overflow-hidden"); + if (focusTarget) { + setTimeout(() => focusTarget.focus(), 0); + } + return false; + }; + + const close = () => { + // also clear results on close to keep things tidy + closeModal(modalId, "import-result"); + document.documentElement.classList.remove("overflow-hidden"); + document.body.classList.remove("overflow-hidden"); + }; + + // wire events + openBtn.addEventListener("click", open); + + if (closeBtn) { + closeBtn.addEventListener("click", (e) => { + e.preventDefault(); + close(); + }); + } + + // click on backdrop only (not the dialog content) closes the modal + if (backdrop) { + backdrop.addEventListener("click", (e) => { + if (e.target === backdrop) { + close(); + } + }); + } + + // ESC to close + modal.addEventListener("keydown", (e) => { + if (e.key === "Escape") { + e.stopPropagation(); + close(); + } + }); + + // FORM SUBMISSION → handle bulk import + const form = safeGetElement("bulk-import-form", true); + if (form) { + form.addEventListener("submit", async (e) => { + e.preventDefault(); + e.stopPropagation(); + const resultEl = safeGetElement("import-result", true); + const indicator = safeGetElement("bulk-import-indicator", true); + + try { + const formData = new FormData(); + + // Get JSON from textarea or file + const jsonTextarea = form?.querySelector('[name="tools_json"]'); + const fileInput = form?.querySelector('[name="tools_file"]'); + + let hasData = false; + + // Check for file upload first (takes precedence) + if (fileInput && fileInput.files.length > 0) { + formData.append("tools_file", fileInput.files[0]); + hasData = true; + } else if (jsonTextarea && jsonTextarea.value.trim()) { + // Validate JSON before sending + try { + const toolsData = JSON.parse(jsonTextarea.value); + if (!Array.isArray(toolsData)) { + throw new Error("JSON must be an array of tools"); + } + formData.append("tools", jsonTextarea.value); + hasData = true; + } catch (err) { + if (resultEl) { + resultEl.innerHTML = ` +
+

Invalid JSON

+

${escapeHtml(err.message)}

+
+ `; + } + return; + } + } + + if (!hasData) { + if (resultEl) { + resultEl.innerHTML = ` +
+

Please provide JSON data or upload a file

+
+ `; + } + return; + } + + // Show loading state + if (indicator) { + indicator.style.display = "flex"; + } + + // Submit to backend + const response = await fetchWithTimeout( + `${window.ROOT_PATH}/admin/tools/import`, + { + method: "POST", + body: formData, + }, + ); + + const result = await response.json(); + + // Display results + if (resultEl) { + if (result.success) { + resultEl.innerHTML = ` +
+

Import Successful

+

${escapeHtml(result.message)}

+
+ `; + + // Close modal and refresh page after delay + setTimeout(() => { + closeModal("bulk-import-modal"); + window.location.reload(); + }, 2000); + } else if (result.imported > 0) { + // Partial success + let detailsHtml = ""; + if (result.details && result.details.failed) { + detailsHtml = + '
    '; + result.details.failed.forEach((item) => { + detailsHtml += `
  • ${escapeHtml(item.name)}: ${escapeHtml(item.error)}
  • `; + }); + detailsHtml += "
"; + } + + resultEl.innerHTML = ` +
+

Partial Import

+

${escapeHtml(result.message)}

+ ${detailsHtml} +
+ `; + } else { + // Complete failure + resultEl.innerHTML = ` +
+

Import Failed

+

${escapeHtml(result.message)}

+
+ `; + } + } + } catch (error) { + console.error("Bulk import error:", error); + if (resultEl) { + resultEl.innerHTML = ` +
+

Import Error

+

${escapeHtml(error.message || "An unexpected error occurred")}

+
+ `; + } + } finally { + // Hide loading state + if (indicator) { + indicator.style.display = "none"; + } + } + + return false; + }); + } +} diff --git a/mcpgateway/templates/admin.html b/mcpgateway/templates/admin.html index f6ce3d7b..f2899a69 100644 --- a/mcpgateway/templates/admin.html +++ b/mcpgateway/templates/admin.html @@ -47,6 +47,11 @@ rel="stylesheet" /> + @@ -1078,11 +1083,20 @@

-
-

- Add New Tool -

+
+

Add New Tool

+ + +
+ +
@@ -1312,6 +1326,58 @@

+ + +