diff --git a/rest_framework_simplejwt/authentication.py b/rest_framework_simplejwt/authentication.py index 239d242a2..b889bc640 100644 --- a/rest_framework_simplejwt/authentication.py +++ b/rest_framework_simplejwt/authentication.py @@ -1,4 +1,4 @@ -from typing import Optional, TypeVar +from typing import TYPE_CHECKING, Optional, TypeVar from django.contrib.auth import get_user_model from django.contrib.auth.models import AbstractBaseUser @@ -12,6 +12,9 @@ from .tokens import Token from .utils import get_md5_hash_password +if TYPE_CHECKING: + from .backends import RawToken + AUTH_HEADER_TYPES = api_settings.AUTH_HEADER_TYPES if not isinstance(api_settings.AUTH_HEADER_TYPES, (list, tuple)): @@ -92,7 +95,7 @@ def get_raw_token(self, header: bytes) -> Optional[bytes]: return parts[1] - def get_validated_token(self, raw_token: bytes) -> Token: + def get_validated_token(self, raw_token: "RawToken") -> Token: """ Validates an encoded JSON web token and returns a validated token wrapper object. diff --git a/rest_framework_simplejwt/backends.py b/rest_framework_simplejwt/backends.py index 0967fc188..de9aa07f3 100644 --- a/rest_framework_simplejwt/backends.py +++ b/rest_framework_simplejwt/backends.py @@ -2,7 +2,7 @@ from collections.abc import Iterable from datetime import timedelta from functools import cached_property -from typing import Any, Optional, Union +from typing import TYPE_CHECKING, Any, Optional, Union, cast import jwt from django.utils.translation import gettext_lazy as _ @@ -14,7 +14,6 @@ ) from .exceptions import TokenBackendError, TokenBackendExpiredToken -from .tokens import Token from .utils import format_lazy try: @@ -24,6 +23,8 @@ except ImportError: JWK_CLIENT_AVAILABLE = False +RawToken = Union[bytes, str] + ALLOWED_ALGORITHMS = { "HS256", "HS384", @@ -114,12 +115,15 @@ def get_leeway(self) -> timedelta: ) ) - def get_verifying_key(self, token: Token) -> Any: + def get_verifying_key(self, token: RawToken) -> Any: if self.algorithm.startswith("HS"): return self.prepared_signing_key if self.jwks_client: try: + if isinstance(token, bytes): + # https://github.com/jpadilla/pyjwt/issues/1047 + token = cast(str, token) return self.jwks_client.get_signing_key_from_jwt(token).key except PyJWKClientError as e: raise TokenBackendError(_("Token is invalid")) from e @@ -148,7 +152,7 @@ def encode(self, payload: dict[str, Any]) -> str: # For PyJWT >= 2.0.0a1 return token - def decode(self, token: Token, verify: bool = True) -> dict[str, Any]: + def decode(self, token: RawToken, verify: bool = True) -> dict[str, Any]: """ Performs a validation of the given token and returns its payload dictionary. diff --git a/rest_framework_simplejwt/settings.py b/rest_framework_simplejwt/settings.py index 518a9ad66..4e17881af 100644 --- a/rest_framework_simplejwt/settings.py +++ b/rest_framework_simplejwt/settings.py @@ -1,5 +1,5 @@ from datetime import timedelta -from typing import Any +from typing import TYPE_CHECKING, Any, Callable, Optional, Tuple, Type from django.conf import settings from django.test.signals import setting_changed @@ -8,6 +8,13 @@ from .utils import format_lazy +if TYPE_CHECKING: + import json + + from .authentication import AuthUser + from .models import TokenUser + from .tokens import Token + USER_SETTINGS = getattr(settings, "SIMPLE_JWT", None) DEFAULTS = { @@ -63,6 +70,11 @@ class APISettings(_APISettings): # pragma: no cover + AUTH_TOKEN_CLASSES: "Tuple[Type[Token], ...]" + JSON_ENCODER: "Optional[Type[json.JSONEncoder]]" + TOKEN_USER_CLASS: "Type[TokenUser]" + USER_AUTHENTICATION_RULE: "Callable[[AuthUser], bool]" + def __check_user_settings(self, user_settings: dict[str, Any]) -> dict[str, Any]: SETTINGS_DOC = "https://django-rest-framework-simplejwt.readthedocs.io/en/latest/settings.html" diff --git a/rest_framework_simplejwt/tokens.py b/rest_framework_simplejwt/tokens.py index d97b3de96..4f911ef3a 100644 --- a/rest_framework_simplejwt/tokens.py +++ b/rest_framework_simplejwt/tokens.py @@ -27,7 +27,7 @@ ) if TYPE_CHECKING: - from .backends import TokenBackend + from .backends import RawToken, TokenBackend T = TypeVar("T", bound="Token") @@ -43,7 +43,11 @@ class Token: token_type: Optional[str] = None lifetime: Optional[timedelta] = None - def __init__(self, token: Optional["Token"] = None, verify: bool = True) -> None: + def __init__( + self, + token: Optional["RawToken"] = None, + verify: bool = True, + ) -> None: """ !!!! IMPORTANT !!!! MUST raise a TokenError with a user-facing error message if the given token is invalid, expired, or otherwise not safe