1
1
from datetime import timedelta
2
+ from logging import getLogger
3
+ from secrets import token_urlsafe
4
+ from typing import List
2
5
3
6
from fastapi import APIRouter , Depends , HTTPException , Request , Security
4
7
from fastapi .responses import HTMLResponse , JSONResponse , RedirectResponse
5
- from pydantic import Field , PrivateAttr
8
+ from pydantic import Field , PrivateAttr , field_validator
6
9
from starlette .status import HTTP_403_FORBIDDEN
7
10
8
11
from csp_gateway .server import GatewayChannels , GatewayModule
9
12
13
+ from ..shared import ChannelSelection
14
+
10
15
# separate to avoid circular
11
16
from ..web import GatewayWebApp
12
17
from .hacks .api_key_middleware_websocket_fix .api_key import (
15
20
APIKeyQuery ,
16
21
)
17
22
23
+ _log = getLogger (__name__ )
18
24
19
- class MountAPIKeyMiddleware (GatewayModule ):
20
- api_key_timeout : timedelta = Field (description = "Cookie timeout for API Key authentication" , default = timedelta (hours = 12 ))
25
+ __all__ = (
26
+ "MountAuthMiddleware" ,
27
+ "MountAPIKeyMiddleware" ,
28
+ )
29
+
30
+ # TODO: More eventually
31
+
32
+
33
+ class MountAuthMiddleware (GatewayModule ):
34
+ enforce : list = Field (default = (), description = "Routes to enforce, default empty means 'all'" )
35
+ channels : ChannelSelection = Field (
36
+ default_factory = ChannelSelection ,
37
+ description = "Channels or subroutes to enforce. If route is not present in `enforce`, implies 'allow all'" ,
38
+ )
21
39
22
- # NOTE: don't make this publically configureable
23
- # as it is needed in gateway.py
24
- _api_key_name : str = PrivateAttr ("token" )
25
- _api_key_secret : str = PrivateAttr ("" )
40
+ enforce_controls : bool = Field (default = False , description = "Whether to allow access to controls routes. Defaults to True" )
41
+ enforce_ui : bool = Field (default = True , description = "Whether to allow web access to the API Key authentication routes. Defaults to True" )
26
42
27
43
unauthorized_status_message : str = "unauthorized"
28
44
45
+ _enforced_channels : List [str ] = PrivateAttr (default_factory = list )
46
+
29
47
def connect (self , channels : GatewayChannels ) -> None :
30
48
# NO-OP
31
49
...
32
50
51
+
52
+ class MountAPIKeyMiddleware (MountAuthMiddleware ):
53
+ api_key : str = Field (default = token_urlsafe (32 ), description = "API Key to use" )
54
+ api_key_name : str = Field (default = "token" , description = "API Key to use" )
55
+ api_key_timeout : timedelta = Field (description = "Cookie timeout for API Key authentication" , default = timedelta (hours = 12 ))
56
+
57
+ _instance_count = 0
58
+
59
+ @field_validator ("api_key_name" , mode = "before" )
60
+ @classmethod
61
+ def _validate_api_key_name (cls , value : str ) -> str :
62
+ if not value :
63
+ raise ValueError ("API Key name must be a non-empty string" )
64
+ value = f"{ value .strip ().lower ()} -{ cls ._instance_count } "
65
+ cls ._instance_count += 1
66
+ return value
67
+
33
68
def rest (self , app : GatewayWebApp ) -> None :
34
69
if app .settings .AUTHENTICATE :
35
- # first, pull out the api key secret from the settings
36
- self ._api_key_secret = app .settings .API_KEY
37
-
38
- # reinitialize header
39
- api_key_query = APIKeyQuery (name = self ._api_key_name , auto_error = False )
40
- api_key_header = APIKeyHeader (name = self ._api_key_name , auto_error = False )
41
- api_key_cookie = APIKeyCookie (name = self ._api_key_name , auto_error = False )
42
-
43
- # routers
44
- auth_router : APIRouter = app .get_router ("auth" )
45
- public_router : APIRouter = app .get_router ("public" )
46
-
47
- # now mount middleware
48
- async def get_api_key (
49
- api_key_query : str = Security (api_key_query ),
50
- api_key_header : str = Security (api_key_header ),
51
- api_key_cookie : str = Security (api_key_cookie ),
52
- ):
53
- if api_key_query == self ._api_key_secret or api_key_header == self ._api_key_secret or api_key_cookie == self ._api_key_secret :
54
- return self ._api_key_secret
55
- else :
70
+ # Use configuration to determine allowed routes
71
+ # for this API key
72
+ self ._calculate_auth (app )
73
+
74
+ # Setup the routes for authentication
75
+ self ._setup_routes (app )
76
+
77
+ def _calculate_auth (self , app : GatewayWebApp ) -> None :
78
+ self ._enforced_channels = self .channels .select_from (app .gateway .channels_model )
79
+
80
+ # Fully form the url
81
+ self ._api_str = app .settings .API_STR
82
+
83
+ def _setup_routes (self , app : GatewayWebApp ) -> None :
84
+ # reinitialize header
85
+ api_key_query = APIKeyQuery (name = self .api_key_name , auto_error = False )
86
+ api_key_header = APIKeyHeader (name = self .api_key_name , auto_error = False )
87
+ api_key_cookie = APIKeyCookie (name = self .api_key_name , auto_error = False )
88
+
89
+ # routers
90
+ auth_router : APIRouter = app .get_router ("auth" )
91
+ public_router : APIRouter = app .get_router ("public" )
92
+
93
+ # now mount middleware
94
+ async def get_api_key (
95
+ request : Request = None ,
96
+ api_key_query : str = Security (api_key_query ),
97
+ api_key_header : str = Security (api_key_header ),
98
+ api_key_cookie : str = Security (api_key_cookie ),
99
+ ):
100
+ if request is None :
101
+ # If request is None, we are not in a request context, return None
102
+ _log .warning ("API Key check: request is None, returning None" )
103
+ return None
104
+
105
+ if hasattr (request .state , "auth" ):
106
+ # Already authenticated, return the API key
107
+ _log .info (f"API Key check: already authenticated, returning { self .api_key_name } " )
108
+ return request .state .auth
109
+
110
+ resolved_path = request .url .path .rstrip ("/" ).replace (self ._api_str , "" ).lstrip ("/" ).rsplit ("/" , 1 )
111
+
112
+ if len (resolved_path ) == 1 :
113
+ root = resolved_path [0 ]
114
+ channel = ""
115
+
116
+ elif len (resolved_path ) > 1 :
117
+ root = resolved_path [0 ]
118
+ channel = resolved_path [1 ]
119
+
120
+ if self .enforce and root not in self .enforce :
121
+ # Route not in enforce, allow
122
+ _log .info (f"API Key check: { root } /{ channel } not in enforced list { self .enforce } , allowing" )
123
+ return ""
124
+
125
+ if root == "controls" and not self .enforce_controls :
126
+ # Controls route not enforced, allow
127
+ _log .info (f"API Key check: root { root } not enforced, allowing" )
128
+ return ""
129
+
130
+ # TODO
131
+ if root in ("" , "auth" , "perspective" ) and not self .enforce_ui :
132
+ # UI route not enforced, allow
133
+ _log .info (f"API Key check: root { root } not enforced, allowing" )
134
+ return ""
135
+
136
+ if root not in ("controls" , "auth" , "perspective" ) and channel and channel not in self ._enforced_channels :
137
+ # Channel not in enforce, allow
138
+ _log .info (f"API Key check: channel { root } /{ channel } not in enforced channels { self ._enforced_channels } , allowing" )
139
+ return ""
140
+
141
+ # Else, enforce
142
+ if api_key_query == self .api_key or api_key_header == self .api_key or api_key_cookie == self .api_key :
143
+ # Return the API key secret to allow access
144
+ _log .info (f"API Key check: { self .api_key_name } matched for { root } /{ channel } , allowing access" )
145
+
146
+ # NOTE: only set this if we are the one validating, not if we are ignoring
147
+ request .state .auth = self .api_key
148
+ return self .api_key
149
+
150
+ _log .warning (f"API Key check: { self .api_key_name } did not match, denying access" )
151
+ raise HTTPException (
152
+ status_code = HTTP_403_FORBIDDEN ,
153
+ detail = self .unauthorized_status_message ,
154
+ )
155
+
156
+ # add auth to all other routes
157
+ app .add_middleware (Depends (get_api_key ))
158
+
159
+ if self .enforce_ui :
160
+
161
+ @auth_router .get ("/login" )
162
+ async def route_login_and_add_cookie (api_key : str = Depends (get_api_key )):
163
+ if not api_key :
56
164
raise HTTPException (
57
165
status_code = HTTP_403_FORBIDDEN ,
58
166
detail = self .unauthorized_status_message ,
59
167
)
60
-
61
- @auth_router .get ("/login" )
62
- async def route_login_and_add_cookie (api_key : str = Depends (get_api_key )):
63
168
response = RedirectResponse (url = "/" )
64
169
response .set_cookie (
65
- self ._api_key_name ,
170
+ self .api_key_name ,
66
171
value = api_key ,
67
172
domain = app .settings .AUTHENTICATION_DOMAIN ,
68
173
httponly = True ,
@@ -74,44 +179,40 @@ async def route_login_and_add_cookie(api_key: str = Depends(get_api_key)):
74
179
@auth_router .get ("/logout" )
75
180
async def route_logout_and_remove_cookie ():
76
181
response = RedirectResponse (url = "/login" )
77
- response .delete_cookie (self ._api_key_name , domain = app .settings .AUTHENTICATION_DOMAIN )
182
+ response .delete_cookie (self .api_key_name , domain = app .settings .AUTHENTICATION_DOMAIN )
78
183
return response
79
184
80
185
# I'm hand rolling these for now...
81
186
@public_router .get ("/login" , response_class = HTMLResponse , include_in_schema = False )
82
187
async def get_login_page (token : str = "" , request : Request = None ):
83
- if token :
84
- if token != "" :
85
- return RedirectResponse (url = f"{ app .settings .API_V1_STR } /auth/login?token={ token } " )
188
+ if token and token != "" :
189
+ return RedirectResponse (url = f"{ self ._api_str } /auth/login?token={ token } " )
86
190
return app .templates .TemplateResponse (
87
191
"login.html.j2" ,
88
- {"request" : request , "api_key_name" : self ._api_key_name },
192
+ {"request" : request , "api_key_name" : self .api_key_name },
89
193
)
90
194
91
195
@public_router .get ("/logout" , response_class = HTMLResponse , include_in_schema = False )
92
196
async def get_logout_page (request : Request = None ):
93
197
return app .templates .TemplateResponse ("logout.html.j2" , {"request" : request })
94
198
95
- # add auth to all other routes
96
- app .add_middleware (Depends (get_api_key ))
97
-
98
- @app .app .exception_handler (403 )
99
- async def custom_403_handler (request : Request = None , * args ):
100
- if "/api" in request .url .path :
101
- # programmatic api access, return json
102
- return JSONResponse (
103
- {
104
- "detail" : self .unauthorized_status_message ,
105
- "status_code" : 403 ,
106
- },
107
- status_code = 403 ,
108
- )
109
- return app .templates .TemplateResponse (
110
- "login.html.j2" ,
199
+ @app .app .exception_handler (403 )
200
+ async def custom_403_handler (request : Request = None , * args ):
201
+ if "/api" in request .url .path :
202
+ # programmatic api access, return json
203
+ return JSONResponse (
111
204
{
112
- "request" : request ,
113
- "api_key_name" : self ._api_key_name ,
114
- "status_code" : 403 ,
115
205
"detail" : self .unauthorized_status_message ,
206
+ "status_code" : 403 ,
116
207
},
208
+ status_code = 403 ,
117
209
)
210
+ return app .templates .TemplateResponse (
211
+ "login.html.j2" ,
212
+ {
213
+ "request" : request ,
214
+ "api_key_name" : self .api_key_name ,
215
+ "status_code" : 403 ,
216
+ "detail" : self .unauthorized_status_message ,
217
+ },
218
+ )
0 commit comments