@@ -43,6 +43,61 @@ def t1(request: Request):
43
43
if i < 5 :
44
44
assert response .text == "test"
45
45
46
+ def test_exempt_when_argument (self , build_starlette_app ):
47
+ app , limiter = build_starlette_app (key_func = get_ipaddr )
48
+
49
+ def return_true ():
50
+ return True
51
+
52
+ def return_false ():
53
+ return False
54
+
55
+ def dynamic (request : Request ):
56
+ user_agent = request .headers .get ("User-Agent" )
57
+ if user_agent is None :
58
+ return False
59
+ return user_agent == "exempt"
60
+
61
+ @limiter .limit ("1/minute" , exempt_when = return_true )
62
+ def always_true (request : Request ):
63
+ return PlainTextResponse ("test" )
64
+
65
+ @limiter .limit ("1/minute" , exempt_when = return_false )
66
+ def always_false (request : Request ):
67
+ return PlainTextResponse ("test" )
68
+
69
+ @limiter .limit ("1/minute" , exempt_when = dynamic )
70
+ def always_dynamic (request : Request ):
71
+ return PlainTextResponse ("test" )
72
+
73
+ app .add_route ("/true" , always_true )
74
+ app .add_route ("/false" , always_false )
75
+ app .add_route ("/dynamic" , always_dynamic )
76
+
77
+ client = TestClient (app )
78
+ # Test always true always exempting
79
+ for i in range (0 , 2 ):
80
+ response = client .get ("/true" )
81
+ assert response .status_code == 200
82
+ assert response .text == "test"
83
+ # Test always false hitting the limit after one hit
84
+ for i in range (0 , 2 ):
85
+ response = client .get ("/false" )
86
+ assert response .status_code == 200 if i < 1 else 429
87
+ if i < 1 :
88
+ assert response .text == "test"
89
+ # Test dynamic not exempting with the correct header
90
+ for i in range (0 , 2 ):
91
+ response = client .get ("/dynamic" , headers = {"User-Agent" : "exempt" })
92
+ assert response .status_code == 200
93
+ assert response .text == "test"
94
+ # Test dynamic exempting with the incorrect header
95
+ for i in range (0 , 2 ):
96
+ response = client .get ("/dynamic" )
97
+ assert response .status_code == 200 if i < 1 else 429
98
+ if i < 1 :
99
+ assert response .text == "test"
100
+
46
101
def test_shared_decorator (self , build_starlette_app ):
47
102
app , limiter = build_starlette_app (key_func = get_ipaddr )
48
103
0 commit comments