diff --git a/lemarche/labels/admin.py b/lemarche/labels/admin.py index b65316d4d..ee5293996 100644 --- a/lemarche/labels/admin.py +++ b/lemarche/labels/admin.py @@ -1,5 +1,4 @@ from django.contrib import admin -from django.db.models import Count from django.urls import reverse from django.utils.html import format_html, mark_safe @@ -10,12 +9,12 @@ @admin.register(Label, site=admin_site) class LabelAdmin(admin.ModelAdmin): - list_display = ["id", "name", "nb_siaes", "created_at"] + list_display = ["id", "name", "siae_count_annotated_with_link", "created_at"] search_fields = ["id", "name", "description"] search_help_text = "Cherche sur les champs : ID, Nom, Description" readonly_fields = [ - "nb_siaes", + "siae_count_annotated_with_link", "logo_url_display", "data_last_sync_date", "logs_display", @@ -31,7 +30,7 @@ class LabelAdmin(admin.ModelAdmin): }, ), ("Logo", {"fields": ("logo_url", "logo_url_display")}), - ("Structures", {"fields": ("nb_siaes",)}), + ("Structures", {"fields": ("siae_count_annotated_with_link",)}), ( "Source de données", { @@ -47,7 +46,7 @@ class LabelAdmin(admin.ModelAdmin): def get_queryset(self, request): qs = super().get_queryset(request) - qs = qs.annotate(siae_count=Count("siaes", distinct=True)) + qs = qs.with_siae_stats() return qs def get_readonly_fields(self, request, obj=None): @@ -73,16 +72,16 @@ def logo_url_display(self, instance): logo_url_display.short_description = "Logo" - def nb_siaes(self, label): - url = reverse("admin:siaes_siae_changelist") + f"?labels__id__exact={label.id}" - return format_html(f'{label.siae_count}') - - nb_siaes.short_description = "Nombre de structures" - nb_siaes.admin_order_field = "siae_count" - def logs_display(self, label=None): if label: return pretty_print_readonly_jsonfield(label.logs) return "-" logs_display.short_description = Label._meta.get_field("logs").verbose_name + + def siae_count_annotated_with_link(self, label): + url = reverse("admin:siaes_siae_changelist") + f"?labels__id__exact={label.id}" + return format_html(f'{label.siae_count_annotated}') + + siae_count_annotated_with_link.short_description = "Nombre de structures" + siae_count_annotated_with_link.admin_order_field = "siae_count_annotated" diff --git a/lemarche/labels/factories.py b/lemarche/labels/factories.py index ed026c652..5ed3edb13 100644 --- a/lemarche/labels/factories.py +++ b/lemarche/labels/factories.py @@ -11,3 +11,9 @@ class Meta: name = factory.Faker("company", locale="fr_FR") # slug: auto-generated website = "https://example.com" + + @factory.post_generation + def siaes(self, create, extracted, **kwargs): + if extracted: + # Add the iterable of groups using bulk addition + self.siaes.add(*extracted) diff --git a/lemarche/labels/models.py b/lemarche/labels/models.py index cd28e2a0f..65ea850fe 100644 --- a/lemarche/labels/models.py +++ b/lemarche/labels/models.py @@ -1,8 +1,14 @@ from django.db import models +from django.db.models import Count from django.template.defaultfilters import slugify from django.utils import timezone +class LabelQuerySet(models.QuerySet): + def with_siae_stats(self): + return self.annotate(siae_count_annotated=Count("siaes", distinct=True)) + + class Label(models.Model): name = models.CharField(verbose_name="Nom", max_length=255) slug = models.SlugField(verbose_name="Slug", max_length=255, unique=True) @@ -18,6 +24,8 @@ class Label(models.Model): created_at = models.DateTimeField(verbose_name="Date de création", default=timezone.now) updated_at = models.DateTimeField(verbose_name="Date de modification", auto_now=True) + objects = models.Manager.from_queryset(LabelQuerySet)() + class Meta: verbose_name = "Label & certification" verbose_name_plural = "Labels & certifications" diff --git a/lemarche/labels/tests.py b/lemarche/labels/tests.py index 1badd0595..08f776956 100644 --- a/lemarche/labels/tests.py +++ b/lemarche/labels/tests.py @@ -1,6 +1,8 @@ from django.test import TestCase from lemarche.labels.factories import LabelFactory +from lemarche.labels.models import Label +from lemarche.siaes.factories import SiaeFactory class LabelModelTest(TestCase): @@ -13,3 +15,17 @@ def test_slug_field(self): def test_str(self): self.assertEqual(str(self.label), "Un label") + + +class LabelQuerySetTest(TestCase): + @classmethod + def setUpTestData(cls): + cls.siae_1 = SiaeFactory() + cls.siae_2 = SiaeFactory() + cls.label = LabelFactory() + cls.label_with_siaes = LabelFactory(siaes=[cls.siae_1, cls.siae_2]) + + def test_with_siae_stats(self): + label_queryset = Label.objects.with_siae_stats() + self.assertEqual(label_queryset.get(id=self.label.id).siae_count_annotated, 0) + self.assertEqual(label_queryset.get(id=self.label_with_siaes.id).siae_count_annotated, 2)