@@ -63,7 +63,7 @@ class BucketType(Enum):
63
63
category = 5
64
64
role = 6
65
65
66
- def get_key (self , ctx : Context | ApplicationContext ) -> Any :
66
+ def get_key (self , ctx : Context | ApplicationContext | Message ) -> Any :
67
67
if self is BucketType .user :
68
68
return ctx .author .id
69
69
elif self is BucketType .guild :
@@ -90,7 +90,7 @@ def get_key(self, ctx: Context | ApplicationContext) -> Any:
90
90
else ctx .author .top_role
91
91
).id # type: ignore
92
92
93
- def __call__ (self , ctx : Context | ApplicationContext ) -> Any :
93
+ def __call__ (self , ctx : Context | ApplicationContext | Message ) -> Any :
94
94
return self .get_key (ctx )
95
95
96
96
@@ -215,14 +215,14 @@ class CooldownMapping:
215
215
def __init__ (
216
216
self ,
217
217
original : Cooldown | None ,
218
- type : Callable [[Context | ApplicationContext ], Any ],
218
+ type : Callable [[Context | ApplicationContext | Message ], Any ],
219
219
) -> None :
220
220
if not callable (type ):
221
221
raise TypeError ("Cooldown type must be a BucketType or callable" )
222
222
223
223
self ._cache : dict [Any , Cooldown ] = {}
224
224
self ._cooldown : Cooldown | None = original
225
- self ._type : Callable [[Context | ApplicationContext ], Any ] = type
225
+ self ._type : Callable [[Context | ApplicationContext | Message ], Any ] = type
226
226
227
227
def copy (self ) -> CooldownMapping :
228
228
ret = CooldownMapping (self ._cooldown , self ._type )
@@ -234,14 +234,14 @@ def valid(self) -> bool:
234
234
return self ._cooldown is not None
235
235
236
236
@property
237
- def type (self ) -> Callable [[Context | ApplicationContext ], Any ]:
237
+ def type (self ) -> Callable [[Context | ApplicationContext | Message ], Any ]:
238
238
return self ._type
239
239
240
240
@classmethod
241
241
def from_cooldown (cls : type [C ], rate , per , type ) -> C :
242
242
return cls (Cooldown (rate , per ), type )
243
243
244
- def _bucket_key (self , ctx : Context | ApplicationContext ) -> Any :
244
+ def _bucket_key (self , ctx : Context | ApplicationContext | Message ) -> Any :
245
245
return self ._type (ctx )
246
246
247
247
def _verify_cache_integrity (self , current : float | None = None ) -> None :
@@ -253,11 +253,11 @@ def _verify_cache_integrity(self, current: float | None = None) -> None:
253
253
for k in dead_keys :
254
254
del self ._cache [k ]
255
255
256
- async def create_bucket (self , ctx : Context | ApplicationContext ) -> Cooldown :
256
+ async def create_bucket (self , ctx : Context | ApplicationContext | Message ) -> Cooldown :
257
257
return self ._cooldown .copy () # type: ignore
258
258
259
259
async def get_bucket (
260
- self , ctx : Context | ApplicationContext , current : float | None = None
260
+ self , ctx : Context | ApplicationContext | Message , current : float | None = None
261
261
) -> Cooldown :
262
262
if self ._type is BucketType .default :
263
263
return self ._cooldown # type: ignore
@@ -274,7 +274,7 @@ async def get_bucket(
274
274
return bucket
275
275
276
276
async def update_rate_limit (
277
- self , ctx : Context | ApplicationContext , current : float | None = None
277
+ self , ctx : Context | ApplicationContext | Message , current : float | None = None
278
278
) -> float | None :
279
279
bucket = await self .get_bucket (ctx , current )
280
280
return bucket .update_rate_limit (current )
@@ -284,13 +284,13 @@ class DynamicCooldownMapping(CooldownMapping):
284
284
def __init__ (
285
285
self ,
286
286
factory : Callable [
287
- [Context | ApplicationContext ], Cooldown | Awaitable [Cooldown ]
287
+ [Context | ApplicationContext | Message ], Cooldown | Awaitable [Cooldown ]
288
288
],
289
- type : Callable [[Context | ApplicationContext ], Any ],
289
+ type : Callable [[Context | ApplicationContext | Message ], Any ],
290
290
) -> None :
291
291
super ().__init__ (None , type )
292
292
self ._factory : Callable [
293
- [Context | ApplicationContext ], Cooldown | Awaitable [Cooldown ]
293
+ [Context | ApplicationContext | Message ], Cooldown | Awaitable [Cooldown ]
294
294
] = factory
295
295
296
296
def copy (self ) -> DynamicCooldownMapping :
@@ -302,7 +302,7 @@ def copy(self) -> DynamicCooldownMapping:
302
302
def valid (self ) -> bool :
303
303
return True
304
304
305
- async def create_bucket (self , ctx : Context | ApplicationContext ) -> Cooldown :
305
+ async def create_bucket (self , ctx : Context | ApplicationContext | Message ) -> Cooldown :
306
306
from ...ext .commands import Context
307
307
308
308
if isinstance (ctx , Context ):
@@ -399,10 +399,10 @@ def __repr__(self) -> str:
399
399
f"<MaxConcurrency per={ self .per !r} number={ self .number } wait={ self .wait } >"
400
400
)
401
401
402
- def get_key (self , ctx : Context | ApplicationContext ) -> Any :
402
+ def get_key (self , ctx : Context | ApplicationContext | Message ) -> Any :
403
403
return self .per .get_key (ctx )
404
404
405
- async def acquire (self , ctx : Context | ApplicationContext ) -> None :
405
+ async def acquire (self , ctx : Context | ApplicationContext | Message ) -> None :
406
406
key = self .get_key (ctx )
407
407
408
408
try :
@@ -414,7 +414,7 @@ async def acquire(self, ctx: Context | ApplicationContext) -> None:
414
414
if not acquired :
415
415
raise MaxConcurrencyReached (self .number , self .per )
416
416
417
- async def release (self , ctx : Context | ApplicationContext ) -> None :
417
+ async def release (self , ctx : Context | ApplicationContext | Message ) -> None :
418
418
# Technically there's no reason for this function to be async
419
419
# But it might be more useful in the future
420
420
key = self .get_key (ctx )
0 commit comments