diff --git a/lemarche/labels/admin.py b/lemarche/labels/admin.py
index b65316d4d..ee5293996 100644
--- a/lemarche/labels/admin.py
+++ b/lemarche/labels/admin.py
@@ -1,5 +1,4 @@
from django.contrib import admin
-from django.db.models import Count
from django.urls import reverse
from django.utils.html import format_html, mark_safe
@@ -10,12 +9,12 @@
@admin.register(Label, site=admin_site)
class LabelAdmin(admin.ModelAdmin):
- list_display = ["id", "name", "nb_siaes", "created_at"]
+ list_display = ["id", "name", "siae_count_annotated_with_link", "created_at"]
search_fields = ["id", "name", "description"]
search_help_text = "Cherche sur les champs : ID, Nom, Description"
readonly_fields = [
- "nb_siaes",
+ "siae_count_annotated_with_link",
"logo_url_display",
"data_last_sync_date",
"logs_display",
@@ -31,7 +30,7 @@ class LabelAdmin(admin.ModelAdmin):
},
),
("Logo", {"fields": ("logo_url", "logo_url_display")}),
- ("Structures", {"fields": ("nb_siaes",)}),
+ ("Structures", {"fields": ("siae_count_annotated_with_link",)}),
(
"Source de données",
{
@@ -47,7 +46,7 @@ class LabelAdmin(admin.ModelAdmin):
def get_queryset(self, request):
qs = super().get_queryset(request)
- qs = qs.annotate(siae_count=Count("siaes", distinct=True))
+ qs = qs.with_siae_stats()
return qs
def get_readonly_fields(self, request, obj=None):
@@ -73,16 +72,16 @@ def logo_url_display(self, instance):
logo_url_display.short_description = "Logo"
- def nb_siaes(self, label):
- url = reverse("admin:siaes_siae_changelist") + f"?labels__id__exact={label.id}"
- return format_html(f'{label.siae_count}')
-
- nb_siaes.short_description = "Nombre de structures"
- nb_siaes.admin_order_field = "siae_count"
-
def logs_display(self, label=None):
if label:
return pretty_print_readonly_jsonfield(label.logs)
return "-"
logs_display.short_description = Label._meta.get_field("logs").verbose_name
+
+ def siae_count_annotated_with_link(self, label):
+ url = reverse("admin:siaes_siae_changelist") + f"?labels__id__exact={label.id}"
+ return format_html(f'{label.siae_count_annotated}')
+
+ siae_count_annotated_with_link.short_description = "Nombre de structures"
+ siae_count_annotated_with_link.admin_order_field = "siae_count_annotated"
diff --git a/lemarche/labels/factories.py b/lemarche/labels/factories.py
index ed026c652..5ed3edb13 100644
--- a/lemarche/labels/factories.py
+++ b/lemarche/labels/factories.py
@@ -11,3 +11,9 @@ class Meta:
name = factory.Faker("company", locale="fr_FR")
# slug: auto-generated
website = "https://example.com"
+
+ @factory.post_generation
+ def siaes(self, create, extracted, **kwargs):
+ if extracted:
+ # Add the iterable of groups using bulk addition
+ self.siaes.add(*extracted)
diff --git a/lemarche/labels/models.py b/lemarche/labels/models.py
index cd28e2a0f..65ea850fe 100644
--- a/lemarche/labels/models.py
+++ b/lemarche/labels/models.py
@@ -1,8 +1,14 @@
from django.db import models
+from django.db.models import Count
from django.template.defaultfilters import slugify
from django.utils import timezone
+class LabelQuerySet(models.QuerySet):
+ def with_siae_stats(self):
+ return self.annotate(siae_count_annotated=Count("siaes", distinct=True))
+
+
class Label(models.Model):
name = models.CharField(verbose_name="Nom", max_length=255)
slug = models.SlugField(verbose_name="Slug", max_length=255, unique=True)
@@ -18,6 +24,8 @@ class Label(models.Model):
created_at = models.DateTimeField(verbose_name="Date de création", default=timezone.now)
updated_at = models.DateTimeField(verbose_name="Date de modification", auto_now=True)
+ objects = models.Manager.from_queryset(LabelQuerySet)()
+
class Meta:
verbose_name = "Label & certification"
verbose_name_plural = "Labels & certifications"
diff --git a/lemarche/labels/tests.py b/lemarche/labels/tests.py
index 1badd0595..08f776956 100644
--- a/lemarche/labels/tests.py
+++ b/lemarche/labels/tests.py
@@ -1,6 +1,8 @@
from django.test import TestCase
from lemarche.labels.factories import LabelFactory
+from lemarche.labels.models import Label
+from lemarche.siaes.factories import SiaeFactory
class LabelModelTest(TestCase):
@@ -13,3 +15,17 @@ def test_slug_field(self):
def test_str(self):
self.assertEqual(str(self.label), "Un label")
+
+
+class LabelQuerySetTest(TestCase):
+ @classmethod
+ def setUpTestData(cls):
+ cls.siae_1 = SiaeFactory()
+ cls.siae_2 = SiaeFactory()
+ cls.label = LabelFactory()
+ cls.label_with_siaes = LabelFactory(siaes=[cls.siae_1, cls.siae_2])
+
+ def test_with_siae_stats(self):
+ label_queryset = Label.objects.with_siae_stats()
+ self.assertEqual(label_queryset.get(id=self.label.id).siae_count_annotated, 0)
+ self.assertEqual(label_queryset.get(id=self.label_with_siaes.id).siae_count_annotated, 2)