diff --git a/ninja_jwt/tokens.py b/ninja_jwt/tokens.py index df3259915..2b7caf375 100644 --- a/ninja_jwt/tokens.py +++ b/ninja_jwt/tokens.py @@ -8,7 +8,7 @@ from django.utils.translation import gettext_lazy as _ from .exceptions import TokenBackendError, TokenError -from .settings import api_settings +from .settings import NinjaJWTSettings, api_settings from .token_blacklist.models import BlacklistedToken, OutstandingToken from .utils import aware_utcnow, datetime_from_epoch, datetime_to_epoch, format_lazy @@ -20,7 +20,8 @@ class Token: """ token_type: Optional[str] = None - lifetime: Optional[datetime] = None + lifetime: Optional[timedelta] = None + settings: NinjaJWTSettings = api_settings def __init__(self, token: Optional[Any] = None, verify: bool = True) -> None: """ @@ -49,7 +50,7 @@ def __init__(self, token: Optional[Any] = None, verify: bool = True) -> None: self.verify() else: # New token. Skip all the verification steps. - self.payload = {api_settings.TOKEN_TYPE_CLAIM: self.token_type} + self.payload = {self.settings.TOKEN_TYPE_CLAIM: self.token_type} # Set "exp" and "iat" claims with default value self.set_exp(from_time=self.current_time, lifetime=self.lifetime) @@ -73,7 +74,7 @@ def __delitem__(self, key: str): def __contains__(self, key: str): return key in self.payload - def get(self, key: str, default=None): + def get(self, key: str, default: Any = None): return self.payload.get(key, default) def __str__(self) -> str: @@ -97,13 +98,10 @@ def verify(self) -> Any: # If the defaults are not None then we should enforce the # requirement of these settings.As above, the spec labels # these as optional. - if ( - api_settings.JTI_CLAIM is not None - and api_settings.JTI_CLAIM not in self.payload - ): + if self.settings.JTI_CLAIM is not None and self.settings.JTI_CLAIM not in self.payload: raise TokenError(_("Token has no id")) - if api_settings.TOKEN_TYPE_CLAIM is not None: + if self.settings.TOKEN_TYPE_CLAIM is not None: self.verify_token_type() def verify_token_type(self) -> Any: @@ -111,7 +109,7 @@ def verify_token_type(self) -> Any: Ensures that the token type claim is present and has the correct value. """ try: - token_type = self.payload[api_settings.TOKEN_TYPE_CLAIM] + token_type = self.payload[self.settings.TOKEN_TYPE_CLAIM] except KeyError as e: raise TokenError(_("Token has no type")) from e @@ -127,13 +125,13 @@ def set_jti(self) -> None: See here: https://tools.ietf.org/html/rfc7519#section-4.1.7 """ - self.payload[api_settings.JTI_CLAIM] = uuid4().hex + self.payload[self.settings.JTI_CLAIM] = uuid4().hex def set_exp( self, claim: str = "exp", from_time: Optional[datetime] = None, - lifetime: Optional[datetime] = None, + lifetime: Optional[timedelta] = None, ) -> None: """ Updates the expiration time of a token. @@ -186,12 +184,12 @@ def for_user(cls, user: AbstractBaseUser) -> "Token": Returns an authorization token for the given user that will be provided after authenticating the user's credentials. """ - user_id = getattr(user, api_settings.USER_ID_FIELD) + user_id = getattr(user, cls.settings.USER_ID_FIELD) if not isinstance(user_id, int): user_id = str(user_id) token = cls() - token[api_settings.USER_ID_CLAIM] = user_id + token[cls.settings.USER_ID_CLAIM] = user_id return token @@ -228,7 +226,7 @@ def check_blacklist(self) -> None: Checks if this token is present in the token blacklist. Raises `TokenError` if so. """ - jti = self.payload[api_settings.JTI_CLAIM] + jti = self.payload[self.settings.JTI_CLAIM] if BlacklistedToken.objects.filter(token__jti=jti).exists(): raise TokenError(_("Token is blacklisted")) @@ -238,7 +236,7 @@ def blacklist(self) -> BlacklistedToken: Ensures this token is included in the outstanding token list and adds it to the blacklist. """ - jti = self.payload[api_settings.JTI_CLAIM] + jti = self.payload[self.settings.JTI_CLAIM] exp = self.payload["exp"] # Ensure outstanding token exists with given jti @@ -259,7 +257,7 @@ def for_user(cls, user: "AbstractBaseUser") -> Token: """ token = super().for_user(user) - jti = token[api_settings.JTI_CLAIM] + jti = token[cls.settings.JTI_CLAIM] exp = token["exp"] OutstandingToken.objects.create( @@ -275,7 +273,6 @@ def for_user(cls, user: "AbstractBaseUser") -> Token: class SlidingToken(BlacklistMixin, Token): token_type: str = "sliding" - lifetime: timedelta = api_settings.SLIDING_TOKEN_LIFETIME def __init__(self, *args: Any, **kwargs: Any) -> None: super().__init__(*args, **kwargs) @@ -283,32 +280,46 @@ def __init__(self, *args: Any, **kwargs: Any) -> None: if self.token is None: # Set sliding refresh expiration claim if new token self.set_exp( - api_settings.SLIDING_TOKEN_REFRESH_EXP_CLAIM, + self.settings.SLIDING_TOKEN_REFRESH_EXP_CLAIM, from_time=self.current_time, - lifetime=api_settings.SLIDING_TOKEN_REFRESH_LIFETIME, + lifetime=self.settings.SLIDING_TOKEN_REFRESH_LIFETIME, ) + @property + def lifetime(self) -> timedelta: + return self.settings.SLIDING_TOKEN_LIFETIME + class AccessToken(Token): token_type: str = "access" - lifetime: timedelta = api_settings.ACCESS_TOKEN_LIFETIME + + @property + def lifetime(self) -> timedelta: + return self.settings.ACCESS_TOKEN_LIFETIME class RefreshToken(BlacklistMixin, Token): token_type: str = "refresh" - lifetime: timedelta = api_settings.REFRESH_TOKEN_LIFETIME - no_copy_claims: Tuple = ( - api_settings.TOKEN_TYPE_CLAIM, - "exp", - # Both of these claims are included even though they may be the same. - # It seems possible that a third party token might have a custom or - # namespaced JTI claim as well as a default "jti" claim. In that case, - # we wouldn't want to copy either one. - api_settings.JTI_CLAIM, - "jti", - ) + access_token_class = AccessToken + @property + def lifetime(self) -> timedelta: + return self.settings.REFRESH_TOKEN_LIFETIME + + @property + def no_copy_claims(self) -> Tuple: + return ( + self.settings.TOKEN_TYPE_CLAIM, + "exp", + # Both of these claims are included even though they may be the same. + # It seems possible that a third party token might have a custom or + # namespaced JTI claim as well as a default "jti" claim. In that case, + # we wouldn't want to copy either one. + self.settings.JTI_CLAIM, + "jti", + ) + @property def access_token(self) -> "AccessToken": """