diff --git a/.env.example b/.env.example
index 3f94d0b94..26806fa59 100644
--- a/.env.example
+++ b/.env.example
@@ -18,12 +18,41 @@ CACHE_TYPE=memory
# CACHE_TYPE=redis
# REDIS_URL=redis://localhost:6379/0
+
+# Maximum number of times to boot redis connection for cold start
+REDIS_MAX_RETRIES=3
+
+# Interval time for next retry of redis connection
+REDIS_RETRY_INTERVAL_MS=2000
+
+#####################################
+# Protocol Settings
+#####################################
+
+# MCP protocol version supported by this gateway
+PROTOCOL_VERSION=2025-03-26
+
+#####################################
+# Authentication
+#####################################
+
+# Admin UI basic-auth credentials
+# PRODUCTION: Change these to strong, unique values!
# Authentication Configuration
JWT_SECRET_KEY=my-test-key
JWT_ALGORITHM=HS256
BASIC_AUTH_USER=admin
BASIC_AUTH_PASSWORD=changeme
AUTH_REQUIRED=true
+
+# Secret used to sign JWTs (use long random value in prod)
+# PRODUCTION: Use a strong, random secret (minimum 32 characters)
+JWT_SECRET_KEY=my-test-key
+
+# Algorithm used to sign JWTs (e.g., HS256)
+JWT_ALGORITHM=HS256
+
+# Expiry time for generated JWT tokens (in minutes; e.g. 7 days)
TOKEN_EXPIRY=10080
REQUIRE_TOKEN_EXPIRATION=false
@@ -32,10 +61,134 @@ MCP_CLIENT_AUTH_ENABLED=true
TRUST_PROXY_AUTH=false
PROXY_USER_HEADER=X-Authenticated-User
+# Used to derive an AES encryption key for secure auth storage
+# Must be a non-empty string (e.g. passphrase or random secret)
+AUTH_ENCRYPTION_SECRET=my-test-salt
+
+#####################################
+# Admin UI and API Toggles
+#####################################
+
+# Enable the visual Admin UI (true/false)
+# PRODUCTION: Set to false for security
+MCPGATEWAY_UI_ENABLED=true
+
+# Enable the Admin API endpoints (true/false)
+# PRODUCTION: Set to false for security
+
# UI/Admin Feature Flags
MCPGATEWAY_UI_ENABLED=true
MCPGATEWAY_ADMIN_API_ENABLED=true
MCPGATEWAY_BULK_IMPORT_ENABLED=true
+
+# Maximum number of tools allowed per bulk import request
+MCPGATEWAY_BULK_IMPORT_MAX_TOOLS=200
+
+# Rate limiting for bulk import endpoint (requests per minute)
+MCPGATEWAY_BULK_IMPORT_RATE_LIMIT=10
+
+#####################################
+# Header Passthrough Configuration
+#####################################
+
+# SECURITY WARNING: Header passthrough is disabled by default for security.
+# Only enable if you understand the security implications and have reviewed
+# which headers should be passed through to backing MCP servers.
+# ENABLE_HEADER_PASSTHROUGH=false
+
+# Default headers to pass through (when feature is enabled)
+# JSON array format recommended: ["X-Tenant-Id", "X-Trace-Id"]
+# Comma-separated also supported: X-Tenant-Id,X-Trace-Id
+# NOTE: Authorization header removed from defaults for security
+# DEFAULT_PASSTHROUGH_HEADERS=["X-Tenant-Id", "X-Trace-Id"]
+
+#####################################
+# Security and CORS
+#####################################
+
+# Skip TLS certificate checks for upstream requests (not recommended in prod)
+SKIP_SSL_VERIFY=false
+
+# CORS origin allowlist (use JSON array of URLs)
+# Example: ["http://localhost:3000"]
+# Do not quote this value. Start with [] to ensure it's valid JSON.
+ALLOWED_ORIGINS='["http://localhost", "http://localhost:4444"]'
+
+# Enable CORS handling in the gateway
+CORS_ENABLED=true
+
+# CORS allow credentials (true/false)
+CORS_ALLOW_CREDENTIALS=true
+
+# Environment setting (development/production) - affects security defaults
+# development: Auto-configures CORS for localhost:3000, localhost:8080, etc.
+# production: Uses APP_DOMAIN for HTTPS origins, enforces secure cookies
+ENVIRONMENT=development
+
+# Domain configuration for production CORS origins
+# In production, automatically creates origins: https://APP_DOMAIN, https://app.APP_DOMAIN, https://admin.APP_DOMAIN
+# For production: set to your actual domain (e.g., mycompany.com)
+APP_DOMAIN=localhost
+
+# Security settings for cookies
+# production: Automatically enables secure cookies regardless of this setting
+# development: Set to false for HTTP development, true for HTTPS
+SECURE_COOKIES=true
+
+# Cookie SameSite attribute for CSRF protection
+# strict: Maximum security, may break some OAuth flows
+# lax: Good balance of security and compatibility (recommended)
+# none: Requires Secure=true, allows cross-site usage
+COOKIE_SAMESITE=lax
+
+#####################################
+# Security Headers Configuration
+#####################################
+
+# Enable security headers middleware (true/false)
+SECURITY_HEADERS_ENABLED=true
+
+# X-Frame-Options setting (DENY, SAMEORIGIN, or ALLOW-FROM uri)
+# DENY: Prevents all iframe embedding (recommended for security)
+# SAMEORIGIN: Allows embedding from same domain only
+# To disable: Set to empty string X_FRAME_OPTIONS=""
+X_FRAME_OPTIONS=DENY
+
+# Other security headers (true/false)
+X_CONTENT_TYPE_OPTIONS_ENABLED=true
+X_XSS_PROTECTION_ENABLED=true
+X_DOWNLOAD_OPTIONS_ENABLED=true
+
+# HSTS (HTTP Strict Transport Security) settings
+HSTS_ENABLED=true
+# HSTS max age in seconds (31536000 = 1 year)
+HSTS_MAX_AGE=31536000
+HSTS_INCLUDE_SUBDOMAINS=true
+
+# Remove server identification headers (true/false)
+REMOVE_SERVER_HEADERS=true
+
+# Enable HTTP Basic Auth for docs endpoints (in addition to Bearer token auth)
+# Uses the same credentials as BASIC_AUTH_USER and BASIC_AUTH_PASSWORD
+DOCS_ALLOW_BASIC_AUTH=false
+
+#####################################
+# Retry Config for HTTP Requests
+#####################################
+
+RETRY_MAX_ATTEMPTS=3
+# seconds
+RETRY_BASE_DELAY=1.0
+# seconds
+RETRY_MAX_DELAY=60.0
+# fraction of delay
+RETRY_JITTER_MAX=0.5
+
+#####################################
+# Logging
+#####################################
+
+# Logging verbosity level: DEBUG, INFO, WARNING, ERROR, CRITICAL
MCPGATEWAY_BULK_IMPORT_MAX_TOOLS=200
MCPGATEWAY_BULK_IMPORT_RATE_LIMIT=10
@@ -132,8 +285,8 @@ WELL_KNOWN_SECURITY_TXT=""
# Example: {"ai.txt": "AI Usage: This service uses AI for tool orchestration...", "dnt-policy.txt": "We respect DNT headers..."}
WELL_KNOWN_CUSTOM_FILES="{}"
-# Cache control for well-known files (seconds)
-WELL_KNOWN_CACHE_MAX_AGE=3600 # 1 hour
+# Cache control for well-known files (seconds) - 3600 = 1 hour
+WELL_KNOWN_CACHE_MAX_AGE=3600
#####################################
# Well-Known URI Examples
diff --git a/docs/docs/architecture/oauth-authorization-code-ui-design.md b/docs/docs/architecture/oauth-authorization-code-ui-design.md
new file mode 100644
index 000000000..2e08b6de2
--- /dev/null
+++ b/docs/docs/architecture/oauth-authorization-code-ui-design.md
@@ -0,0 +1,824 @@
+# OAuth 2.0 Authorization Code Flow UI Implementation Design
+
+**Version**: 1.0
+**Status**: Design Document
+**Date**: December 2024
+**Related**: [OAuth Design Document](./oauth-design.md)
+
+## Executive Summary
+
+This document outlines the design for implementing OAuth 2.0 Authorization Code flow with user consent in the MCP Gateway UI. The implementation will extend the existing OAuth infrastructure to support user delegation flows, token storage, and automatic token refresh, enabling agents to act on behalf of users with proper consent and scoped permissions.
+
+## Current State Analysis
+
+### Existing Implementation
+- ✅ OAuth Manager service with Client Credentials flow
+- ✅ Basic Authorization Code flow support in OAuth Manager
+- ✅ OAuth configuration fields in Gateway creation UI
+- ✅ OAuth callback endpoint (`/oauth/callback`)
+- ✅ Database schema with `oauth_config` JSON field
+- ✅ Client secret encryption/decryption
+
+### Current Limitations
+- ❌ No token storage mechanism for Authorization Code flow
+- ❌ No refresh token handling
+- ❌ Incomplete UI flow for user consent
+- ❌ No token expiration management
+- ❌ Limited error handling for OAuth flows
+
+## Architecture Overview
+
+```mermaid
+graph TD
+ subgraph "MCP Gateway UI"
+ A[Gateway Configuration]
+ B[OAuth Authorization Flow]
+ C[Token Management]
+ D[User Consent Interface]
+ end
+
+ subgraph "Backend Services"
+ E[OAuth Manager]
+ F[Token Storage Service]
+ G[Gateway Service]
+ end
+
+ subgraph "Database"
+ H[Gateway Table]
+ I[OAuth Tokens Table]
+ end
+
+ subgraph "External"
+ J[OAuth Provider]
+ K[User Browser]
+ end
+
+ A --> E
+ B --> E
+ E --> F
+ F --> I
+ G --> F
+ B --> K
+ K --> J
+ J --> B
+ E --> J
+```
+
+## Database Schema Changes
+
+### New OAuth Tokens Table
+
+```sql
+CREATE TABLE oauth_tokens (
+ id VARCHAR(36) PRIMARY KEY DEFAULT (uuid()),
+ gateway_id VARCHAR(36) NOT NULL,
+ user_id VARCHAR(255) NOT NULL, -- OAuth provider user ID
+ access_token TEXT NOT NULL,
+ refresh_token TEXT,
+ token_type VARCHAR(50) DEFAULT 'Bearer',
+ expires_at TIMESTAMP,
+ scopes JSON,
+ created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP,
+ updated_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP ON UPDATE CURRENT_TIMESTAMP,
+
+ FOREIGN KEY (gateway_id) REFERENCES gateways(id) ON DELETE CASCADE,
+ UNIQUE KEY unique_gateway_user (gateway_id, user_id)
+);
+
+-- Index for efficient token lookup
+CREATE INDEX idx_oauth_tokens_gateway_user ON oauth_tokens(gateway_id, user_id);
+CREATE INDEX idx_oauth_tokens_expires ON oauth_tokens(expires_at);
+```
+
+### Modified Gateway Table
+
+```sql
+-- Add new fields to existing oauth_config JSON structure
+ALTER TABLE gateways
+MODIFY COLUMN oauth_config JSON COMMENT 'OAuth 2.0 configuration including grant_type, client_id, encrypted client_secret, URLs, scopes, and token management settings';
+
+-- Updated oauth_config structure:
+{
+ "grant_type": "client_credentials|authorization_code",
+ "client_id": "string",
+ "client_secret": "encrypted_string",
+ "authorization_url": "string",
+ "token_url": "string",
+ "redirect_uri": "string",
+ "scopes": ["scope1", "scope2"],
+ "token_management": {
+ "store_tokens": true,
+ "auto_refresh": true,
+ "refresh_threshold_seconds": 300
+ }
+}
+```
+
+## Core Components
+
+### 1. Token Storage Service
+
+**Location**: `mcpgateway/services/token_storage_service.py`
+
+```python
+from datetime import datetime, timedelta
+from typing import Optional, Dict, Any, List
+from sqlalchemy.orm import Session
+from mcpgateway.db import OAuthToken, DbGateway
+from mcpgateway.utils.oauth_encryption import get_oauth_encryption
+
+class TokenStorageService:
+ """Manages OAuth token storage and retrieval."""
+
+ def __init__(self, db: Session):
+ self.db = db
+ self.encryption = get_oauth_encryption()
+
+ async def store_tokens(
+ self,
+ gateway_id: str,
+ user_id: str,
+ access_token: str,
+ refresh_token: Optional[str],
+ expires_in: int,
+ scopes: List[str]
+ ) -> OAuthToken:
+ """Store OAuth tokens for a gateway-user combination."""
+
+ # Encrypt sensitive tokens
+ encrypted_access = self.encryption.encrypt_secret(access_token)
+ encrypted_refresh = None
+ if refresh_token:
+ encrypted_refresh = self.encryption.encrypt_secret(refresh_token)
+
+ # Calculate expiration
+ expires_at = datetime.utcnow() + timedelta(seconds=expires_in)
+
+ # Create or update token record
+ token_record = self.db.query(OAuthToken).filter(
+ OAuthToken.gateway_id == gateway_id,
+ OAuthToken.user_id == user_id
+ ).first()
+
+ if token_record:
+ # Update existing record
+ token_record.access_token = encrypted_access
+ token_record.refresh_token = encrypted_refresh
+ token_record.expires_at = expires_at
+ token_record.scopes = scopes
+ token_record.updated_at = datetime.utcnow()
+ else:
+ # Create new record
+ token_record = OAuthToken(
+ gateway_id=gateway_id,
+ user_id=user_id,
+ access_token=encrypted_access,
+ refresh_token=encrypted_refresh,
+ expires_at=expires_at,
+ scopes=scopes
+ )
+ self.db.add(token_record)
+
+ self.db.commit()
+ return token_record
+```
+
+ async def get_valid_token(
+ self,
+ gateway_id: str,
+ user_id: str
+ ) -> Optional[str]:
+ """Get a valid access token, refreshing if necessary."""
+
+ token_record = self.db.query(OAuthToken).filter(
+ OAuthToken.gateway_id == gateway_id,
+ OAuthToken.user_id == user_id
+ ).first()
+
+ if not token_record:
+ return None
+
+ # Check if token is expired or near expiration
+ if self._is_token_expired(token_record):
+ if token_record.refresh_token:
+ # Attempt to refresh token
+ new_token = await self._refresh_access_token(token_record)
+ if new_token:
+ return new_token
+ return None
+
+ # Decrypt and return valid token
+ return self.encryption.decrypt_secret(token_record.access_token)
+
+ async def _refresh_access_token(self, token_record: OAuthToken) -> Optional[str]:
+ """Refresh an expired access token using refresh token."""
+ # Implementation for token refresh
+ pass
+
+ def _is_token_expired(self, token_record: OAuthToken, threshold_seconds: int = 300) -> bool:
+ """Check if token is expired or near expiration."""
+ return datetime.utcnow() + timedelta(seconds=threshold_seconds) >= token_record.expires_at
+```
+
+### 2. Enhanced OAuth Manager
+
+**Location**: `mcpgateway/services/oauth_manager.py`
+
+```python
+class OAuthManager:
+ """Enhanced OAuth Manager with token storage support."""
+
+ def __init__(self, token_storage: TokenStorageService):
+ self.token_storage = token_storage
+
+ async def initiate_authorization_code_flow(
+ self,
+ gateway_id: str,
+ credentials: Dict[str, Any]
+ ) -> Dict[str, str]:
+ """Initiate Authorization Code flow and return authorization URL."""
+
+ # Generate state parameter for CSRF protection
+ state = self._generate_state(gateway_id)
+
+ # Store state in session/cache for validation
+ await self._store_authorization_state(gateway_id, state)
+
+ # Generate authorization URL
+ auth_url, _ = self._create_authorization_url(credentials, state)
+
+ return {
+ 'authorization_url': auth_url,
+ 'state': state,
+ 'gateway_id': gateway_id
+ }
+```
+
+ async def complete_authorization_code_flow(
+ self,
+ gateway_id: str,
+ code: str,
+ state: str,
+ credentials: Dict[str, Any]
+ ) -> Dict[str, Any]:
+ """Complete Authorization Code flow and store tokens."""
+
+ # Validate state parameter
+ if not await self._validate_authorization_state(gateway_id, state):
+ raise OAuthError("Invalid state parameter")
+
+ # Exchange code for tokens
+ token_response = await self._exchange_code_for_tokens(credentials, code)
+
+ # Extract user information from token response
+ user_id = self._extract_user_id(token_response, credentials)
+
+ # Store tokens
+ token_record = await self.token_storage.store_tokens(
+ gateway_id=gateway_id,
+ user_id=user_id,
+ access_token=token_response['access_token'],
+ refresh_token=token_response.get('refresh_token'),
+ expires_in=token_response.get('expires_in', 3600),
+ scopes=token_response.get('scope', '').split()
+ )
+
+ return {
+ 'success': True,
+ 'user_id': user_id,
+ 'expires_at': token_record.expires_at.isoformat()
+ }
+
+ async def get_access_token_for_user(
+ self,
+ gateway_id: str,
+ user_id: str
+ ) -> Optional[str]:
+ """Get valid access token for a specific user."""
+ return await self.token_storage.get_valid_token(gateway_id, user_id)
+```
+
+### 3. OAuth Callback Handler
+
+**Location**: `mcpgateway/routers/oauth_router.py`
+
+```python
+from fastapi import APIRouter, Depends, Request, HTTPException
+from fastapi.responses import RedirectResponse, HTMLResponse
+
+oauth_router = APIRouter(prefix="/oauth", tags=["oauth"])
+
+@oauth_router.get("/authorize/{gateway_id}")
+async def initiate_oauth_flow(
+ gateway_id: str,
+ request: Request,
+ db: Session = Depends(get_db)
+) -> RedirectResponse:
+ """Initiate OAuth Authorization Code flow."""
+
+ # Get gateway configuration
+ gateway = db.query(DbGateway).filter(DbGateway.id == gateway_id).first()
+ if not gateway or not gateway.oauth_config:
+ raise HTTPException(status_code=404, detail="Gateway not found or not configured for OAuth")
+
+ # Initiate OAuth flow
+ oauth_manager = OAuthManager(TokenStorageService(db))
+ auth_data = await oauth_manager.initiate_authorization_code_flow(
+ gateway_id, gateway.oauth_config
+ )
+
+ # Redirect user to OAuth provider
+ return RedirectResponse(url=auth_data['authorization_url'])
+```
+
+@oauth_router.get("/callback")
+async def oauth_callback(
+ code: str,
+ state: str,
+ gateway_id: str,
+ request: Request,
+ db: Session = Depends(get_db)
+) -> HTMLResponse:
+ """Handle OAuth callback and complete authorization."""
+
+ try:
+ # Complete OAuth flow
+ oauth_manager = OAuthManager(TokenStorageService(db))
+ gateway = db.query(DbGateway).filter(DbGateway.id == gateway_id).first()
+
+ result = await oauth_manager.complete_authorization_code_flow(
+ gateway_id, code, state, gateway.oauth_config
+ )
+
+ # Return success page with option to return to admin
+ return HTMLResponse(content=f"""
+
+
+
OAuth Authorization Successful
+
+ ✅ OAuth Authorization Successful
+ Gateway: {gateway.name}
+ User: {result['user_id']}
+ Expires: {result['expires_at']}
+ Return to Admin Panel
+
+
+ """)
+
+ except Exception as e:
+ return HTMLResponse(content=f"""
+
+
+ OAuth Authorization Failed
+
+ ❌ OAuth Authorization Failed
+ Error: {str(e)}
+ Return to Admin Panel
+
+
+ """, status_code=400)
+```
+
+## UI Implementation
+
+### 1. Enhanced Gateway Creation Form
+
+**File**: `mcpgateway/templates/admin.html`
+
+```html
+
+
+
+
+
+
+
+
+
+
+
+
+
+```
+
+
+
+
+
+
+
+
+
+
+
+
+ This must match the redirect URI configured in your OAuth application
+
+
+
+
+
+
+
+```
+
+### 2. JavaScript for Dynamic Field Management
+
+**File**: `mcpgateway/static/admin.js`
+
+```javascript
+// OAuth field management
+function toggleOAuthFields() {
+ const grantType = document.getElementById('oauth-grant-type-gw').value;
+ const authCodeFields = document.getElementById('oauth-auth-code-fields-gw');
+
+ if (grantType === 'authorization_code') {
+ authCodeFields.style.display = 'block';
+ // Show additional validation for required fields
+ document.querySelectorAll('#oauth-auth-code-fields-gw input').forEach(input => {
+ input.required = true;
+ });
+ } else {
+ authCodeFields.style.display = 'none';
+ // Remove required validation for hidden fields
+ document.querySelectorAll('#oauth-auth-code-fields-gw input').forEach(input => {
+ input.required = false;
+ });
+ }
+}
+
+// Enhanced gateway form submission
+async function submitGatewayForm(formData) {
+ const grantType = formData.get('oauth_grant_type');
+
+ if (grantType === 'authorization_code') {
+ // Validate required fields
+ const requiredFields = [
+ 'oauth_authorization_url',
+ 'oauth_redirect_uri'
+ ];
+
+ for (const field of requiredFields) {
+ if (!formData.get(field)) {
+ showError(`Field ${field} is required for Authorization Code flow`);
+ return false;
+ }
+ }
+
+ // Check if redirect URI matches expected pattern
+ const redirectUri = formData.get('oauth_redirect_uri');
+ const expectedPattern = window.location.origin + '/oauth/callback';
+
+ if (!redirectUri.includes('/oauth/callback')) {
+ showWarning('Redirect URI should typically end with /oauth/callback for security');
+ }
+ }
+
+ return true;
+}
+```
+
+### 3. Gateway Management Interface
+
+**Enhanced Gateway List View**
+
+```html
+
+
+
+
+ Status:
+ {% if gateway.enabled %}
+
+ Active
+
+ {% else %}
+
+ Inactive
+
+ {% endif %}
+
+
+ {% if gateway.auth_type == 'oauth' and gateway.oauth_config %}
+
+ OAuth: {{ gateway.oauth_config.grant_type.replace('_', ' ').title() }}
+ {% if gateway.oauth_config.grant_type == 'authorization_code' %}
+
+
+ 🔐 Authorize Users
+
+ {% endif %}
+
+ {% endif %}
+
+ |
+```
+
+## OAuth Flow Sequences
+
+### Authorization Code Flow with Token Storage
+
+```mermaid
+sequenceDiagram
+ participant Admin
+ participant Gateway
+ participant OAuth Manager
+ participant Token Storage
+ participant OAuth Provider
+ participant Database
+
+ Admin->>Gateway: Configure OAuth (Auth Code)
+ Gateway->>Database: Store OAuth config
+
+ Admin->>Gateway: Click "Authorize Users"
+ Gateway->>OAuth Manager: Initiate auth flow
+ OAuth Manager->>Token Storage: Store auth state
+ OAuth Manager-->>Gateway: Authorization URL
+ Gateway-->>Admin: Redirect to OAuth Provider
+
+ Admin->>OAuth Provider: Login & Authorize
+ OAuth Provider-->>Gateway: Callback with code
+ Gateway->>OAuth Manager: Exchange code for tokens
+ OAuth Manager->>OAuth Provider: POST /token
+ OAuth Provider-->>OAuth Manager: Access + Refresh tokens
+
+ OAuth Manager->>Token Storage: Store encrypted tokens
+ Token Storage->>Database: Save token record
+ Token Storage-->>OAuth Manager: Confirmation
+ OAuth Manager-->>Gateway: Success
+ Gateway-->>Admin: Authorization complete
+```
+
+### Tool Invocation with Stored Tokens
+
+```mermaid
+sequenceDiagram
+ participant Client
+ participant Gateway
+ participant Token Storage
+ participant OAuth Provider
+ participant MCP Server
+
+ Client->>Gateway: Invoke Tool
+ Gateway->>Token Storage: Get valid access token
+ Token Storage->>Token Storage: Check expiration
+
+ alt Token valid
+ Token Storage-->>Gateway: Decrypted access token
+ Gateway->>MCP Server: Tool request + Bearer token
+ MCP Server-->>Gateway: Tool response
+ Gateway-->>Client: Result
+ else Token expired
+ Token Storage->>OAuth Provider: Refresh token request
+ OAuth Provider-->>Token Storage: New access token
+ Token Storage->>Database: Update token record
+ Token Storage-->>Gateway: New access token
+ Gateway->>MCP Server: Tool request + Bearer token
+ MCP Server-->>Gateway: Tool response
+ Gateway-->>Client: Result
+ end
+```
+
+## Security Considerations
+
+### 1. Token Security
+- **Encryption**: All tokens stored encrypted using `AUTH_ENCRYPTION_SECRET`
+- **Access Control**: Tokens only accessible to authorized gateway operations
+- **Audit Logging**: Track all token operations and access attempts
+
+### 2. OAuth Flow Security
+- **State Validation**: CSRF protection using state parameters
+- **Redirect URI Validation**: Strict validation of callback URLs
+- **Scope Limitation**: Request minimum required scopes only
+- **HTTPS Enforcement**: All OAuth endpoints require HTTPS
+
+### 3. Data Protection
+- **Token Expiration**: Automatic cleanup of expired tokens
+- **User Consent**: Clear indication of what permissions are granted
+- **Revocation Support**: Ability to revoke user access when needed
+
+## Configuration
+
+### Environment Variables
+
+```env
+# OAuth Token Management
+OAUTH_TOKEN_CLEANUP_INTERVAL=3600 # Token cleanup interval in seconds
+OAUTH_TOKEN_EXPIRY_THRESHOLD=300 # Refresh tokens this many seconds before expiry
+OAUTH_MAX_STORED_TOKENS_PER_GATEWAY=100 # Maximum tokens to store per gateway
+
+# Security
+OAUTH_STATE_EXPIRY=300 # Authorization state expiry in seconds
+OAUTH_MAX_AUTHORIZATION_ATTEMPTS=5 # Maximum failed authorization attempts
+```
+
+### Example Gateway Configuration
+
+```json
+{
+ "name": "GitHub MCP with User Delegation",
+ "url": "https://github-mcp.example.com/sse",
+ "auth_type": "oauth",
+ "oauth_config": {
+ "grant_type": "authorization_code",
+ "client_id": "your_github_app_id",
+ "client_secret": "your_github_app_secret",
+ "authorization_url": "https://github.com/login/oauth/authorize",
+ "token_url": "https://github.com/login/oauth/access_token",
+ "redirect_uri": "https://gateway.example.com/oauth/callback",
+ "scopes": ["repo", "read:user"],
+ "token_management": {
+ "store_tokens": true,
+ "auto_refresh": true,
+ "refresh_threshold_seconds": 300
+ }
+ }
+}
+```
+
+## Implementation Phases
+
+### Phase 1: Database and Core Services (Week 1)
+- [ ] Create `oauth_tokens` table
+- [ ] Implement `TokenStorageService`
+- [ ] Enhance `OAuthManager` with token storage
+- [ ] Add token refresh functionality
+
+### Phase 2: OAuth Router and Callback (Week 2)
+- [ ] Implement OAuth authorization router
+- [ ] Create callback handler with token storage
+- [ ] Add state management and validation
+- [ ] Implement user ID extraction from tokens
+
+### Phase 3: UI Enhancements (Week 3)
+- [ ] Update gateway creation form
+- [ ] Add dynamic field management
+- [ ] Implement authorization flow UI
+- [ ] Add token status display
+
+### Phase 4: Integration and Testing (Week 4)
+- [ ] Integrate with existing gateway service
+- [ ] Update tool invocation to use stored tokens
+- [ ] Comprehensive testing of OAuth flows
+- [ ] Security review and documentation
+
+## Testing Strategy
+
+### Unit Tests
+
+```python
+async def test_token_storage_service():
+ """Test token storage and retrieval."""
+ service = TokenStorageService(db)
+
+ # Test token storage
+ token_record = await service.store_tokens(
+ gateway_id="test_gateway",
+ user_id="test_user",
+ access_token="test_token",
+ refresh_token="test_refresh",
+ expires_in=3600,
+ scopes=["repo", "read:user"]
+ )
+
+ assert token_record is not None
+ assert token_record.user_id == "test_user"
+
+ # Test token retrieval
+ token = await service.get_valid_token("test_gateway", "test_user")
+ assert token == "test_token"
+
+async def test_authorization_code_flow():
+ """Test complete authorization code flow."""
+ oauth_manager = OAuthManager(TokenStorageService(db))
+
+ # Test flow initiation
+ auth_data = await oauth_manager.initiate_authorization_code_flow(
+ "test_gateway", test_credentials
+ )
+
+ assert "authorization_url" in auth_data
+ assert "state" in auth_data
+```
+
+### Integration Tests
+
+```python
+async def test_oauth_callback_end_to_end():
+ """Test complete OAuth callback flow."""
+ # Mock OAuth provider responses
+ with responses.RequestsMock() as rsps:
+ rsps.add(
+ responses.POST,
+ "https://oauth.example.com/token",
+ json={
+ "access_token": "test_access_token",
+ "refresh_token": "test_refresh_token",
+ "expires_in": 3600,
+ "scope": "repo read:user"
+ }
+ )
+
+ # Test callback
+ response = await client.get(
+ "/oauth/callback",
+ params={
+ "code": "test_code",
+ "state": "test_state",
+ "gateway_id": "test_gateway"
+ }
+ )
+
+ assert response.status_code == 200
+ assert "OAuth Authorization Successful" in response.text
+```
+
+## Future Enhancements
+
+### 1. Advanced Token Management
+- **Token Rotation**: Automatic token rotation for enhanced security
+- **Multi-User Support**: Support for multiple users per gateway
+- **Token Analytics**: Usage analytics and monitoring
+
+### 2. OAuth Provider Templates
+- **Pre-configured Providers**: Templates for GitHub, GitLab, etc.
+- **Provider-Specific Scopes**: Recommended scopes for common providers
+- **Auto-discovery**: OAuth provider metadata discovery
+
+### 3. Enhanced Security
+- **PKCE Support**: Proof Key for Code Exchange for public clients
+- **JWT Validation**: Support for JWT-based tokens
+- **Audit Trail**: Comprehensive audit logging for compliance
+
+## Conclusion
+
+This design provides a comprehensive framework for implementing OAuth 2.0 Authorization Code flow with user consent in the MCP Gateway UI. The implementation balances security, usability, and maintainability while extending the existing OAuth infrastructure.
+
+Key benefits of this approach:
+- **User Consent**: Proper user delegation with scoped permissions
+- **Token Efficiency**: Reuse of valid tokens with automatic refresh
+- **Security**: Encrypted token storage and comprehensive validation
+- **Scalability**: Support for multiple users and gateways
+- **Maintainability**: Clean separation of concerns and modular design
+
+The phased implementation approach ensures minimal disruption to existing functionality while delivering value incrementally. The design follows OAuth 2.0 best practices and provides a solid foundation for future enhancements.
diff --git a/docs/docs/architecture/oauth-design.md b/docs/docs/architecture/oauth-design.md
new file mode 100644
index 000000000..bbcc7e00f
--- /dev/null
+++ b/docs/docs/architecture/oauth-design.md
@@ -0,0 +1,421 @@
+# OAuth 2.0 Integration Design for MCP Gateway
+
+**Version**: 1.0
+**Status**: Draft
+**Date**: December 2024
+
+## Executive Summary
+
+This document outlines the design for integrating OAuth 2.0 authentication into the MCP Gateway, enabling agents to perform actions on behalf of users without requiring personal access tokens (PATs). The implementation will use the `oauthlib` library and support Client Credentials and Authorization Code flows following OAuth 2.0 best practices.
+
+## Motivation
+
+Current limitations:
+- Personal Access Tokens (PATs) provide broad access with security risks
+- Manual token management across multiple services
+- No native support for delegated authorization with scoped permissions
+
+OAuth 2.0 provides:
+- Standardized authentication flows
+- Scoped access control
+- Temporary access without storing user credentials
+- Industry-standard security practices
+
+## Architecture Overview
+
+```mermaid
+graph TD
+ subgraph "MCP Gateway"
+ A[Admin UI]
+ B[Gateway Service]
+ C[Tool Service]
+ D[OAuth Manager]
+ end
+
+ subgraph "Storage"
+ E[Database]
+ end
+
+ subgraph "External"
+ F[OAuth Provider]
+ G[MCP Server]
+ end
+
+ A --> E
+ B --> D
+ C --> D
+ D --> F
+ B --> G
+ C --> G
+
+ D -.->|Uses| H[oauthlib]
+```
+
+## Database Schema
+
+### Modified Gateway Table
+
+```sql
+ALTER TABLE gateways
+ADD COLUMN oauth_config JSON;
+
+-- OAuth config structure:
+{
+ "grant_type": "client_credentials|authorization_code",
+ "client_id": "string",
+ "client_secret": "encrypted_string",
+ "authorization_url": "string",
+ "token_url": "string",
+ "redirect_uri": "string",
+ "scopes": ["scope1", "scope2"]
+}
+```
+
+## Core Components
+
+### 1. OAuth Manager Service
+
+**Location**: `mcpgateway/services/oauth_manager.py`
+
+```python
+from oauthlib.oauth2 import BackendApplicationClient, WebApplicationClient
+from requests_oauthlib import OAuth2Session
+from typing import Optional, Dict, Any
+
+class OAuthManager:
+ """Manages OAuth 2.0 authentication flows."""
+
+ async def get_access_token(
+ self,
+ credentials: Dict[str, Any]
+ ) -> str:
+ """Get access token based on grant type."""
+ if credentials['grant_type'] == 'client_credentials':
+ return await self._client_credentials_flow(credentials)
+ elif credentials['grant_type'] == 'authorization_code':
+ return await self._authorization_code_flow(credentials)
+ else:
+ raise ValueError(f"Unsupported grant type: {credentials['grant_type']}")
+
+ async def _client_credentials_flow(
+ self,
+ credentials: Dict[str, Any]
+ ) -> str:
+ """Machine-to-machine authentication."""
+ client = BackendApplicationClient(client_id=credentials['client_id'])
+ oauth = OAuth2Session(client=client)
+
+ token = oauth.fetch_token(
+ token_url=credentials['token_url'],
+ client_id=credentials['client_id'],
+ client_secret=credentials['client_secret'],
+ scope=credentials.get('scopes', [])
+ )
+
+ return token['access_token']
+
+ async def _authorization_code_flow(
+ self,
+ credentials: Dict[str, Any]
+ ) -> Dict[str, str]:
+ """User delegation flow - returns authorization URL."""
+ oauth = OAuth2Session(
+ credentials['client_id'],
+ redirect_uri=credentials['redirect_uri'],
+ scope=credentials.get('scopes', [])
+ )
+
+ authorization_url, state = oauth.authorization_url(
+ credentials['authorization_url']
+ )
+
+ return {
+ 'authorization_url': authorization_url,
+ 'state': state
+ }
+
+ async def exchange_code_for_token(
+ self,
+ credentials: Dict[str, Any],
+ code: str,
+ state: str
+ ) -> str:
+ """Exchange authorization code for access token."""
+ oauth = OAuth2Session(
+ credentials['client_id'],
+ state=state,
+ redirect_uri=credentials['redirect_uri']
+ )
+
+ token = oauth.fetch_token(
+ credentials['token_url'],
+ client_secret=credentials['client_secret'],
+ authorization_response=f"{credentials['redirect_uri']}?code={code}&state={state}"
+ )
+
+ return token['access_token']
+```
+
+### 2. Admin UI OAuth Configuration
+
+```html
+
+```
+
+## Implementation Details
+
+### Gateway Service Integration
+
+**File**: `mcpgateway/services/gateway_service.py`
+
+```python
+async def _initialize_gateway(
+ self,
+ url: str,
+ authentication: Optional[Dict[str, str]] = None,
+ transport: str = "SSE"
+) -> tuple:
+ """Initialize gateway with OAuth support."""
+
+ headers = {}
+
+ if authentication and authentication.get('type') == 'oauth':
+ # Get OAuth credentials from database
+ gateway = await self._get_gateway(authentication['gateway_id'])
+ oauth_config = gateway.oauth_config
+
+ # Get access token
+ access_token = await self.oauth_manager.get_access_token(oauth_config)
+ headers = {'Authorization': f'Bearer {access_token}'}
+ else:
+ # Existing authentication logic
+ headers = decode_auth(authentication)
+
+ # Connect to MCP server
+ return await self._connect_to_gateway(url, headers, transport)
+```
+
+### Tool Service Integration
+
+**File**: `mcpgateway/services/tool_service.py`
+
+```python
+async def invoke_tool(
+ self,
+ db: Session,
+ name: str,
+ arguments: Dict[str, Any]
+) -> ToolResult:
+ """Invoke tool with OAuth support."""
+
+ tool = await self.get_tool_by_name(db, name)
+ headers = {}
+
+ if tool.gateway and tool.gateway.auth_type == 'oauth':
+ # Get fresh access token for each request
+ oauth_config = tool.gateway.oauth_config
+ access_token = await self.oauth_manager.get_access_token(oauth_config)
+ headers = {'Authorization': f'Bearer {access_token}'}
+ else:
+ # Existing authentication
+ headers = self._get_tool_headers(tool)
+
+ # Execute tool
+ return await self._execute_tool(tool, arguments, headers)
+```
+
+## OAuth Flow Sequences
+
+### Client Credentials Flow (M2M)
+
+```mermaid
+sequenceDiagram
+ participant Client
+ participant Gateway
+ participant OAuth Manager
+ participant OAuth Provider
+ participant MCP Server
+
+ Client->>Gateway: Configure OAuth (Client Credentials)
+ Client->>Gateway: Invoke Tool
+ Gateway->>OAuth Manager: Get Access Token
+ OAuth Manager->>OAuth Provider: POST /token (client_id, secret)
+ OAuth Provider-->>OAuth Manager: Access Token
+ OAuth Manager-->>Gateway: Access Token
+ Gateway->>MCP Server: Tool Request + Bearer Token
+ MCP Server-->>Gateway: Tool Response
+ Gateway-->>Client: Result
+```
+
+### Authorization Code Flow
+
+```mermaid
+sequenceDiagram
+ participant User
+ participant Gateway
+ participant OAuth Manager
+ participant OAuth Provider
+ participant MCP Server
+
+ User->>Gateway: Configure OAuth (Auth Code)
+ User->>Gateway: Request Authorization
+ Gateway->>OAuth Manager: Get Auth URL
+ OAuth Manager-->>Gateway: Authorization URL
+ Gateway-->>User: Redirect to OAuth Provider
+ User->>OAuth Provider: Login & Authorize
+ OAuth Provider-->>Gateway: Callback with Code
+ Gateway->>OAuth Manager: Exchange Code
+ OAuth Manager->>OAuth Provider: POST /token (code)
+ OAuth Provider-->>OAuth Manager: Access Token
+ OAuth Manager-->>Gateway: Access Token
+ Gateway->>MCP Server: Tool Request + Bearer Token
+ MCP Server-->>Gateway: Response
+ Gateway-->>User: Result
+```
+
+## Security Considerations
+
+1. **Token Storage**: Access tokens are never stored - requested fresh for each operation
+2. **Secret Encryption**: Client secrets encrypted using `AUTH_ENCRYPTION_SECRET`
+3. **HTTPS Required**: All OAuth endpoints must use HTTPS
+4. **Scope Validation**: Request minimum required scopes
+5. **Error Handling**: Comprehensive error handling for OAuth failures
+
+## Configuration
+
+### Environment Variables
+
+```env
+# OAuth Configuration
+OAUTH_REQUEST_TIMEOUT=30 # OAuth request timeout in seconds
+OAUTH_MAX_RETRIES=3 # Max retries for token requests
+
+# Encryption
+AUTH_ENCRYPTION_SECRET=your-secret-key # For encrypting client secrets
+```
+
+### Example Gateway Configuration
+
+```json
+{
+ "name": "GitHub MCP",
+ "url": "https://github-mcp.example.com/sse",
+ "auth_type": "oauth",
+ "oauth_config": {
+ "grant_type": "authorization_code",
+ "client_id": "your_github_app_id",
+ "client_secret": "your_github_app_secret",
+ "authorization_url": "https://github.com/login/oauth/authorize",
+ "token_url": "https://github.com/login/oauth/access_token",
+ "redirect_uri": "https://gateway.example.com/oauth/callback",
+ "scopes": ["repo", "read:user"]
+ }
+}
+```
+
+## Implementation Phases
+
+### Phase 1: Core OAuth Support (Week 1)
+- Implement OAuth Manager
+- Add database schema changes
+- Client Credentials flow
+
+### Phase 2: UI Integration (Week 2)
+- Admin UI OAuth configuration
+- Authorization Code flow
+- OAuth callback endpoint
+
+### Phase 3: Testing & Documentation (Week 3)
+- Integration tests
+- Security review
+- User documentation
+
+## Dependencies
+
+```toml
+# Add to pyproject.toml
+dependencies = [
+ "oauthlib>=3.2.2",
+ "requests-oauthlib>=1.3.1",
+ "cryptography>=41.0.0", # For secret encryption
+]
+```
+
+## Testing
+
+### Unit Tests
+
+```python
+async def test_client_credentials_flow():
+ oauth_manager = OAuthManager()
+ credentials = {
+ "grant_type": "client_credentials",
+ "client_id": "test_client",
+ "client_secret": "test_secret",
+ "token_url": "https://oauth.example.com/token"
+ }
+
+ token = await oauth_manager.get_access_token(credentials)
+ assert token is not None
+ assert isinstance(token, str)
+
+async def test_tool_invocation_with_oauth():
+ tool_service = ToolService(oauth_manager)
+ result = await tool_service.invoke_tool(
+ db=db,
+ name="github_create_issue",
+ arguments={"title": "Test Issue"}
+ )
+ assert result.success
+```
+
+## Future Enhancements
+
+1. **OAuth Provider Templates**: Pre-configured settings for common providers
+2. **Token Refresh**: Support refresh tokens for long-lived access
+3. **PKCE Support**: Add PKCE for public clients
+4. **Multiple OAuth Configs**: Support different OAuth configs per tool
+
+## Conclusion
+
+This OAuth 2.0 integration provides secure, standards-based authentication for MCP Gateway without the complexity of token caching. By requesting fresh tokens for each operation, we ensure simplicity while maintaining security. The implementation follows OAuth 2.0 best practices and enables seamless integration with various OAuth providers.
diff --git a/docs/oauth-setup.md b/docs/oauth-setup.md
new file mode 100644
index 000000000..70aac8a0f
--- /dev/null
+++ b/docs/oauth-setup.md
@@ -0,0 +1,179 @@
+# OAuth 2.0 Setup Guide for MCP Gateway
+
+This guide explains how to configure OAuth 2.0 authentication for federated gateways in MCP Gateway.
+
+## Overview
+
+MCP Gateway supports OAuth 2.0 authentication for connecting to external MCP servers that require OAuth-based authentication. This eliminates the need to store long-lived personal access tokens and provides secure, scoped access to external services.
+
+## Supported OAuth Flows
+
+### 1. Client Credentials Flow (Machine-to-Machine)
+- **Use Case**: Server-to-server communication where no user interaction is required
+- **Best For**: Automated services, background jobs, API integrations
+- **Configuration**: Requires client ID, client secret, and token URL
+
+### 2. Authorization Code Flow (User Delegation)
+- **Use Case**: User-authorized access to external services
+- **Best For**: User-specific integrations, services requiring user consent
+- **Configuration**: Requires additional authorization URL and redirect URI
+
+## Configuration
+
+### Environment Variables
+
+Add these to your `.env` file:
+
+```env
+# OAuth Configuration
+OAUTH_REQUEST_TIMEOUT=30 # OAuth request timeout in seconds
+OAUTH_MAX_RETRIES=3 # Max retries for token requests
+
+# Encryption (for client secrets)
+AUTH_ENCRYPTION_SECRET=your-secure-encryption-key # Must be at least 32 characters
+```
+
+### Gateway Configuration
+
+When adding a new gateway through the Admin UI:
+
+1. **Authentication Type**: Select "OAuth 2.0"
+2. **Grant Type**: Choose between "Client Credentials" or "Authorization Code"
+3. **Client ID**: Your OAuth application's client ID
+4. **Client Secret**: Your OAuth application's client secret (will be encrypted)
+5. **Token URL**: OAuth provider's token endpoint
+6. **Scopes**: Space-separated list of required scopes (e.g., "repo read:user")
+
+#### Authorization Code Flow Additional Fields
+
+If using Authorization Code flow, also configure:
+
+- **Authorization URL**: OAuth provider's authorization endpoint
+- **Redirect URI**: Callback URL (typically `https://your-gateway.com/oauth/callback`)
+
+## Example Configurations
+
+### GitHub OAuth App
+
+```json
+{
+ "grant_type": "authorization_code",
+ "client_id": "your_github_app_id",
+ "client_secret": "your_github_app_secret",
+ "authorization_url": "https://github.com/login/oauth/authorize",
+ "token_url": "https://github.com/login/oauth/access_token",
+ "redirect_uri": "https://gateway.example.com/oauth/callback",
+ "scopes": ["repo", "read:user"]
+}
+```
+
+### Generic OAuth Provider (Client Credentials)
+
+```json
+{
+ "grant_type": "client_credentials",
+ "client_id": "your_client_id",
+ "client_secret": "your_client_secret",
+ "token_url": "https://oauth.example.com/token",
+ "scopes": ["api:read", "api:write"]
+}
+```
+
+## Security Features
+
+### Client Secret Encryption
+- Client secrets are automatically encrypted using AES-256 encryption
+- Encryption key derived from `AUTH_ENCRYPTION_SECRET`
+- Secrets are never stored in plain text
+
+### Token Management
+- Access tokens are requested fresh for each operation
+- No token caching or storage
+- Automatic retry with exponential backoff on failures
+
+### HTTPS Enforcement
+- All OAuth endpoints must use HTTPS
+- Redirect URIs must use secure protocols
+
+## Troubleshooting
+
+### Common Issues
+
+1. **Invalid Client Credentials**
+ - Verify client ID and secret are correct
+ - Ensure OAuth app is properly configured with the provider
+
+2. **Invalid Redirect URI**
+ - Check that redirect URI matches exactly what's configured in OAuth app
+ - Ensure protocol (http/https) matches
+
+3. **Scope Issues**
+ - Verify requested scopes are available for your OAuth app
+ - Check provider's scope documentation
+
+4. **Network Issues**
+ - Verify token and authorization URLs are accessible
+ - Check firewall and network configuration
+
+### Debug Logging
+
+Enable debug logging to troubleshoot OAuth issues:
+
+```env
+LOG_LEVEL=DEBUG
+```
+
+Look for OAuth-related log messages in the gateway logs.
+
+## API Endpoints
+
+### OAuth Callback
+- **URL**: `GET /oauth/callback`
+- **Purpose**: Handle authorization code exchange
+- **Parameters**: `code`, `state`, `gateway_id`
+- **Authentication**: Required
+
+## Testing
+
+### Test OAuth Configuration
+
+1. Configure OAuth settings in the Admin UI
+2. Test the gateway connection
+3. Verify tools/resources are accessible
+4. Check logs for OAuth-related messages
+
+### Unit Tests
+
+Run OAuth unit tests:
+
+```bash
+pytest tests/unit/mcpgateway/test_oauth_manager.py -v
+```
+
+## Best Practices
+
+1. **Use Strong Encryption Keys**: Generate a strong `AUTH_ENCRYPTION_SECRET`
+2. **Minimal Scopes**: Request only the scopes you actually need
+3. **Secure Storage**: Keep encryption keys secure and rotate regularly
+4. **Monitor Usage**: Watch for unusual OAuth activity
+5. **Regular Testing**: Test OAuth flows regularly to ensure they work
+
+## Migration from Personal Access Tokens
+
+To migrate from PAT-based authentication to OAuth:
+
+1. Create OAuth app with your service provider
+2. Configure OAuth settings in gateway
+3. Test connection and functionality
+4. Remove old PAT-based configuration
+5. Update any hardcoded authentication references
+
+## Support
+
+For OAuth-related issues:
+
+1. Check the troubleshooting section above
+2. Review gateway logs for error messages
+3. Verify OAuth provider configuration
+4. Test with a simple OAuth client first
+5. Check provider's OAuth documentation
diff --git a/mcpgateway/admin.py b/mcpgateway/admin.py
index 8128ac691..22e3cd656 100644
--- a/mcpgateway/admin.py
+++ b/mcpgateway/admin.py
@@ -83,6 +83,7 @@
from mcpgateway.utils.create_jwt_token import get_jwt_token
from mcpgateway.utils.error_formatter import ErrorFormatter
from mcpgateway.utils.metadata_capture import MetadataCapture
+from mcpgateway.utils.oauth_encryption import get_oauth_encryption
from mcpgateway.utils.passthrough_headers import PassthroughHeadersError
from mcpgateway.utils.retry_manager import ResilientHttpClient
from mcpgateway.utils.security_cookies import set_auth_cookie
@@ -1558,7 +1559,9 @@ async def admin_ui(
servers = [server.model_dump(by_alias=True) for server in await server_service.list_servers(db, include_inactive=include_inactive)]
resources = [resource.model_dump(by_alias=True) for resource in await resource_service.list_resources(db, include_inactive=include_inactive)]
prompts = [prompt.model_dump(by_alias=True) for prompt in await prompt_service.list_prompts(db, include_inactive=include_inactive)]
- gateways = [gateway.model_dump(by_alias=True) for gateway in await gateway_service.list_gateways(db, include_inactive=include_inactive)]
+ gateways_raw = await gateway_service.list_gateways(db, include_inactive=include_inactive)
+ gateways = [gateway.model_dump(by_alias=True) for gateway in gateways_raw]
+
roots = [root.model_dump(by_alias=True) for root in await root_service.list_roots()]
root_path = settings.app_root_path
max_name_length = settings.validation_max_name_length
@@ -2654,6 +2657,20 @@ async def admin_add_gateway(request: Request, db: Session = Depends(get_db), use
except (json.JSONDecodeError, ValueError):
auth_headers = []
+ # Parse OAuth configuration if present
+ oauth_config_json = str(form.get("oauth_config"))
+ oauth_config: Optional[dict[str, Any]] = None
+ if oauth_config_json and oauth_config_json != "None":
+ try:
+ oauth_config = json.loads(oauth_config_json)
+ # Encrypt the client secret if present
+ if oauth_config and "client_secret" in oauth_config:
+ encryption = get_oauth_encryption(settings.auth_encryption_secret)
+ oauth_config["client_secret"] = encryption.encrypt_secret(oauth_config["client_secret"])
+ except (json.JSONDecodeError, ValueError) as e:
+ LOGGER.error(f"Failed to parse OAuth config: {e}")
+ oauth_config = None
+
# Handle passthrough_headers
passthrough_headers = str(form.get("passthrough_headers"))
if passthrough_headers and passthrough_headers.strip():
@@ -2678,6 +2695,7 @@ async def admin_add_gateway(request: Request, db: Session = Depends(get_db), use
auth_header_key=str(form.get("auth_header_key", "")),
auth_header_value=str(form.get("auth_header_value", "")),
auth_headers=auth_headers if auth_headers else None,
+ oauth_config=oauth_config,
passthrough_headers=passthrough_headers,
)
except KeyError as e:
@@ -2701,8 +2719,22 @@ async def admin_add_gateway(request: Request, db: Session = Depends(get_db), use
created_via=metadata["created_via"],
created_user_agent=metadata["created_user_agent"],
)
+
+ # Provide specific guidance for OAuth Authorization Code flow
+ message = "Gateway registered successfully!"
+ if oauth_config and oauth_config.get("grant_type") == "authorization_code":
+ message = (
+ "Gateway registered successfully! 🎉\n\n"
+ "⚠️ IMPORTANT: This gateway uses OAuth Authorization Code flow.\n"
+ "You must complete the OAuth authorization before tools will work:\n\n"
+ "1. Go to the Gateways list\n"
+ "2. Click the '🔐 Authorize' button for this gateway\n"
+ "3. Complete the OAuth consent flow\n"
+ "4. Return to the admin panel\n\n"
+ "Tools will not work until OAuth authorization is completed."
+ )
return JSONResponse(
- content={"message": "Gateway registered successfully!", "success": True},
+ content={"message": message, "success": True},
status_code=200,
)
@@ -2720,6 +2752,10 @@ async def admin_add_gateway(request: Request, db: Session = Depends(get_db), use
return JSONResponse(content={"message": str(ex), "success": False}, status_code=500)
+# OAuth callback is now handled by the dedicated OAuth router at /oauth/callback
+# This route has been removed to avoid conflicts with the complete implementation
+
+
@admin_router.post("/gateways/{gateway_id}/edit")
async def admin_edit_gateway(
gateway_id: str,
diff --git a/mcpgateway/alembic/versions/add_oauth_tokens_table.py b/mcpgateway/alembic/versions/add_oauth_tokens_table.py
new file mode 100644
index 000000000..5f0c6f11b
--- /dev/null
+++ b/mcpgateway/alembic/versions/add_oauth_tokens_table.py
@@ -0,0 +1,73 @@
+# -*- coding: utf-8 -*-
+"""add oauth tokens table
+
+Revision ID: add_oauth_tokens_table
+Revises: f8c9d3e2a1b4
+Create Date: 2024-12-20 11:00:00.000000
+
+"""
+
+# Standard
+from typing import Sequence, Union
+
+# Third-Party
+from alembic import op
+import sqlalchemy as sa
+
+# revision identifiers, used by Alembic.
+revision: str = "add_oauth_tokens_table"
+down_revision: Union[str, Sequence[str], None] = "f8c9d3e2a1b4"
+branch_labels: Union[str, Sequence[str], None] = None
+depends_on: Union[str, Sequence[str], None] = None
+
+
+def upgrade() -> None:
+ """Add oauth_tokens table for storing OAuth access and refresh tokens."""
+ # Check if we're dealing with a fresh database
+ inspector = sa.inspect(op.get_bind())
+ tables = inspector.get_table_names()
+
+ if "gateways" not in tables:
+ print("Fresh database detected. Skipping migration.")
+ return
+
+ # Create oauth_tokens table
+ op.create_table(
+ "oauth_tokens",
+ sa.Column("id", sa.String(36), primary_key=True),
+ sa.Column("gateway_id", sa.String(36), nullable=False),
+ sa.Column("user_id", sa.String(255), nullable=False),
+ sa.Column("access_token", sa.Text, nullable=False),
+ sa.Column("refresh_token", sa.Text, nullable=True),
+ sa.Column("token_type", sa.String(50), nullable=True, default="Bearer"),
+ sa.Column("expires_at", sa.DateTime(timezone=True), nullable=True),
+ sa.Column("scopes", sa.JSON, nullable=True),
+ sa.Column("created_at", sa.DateTime(timezone=True), server_default=sa.func.now()),
+ sa.Column("updated_at", sa.DateTime(timezone=True), server_default=sa.func.now(), onupdate=sa.func.now()),
+ # Foreign key constraint
+ sa.ForeignKeyConstraint(["gateway_id"], ["gateways.id"], ondelete="CASCADE"),
+ # Unique constraint
+ sa.UniqueConstraint("gateway_id", "user_id", name="unique_gateway_user"),
+ )
+
+ # Create indexes for efficient token lookup
+ op.create_index("idx_oauth_tokens_gateway_user", "oauth_tokens", ["gateway_id", "user_id"])
+ op.create_index("idx_oauth_tokens_expires", "oauth_tokens", ["expires_at"])
+
+ print("Successfully created oauth_tokens table with indexes.")
+
+
+def downgrade() -> None:
+ """Remove oauth_tokens table."""
+ # Check if we're dealing with a fresh database
+ inspector = sa.inspect(op.get_bind())
+ tables = inspector.get_table_names()
+
+ if "oauth_tokens" not in tables:
+ print("oauth_tokens table not found. Skipping migration.")
+ return
+
+ # Remove oauth_tokens table
+ op.drop_table("oauth_tokens")
+
+ print("Successfully removed oauth_tokens table.")
diff --git a/mcpgateway/alembic/versions/f8c9d3e2a1b4_add_oauth_config_to_gateways.py b/mcpgateway/alembic/versions/f8c9d3e2a1b4_add_oauth_config_to_gateways.py
new file mode 100644
index 000000000..40388877f
--- /dev/null
+++ b/mcpgateway/alembic/versions/f8c9d3e2a1b4_add_oauth_config_to_gateways.py
@@ -0,0 +1,55 @@
+# -*- coding: utf-8 -*-
+"""add oauth config to gateways
+
+Revision ID: f8c9d3e2a1b4
+Revises: eb17fd368f9d
+Create Date: 2024-12-20 10:00:00.000000
+
+"""
+
+# Standard
+from typing import Sequence, Union
+
+# Third-Party
+from alembic import op
+import sqlalchemy as sa
+
+# revision identifiers, used by Alembic.
+revision: str = "f8c9d3e2a1b4"
+down_revision: Union[str, Sequence[str], None] = "34492f99a0c4"
+branch_labels: Union[str, Sequence[str], None] = None
+depends_on: Union[str, Sequence[str], None] = None
+
+
+def upgrade() -> None:
+ """Add oauth_config column to gateways table."""
+ # Check if we're dealing with a fresh database
+ inspector = sa.inspect(op.get_bind())
+ tables = inspector.get_table_names()
+
+ if "gateways" not in tables:
+ print("Fresh database detected. Skipping migration.")
+ return
+
+ # Add oauth_config column
+ with op.batch_alter_table("gateways", schema=None) as batch_op:
+ batch_op.add_column(sa.Column("oauth_config", sa.JSON(), nullable=True, comment="OAuth 2.0 configuration including grant_type, client_id, encrypted client_secret, URLs, and scopes"))
+
+ print("Successfully added oauth_config column to gateways table.")
+
+
+def downgrade() -> None:
+ """Remove oauth_config column from gateways table."""
+ # Check if we're dealing with a fresh database
+ inspector = sa.inspect(op.get_bind())
+ tables = inspector.get_table_names()
+
+ if "gateways" not in tables:
+ print("Fresh database detected. Skipping migration.")
+ return
+
+ # Remove oauth_config column
+ with op.batch_alter_table("gateways", schema=None) as batch_op:
+ batch_op.drop_column("oauth_config")
+
+ print("Successfully removed oauth_config column from gateways table.")
diff --git a/mcpgateway/config.py b/mcpgateway/config.py
index 4629b3cdd..d36cb58ce 100644
--- a/mcpgateway/config.py
+++ b/mcpgateway/config.py
@@ -146,6 +146,10 @@ class Settings(BaseSettings):
# Encryption key phrase for auth storage
auth_encryption_secret: str = "my-test-salt"
+ # OAuth Configuration
+ oauth_request_timeout: int = Field(default=30, description="OAuth request timeout in seconds")
+ oauth_max_retries: int = Field(default=3, description="Maximum retries for OAuth token requests")
+
# UI/Admin Feature Flags
mcpgateway_ui_enabled: bool = False
mcpgateway_admin_api_enabled: bool = False
diff --git a/mcpgateway/db.py b/mcpgateway/db.py
index 90d723c85..bdf5ac592 100644
--- a/mcpgateway/db.py
+++ b/mcpgateway/db.py
@@ -1208,9 +1208,15 @@ class Gateway(Base):
# federated_prompts: Mapped[List["Prompt"]] = relationship(secondary=prompt_gateway_table, back_populates="federated_with")
# Authorizations
- auth_type: Mapped[Optional[str]] = mapped_column(default=None) # "basic", "bearer", "headers" or None
+ auth_type: Mapped[Optional[str]] = mapped_column(default=None) # "basic", "bearer", "headers", "oauth" or None
auth_value: Mapped[Optional[Dict[str, str]]] = mapped_column(JSON)
+ # OAuth configuration
+ oauth_config: Mapped[Optional[Dict[str, Any]]] = mapped_column(JSON, nullable=True, comment="OAuth 2.0 configuration including grant_type, client_id, encrypted client_secret, URLs, and scopes")
+
+ # Relationship with OAuth tokens
+ oauth_tokens: Mapped[List["OAuthToken"]] = relationship("OAuthToken", back_populates="gateway", cascade="all, delete-orphan")
+
@event.listens_for(Gateway, "after_update")
def update_tool_names_on_gateway_update(_mapper, connection, target):
@@ -1277,6 +1283,26 @@ class SessionMessageRecord(Base):
session: Mapped["SessionRecord"] = relationship("SessionRecord", back_populates="messages")
+class OAuthToken(Base):
+ """ORM model for OAuth access and refresh tokens."""
+
+ __tablename__ = "oauth_tokens"
+
+ id: Mapped[str] = mapped_column(String(36), primary_key=True, default=lambda: uuid.uuid4().hex)
+ gateway_id: Mapped[str] = mapped_column(String(36), ForeignKey("gateways.id", ondelete="CASCADE"), nullable=False)
+ user_id: Mapped[str] = mapped_column(String(255), nullable=False)
+ access_token: Mapped[str] = mapped_column(Text, nullable=False)
+ refresh_token: Mapped[Optional[str]] = mapped_column(Text, nullable=True)
+ token_type: Mapped[str] = mapped_column(String(50), default="Bearer")
+ expires_at: Mapped[Optional[datetime]] = mapped_column(DateTime(timezone=True), nullable=True)
+ scopes: Mapped[Optional[List[str]]] = mapped_column(JSON, nullable=True)
+ created_at: Mapped[datetime] = mapped_column(DateTime(timezone=True), default=utc_now)
+ updated_at: Mapped[datetime] = mapped_column(DateTime(timezone=True), default=utc_now, onupdate=utc_now)
+
+ # Relationship with gateway
+ gateway: Mapped["Gateway"] = relationship("Gateway", back_populates="oauth_tokens")
+
+
# Event listeners for validation
def validate_tool_schema(mapper, connection, target):
"""
diff --git a/mcpgateway/main.py b/mcpgateway/main.py
index cd0277564..e38d97d66 100644
--- a/mcpgateway/main.py
+++ b/mcpgateway/main.py
@@ -3083,6 +3083,16 @@ async def cleanup_import_statuses(max_age_hours: int = 24, user: str = Depends(r
app.include_router(export_import_router)
app.include_router(well_known_router)
+# Include OAuth router
+try:
+ # First-Party
+ from mcpgateway.routers.oauth_router import oauth_router
+
+ app.include_router(oauth_router)
+ logger.info("OAuth router included")
+except ImportError:
+ logger.debug("OAuth router not available")
+
# Include reverse proxy router if enabled
try:
# First-Party
diff --git a/mcpgateway/routers/oauth_router.py b/mcpgateway/routers/oauth_router.py
new file mode 100644
index 000000000..3f139b6d1
--- /dev/null
+++ b/mcpgateway/routers/oauth_router.py
@@ -0,0 +1,394 @@
+# -*- coding: utf-8 -*-
+"""OAuth Router for MCP Gateway.
+
+This module handles OAuth 2.0 Authorization Code flow endpoints including:
+- Initiating OAuth flows
+- Handling OAuth callbacks
+- Token management
+"""
+
+# Standard
+import logging
+from typing import Any, Dict
+
+# Third-Party
+from fastapi import APIRouter, Depends, HTTPException, Query, Request
+from fastapi.responses import HTMLResponse, RedirectResponse
+from sqlalchemy import select
+from sqlalchemy.orm import Session
+
+# First-Party
+from mcpgateway.db import Gateway, get_db
+from mcpgateway.services.oauth_manager import OAuthError, OAuthManager
+from mcpgateway.services.token_storage_service import TokenStorageService
+
+logger = logging.getLogger(__name__)
+
+oauth_router = APIRouter(prefix="/oauth", tags=["oauth"])
+
+
+@oauth_router.get("/authorize/{gateway_id}")
+async def initiate_oauth_flow(gateway_id: str, request: Request, db: Session = Depends(get_db)) -> RedirectResponse:
+ """Initiates the OAuth 2.0 Authorization Code flow for a specified gateway.
+
+ This endpoint retrieves the OAuth configuration for the given gateway, validates that
+ the gateway supports the Authorization Code flow, and redirects the user to the OAuth
+ provider's authorization URL to begin the OAuth process.
+
+ Args:
+ gateway_id: The unique identifier of the gateway to authorize.
+ request: The FastAPI request object.
+ db: The database session dependency.
+
+ Returns:
+ A redirect response to the OAuth provider's authorization URL.
+
+ Raises:
+ HTTPException: If the gateway is not found, not configured for OAuth, or not using
+ the Authorization Code flow. If an unexpected error occurs during the initiation process.
+ """
+ try:
+ # Get gateway configuration
+ gateway = db.execute(select(Gateway).where(Gateway.id == gateway_id)).scalar_one_or_none()
+
+ if not gateway:
+ raise HTTPException(status_code=404, detail="Gateway not found")
+
+ if not gateway.oauth_config:
+ raise HTTPException(status_code=400, detail="Gateway is not configured for OAuth")
+
+ if gateway.oauth_config.get("grant_type") != "authorization_code":
+ raise HTTPException(status_code=400, detail="Gateway is not configured for Authorization Code flow")
+
+ # Initiate OAuth flow
+ oauth_manager = OAuthManager(token_storage=TokenStorageService(db))
+ auth_data = await oauth_manager.initiate_authorization_code_flow(gateway_id, gateway.oauth_config)
+
+ logger.info(f"Initiated OAuth flow for gateway {gateway_id}")
+
+ # Redirect user to OAuth provider
+ return RedirectResponse(url=auth_data["authorization_url"])
+
+ except HTTPException:
+ raise
+ except Exception as e:
+ logger.error(f"Failed to initiate OAuth flow: {str(e)}")
+ raise HTTPException(status_code=500, detail=f"Failed to initiate OAuth flow: {str(e)}")
+
+
+@oauth_router.get("/callback")
+async def oauth_callback(
+ code: str = Query(..., description="Authorization code from OAuth provider"),
+ state: str = Query(..., description="State parameter for CSRF protection"),
+ # Remove the gateway_id parameter requirement
+ request: Request = None,
+ db: Session = Depends(get_db),
+) -> HTMLResponse:
+ """Handle the OAuth callback and complete the authorization process.
+
+ This endpoint is called by the OAuth provider after the user authorizes access.
+ It receives the authorization code and state parameters, verifies the state,
+ retrieves the corresponding gateway configuration, and exchanges the code for an access token.
+
+ Args:
+ code (str): The authorization code returned by the OAuth provider.
+ state (str): The state parameter for CSRF protection, which encodes the gateway ID.
+ request (Request): The incoming HTTP request object.
+ db (Session): The database session dependency.
+
+ Returns:
+ HTMLResponse: An HTML response indicating the result of the OAuth authorization process.
+ """
+
+ try:
+ # Extract gateway_id from state parameter
+ if "_" not in state:
+ return HTMLResponse(content="❌ Invalid state parameter
", status_code=400)
+
+ gateway_id = state.split("_")[0]
+
+ # Get gateway configuration
+ gateway = db.execute(select(Gateway).where(Gateway.id == gateway_id)).scalar_one_or_none()
+
+ if not gateway:
+ return HTMLResponse(
+ content="""
+
+
+ OAuth Authorization Failed
+
+ ❌ OAuth Authorization Failed
+ Error: Gateway not found
+ Return to Admin Panel
+
+
+ """,
+ status_code=404,
+ )
+
+ if not gateway.oauth_config:
+ return HTMLResponse(
+ content="""
+
+
+ OAuth Authorization Failed
+
+ ❌ OAuth Authorization Failed
+ Error: Gateway has no OAuth configuration
+ Return to Admin Panel
+
+
+ """,
+ status_code=400,
+ )
+
+ # Complete OAuth flow
+ oauth_manager = OAuthManager(token_storage=TokenStorageService(db))
+
+ result = await oauth_manager.complete_authorization_code_flow(gateway_id, code, state, gateway.oauth_config)
+
+ logger.info(f"Completed OAuth flow for gateway {gateway_id}, user {result.get('user_id')}")
+
+ # Return success page with option to return to admin
+ return HTMLResponse(
+ content=f"""
+
+
+
+ OAuth Authorization Successful
+
+
+
+ ✅ OAuth Authorization Successful
+
+
Gateway: {gateway.name}
+
User ID: {result.get('user_id', 'Unknown')}
+
Expires: {result.get('expires_at', 'Unknown')}
+
Status: Authorization completed successfully
+
+
+
+
Next Steps:
+
Now that OAuth authorization is complete, you can fetch tools from the MCP server:
+
+
+
+
+ Return to Admin Panel
+
+
+
+
+ """
+ )
+
+ except OAuthError as e:
+ logger.error(f"OAuth callback failed: {str(e)}")
+ return HTMLResponse(
+ content=f"""
+
+
+
+ OAuth Authorization Failed
+
+
+
+ ❌ OAuth Authorization Failed
+ Error: {str(e)}
+ Please check your OAuth configuration and try again.
+ Return to Admin Panel
+
+
+ """,
+ status_code=400,
+ )
+
+ except Exception as e:
+ logger.error(f"Unexpected error in OAuth callback: {str(e)}")
+ return HTMLResponse(
+ content=f"""
+
+
+
+ OAuth Authorization Failed
+
+
+
+ ❌ OAuth Authorization Failed
+ Unexpected Error: {str(e)}
+ Please contact your administrator for assistance.
+ Return to Admin Panel
+
+
+ """,
+ status_code=500,
+ )
+
+
+@oauth_router.get("/status/{gateway_id}")
+async def get_oauth_status(gateway_id: str, db: Session = Depends(get_db)) -> dict:
+ """Get OAuth status for a gateway.
+
+ Args:
+ gateway_id: ID of the gateway
+ db: Database session
+
+ Returns:
+ OAuth status information
+
+ Raises:
+ HTTPException: If gateway not found or error retrieving status
+ """
+ try:
+ # Get gateway configuration
+ gateway = db.execute(select(Gateway).where(Gateway.id == gateway_id)).scalar_one_or_none()
+
+ if not gateway:
+ raise HTTPException(status_code=404, detail="Gateway not found")
+
+ if not gateway.oauth_config:
+ return {"oauth_enabled": False, "message": "Gateway is not configured for OAuth"}
+
+ # Get OAuth configuration info
+ oauth_config = gateway.oauth_config
+ grant_type = oauth_config.get("grant_type")
+
+ if grant_type == "authorization_code":
+ # For now, return basic info - in a real implementation you might want to
+ # show authorized users, token status, etc.
+ return {
+ "oauth_enabled": True,
+ "grant_type": grant_type,
+ "client_id": oauth_config.get("client_id"),
+ "scopes": oauth_config.get("scopes", []),
+ "authorization_url": oauth_config.get("authorization_url"),
+ "redirect_uri": oauth_config.get("redirect_uri"),
+ "message": "Gateway configured for Authorization Code flow",
+ }
+ else:
+ return {
+ "oauth_enabled": True,
+ "grant_type": grant_type,
+ "client_id": oauth_config.get("client_id"),
+ "scopes": oauth_config.get("scopes", []),
+ "message": f"Gateway configured for {grant_type} flow",
+ }
+
+ except HTTPException:
+ raise
+ except Exception as e:
+ logger.error(f"Failed to get OAuth status: {str(e)}")
+ raise HTTPException(status_code=500, detail=f"Failed to get OAuth status: {str(e)}")
+
+
+@oauth_router.post("/fetch-tools/{gateway_id}")
+async def fetch_tools_after_oauth(gateway_id: str, db: Session = Depends(get_db)) -> Dict[str, Any]:
+ """Fetch tools from MCP server after OAuth completion for Authorization Code flow.
+
+ Args:
+ gateway_id: ID of the gateway to fetch tools for
+ db: Database session
+
+ Returns:
+ Dict containing success status and message with number of tools fetched
+
+ Raises:
+ HTTPException: If fetching tools fails
+ """
+ try:
+ # First-Party
+ from mcpgateway.services.gateway_service import GatewayService
+
+ gateway_service = GatewayService()
+ result = await gateway_service.fetch_tools_after_oauth(db, gateway_id)
+ tools_count = len(result.get("tools", []))
+
+ return {"success": True, "message": f"Successfully fetched and created {tools_count} tools"}
+
+ except Exception as e:
+ logger.error(f"Failed to fetch tools after OAuth for gateway {gateway_id}: {e}")
+ raise HTTPException(status_code=500, detail=f"Failed to fetch tools: {str(e)}")
diff --git a/mcpgateway/schemas.py b/mcpgateway/schemas.py
index 24238937e..dff38dbc8 100644
--- a/mcpgateway/schemas.py
+++ b/mcpgateway/schemas.py
@@ -1844,7 +1844,7 @@ class GatewayCreate(BaseModel):
passthrough_headers: Optional[List[str]] = Field(default=None, description="List of headers allowed to be passed through from client to target")
# Authorizations
- auth_type: Optional[str] = Field(None, description="Type of authentication: basic, bearer, headers, or none")
+ auth_type: Optional[str] = Field(None, description="Type of authentication: basic, bearer, headers, oauth, or none")
# Fields for various types of authentication
auth_username: Optional[str] = Field(None, description="Username for basic authentication")
auth_password: Optional[str] = Field(None, description="Password for basic authentication")
@@ -1853,6 +1853,9 @@ class GatewayCreate(BaseModel):
auth_header_value: Optional[str] = Field(None, description="Value for custom headers authentication")
auth_headers: Optional[List[Dict[str, str]]] = Field(None, description="List of custom headers for authentication")
+ # OAuth 2.0 configuration
+ oauth_config: Optional[Dict[str, Any]] = Field(None, description="OAuth 2.0 configuration including grant_type, client_id, encrypted client_secret, URLs, and scopes")
+
# Adding `auth_value` as an alias for better access post-validation
auth_value: Optional[str] = Field(None, validate_default=True)
tags: Optional[List[str]] = Field(default_factory=list, description="Tags for categorizing the gateway")
@@ -2003,6 +2006,12 @@ def _process_auth_fields(info: ValidationInfo) -> Optional[Dict[str, Any]]:
return encode_auth({"Authorization": f"Bearer {token}"})
+ if auth_type == "oauth":
+ # For OAuth authentication, we don't encode anything here
+ # The OAuth configuration is handled separately in the oauth_config field
+ # This method is only called for traditional auth types
+ return None
+
if auth_type == "authheaders":
# Support both new multi-headers format and legacy single header format
auth_headers = data.get("auth_headers")
@@ -2056,7 +2065,7 @@ def _process_auth_fields(info: ValidationInfo) -> Optional[Dict[str, Any]]:
return encode_auth({header_key: header_value})
- raise ValueError("Invalid 'auth_type'. Must be one of: basic, bearer, or headers.")
+ raise ValueError("Invalid 'auth_type'. Must be one of: basic, bearer, oauth, or headers.")
class GatewayUpdate(BaseModelWithConfigDict):
@@ -2171,13 +2180,13 @@ def create_auth_value(cls, v, info):
return auth_value
@staticmethod
- def _process_auth_fields(values: Dict[str, Any]) -> Optional[Dict[str, Any]]:
+ def _process_auth_fields(info: ValidationInfo) -> Optional[Dict[str, Any]]:
"""
Processes the input authentication fields and returns the correct auth_value.
This method is called based on the selected auth_type.
Args:
- values: Dict container auth information auth_type, auth_username, auth_password, auth_token, auth_header_key and auth_header_value
+ info: ValidationInfo containing auth fields
Returns:
dict: Encoded auth information
@@ -2186,12 +2195,13 @@ def _process_auth_fields(values: Dict[str, Any]) -> Optional[Dict[str, Any]]:
ValueError: If auth type is invalid
"""
- auth_type = values.data.get("auth_type")
+ data = info.data
+ auth_type = data.get("auth_type")
if auth_type == "basic":
# For basic authentication, both username and password must be present
- username = values.data.get("auth_username")
- password = values.data.get("auth_password")
+ username = data.get("auth_username")
+ password = data.get("auth_password")
if not username or not password:
raise ValueError("For 'basic' auth, both 'auth_username' and 'auth_password' must be provided.")
@@ -2200,16 +2210,22 @@ def _process_auth_fields(values: Dict[str, Any]) -> Optional[Dict[str, Any]]:
if auth_type == "bearer":
# For bearer authentication, only token is required
- token = values.data.get("auth_token")
+ token = data.get("auth_token")
if not token:
raise ValueError("For 'bearer' auth, 'auth_token' must be provided.")
return encode_auth({"Authorization": f"Bearer {token}"})
+ if auth_type == "oauth":
+ # For OAuth authentication, we don't encode anything here
+ # The OAuth configuration is handled separately in the oauth_config field
+ # This method is only called for traditional auth types
+ return None
+
if auth_type == "authheaders":
# Support both new multi-headers format and legacy single header format
- auth_headers = values.data.get("auth_headers")
+ auth_headers = data.get("auth_headers")
if auth_headers and isinstance(auth_headers, list):
# New multi-headers format with enhanced validation
header_dict = {}
@@ -2252,15 +2268,15 @@ def _process_auth_fields(values: Dict[str, Any]) -> Optional[Dict[str, Any]]:
return encode_auth(header_dict)
# Legacy single header format (backward compatibility)
- header_key = values.data.get("auth_header_key")
- header_value = values.data.get("auth_header_value")
+ header_key = data.get("auth_header_key")
+ header_value = data.get("auth_header_value")
if not header_key or not header_value:
raise ValueError("For 'headers' auth, either 'auth_headers' list or both 'auth_header_key' and 'auth_header_value' must be provided.")
return encode_auth({header_key: header_value})
- raise ValueError("Invalid 'auth_type'. Must be one of: basic, bearer, or headers.")
+ raise ValueError("Invalid 'auth_type'. Must be one of: basic, bearer, oauth, or headers.")
class GatewayRead(BaseModelWithConfigDict):
@@ -2273,8 +2289,9 @@ class GatewayRead(BaseModelWithConfigDict):
- enabled status
- reachable status
- Last seen timestamp
- - Authentication type: basic, bearer, headers
+ - Authentication type: basic, bearer, headers, oauth
- Authentication value: username/password or token or custom headers
+ - OAuth configuration for OAuth 2.0 authentication
Auto Populated fields:
- Authentication username: for basic auth
@@ -2299,9 +2316,12 @@ class GatewayRead(BaseModelWithConfigDict):
passthrough_headers: Optional[List[str]] = Field(default=None, description="List of headers allowed to be passed through from client to target")
# Authorizations
- auth_type: Optional[str] = Field(None, description="auth_type: basic, bearer, headers or None")
+ auth_type: Optional[str] = Field(None, description="auth_type: basic, bearer, headers, oauth, or None")
auth_value: Optional[str] = Field(None, description="auth value: username/password or token or custom headers")
+ # OAuth 2.0 configuration
+ oauth_config: Optional[Dict[str, Any]] = Field(None, description="OAuth 2.0 configuration including grant_type, client_id, encrypted client_secret, URLs, and scopes")
+
# auth_value will populate the following fields
auth_username: Optional[str] = Field(None, description="username for basic authentication")
auth_password: Optional[str] = Field(None, description="password for basic authentication")
@@ -2401,6 +2421,12 @@ def _populate_auth(cls, values: Self) -> Dict[str, Any]:
if auth_value_encoded == settings.masked_auth_value:
return values
+ # Handle OAuth authentication (no auth_value to decode)
+ if auth_type == "oauth":
+ # OAuth gateways don't have traditional auth_value to decode
+ # They use oauth_config instead
+ return values
+
auth_value = decode_auth(auth_value_encoded)
if auth_type == "basic":
auth = auth_value.get("Authorization")
diff --git a/mcpgateway/services/gateway_service.py b/mcpgateway/services/gateway_service.py
index 79ab646d5..954841867 100644
--- a/mcpgateway/services/gateway_service.py
+++ b/mcpgateway/services/gateway_service.py
@@ -79,6 +79,7 @@
# logging.getLogger("httpx").setLevel(logging.WARNING) # Disables httpx logs for regular health checks
from mcpgateway.services.logging_service import LoggingService
+from mcpgateway.services.oauth_manager import OAuthManager
from mcpgateway.services.tool_service import ToolService
from mcpgateway.utils.create_slug import slugify
from mcpgateway.utils.retry_manager import ResilientHttpClient
@@ -174,7 +175,7 @@ class GatewayConnectionError(GatewayError):
"""
-class GatewayService:
+class GatewayService: # pylint: disable=too-many-instance-attributes
"""Service for managing federated gateways.
Handles:
@@ -230,6 +231,7 @@ def __init__(self) -> None:
self._pending_responses = {}
self.tool_service = ToolService()
self._gateway_failure_counts: dict[str, int] = {}
+ self.oauth_manager = OAuthManager(request_timeout=int(os.getenv("OAUTH_REQUEST_TIMEOUT", "30")), max_retries=int(os.getenv("OAUTH_MAX_RETRIES", "3")))
# For health checks, we determine the leader instance.
self.redis_url = settings.redis_url if settings.cache_type == "redis" else None
@@ -459,7 +461,8 @@ async def register_gateway(
header_dict = {h["key"]: h["value"] for h in gateway.auth_headers if h.get("key")}
auth_value = encode_auth(header_dict) # Encode the dict for consistency
- capabilities, tools, resources, prompts = await self._initialize_gateway(normalized_url, auth_value, gateway.transport)
+ oauth_config = getattr(gateway, "oauth_config", None)
+ capabilities, tools, resources, prompts = await self._initialize_gateway(normalized_url, auth_value, gateway.transport, auth_type, oauth_config)
tools = [
DbTool(
@@ -535,6 +538,7 @@ async def register_gateway(
last_seen=datetime.now(timezone.utc),
auth_type=auth_type,
auth_value=auth_value,
+ oauth_config=oauth_config,
tools=tools,
resources=db_resources,
prompts=db_prompts,
@@ -589,6 +593,98 @@ async def register_gateway(
logger.error(f"Other grouped errors: {other.exceptions}")
raise other.exceptions[0]
+ async def fetch_tools_after_oauth(self, db: Session, gateway_id: str) -> Dict[str, Any]:
+ """Fetch tools from MCP server after OAuth completion for Authorization Code flow.
+
+ Args:
+ db: Database session
+ gateway_id: ID of the gateway to fetch tools for
+
+ Returns:
+ Dict containing capabilities, tools, resources, and prompts
+
+ Raises:
+ GatewayConnectionError: If connection or OAuth fails
+ """
+ try:
+ # Get the gateway
+ gateway = db.execute(select(DbGateway).where(DbGateway.id == gateway_id)).scalar_one_or_none()
+
+ if not gateway:
+ raise ValueError(f"Gateway {gateway_id} not found")
+
+ if not gateway.oauth_config:
+ raise ValueError(f"Gateway {gateway_id} has no OAuth configuration")
+
+ grant_type = gateway.oauth_config.get("grant_type")
+ if grant_type != "authorization_code":
+ raise ValueError(f"Gateway {gateway_id} is not using Authorization Code flow")
+
+ # Get OAuth tokens for this gateway
+ # First-Party
+ from mcpgateway.services.token_storage_service import TokenStorageService # pylint: disable=import-outside-toplevel
+
+ token_storage = TokenStorageService(db)
+
+ # Try to get a valid token for any user (for now, we'll use a placeholder)
+ # In a real implementation, you might want to specify which user's tokens to use
+ access_token = await token_storage.get_any_valid_token(gateway.id)
+
+ if not access_token:
+ raise GatewayConnectionError(f"No valid OAuth tokens found for gateway {gateway.name}. " "Please complete the OAuth authorization flow first.")
+
+ # Now connect to MCP server with the access token
+ authentication = {"Authorization": f"Bearer {access_token}"}
+
+ # Use the existing connection logic
+ if gateway.transport.upper() == "SSE":
+ capabilities, tools, resources, prompts = await self.connect_to_sse_server(gateway.url, authentication)
+ return {"capabilities": capabilities, "tools": tools, "resources": resources, "prompts": prompts}
+ if gateway.transport.upper() == "STREAMABLEHTTP":
+ capabilities, tools, resources, prompts = await self.connect_to_streamablehttp_server(gateway.url, authentication)
+
+ # Filter out any None tools and create DbTool objects
+ tools_to_add = []
+ for tool in tools:
+ if tool is None:
+ logger.warning("Skipping None tool in tools list")
+ continue
+
+ try:
+ db_tool = DbTool(
+ original_name=tool.name,
+ original_name_slug=slugify(tool.name),
+ url=gateway.url.rstrip("/"),
+ description=tool.description,
+ integration_type="MCP", # Gateway-discovered tools are MCP type
+ request_type=tool.request_type,
+ headers=tool.headers,
+ input_schema=tool.input_schema,
+ annotations=tool.annotations,
+ jsonpath_filter=tool.jsonpath_filter,
+ auth_type=gateway.auth_type,
+ auth_value=gateway.auth_value,
+ gateway=gateway, # attach relationship to avoid NoneType during flush
+ )
+ tools_to_add.append(db_tool)
+ except Exception as e:
+ logger.warning(f"Failed to create DbTool for tool {getattr(tool, 'name', 'unknown')}: {e}")
+ continue
+
+ # Add to DB
+ if tools_to_add:
+ db.add_all(tools_to_add)
+ db.commit()
+ else:
+ logger.warning("No valid tools to add to database")
+
+ return {"capabilities": capabilities, "tools": tools, "resources": resources, "prompts": prompts}
+ raise ValueError(f"Unsupported transport type: {gateway.transport}")
+
+ except Exception as e:
+ logger.error(f"Failed to fetch tools after OAuth for gateway {gateway_id}: {e}")
+ raise GatewayConnectionError(f"Failed to fetch tools after OAuth: {str(e)}")
+
async def list_gateways(self, db: Session, include_inactive: bool = False) -> List[GatewayRead]:
"""List all registered gateways.
@@ -632,6 +728,21 @@ async def list_gateways(self, db: Session, include_inactive: bool = False) -> Li
query = query.where(DbGateway.enabled)
gateways = db.execute(query).scalars().all()
+
+ # print("******************************************************************")
+ # for g in gateways:
+ # print("----------------------------")
+ # for attr in dir(g):
+ # if not attr.startswith("_"):
+ # try:
+ # value = getattr(g, attr)
+ # except Exception:
+ # value = ""
+ # print(f"{attr}: {value}")
+ # # print(f"Gateway oauth_config: {g}")
+ # # print(f"Gateway auth_type: {g['auth_type']}")
+ # print("******************************************************************")
+
return [GatewayRead.model_validate(g).masked() for g in gateways]
async def update_gateway(self, db: Session, gateway_id: str, gateway_update: GatewayUpdate, include_inactive: bool = True) -> GatewayRead:
@@ -710,7 +821,7 @@ async def update_gateway(self, db: Session, gateway_id: str, gateway_update: Gat
# Try to reinitialize connection if URL changed
if gateway_update.url is not None:
try:
- capabilities, tools, resources, prompts = await self._initialize_gateway(gateway.url, gateway.auth_value, gateway.transport)
+ capabilities, tools, resources, prompts = await self._initialize_gateway(gateway.url, gateway.auth_value, gateway.transport, gateway.auth_type, gateway.oauth_config)
new_tool_names = [tool.name for tool in tools]
new_resource_uris = [resource.uri for resource in resources]
new_prompt_names = [prompt.name for prompt in prompts]
@@ -897,7 +1008,7 @@ async def toggle_gateway_status(self, db: Session, gateway_id: str, activate: bo
self._active_gateways.add(gateway.url)
# Try to initialize if activating
try:
- capabilities, tools, resources, prompts = await self._initialize_gateway(gateway.url, gateway.auth_value, gateway.transport)
+ capabilities, tools, resources, prompts = await self._initialize_gateway(gateway.url, gateway.auth_value, gateway.transport, gateway.auth_type, gateway.oauth_config)
new_tool_names = [tool.name for tool in tools]
new_resource_uris = [resource.uri for resource in resources]
new_prompt_names = [prompt.name for prompt in prompts]
@@ -1410,7 +1521,7 @@ async def subscribe_events(self) -> AsyncGenerator[Dict[str, Any], None]:
self._event_subscribers.remove(queue)
async def _initialize_gateway(
- self, url: str, authentication: Optional[Dict[str, str]] = None, transport: str = "SSE"
+ self, url: str, authentication: Optional[Dict[str, str]] = None, transport: str = "SSE", auth_type: Optional[str] = None, oauth_config: Optional[Dict[str, Any]] = None
) -> tuple[Dict[str, Any], List[ToolCreate], List[ResourceCreate], List[PromptCreate]]:
"""Initialize connection to a gateway and retrieve its capabilities.
@@ -1422,6 +1533,8 @@ async def _initialize_gateway(
url: Gateway URL to connect to
authentication: Optional authentication headers for the connection
transport: Transport protocol - "SSE" or "StreamableHTTP"
+ auth_type: Authentication type - "basic", "bearer", "headers", "oauth" or None
+ oauth_config: OAuth configuration if auth_type is "oauth"
Returns:
tuple[Dict[str, Any], List[ToolCreate], List[ResourceCreate], List[PromptCreate]]:
@@ -1457,212 +1570,38 @@ async def _initialize_gateway(
if authentication is None:
authentication = {}
- async def connect_to_sse_server(server_url: str, authentication: Optional[Dict[str, str]] = None):
- """Connect to an MCP server running with SSE transport.
-
- Establishes an SSE connection to the MCP server, performs the
- initialization handshake, and retrieves server capabilities,
- tools, resources, and prompts.
-
- Args:
- server_url: URL to connect to the SSE-enabled MCP server
- authentication: Optional authentication headers for the connection
-
- Returns:
- Tuple[Dict, List[ToolCreate], List[ResourceCreate], List[PromptCreate]]:
- Server capabilities, list of tools, resources, and prompts
-
- Examples:
- >>> # Test function signature and defaults
- >>> import inspect
- >>> sig = inspect.signature(connect_to_sse_server)
- >>> list(sig.parameters.keys())
- ['server_url', 'authentication']
- >>> sig.parameters['authentication'].default is None
- True
- >>> sig.parameters['server_url'].annotation
-
-
- >>> # Test authentication parameter handling
- >>> auth = {"Authorization": "Bearer token123"}
- >>> isinstance(auth, dict)
- True
- >>> auth.get("Authorization", "").startswith("Bearer")
- True
- """
- if authentication is None:
- authentication = {}
- # Store the context managers so they stay alive
- decoded_auth = decode_auth(authentication)
-
- if await self._validate_gateway_url(url=server_url, headers=decoded_auth, transport_type="SSE"):
- # Use async with for both sse_client and ClientSession
- async with sse_client(url=server_url, headers=decoded_auth) as streams:
- async with ClientSession(*streams) as session:
- # Initialize the session
- response = await session.initialize()
- capabilities = response.capabilities.model_dump(by_alias=True, exclude_none=True)
- logger.debug(f"Server capabilities: {capabilities}")
-
- response = await session.list_tools()
- tools = response.tools
- tools = [tool.model_dump(by_alias=True, exclude_none=True) for tool in tools]
-
- tools = [ToolCreate.model_validate(tool) for tool in tools]
- if tools:
- logger.info(f"Fetched {len(tools)} tools from gateway")
-
- # Fetch resources if supported
- resources = []
- logger.debug(f"Checking for resources support: {capabilities.get('resources')}")
- if capabilities.get("resources"):
- try:
- response = await session.list_resources()
- raw_resources = response.resources
- for resource in raw_resources:
- resource_data = resource.model_dump(by_alias=True, exclude_none=True)
- # Convert AnyUrl to string if present
- if "uri" in resource_data and hasattr(resource_data["uri"], "unicode_string"):
- resource_data["uri"] = str(resource_data["uri"])
- # Add default content if not present (will be fetched on demand)
- if "content" not in resource_data:
- resource_data["content"] = ""
- try:
- resources.append(ResourceCreate.model_validate(resource_data))
- except Exception:
- # If validation fails, create minimal resource
- resources.append(
- ResourceCreate(
- uri=str(resource_data.get("uri", "")),
- name=resource_data.get("name", ""),
- description=resource_data.get("description"),
- mime_type=resource_data.get("mime_type"),
- template=resource_data.get("template"),
- content="",
- )
- )
- logger.info(f"Fetched {len(resources)} resources from gateway")
- except Exception as e:
- logger.warning(f"Failed to fetch resources: {e}")
-
- # Fetch prompts if supported
- prompts = []
- logger.debug(f"Checking for prompts support: {capabilities.get('prompts')}")
- if capabilities.get("prompts"):
- try:
- response = await session.list_prompts()
- raw_prompts = response.prompts
- for prompt in raw_prompts:
- prompt_data = prompt.model_dump(by_alias=True, exclude_none=True)
- # Add default template if not present
- if "template" not in prompt_data:
- prompt_data["template"] = ""
- try:
- prompts.append(PromptCreate.model_validate(prompt_data))
- except Exception:
- # If validation fails, create minimal prompt
- prompts.append(
- PromptCreate(
- name=prompt_data.get("name", ""),
- description=prompt_data.get("description"),
- template=prompt_data.get("template", ""),
- )
- )
- logger.info(f"Fetched {len(prompts)} prompts from gateway")
- except Exception as e:
- logger.warning(f"Failed to fetch prompts: {e}")
-
- return capabilities, tools, resources, prompts
- raise GatewayConnectionError(f"Failed to initialize gateway at {url}")
-
- async def connect_to_streamablehttp_server(server_url: str, authentication: Optional[Dict[str, str]] = None):
- """Connect to an MCP server running with Streamable HTTP transport.
-
- Establishes a StreamableHTTP connection to the MCP server, performs the
- initialization handshake, and retrieves server capabilities,
- tools, resources, and prompts.
-
- Args:
- server_url: URL to connect to the server
- authentication: Authentication headers for connection to URL
+ # Handle OAuth authentication
+ if auth_type == "oauth" and oauth_config:
+ grant_type = oauth_config.get("grant_type", "client_credentials")
- Returns:
- Tuple[Dict, List[ToolCreate], List[ResourceCreate], List[PromptCreate]]:
- Server capabilities, list of tools, resources, and prompts
- """
- if authentication is None:
+ if grant_type == "authorization_code":
+ # For Authorization Code flow, we can't initialize immediately
+ # because we need user consent. Just store the configuration
+ # and let the user complete the OAuth flow later.
+ logger.info("""OAuth Authorization Code flow configured for gateway. User must complete authorization before gateway can be used.""")
+ # Don't try to get access token here - it will be obtained during tool invocation
authentication = {}
- # Store the context managers so they stay alive
- decoded_auth = decode_auth(authentication)
- # The _validate_gateway_url logic is flawed for streamablehttp, so we bypass it
- # and go straight to the client connection. The outer try/except in
- # _initialize_gateway will handle any connection errors.
- async with streamablehttp_client(url=server_url, headers=decoded_auth) as (read_stream, write_stream, _get_session_id):
- async with ClientSession(read_stream, write_stream) as session:
- # Initialize the session
- response = await session.initialize()
- capabilities = response.capabilities.model_dump(by_alias=True, exclude_none=True)
- logger.debug(f"Server capabilities: {capabilities}")
-
- response = await session.list_tools()
- tools = response.tools
- tools = [tool.model_dump(by_alias=True, exclude_none=True) for tool in tools]
-
- tools = [ToolCreate.model_validate(tool) for tool in tools]
- for tool in tools:
- tool.request_type = "STREAMABLEHTTP"
- if tools:
- logger.info(f"Fetched {len(tools)} tools from gateway")
-
- # Fetch resources if supported
- resources = []
- logger.debug(f"Checking for resources support: {capabilities.get('resources')}")
- if capabilities.get("resources"):
- try:
- response = await session.list_resources()
- raw_resources = response.resources
- resources = []
- for resource in raw_resources:
- resource_data = resource.model_dump(by_alias=True, exclude_none=True)
- # Convert AnyUrl to string if present
- if "uri" in resource_data and hasattr(resource_data["uri"], "unicode_string"):
- resource_data["uri"] = str(resource_data["uri"])
- # Add default content if not present
- if "content" not in resource_data:
- resource_data["content"] = ""
- resources.append(ResourceCreate.model_validate(resource_data))
- logger.info(f"Fetched {len(resources)} resources from gateway")
- except Exception as e:
- logger.warning(f"Failed to fetch resources: {e}")
-
- # Fetch prompts if supported
- prompts = []
- logger.debug(f"Checking for prompts support: {capabilities.get('prompts')}")
- if capabilities.get("prompts"):
- try:
- response = await session.list_prompts()
- raw_prompts = response.prompts
- prompts = []
- for prompt in raw_prompts:
- prompt_data = prompt.model_dump(by_alias=True, exclude_none=True)
- # Add default template if not present
- if "template" not in prompt_data:
- prompt_data["template"] = ""
- prompts.append(PromptCreate.model_validate(prompt_data))
- logger.info(f"Fetched {len(prompts)} prompts from gateway")
- except Exception as e:
- logger.warning(f"Failed to fetch prompts: {e}")
- return capabilities, tools, resources, prompts
+ # Skip MCP server connection for Authorization Code flow
+ # Tools will be fetched after OAuth completion
+ return {}, [], [], []
+ # For Client Credentials flow, we can get the token immediately
+ try:
+ print(f"oauth_config: {oauth_config}")
+ access_token = await self.oauth_manager.get_access_token(oauth_config)
+ authentication = {"Authorization": f"Bearer {access_token}"}
+ except Exception as e:
+ logger.error(f"Failed to obtain OAuth access token: {e}")
+ raise GatewayConnectionError(f"OAuth authentication failed: {str(e)}")
capabilities = {}
tools = []
resources = []
prompts = []
if transport.lower() == "sse":
- capabilities, tools, resources, prompts = await connect_to_sse_server(url, authentication)
+ capabilities, tools, resources, prompts = await self.connect_to_sse_server(url, authentication)
elif transport.lower() == "streamablehttp":
- capabilities, tools, resources, prompts = await connect_to_streamablehttp_server(url, authentication)
+ capabilities, tools, resources, prompts = await self.connect_to_streamablehttp_server(url, authentication)
return capabilities, tools, resources, prompts
except Exception as e:
@@ -1920,3 +1859,171 @@ async def _publish_event(self, event: Dict[str, Any]) -> None:
"""
for queue in self._event_subscribers:
await queue.put(event)
+
+ async def connect_to_sse_server(self, server_url: str, authentication: Optional[Dict[str, str]] = None):
+ """Connect to an MCP server running with SSE transport.
+
+ Args:
+ server_url: The URL of the SSE MCP server to connect to.
+ authentication: Optional dictionary containing authentication headers.
+
+ Returns:
+ Tuple containing (capabilities, tools, resources, prompts) from the MCP server.
+ """
+ if authentication is None:
+ authentication = {}
+ # Use authentication directly instead
+
+ if await self._validate_gateway_url(url=server_url, headers=authentication, transport_type="SSE"):
+ # Use async with for both sse_client and ClientSession
+ async with sse_client(url=server_url, headers=authentication) as streams:
+ async with ClientSession(*streams) as session:
+ # Initialize the session
+ response = await session.initialize()
+ capabilities = response.capabilities.model_dump(by_alias=True, exclude_none=True)
+ logger.debug(f"Server capabilities: {capabilities}")
+
+ response = await session.list_tools()
+ tools = response.tools
+ tools = [tool.model_dump(by_alias=True, exclude_none=True) for tool in tools]
+
+ tools = [ToolCreate.model_validate(tool) for tool in tools]
+ if tools:
+ logger.info(f"Fetched {len(tools)} tools from gateway")
+
+ # Fetch resources if supported
+ resources = []
+ logger.debug(f"Checking for resources support: {capabilities.get('resources')}")
+ if capabilities.get("resources"):
+ try:
+ response = await session.list_resources()
+ raw_resources = response.resources
+ for resource in raw_resources:
+ resource_data = resource.model_dump(by_alias=True, exclude_none=True)
+ # Convert AnyUrl to string if present
+ if "uri" in resource_data and hasattr(resource_data["uri"], "unicode_string"):
+ resource_data["uri"] = str(resource_data["uri"])
+ # Add default content if not present (will be fetched on demand)
+ if "content" not in resource_data:
+ resource_data["content"] = ""
+ try:
+ resources.append(ResourceCreate.model_validate(resource_data))
+ except Exception:
+ # If validation fails, create minimal resource
+ resources.append(
+ ResourceCreate(
+ uri=str(resource_data.get("uri", "")),
+ name=resource_data.get("name", ""),
+ description=resource_data.get("description"),
+ mime_type=resource_data.get("mime_type"),
+ template=resource_data.get("template"),
+ content="",
+ )
+ )
+ logger.info(f"Fetched {len(resources)} resources from gateway")
+ except Exception as e:
+ logger.warning(f"Failed to fetch resources: {e}")
+
+ # Fetch prompts if supported
+ prompts = []
+ logger.debug(f"Checking for prompts support: {capabilities.get('prompts')}")
+ if capabilities.get("prompts"):
+ try:
+ response = await session.list_prompts()
+ raw_prompts = response.prompts
+ for prompt in raw_prompts:
+ prompt_data = prompt.model_dump(by_alias=True, exclude_none=True)
+ # Add default template if not present
+ if "template" not in prompt_data:
+ prompt_data["template"] = ""
+ try:
+ prompts.append(PromptCreate.model_validate(prompt_data))
+ except Exception:
+ # If validation fails, create minimal prompt
+ prompts.append(
+ PromptCreate(
+ name=prompt_data.get("name", ""),
+ description=prompt_data.get("description"),
+ template=prompt_data.get("template", ""),
+ )
+ )
+ logger.info(f"Fetched {len(prompts)} prompts from gateway")
+ except Exception as e:
+ logger.warning(f"Failed to fetch prompts: {e}")
+
+ return capabilities, tools, resources, prompts
+ raise GatewayConnectionError(f"Failed to initialize gateway at {server_url}")
+
+ async def connect_to_streamablehttp_server(self, server_url: str, authentication: Optional[Dict[str, str]] = None):
+ """Connect to an MCP server running with Streamable HTTP transport.
+
+ Args:
+ server_url: The URL of the Streamable HTTP MCP server to connect to.
+ authentication: Optional dictionary containing authentication headers.
+
+ Returns:
+ Tuple containing (capabilities, tools, resources, prompts) from the MCP server.
+ """
+ if authentication is None:
+ authentication = {}
+ # Use authentication directly instead
+
+ # The _validate_gateway_url logic is flawed for streamablehttp, so we bypass it
+ # and go straight to the client connection. The outer try/except in
+ # _initialize_gateway will handle any connection errors.
+ async with streamablehttp_client(url=server_url, headers=authentication) as (read_stream, write_stream, _get_session_id):
+ async with ClientSession(read_stream, write_stream) as session:
+ # Initialize the session
+ response = await session.initialize()
+ capabilities = response.capabilities.model_dump(by_alias=True, exclude_none=True)
+ logger.debug(f"Server capabilities: {capabilities}")
+
+ response = await session.list_tools()
+ tools = response.tools
+ tools = [tool.model_dump(by_alias=True, exclude_none=True) for tool in tools]
+
+ tools = [ToolCreate.model_validate(tool) for tool in tools]
+ for tool in tools:
+ tool.request_type = "STREAMABLEHTTP"
+ if tools:
+ logger.info(f"Fetched {len(tools)} tools from gateway")
+
+ # Fetch resources if supported
+ resources = []
+ logger.debug(f"Checking for resources support: {capabilities.get('resources')}")
+ if capabilities.get("resources"):
+ try:
+ response = await session.list_resources()
+ raw_resources = response.resources
+ resources = []
+ for resource in raw_resources:
+ resource_data = resource.model_dump(by_alias=True, exclude_none=True)
+ # Convert AnyUrl to string if present
+ if "uri" in resource_data and hasattr(resource_data["uri"], "unicode_string"):
+ resource_data["uri"] = str(resource_data["uri"])
+ # Add default content if not present
+ if "content" not in resource_data:
+ resource_data["content"] = ""
+ resources.append(ResourceCreate.model_validate(resource_data))
+ logger.info(f"Fetched {len(resources)} resources from gateway")
+ except Exception as e:
+ logger.warning(f"Failed to fetch resources: {e}")
+
+ # Fetch prompts if supported
+ prompts = []
+ logger.debug(f"Checking for prompts support: {capabilities.get('prompts')}")
+ if capabilities.get("prompts"):
+ try:
+ response = await session.list_prompts()
+ raw_prompts = response.prompts
+ prompts = []
+ for prompt in raw_prompts:
+ prompt_data = prompt.model_dump(by_alias=True, exclude_none=True)
+ # Add default template if not present
+ if "template" not in prompt_data:
+ prompt_data["template"] = ""
+ prompts.append(PromptCreate.model_validate(prompt_data))
+ except Exception as e:
+ logger.warning(f"Failed to fetch prompts: {e}")
+
+ return capabilities, tools, resources, prompts
diff --git a/mcpgateway/services/oauth_manager.py b/mcpgateway/services/oauth_manager.py
new file mode 100644
index 000000000..f7a1dc29c
--- /dev/null
+++ b/mcpgateway/services/oauth_manager.py
@@ -0,0 +1,511 @@
+# -*- coding: utf-8 -*-
+"""OAuth 2.0 Manager for MCP Gateway.
+
+This module handles OAuth 2.0 authentication flows including:
+- Client Credentials (Machine-to-Machine)
+- Authorization Code (User Delegation)
+"""
+
+# Standard
+import asyncio
+import logging
+import secrets
+from typing import Any, Dict, Optional
+
+# Third-Party
+import aiohttp
+from requests_oauthlib import OAuth2Session
+
+# First-Party
+from mcpgateway.config import get_settings
+from mcpgateway.utils.oauth_encryption import get_oauth_encryption
+
+logger = logging.getLogger(__name__)
+
+
+class OAuthManager:
+ """Manages OAuth 2.0 authentication flows."""
+
+ def __init__(self, request_timeout: int = 30, max_retries: int = 3, token_storage: Optional[Any] = None):
+ """Initialize OAuth Manager.
+
+ Args:
+ request_timeout: Timeout for OAuth requests in seconds
+ max_retries: Maximum number of retry attempts for token requests
+ token_storage: Optional TokenStorageService for storing tokens
+ """
+ self.request_timeout = request_timeout
+ self.max_retries = max_retries
+ self.token_storage = token_storage
+
+ async def get_access_token(self, credentials: Dict[str, Any]) -> str:
+ """Get access token based on grant type.
+
+ Args:
+ credentials: OAuth configuration containing grant_type and other params
+
+ Returns:
+ Access token string
+
+ Raises:
+ ValueError: If grant type is unsupported
+ OAuthError: If token acquisition fails
+ """
+ grant_type = credentials.get("grant_type")
+ logger.debug(f"Getting access token for grant type: {grant_type}")
+
+ if grant_type == "client_credentials":
+ return await self._client_credentials_flow(credentials)
+ if grant_type == "authorization_code":
+ # For authorization code flow in gateway initialization, we need to handle this differently
+ # Since this is called during gateway setup, we'll try to use client credentials as fallback
+ # or provide a more helpful error message
+ logger.warning("Authorization code flow requires user interaction. " + "For gateway initialization, consider using 'client_credentials' grant type instead.")
+ # Try to use client credentials flow if possible (some OAuth providers support this)
+ try:
+ return await self._client_credentials_flow(credentials)
+ except Exception as e:
+ raise OAuthError(
+ f"Authorization code flow cannot be used for automatic gateway initialization. "
+ f"Please use 'client_credentials' grant type or complete the OAuth flow manually first. "
+ f"Error: {str(e)}"
+ )
+ else:
+ raise ValueError(f"Unsupported grant type: {grant_type}")
+
+ async def _client_credentials_flow(self, credentials: Dict[str, Any]) -> str:
+ """Machine-to-machine authentication using client credentials.
+
+ Args:
+ credentials: OAuth configuration with client_id, client_secret, token_url
+
+ Returns:
+ Access token string
+
+ Raises:
+ OAuthError: If token acquisition fails after all retries
+ """
+ client_id = credentials["client_id"]
+ client_secret = credentials["client_secret"]
+ token_url = credentials["token_url"]
+ scopes = credentials.get("scopes", [])
+
+ # Decrypt client secret if it's encrypted
+ if len(client_secret) > 50: # Simple heuristic: encrypted secrets are longer
+ try:
+ settings = get_settings()
+ encryption = get_oauth_encryption(settings.auth_encryption_secret)
+ decrypted_secret = encryption.decrypt_secret(client_secret)
+ if decrypted_secret:
+ client_secret = decrypted_secret
+ logger.debug("Successfully decrypted client secret")
+ else:
+ logger.warning("Failed to decrypt client secret, using encrypted version")
+ except Exception as e:
+ logger.warning(f"Failed to decrypt client secret: {e}, using encrypted version")
+
+ # Prepare token request data
+ token_data = {
+ "grant_type": "client_credentials",
+ "client_id": client_id,
+ "client_secret": client_secret,
+ }
+
+ if scopes:
+ token_data["scope"] = " ".join(scopes) if isinstance(scopes, list) else scopes
+
+ # Fetch token with retries
+ for attempt in range(self.max_retries):
+ try:
+ async with aiohttp.ClientSession() as session:
+ async with session.post(token_url, data=token_data, timeout=aiohttp.ClientTimeout(total=self.request_timeout)) as response:
+ response.raise_for_status()
+
+ # GitHub returns form-encoded responses, not JSON
+ content_type = response.headers.get("content-type", "")
+ if "application/x-www-form-urlencoded" in content_type:
+ # Parse form-encoded response
+ text_response = await response.text()
+ token_response = {}
+ for pair in text_response.split("&"):
+ if "=" in pair:
+ key, value = pair.split("=", 1)
+ token_response[key] = value
+ else:
+ # Try JSON response
+ try:
+ token_response = await response.json()
+ except Exception as e:
+ logger.warning(f"Failed to parse JSON response: {e}")
+ # Fallback to text parsing
+ text_response = await response.text()
+ token_response = {"raw_response": text_response}
+
+ if "access_token" not in token_response:
+ raise OAuthError(f"No access_token in response: {token_response}")
+
+ logger.info("""Successfully obtained access token via client credentials""")
+ return token_response["access_token"]
+
+ except aiohttp.ClientError as e:
+ logger.warning(f"Token request attempt {attempt + 1} failed: {str(e)}")
+ if attempt == self.max_retries - 1:
+ raise OAuthError(f"Failed to obtain access token after {self.max_retries} attempts: {str(e)}")
+ await asyncio.sleep(2**attempt) # Exponential backoff
+
+ # This should never be reached due to the exception above, but needed for type safety
+ raise OAuthError("Failed to obtain access token after all retry attempts")
+
+ async def get_authorization_url(self, credentials: Dict[str, Any]) -> Dict[str, str]:
+ """Get authorization URL for user delegation flow.
+
+ Args:
+ credentials: OAuth configuration with client_id, authorization_url, etc.
+
+ Returns:
+ Dict containing authorization_url and state
+ """
+ client_id = credentials["client_id"]
+ redirect_uri = credentials["redirect_uri"]
+ authorization_url = credentials["authorization_url"]
+ scopes = credentials.get("scopes", [])
+
+ # Create OAuth2 session
+ oauth = OAuth2Session(client_id, redirect_uri=redirect_uri, scope=scopes)
+
+ # Generate authorization URL with state for CSRF protection
+ auth_url, state = oauth.authorization_url(authorization_url)
+
+ logger.info(f"Generated authorization URL for client {client_id}")
+
+ return {"authorization_url": auth_url, "state": state}
+
+ async def exchange_code_for_token(self, credentials: Dict[str, Any], code: str, state: str) -> str: # pylint: disable=unused-argument
+ """Exchange authorization code for access token.
+
+ Args:
+ credentials: OAuth configuration
+ code: Authorization code from callback
+ state: State parameter for CSRF validation
+
+ Returns:
+ Access token string
+
+ Raises:
+ OAuthError: If token exchange fails
+ """
+ client_id = credentials["client_id"]
+ client_secret = credentials["client_secret"]
+ token_url = credentials["token_url"]
+ redirect_uri = credentials["redirect_uri"]
+
+ # Decrypt client secret if it's encrypted
+ if len(client_secret) > 50: # Simple heuristic: encrypted secrets are longer
+ try:
+ settings = get_settings()
+ encryption = get_oauth_encryption(settings.auth_encryption_secret)
+ decrypted_secret = encryption.decrypt_secret(client_secret)
+ if decrypted_secret:
+ client_secret = decrypted_secret
+ logger.debug("Successfully decrypted client secret")
+ else:
+ logger.warning("Failed to decrypt client secret, using encrypted version")
+ except Exception as e:
+ logger.warning(f"Failed to decrypt client secret: {e}, using encrypted version")
+
+ # Prepare token exchange data
+ token_data = {
+ "grant_type": "authorization_code",
+ "code": code,
+ "redirect_uri": redirect_uri,
+ "client_id": client_id,
+ "client_secret": client_secret,
+ }
+
+ # Exchange code for token with retries
+ for attempt in range(self.max_retries):
+ try:
+ async with aiohttp.ClientSession() as session:
+ async with session.post(token_url, data=token_data, timeout=aiohttp.ClientTimeout(total=self.request_timeout)) as response:
+ response.raise_for_status()
+
+ # GitHub returns form-encoded responses, not JSON
+ content_type = response.headers.get("content-type", "")
+ if "application/x-www-form-urlencoded" in content_type:
+ # Parse form-encoded response
+ text_response = await response.text()
+ token_response = {}
+ for pair in text_response.split("&"):
+ if "=" in pair:
+ key, value = pair.split("=", 1)
+ token_response[key] = value
+ else:
+ # Try JSON response
+ try:
+ token_response = await response.json()
+ except Exception as e:
+ logger.warning(f"Failed to parse JSON response: {e}")
+ # Fallback to text parsing
+ text_response = await response.text()
+ token_response = {"raw_response": text_response}
+
+ if "access_token" not in token_response:
+ raise OAuthError(f"No access_token in response: {token_response}")
+
+ logger.info("""Successfully exchanged authorization code for access token""")
+ return token_response["access_token"]
+
+ except aiohttp.ClientError as e:
+ logger.warning(f"Token exchange attempt {attempt + 1} failed: {str(e)}")
+ if attempt == self.max_retries - 1:
+ raise OAuthError(f"Failed to exchange code for token after {self.max_retries} attempts: {str(e)}")
+ await asyncio.sleep(2**attempt) # Exponential backoff
+
+ # This should never be reached due to the exception above, but needed for type safety
+ raise OAuthError("Failed to exchange code for token after all retry attempts")
+
+ async def initiate_authorization_code_flow(self, gateway_id: str, credentials: Dict[str, Any]) -> Dict[str, str]:
+ """Initiate Authorization Code flow and return authorization URL.
+
+ Args:
+ gateway_id: ID of the gateway being configured
+ credentials: OAuth configuration with client_id, authorization_url, etc.
+
+ Returns:
+ Dict containing authorization_url and state
+ """
+
+ # Generate state parameter for CSRF protection
+ state = self._generate_state(gateway_id)
+
+ # Store state in session/cache for validation
+ if self.token_storage:
+ await self._store_authorization_state(gateway_id, state)
+
+ # Generate authorization URL
+ auth_url, _ = self._create_authorization_url(credentials, state)
+
+ logger.info(f"Generated authorization URL for gateway {gateway_id}")
+
+ return {"authorization_url": auth_url, "state": state, "gateway_id": gateway_id}
+
+ async def complete_authorization_code_flow(self, gateway_id: str, code: str, state: str, credentials: Dict[str, Any]) -> Dict[str, Any]:
+ """Complete Authorization Code flow and store tokens.
+
+ Args:
+ gateway_id: ID of the gateway
+ code: Authorization code from callback
+ state: State parameter for CSRF validation
+ credentials: OAuth configuration
+
+ Returns:
+ Dict containing success status, user_id, and expiration info
+
+ Raises:
+ OAuthError: If state validation fails or token exchange fails
+ """
+ # Validate state parameter
+ if self.token_storage and not await self._validate_authorization_state(gateway_id, state):
+ raise OAuthError("Invalid state parameter")
+
+ # Exchange code for tokens
+ token_response = await self._exchange_code_for_tokens(credentials, code)
+
+ # Extract user information from token response
+ user_id = self._extract_user_id(token_response, credentials)
+
+ # Store tokens if storage service is available
+ if self.token_storage:
+ token_record = await self.token_storage.store_tokens(
+ gateway_id=gateway_id,
+ user_id=user_id,
+ access_token=token_response["access_token"],
+ refresh_token=token_response.get("refresh_token"),
+ expires_in=token_response.get("expires_in", 3600),
+ scopes=token_response.get("scope", "").split(),
+ )
+
+ return {"success": True, "user_id": user_id, "expires_at": token_record.expires_at.isoformat() if token_record.expires_at else None}
+ return {"success": True, "user_id": user_id, "expires_at": None}
+
+ async def get_access_token_for_user(self, gateway_id: str, user_id: str) -> Optional[str]:
+ """Get valid access token for a specific user.
+
+ Args:
+ gateway_id: ID of the gateway
+ user_id: OAuth provider user ID
+
+ Returns:
+ Valid access token or None if not available
+ """
+ if self.token_storage:
+ return await self.token_storage.get_valid_token(gateway_id, user_id)
+ return None
+
+ def _generate_state(self, gateway_id: str) -> str:
+ """Generate a unique state parameter for CSRF protection.
+
+ Args:
+ gateway_id: ID of the gateway
+
+ Returns:
+ Unique state string
+ """
+ return f"{gateway_id}_{secrets.token_urlsafe(32)}"
+
+ async def _store_authorization_state(self, gateway_id: str, state: str) -> None: # pylint: disable=unused-argument
+ """Store authorization state for validation.
+
+ Args:
+ gateway_id: ID of the gateway
+ state: State parameter to store
+ """
+ # This is a placeholder implementation
+ # In a real implementation, you would store the state in a cache or database
+ # with an expiration time for security
+ logger.debug(f"Stored authorization state for gateway {gateway_id}")
+
+ async def _validate_authorization_state(self, gateway_id: str, state: str) -> bool: # pylint: disable=unused-argument
+ """Validate authorization state parameter.
+
+ Args:
+ gateway_id: ID of the gateway
+ state: State parameter to validate
+
+ Returns:
+ True if state is valid
+ """
+ # This is a placeholder implementation
+ # In a real implementation, you would retrieve and validate the stored state
+ logger.debug(f"Validating authorization state for gateway {gateway_id}")
+ return True # Placeholder: always return True for now
+
+ def _create_authorization_url(self, credentials: Dict[str, Any], state: str) -> tuple[str, str]:
+ """Create authorization URL with state parameter.
+
+ Args:
+ credentials: OAuth configuration
+ state: State parameter for CSRF protection
+
+ Returns:
+ Tuple of (authorization_url, state)
+ """
+ client_id = credentials["client_id"]
+ redirect_uri = credentials["redirect_uri"]
+ authorization_url = credentials["authorization_url"]
+ scopes = credentials.get("scopes", [])
+
+ # Create OAuth2 session
+ oauth = OAuth2Session(client_id, redirect_uri=redirect_uri, scope=scopes)
+
+ # Generate authorization URL with state for CSRF protection
+ auth_url, state = oauth.authorization_url(authorization_url, state=state)
+
+ return auth_url, state
+
+ async def _exchange_code_for_tokens(self, credentials: Dict[str, Any], code: str) -> Dict[str, Any]:
+ """Exchange authorization code for tokens.
+
+ Args:
+ credentials: OAuth configuration
+ code: Authorization code from callback
+
+ Returns:
+ Token response dictionary
+
+ Raises:
+ OAuthError: If token exchange fails
+ """
+ client_id = credentials["client_id"]
+ client_secret = credentials["client_secret"]
+ token_url = credentials["token_url"]
+ redirect_uri = credentials["redirect_uri"]
+
+ # Decrypt client secret if it's encrypted
+ if len(client_secret) > 50: # Simple heuristic: encrypted secrets are longer
+ try:
+ settings = get_settings()
+ encryption = get_oauth_encryption(settings.auth_encryption_secret)
+ decrypted_secret = encryption.decrypt_secret(client_secret)
+ if decrypted_secret:
+ client_secret = decrypted_secret
+ logger.debug("Successfully decrypted client secret")
+ else:
+ logger.warning("Failed to decrypt client secret, using encrypted version")
+ except Exception as e:
+ logger.warning(f"Failed to decrypt client secret: {e}, using encrypted version")
+
+ # Prepare token exchange data
+ token_data = {
+ "grant_type": "authorization_code",
+ "code": code,
+ "redirect_uri": redirect_uri,
+ "client_id": client_id,
+ "client_secret": client_secret,
+ }
+
+ # Exchange code for token with retries
+ for attempt in range(self.max_retries):
+ try:
+ async with aiohttp.ClientSession() as session:
+ async with session.post(token_url, data=token_data, timeout=aiohttp.ClientTimeout(total=self.request_timeout)) as response:
+ response.raise_for_status()
+
+ # GitHub returns form-encoded responses, not JSON
+ content_type = response.headers.get("content-type", "")
+ if "application/x-www-form-urlencoded" in content_type:
+ # Parse form-encoded response
+ text_response = await response.text()
+ token_response = {}
+ for pair in text_response.split("&"):
+ if "=" in pair:
+ key, value = pair.split("=", 1)
+ token_response[key] = value
+ else:
+ # Try JSON response
+ try:
+ token_response = await response.json()
+ except Exception as e:
+ logger.warning(f"Failed to parse JSON response: {e}")
+ # Fallback to text parsing
+ text_response = await response.text()
+ token_response = {"raw_response": text_response}
+
+ if "access_token" not in token_response:
+ raise OAuthError(f"No access_token in response: {token_response}")
+
+ logger.info("""Successfully exchanged authorization code for tokens""")
+ return token_response
+
+ except aiohttp.ClientError as e:
+ logger.warning(f"Token exchange attempt {attempt + 1} failed: {str(e)}")
+ if attempt == self.max_retries - 1:
+ raise OAuthError(f"Failed to exchange code for token after {self.max_retries} attempts: {str(e)}")
+ await asyncio.sleep(2**attempt) # Exponential backoff
+
+ # This should never be reached due to the exception above, but needed for type safety
+ raise OAuthError("Failed to exchange code for token after all retry attempts")
+
+ def _extract_user_id(self, token_response: Dict[str, Any], credentials: Dict[str, Any]) -> str: # pylint: disable=unused-argument
+ """Extract user ID from token response.
+
+ Args:
+ token_response: Response from token exchange
+ credentials: OAuth configuration
+
+ Returns:
+ User ID string
+ """
+ # This is a placeholder implementation
+ # In a real implementation, you might:
+ # 1. Extract user_id from the token response if provided
+ # 2. Make a request to the OAuth provider's user info endpoint
+ # 3. Use a default identifier based on the gateway
+
+ # For now, use a placeholder user ID
+ # In production, you should implement proper user ID extraction
+ return f"user_{credentials.get('client_id', 'unknown')}"
+
+
+class OAuthError(Exception):
+ """OAuth-related errors."""
diff --git a/mcpgateway/services/token_storage_service.py b/mcpgateway/services/token_storage_service.py
new file mode 100644
index 000000000..66165708c
--- /dev/null
+++ b/mcpgateway/services/token_storage_service.py
@@ -0,0 +1,292 @@
+# -*- coding: utf-8 -*-
+"""OAuth Token Storage Service for MCP Gateway.
+
+This module handles the storage, retrieval, and management of OAuth access and refresh tokens
+for Authorization Code flow implementations.
+"""
+
+# Standard
+from datetime import datetime, timedelta
+import logging
+from typing import Any, Dict, List, Optional
+
+# Third-Party
+from sqlalchemy import select
+from sqlalchemy.orm import Session
+
+# First-Party
+from mcpgateway.config import get_settings
+from mcpgateway.db import OAuthToken
+from mcpgateway.services.oauth_manager import OAuthError
+from mcpgateway.utils.oauth_encryption import get_oauth_encryption
+
+logger = logging.getLogger(__name__)
+
+
+class TokenStorageService:
+ """Manages OAuth token storage and retrieval."""
+
+ def __init__(self, db: Session):
+ """Initialize Token Storage Service.
+
+ Args:
+ db: Database session
+ """
+ self.db = db
+ try:
+ settings = get_settings()
+ self.encryption = get_oauth_encryption(settings.auth_encryption_secret)
+ except (ImportError, AttributeError):
+ logger.warning("OAuth encryption not available, using plain text storage")
+ self.encryption = None
+
+ async def store_tokens(self, gateway_id: str, user_id: str, access_token: str, refresh_token: Optional[str], expires_in: int, scopes: List[str]) -> OAuthToken:
+ """Store OAuth tokens for a gateway-user combination.
+
+ Args:
+ gateway_id: ID of the gateway
+ user_id: OAuth provider user ID
+ access_token: Access token from OAuth provider
+ refresh_token: Refresh token from OAuth provider (optional)
+ expires_in: Token expiration time in seconds
+ scopes: List of OAuth scopes granted
+
+ Returns:
+ OAuthToken record
+
+ Raises:
+ OAuthError: If token storage fails
+ """
+ try:
+ # Encrypt sensitive tokens if encryption is available
+ encrypted_access = access_token
+ encrypted_refresh = refresh_token
+
+ if self.encryption:
+ encrypted_access = self.encryption.encrypt_secret(access_token)
+ if refresh_token:
+ encrypted_refresh = self.encryption.encrypt_secret(refresh_token)
+
+ # Calculate expiration
+ expires_at = datetime.utcnow() + timedelta(seconds=expires_in)
+
+ # Create or update token record
+ token_record = self.db.execute(select(OAuthToken).where(OAuthToken.gateway_id == gateway_id, OAuthToken.user_id == user_id)).scalar_one_or_none()
+
+ if token_record:
+ # Update existing record
+ token_record.access_token = encrypted_access
+ token_record.refresh_token = encrypted_refresh
+ token_record.expires_at = expires_at
+ token_record.scopes = scopes
+ token_record.updated_at = datetime.utcnow()
+ logger.info(f"Updated OAuth tokens for gateway {gateway_id}, user {user_id}")
+ else:
+ # Create new record
+ token_record = OAuthToken(gateway_id=gateway_id, user_id=user_id, access_token=encrypted_access, refresh_token=encrypted_refresh, expires_at=expires_at, scopes=scopes)
+ self.db.add(token_record)
+ logger.info(f"Stored new OAuth tokens for gateway {gateway_id}, user {user_id}")
+
+ self.db.commit()
+ return token_record
+
+ except Exception as e:
+ self.db.rollback()
+ logger.error(f"Failed to store OAuth tokens: {str(e)}")
+ raise OAuthError(f"Token storage failed: {str(e)}")
+
+ async def get_valid_token(self, gateway_id: str, user_id: str, threshold_seconds: int = 300) -> Optional[str]:
+ """Get a valid access token, refreshing if necessary.
+
+ Args:
+ gateway_id: ID of the gateway
+ user_id: OAuth provider user ID
+ threshold_seconds: Seconds before expiry to consider token expired
+
+ Returns:
+ Valid access token or None if no valid token available
+ """
+ try:
+ token_record = self.db.execute(select(OAuthToken).where(OAuthToken.gateway_id == gateway_id, OAuthToken.user_id == user_id)).scalar_one_or_none()
+
+ if not token_record:
+ logger.debug(f"No OAuth tokens found for gateway {gateway_id}, user {user_id}")
+ return None
+
+ # Check if token is expired or near expiration
+ if self._is_token_expired(token_record, threshold_seconds):
+ logger.info(f"OAuth token expired for gateway {gateway_id}, user {user_id}")
+ if token_record.refresh_token:
+ # Attempt to refresh token
+ new_token = await self._refresh_access_token(token_record)
+ if new_token:
+ return new_token
+ return None
+
+ # Decrypt and return valid token
+ if self.encryption:
+ return self.encryption.decrypt_secret(token_record.access_token)
+ return token_record.access_token
+
+ except Exception as e:
+ logger.error(f"Failed to retrieve OAuth token: {str(e)}")
+ return None
+
+ async def get_any_valid_token(self, gateway_id: str, threshold_seconds: int = 300) -> Optional[str]:
+ """Get any valid access token for a gateway, regardless of user.
+
+ Args:
+ gateway_id: ID of the gateway
+ threshold_seconds: Seconds before expiry to consider token expired
+
+ Returns:
+ Valid access token or None if no valid token available
+ """
+ try:
+ # Get any token for this gateway
+ token_record = self.db.execute(select(OAuthToken).where(OAuthToken.gateway_id == gateway_id)).scalar_one_or_none()
+
+ if not token_record:
+ logger.debug(f"No OAuth tokens found for gateway {gateway_id}")
+ return None
+
+ # Check if token is expired or near expiration
+ if self._is_token_expired(token_record, threshold_seconds):
+ logger.info(f"OAuth token expired for gateway {gateway_id}")
+ if token_record.refresh_token:
+ # Attempt to refresh token
+ new_token = await self._refresh_access_token(token_record)
+ if new_token:
+ return new_token
+ return None
+
+ # Decrypt and return valid token
+ if self.encryption:
+ return self.encryption.decrypt_secret(token_record.access_token)
+ return token_record.access_token
+
+ except Exception as e:
+ logger.error(f"Failed to retrieve OAuth token: {str(e)}")
+ return None
+
+ async def _refresh_access_token(self, token_record: OAuthToken) -> Optional[str]:
+ """Refresh an expired access token using refresh token.
+
+ Args:
+ token_record: OAuth token record to refresh
+
+ Returns:
+ New access token or None if refresh failed
+ """
+ try:
+ # This is a placeholder for token refresh implementation
+ # In a real implementation, you would:
+ # 1. Decrypt the refresh token
+ # 2. Make a request to the OAuth provider's token endpoint
+ # 3. Update the stored tokens with the new response
+ # 4. Return the new access token
+
+ logger.info(f"Token refresh not yet implemented for gateway {token_record.gateway_id}")
+ return None
+
+ except Exception as e:
+ logger.error(f"Failed to refresh OAuth token: {str(e)}")
+ return None
+
+ def _is_token_expired(self, token_record: OAuthToken, threshold_seconds: int = 300) -> bool:
+ """Check if token is expired or near expiration.
+
+ Args:
+ token_record: OAuth token record to check
+ threshold_seconds: Seconds before expiry to consider token expired
+
+ Returns:
+ True if token is expired or near expiration
+ """
+ if not token_record.expires_at:
+ return True
+
+ return datetime.utcnow() + timedelta(seconds=threshold_seconds) >= token_record.expires_at
+
+ async def get_token_info(self, gateway_id: str, user_id: str) -> Optional[Dict[str, Any]]:
+ """Get information about stored OAuth tokens.
+
+ Args:
+ gateway_id: ID of the gateway
+ user_id: OAuth provider user ID
+
+ Returns:
+ Token information dictionary or None if not found
+ """
+ try:
+ token_record = self.db.execute(select(OAuthToken).where(OAuthToken.gateway_id == gateway_id, OAuthToken.user_id == user_id)).scalar_one_or_none()
+
+ if not token_record:
+ return None
+
+ return {
+ "user_id": token_record.user_id,
+ "token_type": token_record.token_type,
+ "expires_at": token_record.expires_at.isoformat() if token_record.expires_at else None,
+ "scopes": token_record.scopes,
+ "created_at": token_record.created_at.isoformat(),
+ "updated_at": token_record.updated_at.isoformat(),
+ "is_expired": self._is_token_expired(token_record, 0),
+ }
+
+ except Exception as e:
+ logger.error(f"Failed to get token info: {str(e)}")
+ return None
+
+ async def revoke_user_tokens(self, gateway_id: str, user_id: str) -> bool:
+ """Revoke OAuth tokens for a specific user.
+
+ Args:
+ gateway_id: ID of the gateway
+ user_id: OAuth provider user ID
+
+ Returns:
+ True if tokens were revoked successfully
+ """
+ try:
+ token_record = self.db.execute(select(OAuthToken).where(OAuthToken.gateway_id == gateway_id, OAuthToken.user_id == user_id)).scalar_one_or_none()
+
+ if token_record:
+ self.db.delete(token_record)
+ self.db.commit()
+ logger.info(f"Revoked OAuth tokens for gateway {gateway_id}, user {user_id}")
+ return True
+
+ return False
+
+ except Exception as e:
+ self.db.rollback()
+ logger.error(f"Failed to revoke OAuth tokens: {str(e)}")
+ return False
+
+ async def cleanup_expired_tokens(self, max_age_days: int = 30) -> int:
+ """Clean up expired OAuth tokens older than specified days.
+
+ Args:
+ max_age_days: Maximum age of tokens to keep
+
+ Returns:
+ Number of tokens cleaned up
+ """
+ try:
+ cutoff_date = datetime.utcnow() - timedelta(days=max_age_days)
+
+ expired_tokens = self.db.execute(select(OAuthToken).where(OAuthToken.expires_at < cutoff_date)).scalars().all()
+
+ count = len(expired_tokens)
+ for token in expired_tokens:
+ self.db.delete(token)
+
+ self.db.commit()
+ logger.info(f"Cleaned up {count} expired OAuth tokens")
+ return count
+
+ except Exception as e:
+ self.db.rollback()
+ logger.error(f"Failed to cleanup expired tokens: {str(e)}")
+ return 0
diff --git a/mcpgateway/services/tool_service.py b/mcpgateway/services/tool_service.py
index 738c4a372..f7f277aac 100644
--- a/mcpgateway/services/tool_service.py
+++ b/mcpgateway/services/tool_service.py
@@ -44,6 +44,7 @@
from mcpgateway.plugins.framework import GlobalContext, PluginManager, PluginViolationError, ToolPostInvokePayload, ToolPreInvokePayload
from mcpgateway.schemas import ToolCreate, ToolRead, ToolUpdate, TopPerformer
from mcpgateway.services.logging_service import LoggingService
+from mcpgateway.services.oauth_manager import OAuthManager
from mcpgateway.utils.create_slug import slugify
from mcpgateway.utils.metrics_common import build_top_performers
from mcpgateway.utils.passthrough_headers import get_passthrough_headers
@@ -167,6 +168,10 @@ def __init__(self) -> None:
self._event_subscribers: List[asyncio.Queue] = []
self._http_client = ResilientHttpClient(client_args={"timeout": settings.federation_timeout, "verify": not settings.skip_ssl_verify})
self._plugin_manager: PluginManager | None = PluginManager() if settings.plugins_enabled else None
+ self.oauth_manager = OAuthManager(
+ request_timeout=int(settings.oauth_request_timeout if hasattr(settings, "oauth_request_timeout") else 30),
+ max_retries=int(settings.oauth_max_retries if hasattr(settings, "oauth_max_retries") else 3),
+ )
async def initialize(self) -> None:
"""Initialize the service.
@@ -731,10 +736,19 @@ async def invoke_tool(self, db: Session, name: str, arguments: Dict[str, Any], r
# headers = self._get_combined_headers(db, tool, tool.headers or {}, request_headers)
headers = tool.headers or {}
if tool.integration_type == "REST":
- credentials = decode_auth(tool.auth_value)
- # Filter out empty header names/values to avoid "Illegal header name" errors
- filtered_credentials = {k: v for k, v in credentials.items() if k and v}
- headers.update(filtered_credentials)
+ # Handle OAuth authentication for REST tools
+ if tool.auth_type == "oauth" and hasattr(tool, "oauth_config") and tool.oauth_config:
+ try:
+ access_token = await self.oauth_manager.get_access_token(tool.oauth_config)
+ headers["Authorization"] = f"Bearer {access_token}"
+ except Exception as e:
+ logger.error(f"Failed to obtain OAuth access token for tool {tool.name}: {e}")
+ raise ToolInvocationError(f"OAuth authentication failed: {str(e)}")
+ else:
+ credentials = decode_auth(tool.auth_value)
+ # Filter out empty header names/values to avoid "Illegal header name" errors
+ filtered_credentials = {k: v for k, v in credentials.items() if k and v}
+ headers.update(filtered_credentials)
# Only call get_passthrough_headers if we actually have request headers to pass through
if request_headers:
@@ -795,7 +809,43 @@ async def invoke_tool(self, db: Session, name: str, arguments: Dict[str, Any], r
elif tool.integration_type == "MCP":
transport = tool.request_type.lower()
gateway = db.execute(select(DbGateway).where(DbGateway.id == tool.gateway_id).where(DbGateway.enabled)).scalar_one_or_none()
- headers = decode_auth(gateway.auth_value if gateway else None)
+
+ # Handle OAuth authentication for the gateway
+ if gateway and gateway.auth_type == "oauth" and gateway.oauth_config:
+ grant_type = gateway.oauth_config.get("grant_type", "client_credentials")
+
+ if grant_type == "authorization_code":
+ # For Authorization Code flow, try to get stored tokens
+ try:
+ # First-Party
+ from mcpgateway.services.token_storage_service import TokenStorageService # pylint: disable=import-outside-toplevel
+
+ token_storage = TokenStorageService(db)
+
+ # Try to get a valid token for any user (for now, we'll use a placeholder)
+ # In a real implementation, you might want to specify which user's tokens to use
+ access_token = await token_storage.get_any_valid_token(gateway.id)
+
+ if access_token:
+ headers = {"Authorization": f"Bearer {access_token}"}
+ else:
+ # No valid token available - user needs to complete OAuth flow
+ raise ToolInvocationError(
+ f"OAuth Authorization Code flow requires user consent. " f"Please complete the OAuth flow for gateway '{gateway.name}' before using tools."
+ )
+ except Exception as e:
+ logger.error(f"Failed to obtain stored OAuth token for gateway {gateway.name}: {e}")
+ raise ToolInvocationError(f"OAuth token retrieval failed for gateway: {str(e)}")
+ else:
+ # For Client Credentials flow, get token directly
+ try:
+ access_token = await self.oauth_manager.get_access_token(gateway.oauth_config)
+ headers = {"Authorization": f"Bearer {access_token}"}
+ except Exception as e:
+ logger.error(f"Failed to obtain OAuth access token for gateway {gateway.name}: {e}")
+ raise ToolInvocationError(f"OAuth authentication failed for gateway: {str(e)}")
+ else:
+ headers = decode_auth(gateway.auth_value if gateway else None)
# Get combined headers including gateway auth and passthrough
if request_headers:
diff --git a/mcpgateway/static/admin.js b/mcpgateway/static/admin.js
index 1fb196aa8..26c0993c1 100644
--- a/mcpgateway/static/admin.js
+++ b/mcpgateway/static/admin.js
@@ -5462,6 +5462,51 @@ async function handleGatewayFormSubmit(e) {
}
}
+ // Handle OAuth configuration
+ const authType = formData.get("auth_type");
+ if (authType === "oauth") {
+ const oauthConfig = {
+ grant_type: formData.get("oauth_grant_type"),
+ client_id: formData.get("oauth_client_id"),
+ client_secret: formData.get("oauth_client_secret"),
+ token_url: formData.get("oauth_token_url"),
+ scopes: formData.get("oauth_scopes")
+ ? formData
+ .get("oauth_scopes")
+ .split(" ")
+ .filter((s) => s.trim())
+ : [],
+ };
+
+ // Add authorization code specific fields
+ if (oauthConfig.grant_type === "authorization_code") {
+ oauthConfig.authorization_url = formData.get(
+ "oauth_authorization_url",
+ );
+ oauthConfig.redirect_uri = formData.get("oauth_redirect_uri");
+
+ // Add token management options
+ oauthConfig.token_management = {
+ store_tokens: formData.get("oauth_store_tokens") === "on",
+ auto_refresh: formData.get("oauth_auto_refresh") === "on",
+ refresh_threshold_seconds: 300,
+ };
+ }
+
+ // Remove individual OAuth fields and add as oauth_config
+ formData.delete("oauth_grant_type");
+ formData.delete("oauth_client_id");
+ formData.delete("oauth_client_secret");
+ formData.delete("oauth_token_url");
+ formData.delete("oauth_scopes");
+ formData.delete("oauth_authorization_url");
+ formData.delete("oauth_redirect_uri");
+ formData.delete("oauth_store_tokens");
+ formData.delete("oauth_auto_refresh");
+
+ formData.append("oauth_config", JSON.stringify(oauthConfig));
+ }
+
const response = await fetchWithTimeout(
`${window.ROOT_PATH}/admin/gateways`,
{
@@ -6537,6 +6582,21 @@ function setupFormHandlers() {
const gatewayForm = safeGetElement("add-gateway-form");
if (gatewayForm) {
gatewayForm.addEventListener("submit", handleGatewayFormSubmit);
+
+ // Add OAuth authentication type change handler
+ const authTypeField = safeGetElement("auth-type-gw");
+ if (authTypeField) {
+ authTypeField.addEventListener("change", handleAuthTypeChange);
+ }
+
+ // Add OAuth grant type change handler
+ const oauthGrantTypeField = safeGetElement("oauth-grant-type-gw");
+ if (oauthGrantTypeField) {
+ oauthGrantTypeField.addEventListener(
+ "change",
+ handleOAuthGrantTypeChange,
+ );
+ }
}
const resourceForm = safeGetElement("add-resource-form");
@@ -6620,6 +6680,87 @@ function setupFormHandlers() {
}
}
+function handleAuthTypeChange() {
+ const authType = this.value;
+ const basicFields = safeGetElement("auth-basic-fields-gw");
+ const bearerFields = safeGetElement("auth-bearer-fields-gw");
+ const headersFields = safeGetElement("auth-headers-fields-gw");
+ const oauthFields = safeGetElement("auth-oauth-fields-gw");
+
+ // Hide all auth sections first
+ if (basicFields) {
+ basicFields.style.display = "none";
+ }
+ if (bearerFields) {
+ bearerFields.style.display = "none";
+ }
+ if (headersFields) {
+ headersFields.style.display = "none";
+ }
+ if (oauthFields) {
+ oauthFields.style.display = "none";
+ }
+
+ // Show the appropriate section
+ switch (authType) {
+ case "basic":
+ if (basicFields) {
+ basicFields.style.display = "block";
+ }
+ break;
+ case "bearer":
+ if (bearerFields) {
+ bearerFields.style.display = "block";
+ }
+ break;
+ case "authheaders":
+ if (headersFields) {
+ headersFields.style.display = "block";
+ }
+ break;
+ case "oauth":
+ if (oauthFields) {
+ oauthFields.style.display = "block";
+ }
+ break;
+ default:
+ // No auth - keep everything hidden
+ break;
+ }
+}
+
+function handleOAuthGrantTypeChange() {
+ const grantType = this.value;
+ const authCodeFields = safeGetElement("oauth-auth-code-fields-gw");
+
+ if (authCodeFields) {
+ if (grantType === "authorization_code") {
+ authCodeFields.style.display = "block";
+
+ // Make authorization code specific fields required
+ const requiredFields =
+ authCodeFields.querySelectorAll('input[type="url"]');
+ requiredFields.forEach((field) => {
+ field.required = true;
+ });
+
+ // Show additional validation for required fields
+ console.log(
+ "Authorization Code flow selected - additional fields are now required",
+ );
+ } else {
+ authCodeFields.style.display = "none";
+
+ // Remove required validation for hidden fields
+ const requiredFields =
+ authCodeFields.querySelectorAll('input[type="url"]');
+ requiredFields.forEach((field) => {
+ field.required = false;
+ });
+ }
+ }
+}
+
function setupSchemaModeHandlers() {
const schemaModeRadios = document.getElementsByName("schema_input_mode");
const uiBuilderDiv = safeGetElement("ui-builder");
@@ -7563,6 +7704,67 @@ window.removeAuthHeader = removeAuthHeader;
window.updateAuthHeadersJSON = updateAuthHeadersJSON;
window.loadAuthHeaders = loadAuthHeaders;
+/**
+ * Fetch tools from MCP server after OAuth completion for Authorization Code flow
+ * @param {string} gatewayId - ID of the gateway to fetch tools for
+ * @param {string} gatewayName - Name of the gateway for display purposes
+ */
+async function fetchToolsForGateway(gatewayId, gatewayName) {
+ const button = document.getElementById(`fetch-tools-${gatewayId}`);
+ if (!button) {
+ return;
+ }
+
+ // Disable button and show loading state
+ button.disabled = true;
+ button.textContent = "⏳ Fetching...";
+ button.className =
+ "inline-block bg-yellow-600 hover:bg-yellow-700 text-white px-3 py-1 rounded text-sm mr-2";
+
+ try {
+ const response = await fetch(`/oauth/fetch-tools/${gatewayId}`, {
+ method: "POST",
+ });
+
+ const result = await response.json();
+
+ if (response.ok) {
+ // Success
+ button.textContent = "✅ Tools Fetched";
+ button.className =
+ "inline-block bg-green-600 hover:bg-green-700 text-white px-3 py-1 rounded text-sm mr-2";
+
+ // Show success message
+ showSuccessMessage(
+ `Successfully fetched ${result.tools_created} tools from ${gatewayName}`,
+ );
+
+ // Refresh the page to show the new tools
+ setTimeout(() => {
+ window.location.reload();
+ }, 2000);
+ } else {
+ throw new Error(result.detail || "Failed to fetch tools");
+ }
+ } catch (error) {
+ console.error("Failed to fetch tools:", error);
+
+ // Show error state
+ button.textContent = "❌ Retry";
+ button.className =
+ "inline-block bg-red-600 hover:bg-red-700 text-white px-3 py-1 rounded text-sm mr-2";
+ button.disabled = false;
+
+ // Show error message
+ showErrorMessage(
+ `Failed to fetch tools from ${gatewayName}: ${error.message}`,
+ );
+ }
+}
+
+// Expose fetch tools function to global scope
+window.fetchToolsForGateway = fetchToolsForGateway;
+
console.log("🛡️ ContextForge MCP Gateway admin.js initialized");
// ===================================================================
diff --git a/mcpgateway/templates/admin.html b/mcpgateway/templates/admin.html
index a0c1b092c..69ca33395 100644
--- a/mcpgateway/templates/admin.html
+++ b/mcpgateway/templates/admin.html
@@ -2783,6 +2783,36 @@
>
Test
+
+
+
+
+
+
+
+
+ {% if gateway.authType == 'oauth' %}
+
+ 🔐 Authorize
+
+
+
+
+ {% endif %}
{% if gateway.enabled %}