Skip to content

Commit 57223b7

Browse files
committed
Improved logs, added a test
1 parent c168175 commit 57223b7

File tree

2 files changed

+61
-3
lines changed

2 files changed

+61
-3
lines changed

slowapi/extension.py

Lines changed: 6 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
"""
22
The starlette extension to rate-limit requests
33
"""
4+
45
import asyncio
56
import functools
67
import inspect
@@ -734,7 +735,8 @@ async def async_wrapper(*args: Any, **kwargs: Any) -> Response:
734735
if not isinstance(response, Response):
735736
# get the response object from the decorated endpoint function
736737
self._inject_headers(
737-
kwargs.get("response"), request.state.view_rate_limit # type: ignore
738+
kwargs.get("response"),
739+
request.state.view_rate_limit, # type: ignore
738740
)
739741
else:
740742
self._inject_headers(
@@ -766,7 +768,8 @@ def sync_wrapper(*args: Any, **kwargs: Any) -> Response:
766768
if not isinstance(response, Response):
767769
# get the response object from the decorated endpoint function
768770
self._inject_headers(
769-
kwargs.get("response"), request.state.view_rate_limit # type: ignore
771+
kwargs.get("response"),
772+
request.state.view_rate_limit, # type: ignore
770773
)
771774
else:
772775
self._inject_headers(
@@ -803,7 +806,7 @@ def limit(
803806
* **error_message**: string (or callable that returns one) to override the
804807
error message used in the response.
805808
* **exempt_when**: function returning a boolean indicating whether to exempt
806-
the route from the limit
809+
the route from the limit. This function can optionally use a Request object.
807810
* **cost**: integer (or callable that returns one) which is the cost of a hit
808811
* **override_defaults**: whether to override the default limits (default: True)
809812
"""

tests/test_starlette_extension.py

Lines changed: 55 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -43,6 +43,61 @@ def t1(request: Request):
4343
if i < 5:
4444
assert response.text == "test"
4545

46+
def test_exempt_when_argument(self, build_starlette_app):
47+
app, limiter = build_starlette_app(key_func=get_ipaddr)
48+
49+
def return_true():
50+
return True
51+
52+
def return_false():
53+
return False
54+
55+
def dynamic(request: Request):
56+
user_agent = request.headers.get("User-Agent")
57+
if user_agent is None:
58+
return False
59+
return user_agent == "exempt"
60+
61+
@limiter.limit("1/minute", exempt_when=return_true)
62+
def always_true(request: Request):
63+
return PlainTextResponse("test")
64+
65+
@limiter.limit("1/minute", exempt_when=return_false)
66+
def always_false(request: Request):
67+
return PlainTextResponse("test")
68+
69+
@limiter.limit("1/minute", exempt_when=dynamic)
70+
def always_dynamic(request: Request):
71+
return PlainTextResponse("test")
72+
73+
app.add_route("/true", always_true)
74+
app.add_route("/false", always_false)
75+
app.add_route("/dynamic", always_dynamic)
76+
77+
client = TestClient(app)
78+
# Test always true always exempting
79+
for i in range(0, 2):
80+
response = client.get("/true")
81+
assert response.status_code == 200
82+
assert response.text == "test"
83+
# Test always false hitting the limit after one hit
84+
for i in range(0, 2):
85+
response = client.get("/false")
86+
assert response.status_code == 200 if i < 1 else 429
87+
if i < 1:
88+
assert response.text == "test"
89+
# Test dynamic not exempting with the correct header
90+
for i in range(0, 2):
91+
response = client.get("/dynamic", headers={"User-Agent": "exempt"})
92+
assert response.status_code == 200
93+
assert response.text == "test"
94+
# Test dynamic exempting with the incorrect header
95+
for i in range(0, 2):
96+
response = client.get("/dynamic")
97+
assert response.status_code == 200 if i < 1 else 429
98+
if i < 1:
99+
assert response.text == "test"
100+
46101
def test_shared_decorator(self, build_starlette_app):
47102
app, limiter = build_starlette_app(key_func=get_ipaddr)
48103

0 commit comments

Comments
 (0)