From 48df0752b6ae3258992d8feab2ef3de8e40d5adf Mon Sep 17 00:00:00 2001 From: Kim Gustyr Date: Thu, 13 Mar 2025 15:09:18 +0000 Subject: [PATCH 1/3] fix: Incorrect `Token`, `TokenBackend` typing --- rest_framework_simplejwt/backends.py | 2 +- rest_framework_simplejwt/settings.py | 14 +++++++++++++- rest_framework_simplejwt/tokens.py | 8 ++++++-- 3 files changed, 20 insertions(+), 4 deletions(-) diff --git a/rest_framework_simplejwt/backends.py b/rest_framework_simplejwt/backends.py index 0967fc188..4967a7aaa 100644 --- a/rest_framework_simplejwt/backends.py +++ b/rest_framework_simplejwt/backends.py @@ -148,7 +148,7 @@ def encode(self, payload: dict[str, Any]) -> str: # For PyJWT >= 2.0.0a1 return token - def decode(self, token: Token, verify: bool = True) -> dict[str, Any]: + def decode(self, token: Union[bytes, str], verify: bool = True) -> dict[str, Any]: """ Performs a validation of the given token and returns its payload dictionary. diff --git a/rest_framework_simplejwt/settings.py b/rest_framework_simplejwt/settings.py index 518a9ad66..4e17881af 100644 --- a/rest_framework_simplejwt/settings.py +++ b/rest_framework_simplejwt/settings.py @@ -1,5 +1,5 @@ from datetime import timedelta -from typing import Any +from typing import TYPE_CHECKING, Any, Callable, Optional, Tuple, Type from django.conf import settings from django.test.signals import setting_changed @@ -8,6 +8,13 @@ from .utils import format_lazy +if TYPE_CHECKING: + import json + + from .authentication import AuthUser + from .models import TokenUser + from .tokens import Token + USER_SETTINGS = getattr(settings, "SIMPLE_JWT", None) DEFAULTS = { @@ -63,6 +70,11 @@ class APISettings(_APISettings): # pragma: no cover + AUTH_TOKEN_CLASSES: "Tuple[Type[Token], ...]" + JSON_ENCODER: "Optional[Type[json.JSONEncoder]]" + TOKEN_USER_CLASS: "Type[TokenUser]" + USER_AUTHENTICATION_RULE: "Callable[[AuthUser], bool]" + def __check_user_settings(self, user_settings: dict[str, Any]) -> dict[str, Any]: SETTINGS_DOC = "https://django-rest-framework-simplejwt.readthedocs.io/en/latest/settings.html" diff --git a/rest_framework_simplejwt/tokens.py b/rest_framework_simplejwt/tokens.py index d97b3de96..d8d835698 100644 --- a/rest_framework_simplejwt/tokens.py +++ b/rest_framework_simplejwt/tokens.py @@ -1,5 +1,5 @@ from datetime import datetime, timedelta -from typing import TYPE_CHECKING, Any, Generic, Optional, TypeVar +from typing import TYPE_CHECKING, Any, Generic, Optional, TypeVar, Union from uuid import uuid4 from django.conf import settings @@ -43,7 +43,11 @@ class Token: token_type: Optional[str] = None lifetime: Optional[timedelta] = None - def __init__(self, token: Optional["Token"] = None, verify: bool = True) -> None: + def __init__( + self, + token: Union[None, bytes, str] = None, + verify: bool = True, + ) -> None: """ !!!! IMPORTANT !!!! MUST raise a TokenError with a user-facing error message if the given token is invalid, expired, or otherwise not safe From bbd3163c2570baf81d180f7107b5262aff8924df Mon Sep 17 00:00:00 2001 From: Kim Gustyr Date: Thu, 13 Mar 2025 15:57:10 +0000 Subject: [PATCH 2/3] acknowledge correct types, introduce `RawToken` --- rest_framework_simplejwt/authentication.py | 7 +++++-- rest_framework_simplejwt/backends.py | 9 ++++++--- rest_framework_simplejwt/tokens.py | 6 +++--- 3 files changed, 14 insertions(+), 8 deletions(-) diff --git a/rest_framework_simplejwt/authentication.py b/rest_framework_simplejwt/authentication.py index 239d242a2..b889bc640 100644 --- a/rest_framework_simplejwt/authentication.py +++ b/rest_framework_simplejwt/authentication.py @@ -1,4 +1,4 @@ -from typing import Optional, TypeVar +from typing import TYPE_CHECKING, Optional, TypeVar from django.contrib.auth import get_user_model from django.contrib.auth.models import AbstractBaseUser @@ -12,6 +12,9 @@ from .tokens import Token from .utils import get_md5_hash_password +if TYPE_CHECKING: + from .backends import RawToken + AUTH_HEADER_TYPES = api_settings.AUTH_HEADER_TYPES if not isinstance(api_settings.AUTH_HEADER_TYPES, (list, tuple)): @@ -92,7 +95,7 @@ def get_raw_token(self, header: bytes) -> Optional[bytes]: return parts[1] - def get_validated_token(self, raw_token: bytes) -> Token: + def get_validated_token(self, raw_token: "RawToken") -> Token: """ Validates an encoded JSON web token and returns a validated token wrapper object. diff --git a/rest_framework_simplejwt/backends.py b/rest_framework_simplejwt/backends.py index 4967a7aaa..fc86f486c 100644 --- a/rest_framework_simplejwt/backends.py +++ b/rest_framework_simplejwt/backends.py @@ -14,7 +14,6 @@ ) from .exceptions import TokenBackendError, TokenBackendExpiredToken -from .tokens import Token from .utils import format_lazy try: @@ -24,6 +23,8 @@ except ImportError: JWK_CLIENT_AVAILABLE = False +RawToken = Union[bytes, str] + ALLOWED_ALGORITHMS = { "HS256", "HS384", @@ -114,12 +115,14 @@ def get_leeway(self) -> timedelta: ) ) - def get_verifying_key(self, token: Token) -> Any: + def get_verifying_key(self, token: RawToken) -> Any: if self.algorithm.startswith("HS"): return self.prepared_signing_key if self.jwks_client: try: + if isinstance(token, bytes): + token = token.decode("utf-8") return self.jwks_client.get_signing_key_from_jwt(token).key except PyJWKClientError as e: raise TokenBackendError(_("Token is invalid")) from e @@ -148,7 +151,7 @@ def encode(self, payload: dict[str, Any]) -> str: # For PyJWT >= 2.0.0a1 return token - def decode(self, token: Union[bytes, str], verify: bool = True) -> dict[str, Any]: + def decode(self, token: RawToken, verify: bool = True) -> dict[str, Any]: """ Performs a validation of the given token and returns its payload dictionary. diff --git a/rest_framework_simplejwt/tokens.py b/rest_framework_simplejwt/tokens.py index d8d835698..4f911ef3a 100644 --- a/rest_framework_simplejwt/tokens.py +++ b/rest_framework_simplejwt/tokens.py @@ -1,5 +1,5 @@ from datetime import datetime, timedelta -from typing import TYPE_CHECKING, Any, Generic, Optional, TypeVar, Union +from typing import TYPE_CHECKING, Any, Generic, Optional, TypeVar from uuid import uuid4 from django.conf import settings @@ -27,7 +27,7 @@ ) if TYPE_CHECKING: - from .backends import TokenBackend + from .backends import RawToken, TokenBackend T = TypeVar("T", bound="Token") @@ -45,7 +45,7 @@ class Token: def __init__( self, - token: Union[None, bytes, str] = None, + token: Optional["RawToken"] = None, verify: bool = True, ) -> None: """ From 37ac55f9d3d619b3a9634411f9a97e85c437f9b5 Mon Sep 17 00:00:00 2001 From: Kim Gustyr Date: Mon, 17 Mar 2025 15:47:45 +0000 Subject: [PATCH 3/3] avoid extra decode --- rest_framework_simplejwt/backends.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/rest_framework_simplejwt/backends.py b/rest_framework_simplejwt/backends.py index fc86f486c..de9aa07f3 100644 --- a/rest_framework_simplejwt/backends.py +++ b/rest_framework_simplejwt/backends.py @@ -2,7 +2,7 @@ from collections.abc import Iterable from datetime import timedelta from functools import cached_property -from typing import Any, Optional, Union +from typing import TYPE_CHECKING, Any, Optional, Union, cast import jwt from django.utils.translation import gettext_lazy as _ @@ -122,7 +122,8 @@ def get_verifying_key(self, token: RawToken) -> Any: if self.jwks_client: try: if isinstance(token, bytes): - token = token.decode("utf-8") + # https://github.com/jpadilla/pyjwt/issues/1047 + token = cast(str, token) return self.jwks_client.get_signing_key_from_jwt(token).key except PyJWKClientError as e: raise TokenBackendError(_("Token is invalid")) from e