|
| 1 | +from unittest.mock import patch |
| 2 | + |
| 3 | +from django.contrib import messages |
| 4 | +from django.contrib.auth.models import AnonymousUser |
| 5 | +from django.contrib.messages.storage.fallback import FallbackStorage |
| 6 | +from django.contrib.sessions.middleware import SessionMiddleware |
| 7 | +from django.http import HttpResponse |
| 8 | +from django.test import RequestFactory, override_settings |
| 9 | +from requests.exceptions import ConnectionError as RequestsConnectionError |
| 10 | +from social_core.exceptions import AuthCanceled, AuthFailed, AuthForbidden |
| 11 | + |
| 12 | +from dojo.middleware import CustomSocialAuthExceptionMiddleware |
| 13 | + |
| 14 | +from .dojo_test_case import DojoTestCase |
| 15 | + |
| 16 | + |
| 17 | +class TestSocialAuthMiddlewareUnit(DojoTestCase): |
| 18 | + |
| 19 | + """ |
| 20 | + Unit tests: |
| 21 | + Directly test CustomSocialAuthExceptionMiddleware behavior |
| 22 | + by simulating exceptions (ConnectionError, AuthCanceled, AuthFailed, AuthForbidden), |
| 23 | + without relying on actual backend configuration or whether the |
| 24 | + /complete/<backend>/ URLs are registered and accessible. |
| 25 | + """ |
| 26 | + |
| 27 | + def setUp(self): |
| 28 | + self.factory = RequestFactory() |
| 29 | + self.middleware = CustomSocialAuthExceptionMiddleware(lambda *_: HttpResponse("OK")) |
| 30 | + |
| 31 | + def _prepare_request(self, path): |
| 32 | + request = self.factory.get(path) |
| 33 | + request.user = AnonymousUser() |
| 34 | + SessionMiddleware(lambda *_: None).process_request(request) |
| 35 | + request.session.save() |
| 36 | + request._messages = FallbackStorage(request) |
| 37 | + return request |
| 38 | + |
| 39 | + def test_social_auth_exception_redirects_to_login(self): |
| 40 | + login_paths = [ |
| 41 | + "/login/oidc/", |
| 42 | + "/login/auth0/", |
| 43 | + "/login/google-oauth2/", |
| 44 | + "/login/okta-oauth2/", |
| 45 | + "/login/azuread-tenant-oauth2/", |
| 46 | + "/login/gitlab/", |
| 47 | + "/login/keycloak-oauth2/", |
| 48 | + "/login/github/", |
| 49 | + ] |
| 50 | + exceptions = [ |
| 51 | + (RequestsConnectionError("Host unreachable"), "Please use the standard login below."), |
| 52 | + (AuthCanceled("User canceled login"), "Social login was canceled. Please try again or use the standard login."), |
| 53 | + (AuthFailed("Token exchange failed"), "Social login failed. Please try again or use the standard login."), |
| 54 | + (AuthForbidden("User not allowed"), "You are not authorized to log in via this method. Please contact support or use the standard login."), |
| 55 | + ] |
| 56 | + for path in login_paths: |
| 57 | + for exception, expected_message in exceptions: |
| 58 | + with self.subTest(path=path, exception=type(exception).__name__): |
| 59 | + request = self._prepare_request(path) |
| 60 | + response = self.middleware.process_exception(request, exception) |
| 61 | + self.assertEqual(response.status_code, 302) |
| 62 | + self.assertEqual(response.url, "/login?force_login_form") |
| 63 | + storage = list(messages.get_messages(request)) |
| 64 | + self.assertTrue(any(expected_message in str(msg) for msg in storage)) |
| 65 | + |
| 66 | + def test_non_social_auth_path_still_redirects_on_auth_exception(self): |
| 67 | + """Ensure middleware handles AuthFailed even on unrelated paths.""" |
| 68 | + request = self._prepare_request("/some/other/path/") |
| 69 | + exception = AuthFailed("Should be handled globally") |
| 70 | + response = self.middleware.process_exception(request, exception) |
| 71 | + self.assertEqual(response.status_code, 302) |
| 72 | + self.assertEqual(response.url, "/login?force_login_form") |
| 73 | + storage = list(messages.get_messages(request)) |
| 74 | + self.assertTrue(any("Social login failed. Please try again or use the standard login." in str(msg) for msg in storage)) |
| 75 | + |
| 76 | + def test_non_social_auth_path_redirects_on_auth_forbidden(self): |
| 77 | + """Ensure middleware handles AuthForbidden even on unrelated paths.""" |
| 78 | + request = self._prepare_request("/some/other/path/") |
| 79 | + exception = AuthForbidden("User not allowed") |
| 80 | + response = self.middleware.process_exception(request, exception) |
| 81 | + self.assertEqual(response.status_code, 302) |
| 82 | + self.assertEqual(response.url, "/login?force_login_form") |
| 83 | + storage = list(messages.get_messages(request)) |
| 84 | + self.assertTrue(any("You are not authorized to log in via this method." in str(msg) for msg in storage)) |
| 85 | + |
| 86 | + |
| 87 | +@override_settings( |
| 88 | + AUTHENTICATION_BACKENDS=( |
| 89 | + "social_core.backends.github.GithubOAuth2", |
| 90 | + "social_core.backends.gitlab.GitLabOAuth2", |
| 91 | + "social_core.backends.keycloak.KeycloakOAuth2", |
| 92 | + "social_core.backends.azuread_tenant.AzureADTenantOAuth2", |
| 93 | + "social_core.backends.auth0.Auth0OAuth2", |
| 94 | + "social_core.backends.okta.OktaOAuth2", |
| 95 | + "social_core.backends.open_id_connect.OpenIdConnectAuth", |
| 96 | + "django.contrib.auth.backends.ModelBackend", |
| 97 | + ), |
| 98 | +) |
| 99 | +class TestSocialAuthIntegrationFailures(DojoTestCase): |
| 100 | + |
| 101 | + """ |
| 102 | + Integration tests: |
| 103 | + Simulate social login failures by calling /complete/<backend>/ URLs |
| 104 | + and mocking auth_complete() to raise AuthFailed, AuthCanceled, and AuthForbidden. |
| 105 | + Verifies that the middleware is correctly integrated and handles backend failures. |
| 106 | + """ |
| 107 | + |
| 108 | + BACKEND_CLASS_PATHS = { |
| 109 | + "github": "social_core.backends.github.GithubOAuth2", |
| 110 | + "gitlab": "social_core.backends.gitlab.GitLabOAuth2", |
| 111 | + "keycloak": "social_core.backends.keycloak.KeycloakOAuth2", |
| 112 | + "azuread-tenant-oauth2": "social_core.backends.azuread_tenant.AzureADTenantOAuth2", |
| 113 | + "auth0": "social_core.backends.auth0.Auth0OAuth2", |
| 114 | + "okta-oauth2": "social_core.backends.okta.OktaOAuth2", |
| 115 | + "oidc": "social_core.backends.open_id_connect.OpenIdConnectAuth", |
| 116 | + } |
| 117 | + |
| 118 | + def _test_backend_exception(self, backend_slug, exception, expected_message): |
| 119 | + backend_class_path = self.BACKEND_CLASS_PATHS[backend_slug] |
| 120 | + with patch(f"{backend_class_path}.auth_complete", side_effect=exception): |
| 121 | + response = self.client.get(f"/complete/{backend_slug}/", follow=True) |
| 122 | + self.assertEqual(response.status_code, 200) |
| 123 | + self.assertContains(response, expected_message) |
| 124 | + |
| 125 | + def test_all_backends_auth_failed(self): |
| 126 | + for backend in self.BACKEND_CLASS_PATHS: |
| 127 | + with self.subTest(backend=backend): |
| 128 | + self._test_backend_exception(backend, AuthFailed(backend=None), "Social login failed. Please try again or use the standard login.") |
| 129 | + |
| 130 | + def test_all_backends_auth_canceled(self): |
| 131 | + for backend in self.BACKEND_CLASS_PATHS: |
| 132 | + with self.subTest(backend=backend): |
| 133 | + self._test_backend_exception(backend, AuthCanceled(backend=None), "Social login was canceled. Please try again or use the standard login.") |
| 134 | + |
| 135 | + def test_all_backends_auth_forbidden(self): |
| 136 | + for backend in self.BACKEND_CLASS_PATHS: |
| 137 | + with self.subTest(backend=backend): |
| 138 | + self._test_backend_exception( |
| 139 | + backend, |
| 140 | + AuthForbidden(backend=None), |
| 141 | + "You are not authorized to log in via this method. Please contact support or use the standard login.", |
| 142 | + ) |
0 commit comments