Skip to content

Commit 4f9c96a

Browse files
committed
fix: Address PR feedback
1 parent 85c5685 commit 4f9c96a

File tree

4 files changed

+167
-124
lines changed

4 files changed

+167
-124
lines changed

src/mcp_proxy/auth.py

Lines changed: 10 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,9 @@
11
"""Simple authentication middleware for MCP proxy."""
22

33
import logging
4+
from collections.abc import Awaitable, Callable
45

6+
from starlette.applications import Starlette
57
from starlette.middleware.base import BaseHTTPMiddleware
68
from starlette.requests import Request
79
from starlette.responses import JSONResponse, Response
@@ -12,12 +14,16 @@
1214
class AuthMiddleware(BaseHTTPMiddleware):
1315
"""Simple API key authentication middleware."""
1416

15-
def __init__(self, app, api_key: str | None = None) -> None:
17+
def __init__(self, app: Starlette, api_key: str | None = None) -> None:
1618
"""Initialize middleware with optional API key."""
1719
super().__init__(app)
1820
self.api_key = api_key
1921

20-
async def dispatch(self, request: Request, call_next) -> Response:
22+
async def dispatch(
23+
self,
24+
request: Request,
25+
call_next: Callable[[Request], Awaitable[Response]],
26+
) -> Response:
2127
"""Check API key for protected endpoints."""
2228
# Skip auth if no API key configured
2329
if not self.api_key:
@@ -30,11 +36,7 @@ async def dispatch(self, request: Request, call_next) -> Response:
3036
# Check if path needs protection (/sse, /mcp, /messages, /servers/*/sse, /servers/*/mcp)
3137
path = request.url.path
3238
needs_auth = (
33-
path.startswith("/sse") or
34-
path.startswith("/mcp") or
35-
path.startswith("/messages") or
36-
"/sse" in path or
37-
"/mcp" in path
39+
path.startswith(("/sse", "/mcp", "/messages")) or "/sse" in path or "/mcp" in path
3840
)
3941

4042
if not needs_auth:
@@ -47,8 +49,7 @@ async def dispatch(self, request: Request, call_next) -> Response:
4749
logger.warning("Auth failed for %s %s", request.method, path)
4850
return JSONResponse(
4951
{"error": "Unauthorized"},
50-
status_code=401
52+
status_code=401,
5153
)
5254

5355
return await call_next(request)
54-

src/mcp_proxy/mcp_server.py

