Skip to content
Open
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 5 additions & 2 deletions rest_framework_simplejwt/authentication.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
from typing import Optional, TypeVar
from typing import TYPE_CHECKING, Optional, TypeVar

from django.contrib.auth import get_user_model
from django.contrib.auth.models import AbstractBaseUser
Expand All @@ -12,6 +12,9 @@
from .tokens import Token
from .utils import get_md5_hash_password

if TYPE_CHECKING:
from .backends import RawToken

AUTH_HEADER_TYPES = api_settings.AUTH_HEADER_TYPES

if not isinstance(api_settings.AUTH_HEADER_TYPES, (list, tuple)):
Expand Down Expand Up @@ -92,7 +95,7 @@ def get_raw_token(self, header: bytes) -> Optional[bytes]:

return parts[1]

def get_validated_token(self, raw_token: bytes) -> Token:
def get_validated_token(self, raw_token: "RawToken") -> Token:
"""
Validates an encoded JSON web token and returns a validated token
wrapper object.
Expand Down
9 changes: 6 additions & 3 deletions rest_framework_simplejwt/backends.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,6 @@
)

from .exceptions import TokenBackendError, TokenBackendExpiredToken
from .tokens import Token
from .utils import format_lazy

try:
Expand All @@ -24,6 +23,8 @@
except ImportError:
JWK_CLIENT_AVAILABLE = False

RawToken = Union[bytes, str]
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

hm i wish there was a way to customize the typing based on the version of JWT.

Specifically, I wonder if we can import something consistent from PyJWT itself that represents the token. Using a union is also not entirely correct (better than what it was before for sure).

Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Using a union is also not entirely correct

It is annotated as this exact union in the recent PyJWT versions.

FWIW, bytes | str is also technically supported since 1.7.1 despite the formal jwt: str annotation — see https://github.com/jpadilla/pyjwt/blob/b65e1ac6dc4d11801f3642eaab34ae6a54162c18/jwt/api_jws.py#L171-L177


ALLOWED_ALGORITHMS = {
"HS256",
"HS384",
Expand Down Expand Up @@ -114,12 +115,14 @@ def get_leeway(self) -> timedelta:
)
)

def get_verifying_key(self, token: Token) -> Any:
def get_verifying_key(self, token: RawToken) -> Any:
if self.algorithm.startswith("HS"):
return self.prepared_signing_key

if self.jwks_client:
try:
if isinstance(token, bytes):
token = token.decode("utf-8")
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

do both v1.7.1 and v2.x versions of JWT expect the token to be string? I can understand v2, but how about v1?

Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

1.7.1 does not implement a PyJWKClient, forcing self.jwks_client to always be None.

Annoyingly, 2.x PyJWKClient.get_signing_key_from_jwt has an incorrect annotation and can actually handle the union type as it wraps code referenced in my other comment. To avoid an extra decode call, I'd just cast it to str until the upstream issue is resolved.

return self.jwks_client.get_signing_key_from_jwt(token).key
except PyJWKClientError as e:
raise TokenBackendError(_("Token is invalid")) from e
Expand Down Expand Up @@ -148,7 +151,7 @@ def encode(self, payload: dict[str, Any]) -> str:
# For PyJWT >= 2.0.0a1
return token

def decode(self, token: Token, verify: bool = True) -> dict[str, Any]:
def decode(self, token: RawToken, verify: bool = True) -> dict[str, Any]:
"""
Performs a validation of the given token and returns its payload
dictionary.
Expand Down
14 changes: 13 additions & 1 deletion rest_framework_simplejwt/settings.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
from datetime import timedelta
from typing import Any
from typing import TYPE_CHECKING, Any, Callable, Optional, Tuple, Type

from django.conf import settings
from django.test.signals import setting_changed
Expand All @@ -8,6 +8,13 @@

from .utils import format_lazy

if TYPE_CHECKING:
import json

from .authentication import AuthUser
from .models import TokenUser
from .tokens import Token

USER_SETTINGS = getattr(settings, "SIMPLE_JWT", None)

DEFAULTS = {
Expand Down Expand Up @@ -63,6 +70,11 @@


class APISettings(_APISettings): # pragma: no cover
AUTH_TOKEN_CLASSES: "Tuple[Type[Token], ...]"
JSON_ENCODER: "Optional[Type[json.JSONEncoder]]"
TOKEN_USER_CLASS: "Type[TokenUser]"
USER_AUTHENTICATION_RULE: "Callable[[AuthUser], bool]"

def __check_user_settings(self, user_settings: dict[str, Any]) -> dict[str, Any]:
SETTINGS_DOC = "https://django-rest-framework-simplejwt.readthedocs.io/en/latest/settings.html"

Expand Down
8 changes: 6 additions & 2 deletions rest_framework_simplejwt/tokens.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@
)

if TYPE_CHECKING:
from .backends import TokenBackend
from .backends import RawToken, TokenBackend

T = TypeVar("T", bound="Token")

Expand All @@ -43,7 +43,11 @@ class Token:
token_type: Optional[str] = None
lifetime: Optional[timedelta] = None

def __init__(self, token: Optional["Token"] = None, verify: bool = True) -> None:
def __init__(
self,
token: Optional["RawToken"] = None,
verify: bool = True,
) -> None:
"""
!!!! IMPORTANT !!!! MUST raise a TokenError with a user-facing error
message if the given token is invalid, expired, or otherwise not safe
Expand Down
Loading