Skip to content

Commit 09d4247

Browse files
committed
[WIP] Allow for multiple api key middlewares
Signed-off-by: Tim Paine <3105306+timkpaine@users.noreply.github.com>
1 parent 8fb8f68 commit 09d4247

File tree

7 files changed

+195
-69
lines changed

7 files changed

+195
-69
lines changed

csp_gateway/server/config/gateway/omnibus.yaml

Lines changed: 17 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -30,14 +30,28 @@ modules:
3030
_target_: csp_gateway.MountWebSocketRoutes
3131
mount_api_key_middleware:
3232
_target_: csp_gateway.MountAPIKeyMiddleware
33+
api_key: 12345
34+
enforce_ui: false
35+
enforce_controls: false
36+
mount_api_key_middleware_ui:
37+
_target_: csp_gateway.MountAPIKeyMiddleware
38+
api_key: token
39+
enforce: []
40+
enforce_ui: true
41+
enforce_controls: false
42+
mount_api_key_middleware_controls:
43+
_target_: csp_gateway.MountAPIKeyMiddleware
44+
api_key: 54321
45+
enforce: []
46+
enforce_ui: false
47+
enforce_controls: true
3348

3449
gateway:
3550
_target_: csp_gateway.Gateway
3651
settings:
3752
PORT: ${port}
3853
AUTHENTICATE: ${authenticate}
3954
UI: True
40-
API_KEY: "12345"
4155
modules:
4256
- /modules/example_module
4357
- /modules/example_module_feedback
@@ -49,6 +63,8 @@ gateway:
4963
- /modules/mount_rest_routes
5064
- /modules/mount_websocket_routes
5165
- /modules/mount_api_key_middleware
66+
- /modules/mount_api_key_middleware_ui
67+
- /modules/mount_api_key_middleware_controls
5268
channels:
5369
_target_: csp_gateway.server.demo.ExampleGatewayChannels
5470

csp_gateway/server/demo/config/omnibus.yaml

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@ defaults:
33
- /gateway: omnibus
44
- _self_
55

6-
# csp-gateway-start --config-dir=csp_gateway/server/omnibus +config=omnibus
6+
# csp-gateway-start --config-dir=csp_gateway/server/demo +config=omnibus
77

8-
authenticate: False
8+
authenticate: True
99
port: 8000