Lines changed: 121 additions & 81 deletions
Original file line numberDiff line numberDiff line change
@@ -100,6 +100,109 @@ async def handle_streamable_http_instance(scope: Scope, receive: Receive, send:
100100
return routes, http_session_manager
101101

102102

103+
async def _setup_default_server(
104+
stack: contextlib.AsyncExitStack,
105+
default_server_params: StdioServerParameters,
106+
mcp_settings: MCPServerSettings,
107+
) -> list[BaseRoute]:
108+
"""Setup default server and return its routes."""
109+
logger.info(
110+
"Setting up default server: %s %s",
111+
default_server_params.command,
112+
" ".join(default_server_params.args),
113+
)
114+
stdio_streams = await stack.enter_async_context(stdio_client(default_server_params))
115+
session = await stack.enter_async_context(ClientSession(*stdio_streams))
116+
proxy = await create_proxy_server(session)
117+
118+
instance_routes, http_manager = create_single_instance_routes(
119+
proxy,
120+
stateless_instance=mcp_settings.stateless,
121+
)
122+
await stack.enter_async_context(http_manager.run())
123+
_global_status["server_instances"]["default"] = "configured"
124+
return instance_routes
125+
126+
127+
async def _setup_named_servers(
128+
stack: contextlib.AsyncExitStack,
129+
named_server_params: dict[str, StdioServerParameters],
130+
mcp_settings: MCPServerSettings,
131+
) -> list[BaseRoute]:
132+
"""Setup named servers and return their routes."""
133+
routes: list[BaseRoute] = []
134+
for name, params in named_server_params.items():
135+
logger.info(
136+
"Setting up named server '%s': %s %s",
137+
name,
138+
params.command,
139+
" ".join(params.args),
140+
)
141+
stdio_streams_named = await stack.enter_async_context(stdio_client(params))
142+
session_named = await stack.enter_async_context(ClientSession(*stdio_streams_named))
143+
proxy_named = await create_proxy_server(session_named)
144+
145+
instance_routes_named, http_manager_named = create_single_instance_routes(
146+
proxy_named,
147+
stateless_instance=mcp_settings.stateless,
148+
)
149+
await stack.enter_async_context(http_manager_named.run())
150+
151+
# Mount these routes under /servers/<name>/
152+
server_mount = Mount(f"/servers/{name}", routes=instance_routes_named)
153+
routes.append(server_mount)
154+
_global_status["server_instances"][name] = "configured"
155+
return routes
156+
157+
158+
def _create_middleware(mcp_settings: MCPServerSettings) -> list[Middleware]:
159+
"""Create middleware list based on settings."""
160+
middleware: list[Middleware] = []
161+
if mcp_settings.api_key:
162+
middleware.append(
163+
Middleware(AuthMiddleware, api_key=mcp_settings.api_key),
164+
)
165+
if mcp_settings.allow_origins:
166+
middleware.append(
167+
Middleware(
168+
CORSMiddleware,
169+
allow_origins=mcp_settings.allow_origins,
170+
allow_methods=["*"],
171+
allow_headers=["*"],
172+
),
173+
)
174+
return middleware
175+
176+
177+
def _log_server_urls(
178+
mcp_settings: MCPServerSettings,
179+
default_server_params: StdioServerParameters | None,
180+
named_server_params: dict[str, StdioServerParameters],
181+
) -> None:
182+
"""Log the SSE URLs for all configured servers."""
183+
base_url = f"http://{mcp_settings.bind_host}:{mcp_settings.port}"
184+
sse_urls = []
185+
186+
# Add default server if configured
187+
if default_server_params:
188+
sse_urls.append(f"{base_url}/sse")
189+
190+
# Add named servers
191+
sse_urls.extend([f"{base_url}/servers/{name}/sse" for name in named_server_params])
192+
193+
# Display the SSE URLs prominently
194+
if sse_urls:
195+
logger.info("Serving MCP Servers via SSE:")
196+
for url in sse_urls:
197+
logger.info(" - %s", url)
198+
199+
logger.debug(
200+
"Serving incoming MCP requests on %s:%s",
201+
mcp_settings.bind_host,
202+
mcp_settings.port,
203+
)
204+
205+
103206
async def run_mcp_server(
104207
mcp_settings: MCPServerSettings,
105208
default_server_params: StdioServerParameters | None = None,
@@ -109,9 +212,14 @@ async def run_mcp_server(
109212
if named_server_params is None:
110213
named_server_params = {}
111214

215+
if not default_server_params and not named_server_params:
216+
logger.error("No servers configured to run.")
217+
return
218+
112219
all_routes: list[BaseRoute] = [
113220
Route("/status", endpoint=_handle_status), # Global status endpoint
114221
]
222+
115223
# Use AsyncExitStack to manage lifecycles of multiple components
116224
async with contextlib.AsyncExitStack() as stack:
117225
# Manage lifespans of all StreamableHTTPSessionManagers
@@ -124,105 +232,37 @@ async def combined_lifespan(_app: Starlette) -> AsyncIterator[None]:
124232

125233
# Setup default server if configured
126234
if default_server_params:
127-
logger.info(
128-
"Setting up default server: %s %s",
129-
default_server_params.command,
130-
" ".join(default_server_params.args),
235+
instance_routes = await _setup_default_server(
236+
stack,
237+
default_server_params,
238+
mcp_settings,
131239
)
132-
stdio_streams = await stack.enter_async_context(stdio_client(default_server_params))
133-
session = await stack.enter_async_context(ClientSession(*stdio_streams))
134-
proxy = await create_proxy_server(session)
135-
136-
instance_routes, http_manager = create_single_instance_routes(
137-
proxy,
138-
stateless_instance=mcp_settings.stateless,
139-
)
140-
await stack.enter_async_context(http_manager.run()) # Manage lifespan by calling run()
141240
all_routes.extend(instance_routes)
142-
_global_status["server_instances"]["default"] = "configured"
143241

144242
# Setup named servers
145-
for name, params in named_server_params.items():
146-
logger.info(
147-
"Setting up named server '%s': %s %s",
148-
name,
149-
params.command,
150-
" ".join(params.args),
151-
)
152-
stdio_streams_named = await stack.enter_async_context(stdio_client(params))
153-
session_named = await stack.enter_async_context(ClientSession(*stdio_streams_named))
154-
proxy_named = await create_proxy_server(session_named)
155-
156-
instance_routes_named, http_manager_named = create_single_instance_routes(
157-
proxy_named,
158-
stateless_instance=mcp_settings.stateless,
159-
)
160-
await stack.enter_async_context(
161-
http_manager_named.run(),
162-
) # Manage lifespan by calling run()
163-
164-
# Mount these routes under /servers/<name>/
165-
server_mount = Mount(f"/servers/{name}", routes=instance_routes_named)
166-
all_routes.append(server_mount)
167-
_global_status["server_instances"][name] = "configured"
168-
169-
if not default_server_params and not named_server_params:
170-
logger.error("No servers configured to run.")
171-
return
172-
173-
middleware: list[Middleware] = []
174-
if mcp_settings.api_key:
175-
middleware.append(
176-
Middleware(AuthMiddleware, api_key=mcp_settings.api_key),
177-
)
178-
if mcp_settings.allow_origins:
179-
middleware.append(
180-
Middleware(
181-
CORSMiddleware,
182-
allow_origins=mcp_settings.allow_origins,
183-
allow_methods=["*"],
184-
allow_headers=["*"],
185-
),
186-
)
243+
if named_server_params:
244+
named_routes = await _setup_named_servers(stack, named_server_params, mcp_settings)
245+
all_routes.extend(named_routes)
187246

247+
# Create middleware and Starlette app
248+
middleware = _create_middleware(mcp_settings)
188249
starlette_app = Starlette(
189250
debug=(mcp_settings.log_level == "DEBUG"),
190251
routes=all_routes,
191252
middleware=middleware,
192253
lifespan=combined_lifespan,
193254
)
194-
195255
starlette_app.router.redirect_slashes = False
196256

257+
# Log server URLs
258+
_log_server_urls(mcp_settings, default_server_params, named_server_params)
259+
260+
# Start the server
197261
config = uvicorn.Config(
198262
starlette_app,
199263
host=mcp_settings.bind_host,
200264
port=mcp_settings.port,
201265
log_level=mcp_settings.log_level.lower(),
202266
)
203267
http_server = uvicorn.Server(config)
204-
205-
# Print out the SSE URLs for all configured servers
206-
base_url = f"http://{mcp_settings.bind_host}:{mcp_settings.port}"
207-
sse_urls = []
208-
209-
# Add default server if configured
210-
if default_server_params:
211-
sse_urls.append(f"{base_url}/sse")
212-
213-
# Add named servers
214-
sse_urls.extend([f"{base_url}/servers/{name}/sse" for name in named_server_params])
215-
216-
# Display the SSE URLs prominently
217-
if sse_urls:
218-
# Using print directly for user visibility, with noqa to ignore linter warnings
219-
logger.info("Serving MCP Servers via SSE:")
220-
for url in sse_urls:
221-
logger.info(" - %s", url)
222-
223-
logger.debug(
224-
"Serving incoming MCP requests on %s:%s",
225-
mcp_settings.bind_host,
226-
mcp_settings.port,
227-
)
228268
await http_server.serve()

0 commit comments

Comments
 (0)