Skip to content
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
75 changes: 43 additions & 32 deletions ninja_jwt/tokens.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@
from django.utils.translation import gettext_lazy as _

from .exceptions import TokenBackendError, TokenError
from .settings import api_settings
from .settings import NinjaJWTSettings, api_settings
from .token_blacklist.models import BlacklistedToken, OutstandingToken
from .utils import aware_utcnow, datetime_from_epoch, datetime_to_epoch, format_lazy

Expand All @@ -20,7 +20,8 @@ class Token:
"""

token_type: Optional[str] = None
lifetime: Optional[datetime] = None
lifetime: Optional[timedelta] = None
settings: NinjaJWTSettings = api_settings
Copy link
Owner

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

settings is better off a property than a cls attribute

@property
def ninja_jwt_setting(self) -> NinjaJWTSettings:
    return api_settings


def __init__(self, token: Optional[Any] = None, verify: bool = True) -> None:
"""
Expand Down Expand Up @@ -49,7 +50,7 @@ def __init__(self, token: Optional[Any] = None, verify: bool = True) -> None:
self.verify()
else:
# New token. Skip all the verification steps.
self.payload = {api_settings.TOKEN_TYPE_CLAIM: self.token_type}
self.payload = {self.settings.TOKEN_TYPE_CLAIM: self.token_type}

# Set "exp" and "iat" claims with default value
self.set_exp(from_time=self.current_time, lifetime=self.lifetime)
Expand All @@ -73,7 +74,7 @@ def __delitem__(self, key: str):
def __contains__(self, key: str):
return key in self.payload

def get(self, key: str, default=None):
def get(self, key: str, default: Any = None):
return self.payload.get(key, default)

def __str__(self) -> str:
Expand All @@ -97,21 +98,18 @@ def verify(self) -> Any:
# If the defaults are not None then we should enforce the
# requirement of these settings.As above, the spec labels
# these as optional.
if (
api_settings.JTI_CLAIM is not None
and api_settings.JTI_CLAIM not in self.payload
):
if self.settings.JTI_CLAIM is not None and self.settings.JTI_CLAIM not in self.payload:
raise TokenError(_("Token has no id"))

if api_settings.TOKEN_TYPE_CLAIM is not None:
if self.settings.TOKEN_TYPE_CLAIM is not None:
self.verify_token_type()

def verify_token_type(self) -> Any:
"""
Ensures that the token type claim is present and has the correct value.
"""
try:
token_type = self.payload[api_settings.TOKEN_TYPE_CLAIM]
token_type = self.payload[self.settings.TOKEN_TYPE_CLAIM]
except KeyError as e:
raise TokenError(_("Token has no type")) from e

Expand All @@ -127,13 +125,13 @@ def set_jti(self) -> None:
See here:
https://tools.ietf.org/html/rfc7519#section-4.1.7
"""
self.payload[api_settings.JTI_CLAIM] = uuid4().hex
self.payload[self.settings.JTI_CLAIM] = uuid4().hex

def set_exp(
self,
claim: str = "exp",
from_time: Optional[datetime] = None,
lifetime: Optional[datetime] = None,
lifetime: Optional[timedelta] = None,
) -> None:
"""
Updates the expiration time of a token.
Expand Down Expand Up @@ -186,12 +184,12 @@ def for_user(cls, user: AbstractBaseUser) -> "Token":
Returns an authorization token for the given user that will be provided
after authenticating the user's credentials.
"""
user_id = getattr(user, api_settings.USER_ID_FIELD)
user_id = getattr(user, cls.settings.USER_ID_FIELD)
if not isinstance(user_id, int):
user_id = str(user_id)

token = cls()
token[api_settings.USER_ID_CLAIM] = user_id
token[cls.settings.USER_ID_CLAIM] = user_id

return token

Expand Down Expand Up @@ -228,7 +226,7 @@ def check_blacklist(self) -> None:
Checks if this token is present in the token blacklist. Raises
`TokenError` if so.
"""
jti = self.payload[api_settings.JTI_CLAIM]
jti = self.payload[self.settings.JTI_CLAIM]

if BlacklistedToken.objects.filter(token__jti=jti).exists():
raise TokenError(_("Token is blacklisted"))
Expand All @@ -238,7 +236,7 @@ def blacklist(self) -> BlacklistedToken:
Ensures this token is included in the outstanding token list and
adds it to the blacklist.
"""
jti = self.payload[api_settings.JTI_CLAIM]
jti = self.payload[self.settings.JTI_CLAIM]
exp = self.payload["exp"]

# Ensure outstanding token exists with given jti
Expand All @@ -259,7 +257,7 @@ def for_user(cls, user: "AbstractBaseUser") -> Token:
"""
token = super().for_user(user)

jti = token[api_settings.JTI_CLAIM]
jti = token[cls.settings.JTI_CLAIM]
exp = token["exp"]

OutstandingToken.objects.create(
Expand All @@ -275,40 +273,53 @@ def for_user(cls, user: "AbstractBaseUser") -> Token:

class SlidingToken(BlacklistMixin, Token):
token_type: str = "sliding"
lifetime: timedelta = api_settings.SLIDING_TOKEN_LIFETIME

def __init__(self, *args: Any, **kwargs: Any) -> None:
super().__init__(*args, **kwargs)

if self.token is None:
# Set sliding refresh expiration claim if new token
self.set_exp(
api_settings.SLIDING_TOKEN_REFRESH_EXP_CLAIM,
self.settings.SLIDING_TOKEN_REFRESH_EXP_CLAIM,
from_time=self.current_time,
lifetime=api_settings.SLIDING_TOKEN_REFRESH_LIFETIME,
lifetime=self.settings.SLIDING_TOKEN_REFRESH_LIFETIME,
)

@property
def lifetime(self) -> timedelta:
return self.settings.SLIDING_TOKEN_LIFETIME


class AccessToken(Token):
token_type: str = "access"
lifetime: timedelta = api_settings.ACCESS_TOKEN_LIFETIME

@property
def lifetime(self) -> timedelta:
return self.settings.ACCESS_TOKEN_LIFETIME


class RefreshToken(BlacklistMixin, Token):
token_type: str = "refresh"
lifetime: timedelta = api_settings.REFRESH_TOKEN_LIFETIME
no_copy_claims: Tuple = (
api_settings.TOKEN_TYPE_CLAIM,
"exp",
# Both of these claims are included even though they may be the same.
# It seems possible that a third party token might have a custom or
# namespaced JTI claim as well as a default "jti" claim. In that case,
# we wouldn't want to copy either one.
api_settings.JTI_CLAIM,
"jti",
)

access_token_class = AccessToken

@property
def lifetime(self) -> timedelta:
return self.settings.REFRESH_TOKEN_LIFETIME

@property
def no_copy_claims(self) -> Tuple:
return (
self.settings.TOKEN_TYPE_CLAIM,
"exp",
# Both of these claims are included even though they may be the same.
# It seems possible that a third party token might have a custom or
# namespaced JTI claim as well as a default "jti" claim. In that case,
# we wouldn't want to copy either one.
self.settings.JTI_CLAIM,
"jti",
)

@property
def access_token(self) -> "AccessToken":
"""
Expand Down
Loading