|
| 1 | +"""Simplified tests for authentication middleware.""" |
| 2 | + |
| 3 | +import pytest |
| 4 | +from starlette.applications import Starlette |
| 5 | +from starlette.middleware import Middleware |
| 6 | +from starlette.responses import JSONResponse |
| 7 | +from starlette.routing import Route |
| 8 | +from starlette.testclient import TestClient |
| 9 | + |
| 10 | +from mcp_proxy.auth import AuthMiddleware |
| 11 | + |
| 12 | + |
| 13 | +async def dummy_endpoint(request): |
| 14 | + """Simple endpoint for testing.""" |
| 15 | + return JSONResponse({"message": "success"}) |
| 16 | + |
| 17 | + |
| 18 | +async def status_endpoint(request): |
| 19 | + """Status endpoint.""" |
| 20 | + return JSONResponse({"status": "ok"}) |
| 21 | + |
| 22 | + |
| 23 | +def create_app_without_auth(): |
| 24 | + """Create app without authentication.""" |
| 25 | + routes = [ |
| 26 | + Route("/sse", dummy_endpoint), |
| 27 | + Route("/mcp/test", dummy_endpoint), |
| 28 | + Route("/messages/test", dummy_endpoint), |
| 29 | + Route("/status", status_endpoint), |
| 30 | + Route("/other", dummy_endpoint), |
| 31 | + ] |
| 32 | + return Starlette(routes=routes) |
| 33 | + |
| 34 | + |
| 35 | +def create_app_with_auth(): |
| 36 | + """Create app with authentication.""" |
| 37 | + routes = [ |
| 38 | + Route("/sse", dummy_endpoint), |
| 39 | + Route("/mcp/test", dummy_endpoint), |
| 40 | + Route("/messages/test", dummy_endpoint), |
| 41 | + Route("/status", status_endpoint), |
| 42 | + Route("/other", dummy_endpoint), |
| 43 | + Route("/servers/test/sse", dummy_endpoint), |
| 44 | + Route("/servers/test/mcp", dummy_endpoint), |
| 45 | + ] |
| 46 | + middleware = [Middleware(AuthMiddleware, api_key="test-api-key")] |
| 47 | + return Starlette(routes=routes, middleware=middleware) |
| 48 | + |
| 49 | + |
| 50 | +def test_no_auth_allows_all(): |
| 51 | + """Test that all requests work without authentication configured.""" |
| 52 | + app = create_app_without_auth() |
| 53 | + with TestClient(app) as client: |
| 54 | + assert client.get("/sse").status_code == 200 |
| 55 | + assert client.get("/mcp/test").status_code == 200 |
| 56 | + assert client.get("/status").status_code == 200 |
| 57 | + |
| 58 | + |
| 59 | +def test_auth_blocks_protected_endpoints(): |
| 60 | + """Test that protected endpoints are blocked without API key.""" |
| 61 | + app = create_app_with_auth() |
| 62 | + with TestClient(app) as client: |
| 63 | + response = client.get("/sse") |
| 64 | + assert response.status_code == 401 |
| 65 | + assert response.json() == {"error": "Unauthorized"} |
| 66 | + |
| 67 | + response = client.get("/mcp/test") |
| 68 | + assert response.status_code == 401 |
| 69 | + |
| 70 | + response = client.get("/messages/test") |
| 71 | + assert response.status_code == 401 |
| 72 | + |
| 73 | + |
| 74 | +def test_auth_allows_with_key(): |
| 75 | + """Test that requests work with correct API key.""" |
| 76 | + app = create_app_with_auth() |
| 77 | + with TestClient(app) as client: |
| 78 | + headers = {"x-api-key": "test-api-key"} |
| 79 | + |
| 80 | + response = client.get("/sse", headers=headers) |
| 81 | + assert response.status_code == 200 |
| 82 | + assert response.json() == {"message": "success"} |
| 83 | + |
| 84 | + response = client.get("/mcp/test", headers=headers) |
| 85 | + assert response.status_code == 200 |
| 86 | + |
| 87 | + |
| 88 | +def test_auth_blocks_wrong_key(): |
| 89 | + """Test that requests are blocked with wrong API key.""" |
| 90 | + app = create_app_with_auth() |
| 91 | + with TestClient(app) as client: |
| 92 | + headers = {"x-api-key": "wrong-key"} |
| 93 | + |
| 94 | + response = client.get("/sse", headers=headers) |
| 95 | + assert response.status_code == 401 |
| 96 | + |
| 97 | + |
| 98 | +def test_status_not_protected(): |
| 99 | + """Test that /status endpoint is not protected.""" |
| 100 | + app = create_app_with_auth() |
| 101 | + with TestClient(app) as client: |
| 102 | + response = client.get("/status") |
| 103 | + assert response.status_code == 200 |
| 104 | + assert response.json() == {"status": "ok"} |
| 105 | + |
| 106 | + |
| 107 | +def test_other_endpoints_not_protected(): |
| 108 | + """Test that non-SSE/MCP endpoints are not protected.""" |
| 109 | + app = create_app_with_auth() |
| 110 | + with TestClient(app) as client: |
| 111 | + response = client.get("/other") |
| 112 | + assert response.status_code == 200 |
| 113 | + |
| 114 | + |
| 115 | +def test_options_allowed(): |
| 116 | + """Test that OPTIONS requests are allowed without auth.""" |
| 117 | + app = create_app_with_auth() |
| 118 | + with TestClient(app) as client: |
| 119 | + response = client.options("/sse") |
| 120 | + assert response.status_code != 401 |
| 121 | + |
| 122 | + |
| 123 | +def test_case_insensitive_header(): |
| 124 | + """Test that API key header is case-insensitive.""" |
| 125 | + app = create_app_with_auth() |
| 126 | + with TestClient(app) as client: |
| 127 | + # Different case variations |
| 128 | + headers = {"X-API-KEY": "test-api-key"} |
| 129 | + response = client.get("/sse", headers=headers) |
| 130 | + assert response.status_code == 200 |
| 131 | + |
| 132 | + headers = {"X-Api-Key": "test-api-key"} |
| 133 | + response = client.get("/sse", headers=headers) |
| 134 | + assert response.status_code == 200 |
| 135 | + |
| 136 | + |
| 137 | +def test_named_servers_protected(): |
| 138 | + """Test that named server endpoints are protected.""" |
| 139 | + app = create_app_with_auth() |
| 140 | + with TestClient(app) as client: |
| 141 | + # Without auth |
| 142 | + response = client.get("/servers/test/sse") |
| 143 | + assert response.status_code == 401 |
| 144 | + |
| 145 | + response = client.get("/servers/test/mcp") |
| 146 | + assert response.status_code == 401 |
| 147 | + |
| 148 | + # With auth |
| 149 | + headers = {"x-api-key": "test-api-key"} |
| 150 | + response = client.get("/servers/test/sse", headers=headers) |
| 151 | + assert response.status_code == 200 |
| 152 | + |
| 153 | + response = client.get("/servers/test/mcp", headers=headers) |
| 154 | + assert response.status_code == 200 |
| 155 | + |
0 commit comments