From 72fa239a2b0144e000fdb0e860367645c8ae125c Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Pablo=20Torregrosa=20P=C3=A1ez?= <20774994+pablotp@users.noreply.github.com> Date: Tue, 16 Sep 2025 16:43:32 +0200 Subject: [PATCH 1/4] feat: Add authentication to the endpoints --- README.md | 44 ++++++++++ src/mcp_proxy/__main__.py | 9 +++ src/mcp_proxy/auth.py | 54 +++++++++++++ src/mcp_proxy/mcp_server.py | 6 ++ tests/test_auth_simple.py | 155 ++++++++++++++++++++++++++++++++++++ tests/test_mcp_server.py | 46 +++++++++++ 6 files changed, 314 insertions(+) create mode 100644 src/mcp_proxy/auth.py create mode 100644 tests/test_auth_simple.py diff --git a/README.md b/README.md index ad9ce87..6c3700f 100644 --- a/README.md +++ b/README.md @@ -15,6 +15,7 @@ - [2.1 Configuration](#21-configuration) - [2.2 Example usage](#22-example-usage) - [Named Servers](#named-servers) + - [Authentication](#authentication) - [Installation](#installation) - [Installing via Smithery](#installing-via-smithery) - [Installing via PyPI](#installing-via-pypi) @@ -126,6 +127,7 @@ Arguments | `--cwd` | No | The working directory to pass to the MCP stdio server process. | /tmp | | `--pass-environment` | No | Pass through all environment variables when spawning the server | --no-pass-environment | | `--allow-origin` | No | Allowed origins for the SSE server. Can be used multiple times. Default is no CORS allowed. | --allow-origin "\*" | +| `--api-key` | No | API key for authentication. Can also be set via MCP_PROXY_API_KEY env var. | --api-key YOUR_SECRET_KEY | | `--stateless` | No | Enable stateless mode for streamable http transports. Default is False | --no-stateless | | `--named-server NAME COMMAND_STRING` | No | Defines a named stdio server. | --named-server fetch 'uvx mcp-server-fetch' | | `--named-server-config FILE_PATH` | No | Path to a JSON file defining named stdio servers. | --named-server-config /path/to/servers.json | @@ -211,6 +213,46 @@ The JSON file should follow this structure: - `enabled`: (Optional) If `false`, this server definition will be skipped. Defaults to `true`. - `timeout` and `transportType`: These fields are present in standard MCP client configurations but are currently **ignored** by `mcp-proxy` when loading named servers. The transport type is implicitly "stdio". +## Authentication + +The MCP proxy supports optional API key authentication to protect your endpoints. When enabled, all requests to `/sse` and `/mcp` endpoints (including named server paths like `/servers/*/sse` and `/servers/*/mcp`) require a valid API key. + +### Configuration + +You can configure authentication in two ways: + +1. **Command-line argument**: `--api-key YOUR_SECRET_KEY` +2. **Environment variable**: `MCP_PROXY_API_KEY=YOUR_SECRET_KEY` + +If no API key is configured, authentication is disabled by default (backward compatible). + +### Usage + +When authentication is enabled, clients must include the API key in their requests using the `X-API-Key` header (case-insensitive): + +```bash +# Example: Connecting to a protected SSE endpoint +curl -H "X-API-Key: YOUR_SECRET_KEY" http://localhost:8080/sse + +# Example: Starting a protected server +mcp-proxy --port 8080 --api-key YOUR_SECRET_KEY uvx mcp-server-fetch + +# Example: Using environment variable +export MCP_PROXY_API_KEY=YOUR_SECRET_KEY +mcp-proxy --port 8080 uvx mcp-server-fetch +``` + +### Protected Endpoints + +- `/sse` and `/servers/*/sse` - SSE endpoints +- `/mcp` and `/servers/*/mcp` - MCP endpoints +- `/messages/*` - Message endpoints + +### Unprotected Endpoints + +- `/status` - Health check endpoint (always accessible) +- OPTIONS requests - CORS preflight requests + ## Installation ### Installing via Smithery @@ -361,6 +403,7 @@ SSE server options: --sse-host SSE_HOST (deprecated) Same as --host --allow-origin ALLOW_ORIGIN [ALLOW_ORIGIN ...] Allowed origins for the SSE server. Can be used multiple times. Default is no CORS allowed. + --api-key API_KEY API key for authentication. Can also be set via MCP_PROXY_API_KEY env var. If not provided, authentication is disabled. Examples: mcp-proxy http://localhost:8080/sse @@ -372,6 +415,7 @@ Examples: mcp-proxy --port 8080 --named-server-config /path/to/servers.json -- my-default-command --arg1 mcp-proxy --port 8080 -e KEY VALUE -e ANOTHER_KEY ANOTHER_VALUE -- my-default-command mcp-proxy --port 8080 --allow-origin='*' -- my-default-command + mcp-proxy --port 8080 --api-key YOUR_SECRET_KEY -- my-default-command ``` ### Example config file diff --git a/src/mcp_proxy/__main__.py b/src/mcp_proxy/__main__.py index b9fdc2f..eaf4201 100644 --- a/src/mcp_proxy/__main__.py +++ b/src/mcp_proxy/__main__.py @@ -201,6 +201,14 @@ def _add_arguments_to_parser(parser: argparse.ArgumentParser) -> None: "Default is no CORS allowed." ), ) + mcp_server_group.add_argument( + "--api-key", + default=os.getenv("MCP_PROXY_API_KEY"), + help=( + "API key for authentication. Can also be set via MCP_PROXY_API_KEY env var. " + "If not provided, authentication is disabled." + ), + ) def _setup_logging(*, debug: bool) -> logging.Logger: @@ -335,6 +343,7 @@ def _create_mcp_settings(args_parsed: argparse.Namespace) -> MCPServerSettings: stateless=args_parsed.stateless, allow_origins=args_parsed.allow_origin if len(args_parsed.allow_origin) > 0 else None, log_level="DEBUG" if args_parsed.debug else "INFO", + api_key=args_parsed.api_key, ) diff --git a/src/mcp_proxy/auth.py b/src/mcp_proxy/auth.py new file mode 100644 index 0000000..f3d86a9 --- /dev/null +++ b/src/mcp_proxy/auth.py @@ -0,0 +1,54 @@ +"""Simple authentication middleware for MCP proxy.""" + +import logging + +from starlette.middleware.base import BaseHTTPMiddleware +from starlette.requests import Request +from starlette.responses import JSONResponse, Response + +logger = logging.getLogger(__name__) + + +class AuthMiddleware(BaseHTTPMiddleware): + """Simple API key authentication middleware.""" + + def __init__(self, app, api_key: str | None = None) -> None: + """Initialize middleware with optional API key.""" + super().__init__(app) + self.api_key = api_key + + async def dispatch(self, request: Request, call_next) -> Response: + """Check API key for protected endpoints.""" + # Skip auth if no API key configured + if not self.api_key: + return await call_next(request) + + # Allow OPTIONS (CORS preflight) and /status endpoint + if request.method == "OPTIONS" or request.url.path == "/status": + return await call_next(request) + + # Check if path needs protection (/sse, /mcp, /messages, /servers/*/sse, /servers/*/mcp) + path = request.url.path + needs_auth = ( + path.startswith("/sse") or + path.startswith("/mcp") or + path.startswith("/messages") or + "/sse" in path or + "/mcp" in path + ) + + if not needs_auth: + return await call_next(request) + + # Check for API key in headers (case-insensitive) + api_key = request.headers.get("x-api-key", "") + + if api_key != self.api_key: + logger.warning("Auth failed for %s %s", request.method, path) + return JSONResponse( + {"error": "Unauthorized"}, + status_code=401 + ) + + return await call_next(request) + diff --git a/src/mcp_proxy/mcp_server.py b/src/mcp_proxy/mcp_server.py index a39abb5..0d553b3 100644 --- a/src/mcp_proxy/mcp_server.py +++ b/src/mcp_proxy/mcp_server.py @@ -21,6 +21,7 @@ from starlette.routing import BaseRoute, Mount, Route from starlette.types import Receive, Scope, Send +from .auth import AuthMiddleware from .proxy_server import create_proxy_server logger = logging.getLogger(__name__) @@ -35,6 +36,7 @@ class MCPServerSettings: stateless: bool = False allow_origins: list[str] | None = None log_level: Literal["DEBUG", "INFO", "WARNING", "ERROR", "CRITICAL"] = "INFO" + api_key: str | None = None # To store last activity for multiple servers if needed, though status endpoint is global for now. @@ -169,6 +171,10 @@ async def combined_lifespan(_app: Starlette) -> AsyncIterator[None]: return middleware: list[Middleware] = [] + if mcp_settings.api_key: + middleware.append( + Middleware(AuthMiddleware, api_key=mcp_settings.api_key), + ) if mcp_settings.allow_origins: middleware.append( Middleware( diff --git a/tests/test_auth_simple.py b/tests/test_auth_simple.py new file mode 100644 index 0000000..0b24b3a --- /dev/null +++ b/tests/test_auth_simple.py @@ -0,0 +1,155 @@ +"""Simplified tests for authentication middleware.""" + +import pytest +from starlette.applications import Starlette +from starlette.middleware import Middleware +from starlette.responses import JSONResponse +from starlette.routing import Route +from starlette.testclient import TestClient + +from mcp_proxy.auth import AuthMiddleware + + +async def dummy_endpoint(request): + """Simple endpoint for testing.""" + return JSONResponse({"message": "success"}) + + +async def status_endpoint(request): + """Status endpoint.""" + return JSONResponse({"status": "ok"}) + + +def create_app_without_auth(): + """Create app without authentication.""" + routes = [ + Route("/sse", dummy_endpoint), + Route("/mcp/test", dummy_endpoint), + Route("/messages/test", dummy_endpoint), + Route("/status", status_endpoint), + Route("/other", dummy_endpoint), + ] + return Starlette(routes=routes) + + +def create_app_with_auth(): + """Create app with authentication.""" + routes = [ + Route("/sse", dummy_endpoint), + Route("/mcp/test", dummy_endpoint), + Route("/messages/test", dummy_endpoint), + Route("/status", status_endpoint), + Route("/other", dummy_endpoint), + Route("/servers/test/sse", dummy_endpoint), + Route("/servers/test/mcp", dummy_endpoint), + ] + middleware = [Middleware(AuthMiddleware, api_key="test-api-key")] + return Starlette(routes=routes, middleware=middleware) + + +def test_no_auth_allows_all(): + """Test that all requests work without authentication configured.""" + app = create_app_without_auth() + with TestClient(app) as client: + assert client.get("/sse").status_code == 200 + assert client.get("/mcp/test").status_code == 200 + assert client.get("/status").status_code == 200 + + +def test_auth_blocks_protected_endpoints(): + """Test that protected endpoints are blocked without API key.""" + app = create_app_with_auth() + with TestClient(app) as client: + response = client.get("/sse") + assert response.status_code == 401 + assert response.json() == {"error": "Unauthorized"} + + response = client.get("/mcp/test") + assert response.status_code == 401 + + response = client.get("/messages/test") + assert response.status_code == 401 + + +def test_auth_allows_with_key(): + """Test that requests work with correct API key.""" + app = create_app_with_auth() + with TestClient(app) as client: + headers = {"x-api-key": "test-api-key"} + + response = client.get("/sse", headers=headers) + assert response.status_code == 200 + assert response.json() == {"message": "success"} + + response = client.get("/mcp/test", headers=headers) + assert response.status_code == 200 + + +def test_auth_blocks_wrong_key(): + """Test that requests are blocked with wrong API key.""" + app = create_app_with_auth() + with TestClient(app) as client: + headers = {"x-api-key": "wrong-key"} + + response = client.get("/sse", headers=headers) + assert response.status_code == 401 + + +def test_status_not_protected(): + """Test that /status endpoint is not protected.""" + app = create_app_with_auth() + with TestClient(app) as client: + response = client.get("/status") + assert response.status_code == 200 + assert response.json() == {"status": "ok"} + + +def test_other_endpoints_not_protected(): + """Test that non-SSE/MCP endpoints are not protected.""" + app = create_app_with_auth() + with TestClient(app) as client: + response = client.get("/other") + assert response.status_code == 200 + + +def test_options_allowed(): + """Test that OPTIONS requests are allowed without auth.""" + app = create_app_with_auth() + with TestClient(app) as client: + response = client.options("/sse") + assert response.status_code != 401 + + +def test_case_insensitive_header(): + """Test that API key header is case-insensitive.""" + app = create_app_with_auth() + with TestClient(app) as client: + # Different case variations + headers = {"X-API-KEY": "test-api-key"} + response = client.get("/sse", headers=headers) + assert response.status_code == 200 + + headers = {"X-Api-Key": "test-api-key"} + response = client.get("/sse", headers=headers) + assert response.status_code == 200 + + +def test_named_servers_protected(): + """Test that named server endpoints are protected.""" + app = create_app_with_auth() + with TestClient(app) as client: + # Without auth + response = client.get("/servers/test/sse") + assert response.status_code == 401 + + response = client.get("/servers/test/mcp") + assert response.status_code == 401 + + # With auth + headers = {"x-api-key": "test-api-key"} + response = client.get("/servers/test/sse", headers=headers) + assert response.status_code == 200 + + response = client.get("/servers/test/mcp", headers=headers) + assert response.status_code == 200 + diff --git a/tests/test_mcp_server.py b/tests/test_mcp_server.py index b568551..38fc6f1 100644 --- a/tests/test_mcp_server.py +++ b/tests/test_mcp_server.py @@ -621,6 +621,51 @@ async def test_run_mcp_server_exception_handling( assert "Connection failed" in str(e) # noqa: PT017 +async def test_run_mcp_server_with_authentication( + mock_stdio_params: StdioServerParameters, +) -> None: + """Test run_mcp_server with authentication enabled.""" + auth_settings = MCPServerSettings( + bind_host="127.0.0.1", + port=8080, + api_key="test-secret-key", + ) + + with ( + patch("mcp_proxy.mcp_server.stdio_client") as mock_stdio_client, + patch("mcp_proxy.mcp_server.ClientSession") as mock_client_session, + patch("mcp_proxy.mcp_server.create_proxy_server") as mock_create_proxy, + patch("mcp_proxy.mcp_server.create_single_instance_routes") as mock_create_routes, + patch("mcp_proxy.mcp_server.Starlette") as mock_starlette, + patch("uvicorn.Server") as mock_uvicorn_server, + ): + # Setup mocks + mock_stdio_context, mock_session_context, mock_session, mock_http_manager, mock_routes = ( + setup_async_context_mocks() + ) + mock_stdio_client.return_value = mock_stdio_context + mock_client_session.return_value = mock_session_context + + mock_proxy = AsyncMock() + mock_create_proxy.return_value = mock_proxy + mock_create_routes.return_value = (mock_routes, mock_http_manager) + + mock_server_instance = AsyncMock() + mock_uvicorn_server.return_value = mock_server_instance + + # Run the function + await run_mcp_server(auth_settings, mock_stdio_params, {}) + + # Verify Starlette was called with AuthMiddleware + mock_starlette.assert_called_once() + call_args = mock_starlette.call_args + middleware = call_args.kwargs["middleware"] + + assert len(middleware) == 1 + assert middleware[0].cls.__name__ == "AuthMiddleware" + assert middleware[0].kwargs == {"api_key": "test-secret-key"} + + async def test_run_mcp_server_both_default_and_named_servers( mock_settings: MCPServerSettings, mock_stdio_params: StdioServerParameters, @@ -672,3 +717,4 @@ async def test_run_mcp_server_both_default_and_named_servers( ) mock_server_instance.serve.assert_called_once() + From f4dd492b41ed1185746a9875a1c44c7a6841b702 Mon Sep 17 00:00:00 2001 From: Baptiste Fontaine Date: Mon, 22 Sep 2025 07:53:19 +0200 Subject: [PATCH 2/4] feat: add --log-level option (#102) Fixes https://github.com/sparfenyuk/mcp-proxy/issues/95 This adds the option `--log-level {DEBUG,INFO,WARNING,ERROR,CRITICAL}` to define the log level. If both `--debug` and `--log-level` are provided, the former takes precedence. --- README.md | 5 +++-- src/mcp_proxy/__main__.py | 21 ++++++++++++++++----- 2 files changed, 19 insertions(+), 7 deletions(-) diff --git a/README.md b/README.md index 6c3700f..d22aac5 100644 --- a/README.md +++ b/README.md @@ -359,7 +359,7 @@ services: ```bash usage: mcp-proxy [-h] [--version] [-H KEY VALUE] [--transport {sse,streamablehttp}] [-e KEY VALUE] [--cwd CWD] - [--pass-environment | --no-pass-environment] [--debug | --no-debug] + [--pass-environment | --no-pass-environment] [--log-level LEVEL] [--debug | --no-debug] [--named-server NAME COMMAND_STRING] [--named-server-config FILE_PATH] [--port PORT] [--host HOST] [--stateless | --no-stateless] [--sse-port SSE_PORT] @@ -388,7 +388,8 @@ stdio client options: --cwd CWD The working directory to use when spawning the default server process. Named servers inherit the proxy's CWD. --pass-environment, --no-pass-environment Pass through all environment variables when spawning all server processes. - --debug, --no-debug Enable debug mode with detailed logging output. + --log-level LEVEL Set the log level. Default is INFO. + --debug, --no-debug Enable debug mode with detailed logging output. Equivalent to --log-level DEBUG. If both --debug and --log-level are provided, --debug takes precedence. --named-server NAME COMMAND_STRING Define a named stdio server. NAME is for the URL path /servers/NAME/. COMMAND_STRING is a single string with the command and its arguments (e.g., 'uvx mcp-server-fetch --timeout 10'). These servers inherit the proxy's CWD and environment from --pass-environment. --named-server-config FILE_PATH diff --git a/src/mcp_proxy/__main__.py b/src/mcp_proxy/__main__.py index eaf4201..c8df814 100644 --- a/src/mcp_proxy/__main__.py +++ b/src/mcp_proxy/__main__.py @@ -131,10 +131,21 @@ def _add_arguments_to_parser(parser: argparse.ArgumentParser) -> None: help="Pass through all environment variables when spawning all server processes.", default=False, ) + stdio_client_options.add_argument( + "--log-level", + type=str, + default="INFO", + metavar="LEVEL", + choices=["DEBUG", "INFO", "WARNING", "ERROR", "CRITICAL"], + help="Set the log level. Default is INFO.", + ) stdio_client_options.add_argument( "--debug", action=argparse.BooleanOptionalAction, - help="Enable debug mode with detailed logging output.", + help=( + "Enable debug mode with detailed logging output. Equivalent to --log-level DEBUG. " + "If both --debug and --log-level are provided, --debug takes precedence." + ), default=False, ) stdio_client_options.add_argument( @@ -211,10 +222,10 @@ def _add_arguments_to_parser(parser: argparse.ArgumentParser) -> None: ) -def _setup_logging(*, debug: bool) -> logging.Logger: +def _setup_logging(*, level: str, debug: bool) -> logging.Logger: """Set up logging configuration and return the logger.""" logging.basicConfig( - level=logging.DEBUG if debug else logging.INFO, + level=logging.DEBUG if debug else level, format="[%(levelname)1.1s %(asctime)s.%(msecs).03d %(name)s] %(message)s", ) return logging.getLogger(__name__) @@ -342,7 +353,7 @@ def _create_mcp_settings(args_parsed: argparse.Namespace) -> MCPServerSettings: port=args_parsed.port if args_parsed.port is not None else args_parsed.sse_port, stateless=args_parsed.stateless, allow_origins=args_parsed.allow_origin if len(args_parsed.allow_origin) > 0 else None, - log_level="DEBUG" if args_parsed.debug else "INFO", + log_level="DEBUG" if args_parsed.debug else args_parsed.log_level, api_key=args_parsed.api_key, ) @@ -351,7 +362,7 @@ def main() -> None: """Start the client using asyncio.""" parser = _setup_argument_parser() args_parsed = parser.parse_args() - logger = _setup_logging(debug=args_parsed.debug) + logger = _setup_logging(level=args_parsed.log_level, debug=args_parsed.debug) # Validate required arguments if ( From 2f8f658e10388e191f991db5dc34c571dae7885a Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Pablo=20Torregrosa=20P=C3=A1ez?= <20774994+pablotp@users.noreply.github.com> Date: Mon, 22 Sep 2025 11:05:21 +0200 Subject: [PATCH 3/4] fix: Address PR feedback --- src/mcp_proxy/auth.py | 19 ++-- src/mcp_proxy/mcp_server.py | 202 +++++++++++++++++++++--------------- tests/test_auth_simple.py | 69 ++++++------ tests/test_mcp_server.py | 1 - 4 files changed, 167 insertions(+), 124 deletions(-) diff --git a/src/mcp_proxy/auth.py b/src/mcp_proxy/auth.py index f3d86a9..7446721 100644 --- a/src/mcp_proxy/auth.py +++ b/src/mcp_proxy/auth.py @@ -1,7 +1,9 @@ """Simple authentication middleware for MCP proxy.""" import logging +from collections.abc import Awaitable, Callable +from starlette.applications import Starlette from starlette.middleware.base import BaseHTTPMiddleware from starlette.requests import Request from starlette.responses import JSONResponse, Response @@ -12,12 +14,16 @@ class AuthMiddleware(BaseHTTPMiddleware): """Simple API key authentication middleware.""" - def __init__(self, app, api_key: str | None = None) -> None: + def __init__(self, app: Starlette, api_key: str | None = None) -> None: """Initialize middleware with optional API key.""" super().__init__(app) self.api_key = api_key - async def dispatch(self, request: Request, call_next) -> Response: + async def dispatch( + self, + request: Request, + call_next: Callable[[Request], Awaitable[Response]], + ) -> Response: """Check API key for protected endpoints.""" # Skip auth if no API key configured if not self.api_key: @@ -30,11 +36,7 @@ async def dispatch(self, request: Request, call_next) -> Response: # Check if path needs protection (/sse, /mcp, /messages, /servers/*/sse, /servers/*/mcp) path = request.url.path needs_auth = ( - path.startswith("/sse") or - path.startswith("/mcp") or - path.startswith("/messages") or - "/sse" in path or - "/mcp" in path + path.startswith(("/sse", "/mcp", "/messages")) or "/sse" in path or "/mcp" in path ) if not needs_auth: @@ -47,8 +49,7 @@ async def dispatch(self, request: Request, call_next) -> Response: logger.warning("Auth failed for %s %s", request.method, path) return JSONResponse( {"error": "Unauthorized"}, - status_code=401 + status_code=401, ) return await call_next(request) - diff --git a/src/mcp_proxy/mcp_server.py b/src/mcp_proxy/mcp_server.py index 0d553b3..963c8ce 100644 --- a/src/mcp_proxy/mcp_server.py +++ b/src/mcp_proxy/mcp_server.py @@ -100,6 +100,109 @@ async def handle_streamable_http_instance(scope: Scope, receive: Receive, send: return routes, http_session_manager +async def _setup_default_server( + stack: contextlib.AsyncExitStack, + default_server_params: StdioServerParameters, + mcp_settings: MCPServerSettings, +) -> list[BaseRoute]: + """Setup default server and return its routes.""" + logger.info( + "Setting up default server: %s %s", + default_server_params.command, + " ".join(default_server_params.args), + ) + stdio_streams = await stack.enter_async_context(stdio_client(default_server_params)) + session = await stack.enter_async_context(ClientSession(*stdio_streams)) + proxy = await create_proxy_server(session) + + instance_routes, http_manager = create_single_instance_routes( + proxy, + stateless_instance=mcp_settings.stateless, + ) + await stack.enter_async_context(http_manager.run()) + _global_status["server_instances"]["default"] = "configured" + return instance_routes + + +async def _setup_named_servers( + stack: contextlib.AsyncExitStack, + named_server_params: dict[str, StdioServerParameters], + mcp_settings: MCPServerSettings, +) -> list[BaseRoute]: + """Setup named servers and return their routes.""" + routes: list[BaseRoute] = [] + for name, params in named_server_params.items(): + logger.info( + "Setting up named server '%s': %s %s", + name, + params.command, + " ".join(params.args), + ) + stdio_streams_named = await stack.enter_async_context(stdio_client(params)) + session_named = await stack.enter_async_context(ClientSession(*stdio_streams_named)) + proxy_named = await create_proxy_server(session_named) + + instance_routes_named, http_manager_named = create_single_instance_routes( + proxy_named, + stateless_instance=mcp_settings.stateless, + ) + await stack.enter_async_context(http_manager_named.run()) + + # Mount these routes under /servers// + server_mount = Mount(f"/servers/{name}", routes=instance_routes_named) + routes.append(server_mount) + _global_status["server_instances"][name] = "configured" + return routes + + +def _create_middleware(mcp_settings: MCPServerSettings) -> list[Middleware]: + """Create middleware list based on settings.""" + middleware: list[Middleware] = [] + if mcp_settings.api_key: + middleware.append( + Middleware(AuthMiddleware, api_key=mcp_settings.api_key), + ) + if mcp_settings.allow_origins: + middleware.append( + Middleware( + CORSMiddleware, + allow_origins=mcp_settings.allow_origins, + allow_methods=["*"], + allow_headers=["*"], + ), + ) + return middleware + + +def _log_server_urls( + mcp_settings: MCPServerSettings, + default_server_params: StdioServerParameters | None, + named_server_params: dict[str, StdioServerParameters], +) -> None: + """Log the SSE URLs for all configured servers.""" + base_url = f"http://{mcp_settings.bind_host}:{mcp_settings.port}" + sse_urls = [] + + # Add default server if configured + if default_server_params: + sse_urls.append(f"{base_url}/sse") + + # Add named servers + sse_urls.extend([f"{base_url}/servers/{name}/sse" for name in named_server_params]) + + # Display the SSE URLs prominently + if sse_urls: + logger.info("Serving MCP Servers via SSE:") + for url in sse_urls: + logger.info(" - %s", url) + + logger.debug( + "Serving incoming MCP requests on %s:%s", + mcp_settings.bind_host, + mcp_settings.port, + ) + + async def run_mcp_server( mcp_settings: MCPServerSettings, default_server_params: StdioServerParameters | None = None, @@ -109,9 +212,14 @@ async def run_mcp_server( if named_server_params is None: named_server_params = {} + if not default_server_params and not named_server_params: + logger.error("No servers configured to run.") + return + all_routes: list[BaseRoute] = [ Route("/status", endpoint=_handle_status), # Global status endpoint ] + # Use AsyncExitStack to manage lifecycles of multiple components async with contextlib.AsyncExitStack() as stack: # Manage lifespans of all StreamableHTTPSessionManagers @@ -124,76 +232,32 @@ async def combined_lifespan(_app: Starlette) -> AsyncIterator[None]: # Setup default server if configured if default_server_params: - logger.info( - "Setting up default server: %s %s", - default_server_params.command, - " ".join(default_server_params.args), + instance_routes = await _setup_default_server( + stack, + default_server_params, + mcp_settings, ) - stdio_streams = await stack.enter_async_context(stdio_client(default_server_params)) - session = await stack.enter_async_context(ClientSession(*stdio_streams)) - proxy = await create_proxy_server(session) - - instance_routes, http_manager = create_single_instance_routes( - proxy, - stateless_instance=mcp_settings.stateless, - ) - await stack.enter_async_context(http_manager.run()) # Manage lifespan by calling run() all_routes.extend(instance_routes) - _global_status["server_instances"]["default"] = "configured" # Setup named servers - for name, params in named_server_params.items(): - logger.info( - "Setting up named server '%s': %s %s", - name, - params.command, - " ".join(params.args), - ) - stdio_streams_named = await stack.enter_async_context(stdio_client(params)) - session_named = await stack.enter_async_context(ClientSession(*stdio_streams_named)) - proxy_named = await create_proxy_server(session_named) - - instance_routes_named, http_manager_named = create_single_instance_routes( - proxy_named, - stateless_instance=mcp_settings.stateless, - ) - await stack.enter_async_context( - http_manager_named.run(), - ) # Manage lifespan by calling run() - - # Mount these routes under /servers// - server_mount = Mount(f"/servers/{name}", routes=instance_routes_named) - all_routes.append(server_mount) - _global_status["server_instances"][name] = "configured" - - if not default_server_params and not named_server_params: - logger.error("No servers configured to run.") - return - - middleware: list[Middleware] = [] - if mcp_settings.api_key: - middleware.append( - Middleware(AuthMiddleware, api_key=mcp_settings.api_key), - ) - if mcp_settings.allow_origins: - middleware.append( - Middleware( - CORSMiddleware, - allow_origins=mcp_settings.allow_origins, - allow_methods=["*"], - allow_headers=["*"], - ), - ) + if named_server_params: + named_routes = await _setup_named_servers(stack, named_server_params, mcp_settings) + all_routes.extend(named_routes) + # Create middleware and Starlette app + middleware = _create_middleware(mcp_settings) starlette_app = Starlette( debug=(mcp_settings.log_level == "DEBUG"), routes=all_routes, middleware=middleware, lifespan=combined_lifespan, ) - starlette_app.router.redirect_slashes = False + # Log server URLs + _log_server_urls(mcp_settings, default_server_params, named_server_params) + + # Start the server config = uvicorn.Config( starlette_app, host=mcp_settings.bind_host, @@ -201,28 +265,4 @@ async def combined_lifespan(_app: Starlette) -> AsyncIterator[None]: log_level=mcp_settings.log_level.lower(), ) http_server = uvicorn.Server(config) - - # Print out the SSE URLs for all configured servers - base_url = f"http://{mcp_settings.bind_host}:{mcp_settings.port}" - sse_urls = [] - - # Add default server if configured - if default_server_params: - sse_urls.append(f"{base_url}/sse") - - # Add named servers - sse_urls.extend([f"{base_url}/servers/{name}/sse" for name in named_server_params]) - - # Display the SSE URLs prominently - if sse_urls: - # Using print directly for user visibility, with noqa to ignore linter warnings - logger.info("Serving MCP Servers via SSE:") - for url in sse_urls: - logger.info(" - %s", url) - - logger.debug( - "Serving incoming MCP requests on %s:%s", - mcp_settings.bind_host, - mcp_settings.port, - ) await http_server.serve() diff --git a/tests/test_auth_simple.py b/tests/test_auth_simple.py index 0b24b3a..c4df535 100644 --- a/tests/test_auth_simple.py +++ b/tests/test_auth_simple.py @@ -1,26 +1,30 @@ """Simplified tests for authentication middleware.""" -import pytest from starlette.applications import Starlette from starlette.middleware import Middleware +from starlette.requests import Request from starlette.responses import JSONResponse from starlette.routing import Route from starlette.testclient import TestClient from mcp_proxy.auth import AuthMiddleware +# HTTP status codes +HTTP_OK = 200 +HTTP_UNAUTHORIZED = 401 -async def dummy_endpoint(request): + +async def dummy_endpoint(_request: Request) -> JSONResponse: """Simple endpoint for testing.""" return JSONResponse({"message": "success"}) -async def status_endpoint(request): +async def status_endpoint(_request: Request) -> JSONResponse: """Status endpoint.""" return JSONResponse({"status": "ok"}) -def create_app_without_auth(): +def create_app_without_auth() -> Starlette: """Create app without authentication.""" routes = [ Route("/sse", dummy_endpoint), @@ -32,7 +36,7 @@ def create_app_without_auth(): return Starlette(routes=routes) -def create_app_with_auth(): +def create_app_with_auth() -> Starlette: """Create app with authentication.""" routes = [ Route("/sse", dummy_endpoint), @@ -47,109 +51,108 @@ def create_app_with_auth(): return Starlette(routes=routes, middleware=middleware) -def test_no_auth_allows_all(): +def test_no_auth_allows_all() -> None: """Test that all requests work without authentication configured.""" app = create_app_without_auth() with TestClient(app) as client: - assert client.get("/sse").status_code == 200 - assert client.get("/mcp/test").status_code == 200 - assert client.get("/status").status_code == 200 + assert client.get("/sse").status_code == HTTP_OK + assert client.get("/mcp/test").status_code == HTTP_OK + assert client.get("/status").status_code == HTTP_OK -def test_auth_blocks_protected_endpoints(): +def test_auth_blocks_protected_endpoints() -> None: """Test that protected endpoints are blocked without API key.""" app = create_app_with_auth() with TestClient(app) as client: response = client.get("/sse") - assert response.status_code == 401 + assert response.status_code == HTTP_UNAUTHORIZED assert response.json() == {"error": "Unauthorized"} response = client.get("/mcp/test") - assert response.status_code == 401 + assert response.status_code == HTTP_UNAUTHORIZED response = client.get("/messages/test") - assert response.status_code == 401 + assert response.status_code == HTTP_UNAUTHORIZED -def test_auth_allows_with_key(): +def test_auth_allows_with_key() -> None: """Test that requests work with correct API key.""" app = create_app_with_auth() with TestClient(app) as client: headers = {"x-api-key": "test-api-key"} response = client.get("/sse", headers=headers) - assert response.status_code == 200 + assert response.status_code == HTTP_OK assert response.json() == {"message": "success"} response = client.get("/mcp/test", headers=headers) - assert response.status_code == 200 + assert response.status_code == HTTP_OK -def test_auth_blocks_wrong_key(): +def test_auth_blocks_wrong_key() -> None: """Test that requests are blocked with wrong API key.""" app = create_app_with_auth() with TestClient(app) as client: headers = {"x-api-key": "wrong-key"} response = client.get("/sse", headers=headers) - assert response.status_code == 401 + assert response.status_code == HTTP_UNAUTHORIZED -def test_status_not_protected(): +def test_status_not_protected() -> None: """Test that /status endpoint is not protected.""" app = create_app_with_auth() with TestClient(app) as client: response = client.get("/status") - assert response.status_code == 200 + assert response.status_code == HTTP_OK assert response.json() == {"status": "ok"} -def test_other_endpoints_not_protected(): +def test_other_endpoints_not_protected() -> None: """Test that non-SSE/MCP endpoints are not protected.""" app = create_app_with_auth() with TestClient(app) as client: response = client.get("/other") - assert response.status_code == 200 + assert response.status_code == HTTP_OK -def test_options_allowed(): +def test_options_allowed() -> None: """Test that OPTIONS requests are allowed without auth.""" app = create_app_with_auth() with TestClient(app) as client: response = client.options("/sse") - assert response.status_code != 401 + assert response.status_code != HTTP_UNAUTHORIZED -def test_case_insensitive_header(): +def test_case_insensitive_header() -> None: """Test that API key header is case-insensitive.""" app = create_app_with_auth() with TestClient(app) as client: # Different case variations headers = {"X-API-KEY": "test-api-key"} response = client.get("/sse", headers=headers) - assert response.status_code == 200 + assert response.status_code == HTTP_OK headers = {"X-Api-Key": "test-api-key"} response = client.get("/sse", headers=headers) - assert response.status_code == 200 + assert response.status_code == HTTP_OK -def test_named_servers_protected(): +def test_named_servers_protected() -> None: """Test that named server endpoints are protected.""" app = create_app_with_auth() with TestClient(app) as client: # Without auth response = client.get("/servers/test/sse") - assert response.status_code == 401 + assert response.status_code == HTTP_UNAUTHORIZED response = client.get("/servers/test/mcp") - assert response.status_code == 401 + assert response.status_code == HTTP_UNAUTHORIZED # With auth headers = {"x-api-key": "test-api-key"} response = client.get("/servers/test/sse", headers=headers) - assert response.status_code == 200 + assert response.status_code == HTTP_OK response = client.get("/servers/test/mcp", headers=headers) - assert response.status_code == 200 - + assert response.status_code == HTTP_OK diff --git a/tests/test_mcp_server.py b/tests/test_mcp_server.py index 38fc6f1..015125a 100644 --- a/tests/test_mcp_server.py +++ b/tests/test_mcp_server.py @@ -717,4 +717,3 @@ async def test_run_mcp_server_both_default_and_named_servers( ) mock_server_instance.serve.assert_called_once() - From 218da4f18f75498cca173bb7aab4b0c90a0036c3 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Pablo=20Torregrosa=20P=C3=A1ez?= <20774994+pablotp@users.noreply.github.com> Date: Mon, 22 Sep 2025 11:18:18 +0200 Subject: [PATCH 4/4] fix: Address PR feedback v2 --- .dockerignore | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/.dockerignore b/.dockerignore index 0d149f1..c47f2ba 100644 --- a/.dockerignore +++ b/.dockerignore @@ -5,4 +5,4 @@ !/src/ !pyproject.toml !uv.lock -!LICENSE \ No newline at end of file +!LICENSE