csp_gateway/server/demo/omnibus.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -279,7 +279,7 @@ def push_to_perspective( # type: ignore[no-untyped-def]
279279
# be instantiated directly as we do so here:
280280

281281
# Setting authentication
282-
settings = GatewaySettings(API_KEY="12345", AUTHENTICATE=False)
282+
settings = GatewaySettings(AUTHENTICATE=False)
283283

284284
# instantiate gateway
285285
gateway = Gateway(

csp_gateway/server/gateway/gateway.py

Lines changed: 17 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -275,12 +275,26 @@ def start(
275275
log.info("Launching web server on:")
276276
url = f"http://{gethostname()}:{self.settings.PORT}"
277277

278-
if ui:
279-
if self.settings.AUTHENTICATE:
280-
log.info(f"\tUI: {url}?token={self.settings.API_KEY}")
278+
if ui and self.settings.AUTHENTICATE:
279+
from ..middleware import MountAPIKeyMiddleware
280+
281+
# TODO: Will need to handle others
282+
auth = ""
283+
284+
# Find any middleware enforcing auth
285+
for module in self.modules:
286+
if isinstance(module, MountAPIKeyMiddleware) and module.enforce_ui is True:
287+
auth = module.api_key
288+
break
289+
290+
if auth:
291+
log.info(f"\tUI: {url}?{module.api_key_name}={auth}")
281292
else:
282293
log.info(f"\tUI: {url}")
283294

295+
else:
296+
log.info(f"\tUI: {url}")
297+
284298
log.info(f"\tDocs: {url}/docs")
285299
log.info(f"\tDocs: {url}/redoc")
286300

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1 +1 @@
1-
from .api_key import MountAPIKeyMiddleware
1+
from .api_key import *
Lines changed: 157 additions & 56 deletions
Original file line numberDiff line numberDiff line change
@@ -1,12 +1,17 @@
11
from datetime import timedelta
2+
from logging import getLogger
3+
from secrets import token_urlsafe
4+
from typing import List
25

36
from fastapi import APIRouter, Depends, HTTPException, Request, Security
47
from fastapi.responses import HTMLResponse, JSONResponse, RedirectResponse
5-
from pydantic import Field, PrivateAttr
8+
from pydantic import Field, PrivateAttr, field_validator
69
from starlette.status import HTTP_403_FORBIDDEN
710

811
from csp_gateway.server import GatewayChannels, GatewayModule
912

13+
from ..shared import ChannelSelection
14+
1015
# separate to avoid circular
1116
from ..web import GatewayWebApp
1217
from .hacks.api_key_middleware_websocket_fix.api_key import (
@@ -15,54 +20,154 @@
1520
APIKeyQuery,
1621
)
1722

23+
_log = getLogger(__name__)
1824

19-
class MountAPIKeyMiddleware(GatewayModule):
20-
api_key_timeout: timedelta = Field(description="Cookie timeout for API Key authentication", default=timedelta(hours=12))
25+
__all__ = (
26+
"MountAuthMiddleware",
27+
"MountAPIKeyMiddleware",
28+
)
29+
30+
# TODO: More eventually
31+
32+
33+
class MountAuthMiddleware(GatewayModule):
34+
enforce: list = Field(default=(), description="Routes to enforce, default empty means 'all'")
35+
channels: ChannelSelection = Field(
36+
default_factory=ChannelSelection,
37+
description="Channels or subroutes to enforce. If route is not present in `enforce`, implies 'allow all'",
38+
)
2139

22-
# NOTE: don't make this publically configureable
23-
# as it is needed in gateway.py
24-
_api_key_name: str = PrivateAttr("token")
25-
_api_key_secret: str = PrivateAttr("")
40+
enforce_controls: bool = Field(default=False, description="Whether to allow access to controls routes. Defaults to True")
41+
enforce_ui: bool = Field(default=True, description="Whether to allow web access to the API Key authentication routes. Defaults to True")
2642

2743
unauthorized_status_message: str = "unauthorized"
2844

45+
_enforced_channels: List[str] = PrivateAttr(default_factory=list)
46+
2947
def connect(self, channels: GatewayChannels) -> None:
3048
# NO-OP
3149
...
3250

51+
52+
class MountAPIKeyMiddleware(MountAuthMiddleware):
53+
api_key: str = Field(default=token_urlsafe(32), description="API Key to use")
54+
api_key_name: str = Field(default="token", description="API Key to use")
55+
api_key_timeout: timedelta = Field(description="Cookie timeout for API Key authentication", default=timedelta(hours=12))
56+
57+
_instance_count = 0
58+
59+
@field_validator("api_key_name", mode="before")
60+
@classmethod
61+
def _validate_api_key_name(cls, value: str) -> str:
62+
if not value:
63+
raise ValueError("API Key name must be a non-empty string")
64+
value = f"{value.strip().lower()}-{cls._instance_count}"
65+
cls._instance_count += 1
66+
return value
67+
3368
def rest(self, app: GatewayWebApp) -> None:
3469
if app.settings.AUTHENTICATE:
35-
# first, pull out the api key secret from the settings
36-
self._api_key_secret = app.settings.API_KEY
37-
38-
# reinitialize header
39-
api_key_query = APIKeyQuery(name=self._api_key_name, auto_error=False)
40-
api_key_header = APIKeyHeader(name=self._api_key_name, auto_error=False)
41-
api_key_cookie = APIKeyCookie(name=self._api_key_name, auto_error=False)
42-
43-
# routers
44-
auth_router: APIRouter = app.get_router("auth")
45-
public_router: APIRouter = app.get_router("public")
46-
47-
# now mount middleware
48-
async def get_api_key(
49-
api_key_query: str = Security(api_key_query),
50-
api_key_header: str = Security(api_key_header),
51-
api_key_cookie: str = Security(api_key_cookie),
52-
):
53-
if api_key_query == self._api_key_secret or api_key_header == self._api_key_secret or api_key_cookie == self._api_key_secret:
54-
return self._api_key_secret
55-
else:
70+
# Use configuration to determine allowed routes
71+
# for this API key
72+
self._calculate_auth(app)
73+
74+
# Setup the routes for authentication
75+
self._setup_routes(app)
76+
77+
def _calculate_auth(self, app: GatewayWebApp) -> None:
78+
self._enforced_channels = self.channels.select_from(app.gateway.channels_model)
79+
80+
# Fully form the url
81+
self._api_str = app.settings.API_STR
82+
83+
def _setup_routes(self, app: GatewayWebApp) -> None:
84+
# reinitialize header
85+
api_key_query = APIKeyQuery(name=self.api_key_name, auto_error=False)
86+
api_key_header = APIKeyHeader(name=self.api_key_name, auto_error=False)
87+
api_key_cookie = APIKeyCookie(name=self.api_key_name, auto_error=False)
88+
89+
# routers
90+
auth_router: APIRouter = app.get_router("auth")
91+
public_router: APIRouter = app.get_router("public")
92+
93+
# now mount middleware
94+
async def get_api_key(
95+
request: Request = None,
96+
api_key_query: str = Security(api_key_query),
97+
api_key_header: str = Security(api_key_header),
98+
api_key_cookie: str = Security(api_key_cookie),
99+
):
100+
if request is None:
101+
# If request is None, we are not in a request context, return None
102+
_log.warning("API Key check: request is None, returning None")
103+
return None
104+
105+
if hasattr(request.state, "auth"):
106+
# Already authenticated, return the API key
107+
_log.info(f"API Key check: already authenticated, returning {self.api_key_name}")
108+
return request.state.auth
109+
110+
resolved_path = request.url.path.rstrip("/").replace(self._api_str, "").lstrip("/").rsplit("/", 1)
111+
112+
if len(resolved_path) == 1:
113+
root = resolved_path[0]
114+
channel = ""
115+
116+
elif len(resolved_path) > 1:
117+
root = resolved_path[0]
118+
channel = resolved_path[1]
119+
120+
if self.enforce and root not in self.enforce:
121+
# Route not in enforce, allow
122+
_log.info(f"API Key check: {root}/{channel} not in enforced list {self.enforce}, allowing")
123+
return ""
124+
125+
if root == "controls" and not self.enforce_controls:
126+
# Controls route not enforced, allow
127+
_log.info(f"API Key check: root {root} not enforced, allowing")
128+
return ""
129+
130+
# TODO
131+
if root in ("", "auth", "perspective") and not self.enforce_ui:
132+
# UI route not enforced, allow
133+
_log.info(f"API Key check: root {root} not enforced, allowing")
134+
return ""
135+
136+
if root not in ("controls", "auth", "perspective") and channel and channel not in self._enforced_channels:
137+
# Channel not in enforce, allow
138+
_log.info(f"API Key check: channel {root}/{channel} not in enforced channels {self._enforced_channels}, allowing")
139+
return ""
140+
141+
# Else, enforce
142+
if api_key_query == self.api_key or api_key_header == self.api_key or api_key_cookie == self.api_key:
143+
# Return the API key secret to allow access
144+
_log.info(f"API Key check: {self.api_key_name} matched for {root}/{channel}, allowing access")
145+
146+
# NOTE: only set this if we are the one validating, not if we are ignoring
147+
request.state.auth = self.api_key
148+
return self.api_key
149+
150+
_log.warning(f"API Key check: {self.api_key_name} did not match, denying access")
151+
raise HTTPException(
152+
status_code=HTTP_403_FORBIDDEN,
153+
detail=self.unauthorized_status_message,
154+
)
155+
156+
# add auth to all other routes
157+
app.add_middleware(Depends(get_api_key))
158+
159+
if self.enforce_ui:
160+
161+
@auth_router.get("/login")
162+
async def route_login_and_add_cookie(api_key: str = Depends(get_api_key)):
163+
if not api_key:
56164
raise HTTPException(
57165
status_code=HTTP_403_FORBIDDEN,
58166
detail=self.unauthorized_status_message,
59167
)
60-
61-
@auth_router.get("/login")
62-
async def route_login_and_add_cookie(api_key: str = Depends(get_api_key)):
63168
response = RedirectResponse(url="/")
64169
response.set_cookie(
65-
self._api_key_name,
170+
self.api_key_name,
66171
value=api_key,
67172
domain=app.settings.AUTHENTICATION_DOMAIN,
68173
httponly=True,
@@ -74,44 +179,40 @@ async def route_login_and_add_cookie(api_key: str = Depends(get_api_key)):
74179
@auth_router.get("/logout")
75180
async def route_logout_and_remove_cookie():
76181
response = RedirectResponse(url="/login")
77-
response.delete_cookie(self._api_key_name, domain=app.settings.AUTHENTICATION_DOMAIN)
182+
response.delete_cookie(self.api_key_name, domain=app.settings.AUTHENTICATION_DOMAIN)
78183
return response
79184

80185
# I'm hand rolling these for now...
81186
@public_router.get("/login", response_class=HTMLResponse, include_in_schema=False)
82187
async def get_login_page(token: str = "", request: Request = None):
83-
if token:
84-
if token != "":
85-
return RedirectResponse(url=f"{app.settings.API_V1_STR}/auth/login?token={token}")
188+
if token and token != "":
189+
return RedirectResponse(url=f"{self._api_str}/auth/login?token={token}")
86190
return app.templates.TemplateResponse(
87191
"login.html.j2",
88-
{"request": request, "api_key_name": self._api_key_name},
192+
{"request": request, "api_key_name": self.api_key_name},
89193
)
90194

91195
@public_router.get("/logout", response_class=HTMLResponse, include_in_schema=False)
92196
async def get_logout_page(request: Request = None):
93197
return app.templates.TemplateResponse("logout.html.j2", {"request": request})
94198

95-
# add auth to all other routes
96-
app.add_middleware(Depends(get_api_key))
97-
98-
@app.app.exception_handler(403)
99-
async def custom_403_handler(request: Request = None, *args):
100-
if "/api" in request.url.path:
101-
# programmatic api access, return json
102-
return JSONResponse(
103-
{
104-
"detail": self.unauthorized_status_message,
105-
"status_code": 403,
106-
},
107-
status_code=403,
108-
)
109-
return app.templates.TemplateResponse(
110-
"login.html.j2",
199+
@app.app.exception_handler(403)
200+
async def custom_403_handler(request: Request = None, *args):
201+
if "/api" in request.url.path:
202+
# programmatic api access, return json
203+
return JSONResponse(
111204
{
112-
"request": request,
113-
"api_key_name": self._api_key_name,
114-
"status_code": 403,
115205
"detail": self.unauthorized_status_message,
206+
"status_code": 403,
116207
},
208+
status_code=403,
117209
)
210+
return app.templates.TemplateResponse(
211+
"login.html.j2",
212+
{
213+
"request": request,
214+
"api_key_name": self.api_key_name,
215+
"status_code": 403,
216+
"detail": self.unauthorized_status_message,
217+
},
218+
)

csp_gateway/server/settings.py

Lines changed: 0 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,3 @@
1-
from secrets import token_urlsafe
21
from socket import gethostname
32
from typing import List
43

@@ -31,8 +30,4 @@ class Settings(BaseSettings):
3130

3231
UI: bool = Field(False, description="Enables ui in the web application")
3332
AUTHENTICATE: bool = Field(False, description="Whether to authenticate users for access to the web application")
34-
API_KEY: str = Field(
35-
token_urlsafe(32),
36-
description="The API key for access if `AUTHENTICATE=True`. The default is auto-generated, but a user-provided value can be used.",
37-
)
3833
AUTHENTICATION_DOMAIN: str = gethostname()

0 commit comments

Comments
 (0)