diff --git a/config/settings/base.py b/config/settings/base.py index c3de97cce..911ee5f2a 100644 --- a/config/settings/base.py +++ b/config/settings/base.py @@ -578,6 +578,9 @@ "DEFAULT_FILTER_BACKENDS": ["django_filters.rest_framework.DjangoFilterBackend"], "DEFAULT_PAGINATION_CLASS": "rest_framework.pagination.LimitOffsetPagination", "PAGE_SIZE": 100, + "DEFAULT_AUTHENTICATION_CLASSES": [ + "lemarche.api.authentication.CustomBearerAuthentication", + ], } diff --git a/lemarche/api/authentication.py b/lemarche/api/authentication.py new file mode 100644 index 000000000..465df6bec --- /dev/null +++ b/lemarche/api/authentication.py @@ -0,0 +1,92 @@ +import logging + +from rest_framework.authentication import BaseAuthentication +from rest_framework.exceptions import AuthenticationFailed + +from lemarche.users.models import User + + +logger = logging.getLogger(__name__) + + +class CustomBearerAuthentication(BaseAuthentication): + """ + Authentication via: + 1. Authorization header: Bearer (recommended). + 2. URL parameter ?token= (deprecated, temporary support). + """ + + def authenticate(self, request): + token = None + warning_issued = False + + # Priority to the Authorization header + auth_header = request.headers.get("Authorization") + if auth_header and auth_header.startswith("Bearer "): + token = auth_header.split("Bearer ")[1] + elif request.GET.get("token"): # Otherwise, try the URL parameter + token = request.GET.get("token") + warning_issued = True + logger.warning("Authentication via URL token detected. This method is deprecated and less secure.") + + # If no token is provided + if not token: + return None + + # Check the minimum length of the token + if len(token) < 64: + raise AuthenticationFailed("Token too short. Possible security issue detected.") + + if not token.isalnum(): + raise AuthenticationFailed("Token contains invalid characters. Possible security issue detected.") + + # Validate the token + try: + user = User.objects.has_api_key().get(api_key=token) + except User.DoesNotExist: + raise AuthenticationFailed("Invalid or expired token") + + # Add a warning in the response for URL tokens + if warning_issued: + request._deprecated_auth_warning = True # Marker for middleware or view + + # Return the user and the token + return (user, token) + + def authenticate_header(self, request): + """ + Returns the expected header for 401 responses. + """ + return 'Bearer realm="api"' + + +class DeprecationWarningMiddleware: + """ + Middleware to inform users that authentication via URL `?token=` is deprecated. + + This middleware checks if the request contains a deprecated authentication token + and adds a warning header to the response if it does. + + Attributes: + get_response (callable): The next middleware or view in the chain. + + Methods: + __call__(request): + Processes the request and adds a deprecation warning header to the response + if the request contains a deprecated authentication token. + """ + + def __init__(self, get_response): + self.get_response = get_response + + def __call__(self, request): + response = self.get_response(request) + + # Add a warning if the marker is set in the request + if hasattr(request, "_deprecated_auth_warning") and request._deprecated_auth_warning: + response.headers["Deprecation-Warning"] = ( + "URL token authentication is deprecated and will be removed on 2025/01. " + "Please use Authorization header with Bearer tokens." + ) + + return response diff --git a/lemarche/api/siaes/tests.py b/lemarche/api/siaes/tests.py index 060934db6..8855e2ed2 100644 --- a/lemarche/api/siaes/tests.py +++ b/lemarche/api/siaes/tests.py @@ -1,6 +1,7 @@ from django.test import TestCase from django.urls import reverse +from lemarche.api.utils import generate_random_string from lemarche.networks.factories import NetworkFactory from lemarche.sectors.factories import SectorFactory from lemarche.siaes import constants as siae_constants @@ -14,7 +15,8 @@ class SiaeListApiTest(TestCase): def setUpTestData(cls): for _ in range(12): SiaeFactory() - UserFactory(api_key="admin") + cls.user_token = generate_random_string() + UserFactory(api_key=cls.user_token) def test_should_return_siae_sublist_to_anonymous_users(self): url = reverse("api:siae-list") # anonymous user @@ -32,7 +34,7 @@ def test_should_return_siae_sublist_to_anonymous_users(self): self.assertTrue("created_at" in response.data[0]) def test_should_return_detailed_siae_list_with_pagination_to_authenticated_users(self): - url = reverse("api:siae-list") + "?token=admin" + url = reverse("api:siae-list") + "?token=" + self.user_token response = self.client.get(url) self.assertEqual(response.data["count"], 12) self.assertEqual(len(response.data["results"]), 12) @@ -81,13 +83,14 @@ def setUpTestData(cls): kind=siae_constants.KIND_EI, presta_type=[siae_constants.PRESTA_DISP], department="01" ) siae_with_network_2.networks.add(cls.network_2) - UserFactory(api_key="admin") + cls.user_token = generate_random_string() + UserFactory(api_key=cls.user_token) def test_siae_count(self): self.assertEqual(Siae.objects.count(), 9) def test_should_return_siae_list(self): - url = reverse("api:siae-list") + "?token=admin" + url = reverse("api:siae-list") + "?token=" + self.user_token response = self.client.get(url) self.assertEqual(response.data["count"], 4 + 2 + 2) self.assertEqual(len(response.data["results"]), 4 + 2 + 2) @@ -101,68 +104,81 @@ def test_should_not_filter_siae_list_for_anonymous_user(self): self.assertEqual(len(response.data), 4 + 2 + 2) def test_should_filter_siae_list_by_is_active(self): - url = reverse("api:siae-list") + "?is_active=false&token=admin" + url = reverse("api:siae-list") + "?is_active=false&token=" + self.user_token response = self.client.get(url) self.assertEqual(response.data["count"], 1) self.assertEqual(len(response.data["results"]), 1) - url = reverse("api:siae-list") + "?is_active=true&token=admin" + url = reverse("api:siae-list") + "?is_active=true&token=" + self.user_token response = self.client.get(url) self.assertEqual(response.data["count"], 3 + 2 + 2) self.assertEqual(len(response.data["results"]), 3 + 2 + 2) def test_should_filter_siae_list_by_kind(self): # single - url = reverse("api:siae-list") + f"?kind={siae_constants.KIND_ETTI}&token=admin" + url = reverse("api:siae-list") + f"?kind={siae_constants.KIND_ETTI}&token=" + self.user_token response = self.client.get(url) self.assertEqual(response.data["count"], 1) self.assertEqual(len(response.data["results"]), 1) # multiple - url = reverse("api:siae-list") + f"?kind={siae_constants.KIND_ETTI}&kind={siae_constants.KIND_ACI}&token=admin" + url = ( + reverse("api:siae-list") + + f"?kind={siae_constants.KIND_ETTI}&kind={siae_constants.KIND_ACI}&token=" + + self.user_token + ) response = self.client.get(url) self.assertEqual(response.data["count"], 1 + 1) self.assertEqual(len(response.data["results"]), 1 + 1) def test_should_filter_siae_list_by_presta_type(self): # single - url = reverse("api:siae-list") + f"?presta_type={siae_constants.PRESTA_BUILD}&token=admin" + url = reverse("api:siae-list") + f"?presta_type={siae_constants.PRESTA_BUILD}&token=" + self.user_token response = self.client.get(url) self.assertEqual(response.data["count"], 1) self.assertEqual(len(response.data["results"]), 1) # multiple url = ( reverse("api:siae-list") - + f"?presta_type={siae_constants.PRESTA_BUILD}&presta_type={siae_constants.PRESTA_PREST}&token=admin" + + f"?presta_type={siae_constants.PRESTA_BUILD}&presta_type={siae_constants.PRESTA_PREST}&token=" + + self.user_token ) response = self.client.get(url) self.assertEqual(response.data["count"], 1 + 1) self.assertEqual(len(response.data["results"]), 1 + 1) def test_should_filter_siae_list_by_department(self): - url = reverse("api:siae-list") + "?department=38&token=admin" + url = reverse("api:siae-list") + "?department=38&token=" + self.user_token response = self.client.get(url) self.assertEqual(response.data["count"], 1) self.assertEqual(len(response.data["results"]), 1) def test_should_filter_siae_list_by_sector(self): # single - url = reverse("api:siae-list") + f"?sectors={self.sector_1.slug}&token=admin" + url = reverse("api:siae-list") + f"?sectors={self.sector_1.slug}&token=" + self.user_token response = self.client.get(url) self.assertEqual(response.data["count"], 1) self.assertEqual(len(response.data["results"]), 1) # multiple - url = reverse("api:siae-list") + f"?sectors={self.sector_1.slug}§ors={self.sector_2.slug}&token=admin" + url = ( + reverse("api:siae-list") + + f"?sectors={self.sector_1.slug}§ors={self.sector_2.slug}&token=" + + self.user_token + ) response = self.client.get(url) self.assertEqual(response.data["count"], 1 + 1) self.assertEqual(len(response.data["results"]), 1 + 1) def test_should_filter_siae_list_by_network(self): # single - url = reverse("api:siae-list") + f"?networks={self.network_1.slug}&token=admin" + url = reverse("api:siae-list") + f"?networks={self.network_1.slug}&token=" + self.user_token response = self.client.get(url) self.assertEqual(response.data["count"], 1) self.assertEqual(len(response.data["results"]), 1) # multiple - url = reverse("api:siae-list") + f"?networks={self.network_1.slug}&networks={self.network_2.slug}&token=admin" + url = ( + reverse("api:siae-list") + + f"?networks={self.network_1.slug}&networks={self.network_2.slug}&token=" + + self.user_token + ) response = self.client.get(url) self.assertEqual(response.data["count"], 1 + 1) self.assertEqual(len(response.data["results"]), 1 + 1) @@ -172,17 +188,17 @@ class SiaeDetailApiTest(TestCase): @classmethod def setUpTestData(cls): cls.siae = SiaeFactory() - UserFactory(api_key="admin") + cls.user_token = generate_random_string() + UserFactory(api_key=cls.user_token) def test_should_return_4O4_if_siae_excluded(self): siae_opcs = SiaeFactory(kind="OPCS") - for siae in [siae_opcs]: - url = reverse("api:siae-detail", args=[siae.id]) # anonymous - response = self.client.get(url) - self.assertEqual(response.status_code, 404) - url = reverse("api:siae-detail", args=[siae.id]) + "?token=admin" - response = self.client.get(url) - self.assertEqual(response.status_code, 404) + url = reverse("api:siae-detail", args=[siae_opcs.id]) # anonymous + response = self.client.get(url) + self.assertEqual(response.status_code, 404) + url = reverse("api:siae-detail", args=[siae_opcs.id]) + "?token=" + self.user_token + response = self.client.get(url) + self.assertEqual(response.status_code, 404) def test_should_return_simple_siae_object_to_anonymous_users(self): url = reverse("api:siae-detail", args=[self.siae.id]) # anonymous user @@ -201,7 +217,7 @@ def test_should_return_simple_siae_object_to_anonymous_users(self): self.assertTrue("labels" not in response.data) def test_should_return_detailed_siae_object_to_authenticated_users(self): - url = reverse("api:siae-detail", args=[self.siae.id]) + "?token=admin" + url = reverse("api:siae-detail", args=[self.siae.id]) + "?token=" + self.user_token response = self.client.get(url) self.assertTrue("id" in response.data) self.assertTrue("name" in response.data) @@ -220,8 +236,9 @@ def test_should_return_detailed_siae_object_to_authenticated_users(self): class SiaeRetrieveBySlugApiTest(TestCase): @classmethod def setUpTestData(cls): + cls.user_token = generate_random_string() SiaeFactory(name="Une structure", siret="12312312312345", department="38") - UserFactory(api_key="admin") + UserFactory(api_key=cls.user_token) def test_should_return_404_if_slug_unknown(self): url = reverse("api:siae-retrieve-by-slug", args=["test-123"]) # anonymous user @@ -230,13 +247,12 @@ def test_should_return_404_if_slug_unknown(self): def test_should_return_4O4_if_siae_excluded(self): siae_opcs = SiaeFactory(kind="OPCS") - for siae in [siae_opcs]: - url = reverse("api:siae-retrieve-by-slug", args=[siae.slug]) # anonymous - response = self.client.get(url) - self.assertEqual(response.status_code, 404) - url = reverse("api:siae-retrieve-by-slug", args=[siae.slug]) + "?token=admin" - response = self.client.get(url) - self.assertEqual(response.status_code, 404) + url = reverse("api:siae-retrieve-by-slug", args=[siae_opcs.slug]) # anonymous + response = self.client.get(url) + self.assertEqual(response.status_code, 404) + url = reverse("api:siae-retrieve-by-slug", args=[siae_opcs.slug]) + "?token=" + self.user_token + response = self.client.get(url) + self.assertEqual(response.status_code, 404) def test_should_return_siae_if_slug_known(self): url = reverse("api:siae-retrieve-by-slug", args=["une-structure-38"]) # anonymous user @@ -248,7 +264,7 @@ def test_should_return_siae_if_slug_known(self): self.assertTrue("sectors" not in response.data) def test_should_return_detailed_siae_object_to_authenticated_user(self): - url = reverse("api:siae-retrieve-by-slug", args=["une-structure-38"]) + "?token=admin" + url = reverse("api:siae-retrieve-by-slug", args=["une-structure-38"]) + "?token=" + self.user_token response = self.client.get(url) self.assertEqual(response.status_code, 200) self.assertEqual(response.data["siret"], "12312312312345") @@ -262,7 +278,8 @@ def setUpTestData(cls): SiaeFactory(name="Une structure", siret="12312312312345", department="38") SiaeFactory(name="Une autre structure", siret="22222222233333", department="69") SiaeFactory(name="Une autre structure avec le meme siret", siret="22222222233333", department="69") - UserFactory(api_key="admin") + cls.user_token = generate_random_string() + UserFactory(api_key=cls.user_token) def test_should_return_400_if_siren_malformed(self): # anonymous user @@ -281,13 +298,12 @@ def test_should_return_empty_list_if_siren_unknown(self): def test_should_return_4O4_if_siae_excluded(self): siae_opcs = SiaeFactory(kind="OPCS", siret="99999999999999") - for siae in [siae_opcs]: - url = reverse("api:siae-retrieve-by-siren", args=[siae.siren]) # anonymous - response = self.client.get(url) - self.assertEqual(len(response.data), 0) - url = reverse("api:siae-retrieve-by-siren", args=[siae.siren]) + "?token=admin" - response = self.client.get(url) - self.assertEqual(len(response.data), 0) + url = reverse("api:siae-retrieve-by-siren", args=[siae_opcs.siren]) # anonymous + response = self.client.get(url) + self.assertEqual(len(response.data), 0) + url = reverse("api:siae-retrieve-by-siren", args=[siae_opcs.siren]) + "?token=" + self.user_token + response = self.client.get(url) + self.assertEqual(len(response.data), 0) def test_should_return_siae_list_if_siren_known(self): # anonymous user @@ -308,7 +324,7 @@ def test_should_return_siae_list_if_siren_known(self): self.assertEqual(response.data[1]["siret"], "22222222233333") self.assertTrue("sectors" not in response.data[0]) # authenticated user - url = reverse("api:siae-retrieve-by-siren", args=["123123123"]) + "?token=admin" + url = reverse("api:siae-retrieve-by-siren", args=["123123123"]) + "?token=" + self.user_token response = self.client.get(url) self.assertEqual(response.status_code, 200) # self.assertEqual(type(response.data), list) @@ -316,7 +332,7 @@ def test_should_return_siae_list_if_siren_known(self): self.assertEqual(response.data[0]["siret"], "12312312312345") self.assertEqual(response.data[0]["slug"], "une-structure-38") self.assertTrue("sectors" in response.data[0]) - url = reverse("api:siae-retrieve-by-siren", args=["222222222"]) + "?token=admin" + url = reverse("api:siae-retrieve-by-siren", args=["222222222"]) + "?token=" + self.user_token response = self.client.get(url) self.assertEqual(response.status_code, 200) # self.assertEqual(type(response.data), list) @@ -332,7 +348,8 @@ def setUpTestData(cls): SiaeFactory(name="Une structure", siret="12312312312345", department="38") SiaeFactory(name="Une autre structure", siret="22222222233333", department="69") SiaeFactory(name="Une autre structure avec le meme siret", siret="22222222233333", department="69") - UserFactory(api_key="admin") + cls.user_token = generate_random_string() + UserFactory(api_key=cls.user_token) def test_should_return_404_if_siret_malformed(self): # anonymous user @@ -351,13 +368,12 @@ def test_should_return_empty_list_if_siret_unknown(self): def test_should_return_4O4_if_siae_excluded(self): siae_opcs = SiaeFactory(kind="OPCS", siret="99999999999999") - for siae in [siae_opcs]: - url = reverse("api:siae-retrieve-by-siret", args=[siae.siret]) # anonymous - response = self.client.get(url) - self.assertEqual(len(response.data), 0) - url = reverse("api:siae-retrieve-by-siret", args=[siae.siret]) + "?token=admin" - response = self.client.get(url) - self.assertEqual(len(response.data), 0) + url = reverse("api:siae-retrieve-by-siret", args=[siae_opcs.siret]) # anonymous + response = self.client.get(url) + self.assertEqual(len(response.data), 0) + url = reverse("api:siae-retrieve-by-siret", args=[siae_opcs.siret]) + "?token=" + self.user_token + response = self.client.get(url) + self.assertEqual(len(response.data), 0) def test_should_return_siae_list_if_siret_known(self): # anonymous user @@ -378,7 +394,7 @@ def test_should_return_siae_list_if_siret_known(self): self.assertEqual(response.data[1]["siret"], "22222222233333") self.assertTrue("sectors" not in response.data[0]) # authenticated user - url = reverse("api:siae-retrieve-by-siret", args=["12312312312345"]) + "?token=admin" + url = reverse("api:siae-retrieve-by-siret", args=["12312312312345"]) + "?token=" + self.user_token response = self.client.get(url) self.assertEqual(response.status_code, 200) # self.assertEqual(type(response.data), list) @@ -386,7 +402,7 @@ def test_should_return_siae_list_if_siret_known(self): self.assertEqual(response.data[0]["siret"], "12312312312345") self.assertEqual(response.data[0]["slug"], "une-structure-38") self.assertTrue("sectors" in response.data[0]) - url = reverse("api:siae-retrieve-by-siret", args=["22222222233333"]) + "?token=admin" + url = reverse("api:siae-retrieve-by-siret", args=["22222222233333"]) + "?token=" + self.user_token response = self.client.get(url) self.assertEqual(response.status_code, 200) # self.assertEqual(type(response.data), list) diff --git a/lemarche/api/siaes/views.py b/lemarche/api/siaes/views.py index aa9fc3196..718c14721 100644 --- a/lemarche/api/siaes/views.py +++ b/lemarche/api/siaes/views.py @@ -6,7 +6,7 @@ from lemarche.api.siaes.filters import SiaeFilter from lemarche.api.siaes.serializers import SiaeDetailSerializer, SiaeListSerializer -from lemarche.api.utils import BasicChoiceSerializer, BasicChoiceWithParentSerializer, check_user_token +from lemarche.api.utils import BasicChoiceSerializer, BasicChoiceWithParentSerializer from lemarche.siaes import constants as siae_constants from lemarche.siaes.models import Siae @@ -24,7 +24,9 @@ class SiaeViewSet(mixins.ListModelMixin, mixins.RetrieveModelMixin, viewsets.Gen summary="Lister toutes les structures", tags=[Siae._meta.verbose_name_plural], parameters=[ - OpenApiParameter(name="token", description="Token Utilisateur", required=False, type=str), + OpenApiParameter( + name="token", description="Token Utilisateur (pour compatibilité ancienne)", required=False, type=str + ), ], ) def list(self, request, format=None): @@ -33,23 +35,24 @@ def list(self, request, format=None): Un token est nécessaire pour l'accès complet à cette ressource. """ - if request.method == "GET": - token = request.GET.get("token", None) - if not token: - serializer = SiaeListSerializer( - self.get_queryset()[:10], - many=True, - ) - return Response(serializer.data) - else: - check_user_token(token) - return super().list(request, format) + if request.user.is_authenticated: + # Utilisateur authentifié : accès complet + return super().list(request, format) + else: + # Utilisateur non authentifié : limiter à 10 résultats + serializer = SiaeListSerializer( + self.get_queryset()[:10], + many=True, + ) + return Response(serializer.data) @extend_schema( summary="Détail d'une structure (par son id)", tags=[Siae._meta.verbose_name_plural], parameters=[ - OpenApiParameter(name="token", description="Token Utilisateur", required=False, type=str), + OpenApiParameter( + name="token", description="Token Utilisateur (pour compatibilité ancienne)", required=False, type=str + ), ], responses=SiaeDetailSerializer, ) @@ -65,7 +68,9 @@ def retrieve(self, request, pk=None, format=None): summary="Détail d'une structure (par son slug)", tags=[Siae._meta.verbose_name_plural], parameters=[ - OpenApiParameter(name="token", description="Token Utilisateur", required=False, type=str), + OpenApiParameter( + name="token", description="Token Utilisateur (pour compatibilité ancienne)", required=False, type=str + ), ], responses=SiaeDetailSerializer, ) @@ -82,7 +87,9 @@ def retrieve_by_slug(self, request, slug=None, format=None): summary="Détail d'une structure (par son siren)", tags=[Siae._meta.verbose_name_plural], parameters=[ - OpenApiParameter(name="token", description="Token Utilisateur", required=False, type=str), + OpenApiParameter( + name="token", description="Token Utilisateur (pour compatibilité ancienne)", required=False, type=str + ), ], responses=SiaeDetailSerializer, ) @@ -100,7 +107,9 @@ def retrieve_by_siren(self, request, siren=None, format=None): summary="Détail d'une structure (par son siret)", tags=[Siae._meta.verbose_name_plural], parameters=[ - OpenApiParameter(name="token", description="Token Utilisateur", required=False, type=str), + OpenApiParameter( + name="token", description="Token Utilisateur (pour compatibilité ancienne)", required=False, type=str + ), ], responses=SiaeDetailSerializer, ) @@ -115,14 +124,12 @@ def retrieve_by_siret(self, request, siret=None, format=None): return self._list_return(request, queryset, format) def _retrieve_return(self, request, queryset, format): - token = request.GET.get("token", None) - if not token: + if not request.user.is_authenticated: serializer = SiaeListSerializer( queryset, many=False, ) else: - check_user_token(token) serializer = SiaeDetailSerializer( queryset, many=False, @@ -130,14 +137,12 @@ def _retrieve_return(self, request, queryset, format): return Response(serializer.data) def _list_return(self, request, queryset, format): - token = request.GET.get("token", None) - if not token: + if not request.user.is_authenticated: serializer = SiaeListSerializer( queryset, many=True, ) else: - check_user_token(token) serializer = SiaeDetailSerializer( queryset, many=True, diff --git a/lemarche/api/tenders/tests.py b/lemarche/api/tenders/tests.py index c345cdfb5..c11fdc66c 100644 --- a/lemarche/api/tenders/tests.py +++ b/lemarche/api/tenders/tests.py @@ -4,6 +4,7 @@ from django.test import TestCase from django.urls import reverse +from lemarche.api.utils import generate_random_string from lemarche.perimeters.factories import PerimeterFactory from lemarche.sectors.factories import SectorFactory from lemarche.tenders import constants as tender_constants @@ -45,10 +46,11 @@ class TenderCreateApiTest(TestCase): @classmethod def setUpTestData(cls): - cls.url = reverse("api:tenders-list") + "?token=admin" + cls.user_token = generate_random_string() + cls.url = reverse("api:tenders-list") + "?token=" + cls.user_token cls.user = UserFactory() cls.user_buyer = UserFactory(kind=User.KIND_BUYER, company_name="Entreprise Buyer") - cls.user_with_token = UserFactory(email="admin@example.com", api_key="admin") + cls.user_with_token = UserFactory(email="admin@example.com", api_key=cls.user_token) cls.perimeter = PerimeterFactory() cls.sector_1 = SectorFactory() cls.sector_2 = SectorFactory() @@ -256,8 +258,9 @@ def test_create_tender_with_distance_location(self): class TenderCreateApiPartnerTest(TestCase): @classmethod def setUpTestData(cls): - cls.url = reverse("api:tenders-list") + "?token=approch" - cls.user_partner_with_token = UserFactory(email="approch@example.com", api_key="approch") + cls.api_token_approch = generate_random_string() + cls.url = reverse("api:tenders-list") + "?token=" + cls.api_token_approch + cls.user_partner_with_token = UserFactory(email="approch@example.com", api_key=cls.api_token_approch) def test_partner_approch_can_create_tender(self): with self.settings(PARTNER_APPROCH_USER_ID=self.user_partner_with_token.id): diff --git a/lemarche/api/tenders/views.py b/lemarche/api/tenders/views.py index 15965529c..9ab1d8b4b 100644 --- a/lemarche/api/tenders/views.py +++ b/lemarche/api/tenders/views.py @@ -2,9 +2,10 @@ from django.utils import timezone from drf_spectacular.utils import OpenApiParameter, extend_schema from rest_framework import mixins, viewsets +from rest_framework.permissions import IsAuthenticated from lemarche.api.tenders.serializers import TenderSerializer -from lemarche.api.utils import BasicChoiceSerializer, check_user_token +from lemarche.api.utils import BasicChoiceSerializer from lemarche.tenders import constants as tender_constants from lemarche.tenders.models import Tender from lemarche.users import constants as user_constants @@ -16,6 +17,7 @@ class TenderViewSet(mixins.CreateModelMixin, viewsets.GenericViewSet): + permission_classes = [IsAuthenticated] serializer_class = TenderSerializer @extend_schema( @@ -26,8 +28,6 @@ class TenderViewSet(mixins.CreateModelMixin, viewsets.GenericViewSet): ], ) def create(self, request, *args, **kwargs): - token = request.GET.get("token", None) - check_user_token(token) return super().create(request, args, kwargs) def perform_create(self, serializer: TenderSerializer): diff --git a/lemarche/api/tests.py b/lemarche/api/tests.py new file mode 100644 index 000000000..a7528b4b5 --- /dev/null +++ b/lemarche/api/tests.py @@ -0,0 +1,160 @@ +from django.http import HttpResponse +from django.test import RequestFactory, TestCase +from rest_framework.exceptions import AuthenticationFailed + +from lemarche.api.authentication import CustomBearerAuthentication, DeprecationWarningMiddleware +from lemarche.api.utils import generate_random_string +from lemarche.users.factories import UserFactory + + +class CustomBearerAuthenticationTest(TestCase): + def setUp(self): + self.factory = RequestFactory() + self.authentication = CustomBearerAuthentication() + + self.user_token = generate_random_string() + self.user = UserFactory(api_key=self.user_token) + self.url = "/api/endpoint/" + + def test_authentication_with_authorization_header(self): + """ + Test the authentication process using the Authorization header. + + This test simulates a GET request with a Bearer token in the Authorization header. + It verifies that the authentication method correctly identifies the user and token. + + Steps: + 1. Create a GET request to the specified URL. + 2. Add the Authorization header with a Bearer token. + 3. Authenticate the request. + 4. Assert that the returned user matches the expected user. + 5. Assert that the returned token matches the expected token. + """ + request = self.factory.get(self.url) + + request.headers = {"Authorization": "Bearer " + self.user_token} + + user, token = self.authentication.authenticate(request) + + self.assertEqual(user, self.user) + self.assertEqual(token, self.user_token) + + def test_authentication_with_url_token(self): + """ + Test the authentication process using a token provided in the URL. + + This test simulates a GET request with a token appended to the URL query string. + It verifies that the authentication mechanism correctly identifies the user and + token from the request. + + Assertions: + - The authenticated user should match the expected user. + - The token extracted from the request should match the expected user token. + """ + request = self.factory.get(self.url + "?token=" + self.user_token) + + user, token = self.authentication.authenticate(request) + + self.assertEqual(user, self.user) + self.assertEqual(token, self.user_token) + + def test_authentication_with_short_token(self): + """ + Test the authentication process with a short token. + + This test simulates a request with a token that is too short and verifies + that the authentication process raises an AuthenticationFailed exception + with the appropriate error message. + + Steps: + 1. Create a GET request with a short token in the Authorization header. + 2. Attempt to authenticate the request. + 3. Assert that an AuthenticationFailed exception is raised. + 4. Verify that the exception message is "Token too short. Possible security issue detected." + """ + # Requête avec un token trop court + request = self.factory.get(self.url) + request.headers = {"Authorization": "Bearer short"} + + with self.assertRaises(AuthenticationFailed) as context: + self.authentication.authenticate(request) + + self.assertEqual(str(context.exception), "Token too short. Possible security issue detected.") + + def test_authentication_with_invalid_token(self): + """ + Test the authentication process with an invalid token. + + This test ensures that the authentication mechanism correctly raises an + AuthenticationFailed exception when an invalid or expired token is provided + in the request headers. + + Steps: + 1. Create a GET request with an invalid token in the Authorization header. + 2. Attempt to authenticate the request. + 3. Verify that an AuthenticationFailed exception is raised. + 4. Check that the exception message is "Invalid or expired token". + """ + # Requête avec un token invalide + request = self.factory.get(self.url) + request.headers = {"Authorization": "Bearer in" + self.user_token} + + with self.assertRaises(AuthenticationFailed) as context: + self.authentication.authenticate(request) + + self.assertEqual(str(context.exception), "Invalid or expired token") + + def test_authentication_with_no_token(self): + """ + Test case for authentication without a token. + + This test verifies that the authentication method returns None + when no token is provided in the request. + """ + # Requête sans token + request = self.factory.get(self.url) + + result = self.authentication.authenticate(request) + + self.assertIsNone(result) + + +class DeprecationWarningMiddlewareTest(TestCase): + def setUp(self): + self.factory = RequestFactory() + self.middleware = DeprecationWarningMiddleware(lambda request: HttpResponse("Test response")) + + def test_no_deprecation_warning(self): + """ + Test that no deprecation warning is present in the response. + + This test sends a GET request to a specific API endpoint and checks + that the response does not contain a 'Deprecation-Warning' attribute. + """ + request = self.factory.get("/api/some-endpoint/") + response = self.middleware(request) + + self.assertFalse(hasattr(response, "Deprecation-Warning")) + + def test_with_deprecation_warning(self): + """ + Test that a deprecation warning is included in the response when the request + contains the _deprecated_auth_warning marker. + + This test simulates a request to an endpoint with the _deprecated_auth_warning + marker set to True. It then checks that the response includes a "Deprecation-Warning" + header with the expected deprecation message indicating that URL token authentication + is deprecated and will be removed by January 2025, and advises to use the Authorization + header with Bearer tokens instead. + """ + request = self.factory.get("/api/some-endpoint/") + request._deprecated_auth_warning = True # Ajouter le marqueur + + response = self.middleware(request) + + self.assertIn("Deprecation-Warning", response) + self.assertEqual( + response["Deprecation-Warning"], + "URL token authentication is deprecated and will be removed on 2025/01. " + "Please use Authorization header with Bearer tokens.", + ) diff --git a/lemarche/api/utils.py b/lemarche/api/utils.py index 68d39c01e..e4968737c 100644 --- a/lemarche/api/utils.py +++ b/lemarche/api/utils.py @@ -1,26 +1,7 @@ -from rest_framework import serializers -from rest_framework.exceptions import APIException - -from lemarche.users.models import User - - -# Custom Service Exceptions -class Unauthorized(APIException): - status_code = 401 - default_detail = "Unauthorized" - default_code = "unauthorized" - +import random +import string -def check_user_token(token): - """ - User token functionnality is temporary, and only used - to trace API usage and support : once a proper - auth protocol is implemented it will be replaced - """ - try: - return User.objects.has_api_key().get(api_key=token) - except (User.DoesNotExist, AssertionError): - raise Unauthorized +from rest_framework import serializers def custom_preprocessing_hook(endpoints): @@ -36,6 +17,10 @@ def custom_preprocessing_hook(endpoints): return filtered +def generate_random_string(n=64): + return "".join(random.choices(string.ascii_letters + string.digits, k=n)) + + class BasicChoiceSerializer(serializers.Serializer): id = serializers.CharField() name = serializers.CharField() diff --git a/lemarche/users/models.py b/lemarche/users/models.py index 6cbdb843b..804717f6d 100644 --- a/lemarche/users/models.py +++ b/lemarche/users/models.py @@ -5,7 +5,7 @@ from django.db import models from django.db.models import Count from django.db.models.functions import Greatest, Lower -from django.db.models.signals import post_save +from django.db.models.signals import post_save, pre_save from django.dispatch import receiver from django.forms.models import model_to_dict from django.utils import timezone @@ -410,6 +410,23 @@ def tender_siae_unread_count(self): return Tender.objects.unread(self).count() +@receiver(pre_save, sender=User) +def update_api_key_last_update(sender, instance, **kwargs): + """ + Before saving a user, add the to `api_key_last_updated` + """ + if instance.pk: + try: + old_instance = sender.objects.get(pk=instance.pk) + if old_instance.api_key != instance.api_key: + instance.api_key_last_updated = timezone.now() + except sender.DoesNotExist: + if instance.api_key: + instance.api_key_last_updated = timezone.now() + elif instance.api_key: + instance.api_key_last_updated = timezone.now() + + @receiver(post_save, sender=User) def user_post_save(sender, instance, **kwargs): if settings.BITOUBI_ENV == "prod":