diff --git a/rest_framework_simplejwt/serializers.py b/rest_framework_simplejwt/serializers.py index f2b12ab7e..bb5fb8570 100644 --- a/rest_framework_simplejwt/serializers.py +++ b/rest_framework_simplejwt/serializers.py @@ -8,8 +8,11 @@ from .settings import api_settings from .tokens import RefreshToken, SlidingToken, UntypedToken +from .authentication import JWTAuthentication +from .utils import datetime_from_epoch + if api_settings.BLACKLIST_AFTER_ROTATION: - from .token_blacklist.models import BlacklistedToken + from .token_blacklist.models import BlacklistedToken, OutstandingToken class PasswordField(serializers.CharField): @@ -57,7 +60,8 @@ def validate(self, attrs): @classmethod def get_token(cls, user): - raise NotImplementedError('Must implement `get_token` method for `TokenObtainSerializer` subclasses') + raise NotImplementedError( + 'Must implement `get_token` method for `TokenObtainSerializer` subclasses') class TokenObtainPairSerializer(TokenObtainSerializer): @@ -104,9 +108,12 @@ class TokenRefreshSerializer(serializers.Serializer): def validate(self, attrs): refresh = RefreshToken(attrs['refresh']) - data = {'access': str(refresh.access_token)} + data = {} if api_settings.ROTATE_REFRESH_TOKENS: + auth = JWTAuthentication() + user = auth.get_user(validated_token=refresh) + if api_settings.BLACKLIST_AFTER_ROTATION: try: # Attempt to blacklist the given refresh token @@ -120,8 +127,21 @@ def validate(self, attrs): refresh.set_exp() refresh.set_iat() + OutstandingToken.objects.create( + user=user, + jti=refresh[api_settings.JTI_CLAIM], + token=str(refresh), + created_at=refresh.current_time, + expires_at=datetime_from_epoch(refresh['exp']) + ) + data['refresh'] = str(refresh) + data['access'] = str(refresh.access_token) + + else: + data['access'] = str(refresh.access_token) + return data diff --git a/rest_framework_simplejwt/tokens.py b/rest_framework_simplejwt/tokens.py index 75ff573ea..c4d7020ec 100644 --- a/rest_framework_simplejwt/tokens.py +++ b/rest_framework_simplejwt/tokens.py @@ -165,7 +165,8 @@ def check_exp(self, claim='exp', current_time=None): claim_time = datetime_from_epoch(claim_value) if claim_time <= current_time: - raise TokenError(format_lazy(_("Token '{}' claim has expired"), claim)) + raise TokenError(format_lazy( + _("Token '{}' claim has expired"), claim)) @classmethod def for_user(cls, user): @@ -181,9 +182,9 @@ def for_user(cls, user): token[api_settings.USER_ID_CLAIM] = user_id return token - + _token_backend = None - + def get_token_backend(self): if self._token_backend is None: self._token_backend = import_string( @@ -223,13 +224,9 @@ def blacklist(self): jti = self.payload[api_settings.JTI_CLAIM] exp = self.payload['exp'] - # Ensure outstanding token exists with given jti - token, _ = OutstandingToken.objects.get_or_create( + # Outstanding token will always exist + token = OutstandingToken.objects.get( jti=jti, - defaults={ - 'token': str(self), - 'expires_at': datetime_from_epoch(exp), - }, ) return BlacklistedToken.objects.get_or_create(token=token)