From 6784e680c3c338fc1ed56ee2ef1aded53f8395ae Mon Sep 17 00:00:00 2001 From: Rust Saiargaliev Date: Fri, 9 Aug 2024 16:38:31 +0200 Subject: [PATCH] Fix #385 -- Handle bulk creation when using reverse related name (#486) * Fix issue and add test coverage (#385) * Reuse existing test models, check the query count --------- Co-authored-by: Suraj Magdum --- CHANGELOG.md | 1 + model_bakery/baker.py | 14 ++++++++++++++ tests/test_baker.py | 15 ++++++++++++++- 3 files changed, 29 insertions(+), 1 deletion(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index c9862b12..b4d09e06 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -10,6 +10,7 @@ and this project adheres to [Semantic Versioning](http://semver.org/). ### Added ### Changed +- Handle bulk creation when using reverse related name ### Removed diff --git a/model_bakery/baker.py b/model_bakery/baker.py index 2723ece6..7dd24da7 100644 --- a/model_bakery/baker.py +++ b/model_bakery/baker.py @@ -864,4 +864,18 @@ def bulk_create(baker: Baker[M], quantity: int, **kwargs) -> List[M]: for obj in kwargs[field.name] ] ) + + # set many-to-many relations that are specified using related name from kwargs + for field in baker.model._meta.get_fields(): + if field.many_to_many and hasattr(field, "related_model"): + reverse_relation_name = ( + field.related_query_name + or field.related_name + or f"{field.related_model._meta.model_name}_set" + ) + if reverse_relation_name in kwargs: + getattr(entry, reverse_relation_name).set( + kwargs[reverse_relation_name] + ) + return created_entries diff --git a/tests/test_baker.py b/tests/test_baker.py index 65323438..e0c48744 100644 --- a/tests/test_baker.py +++ b/tests/test_baker.py @@ -1056,8 +1056,8 @@ def test_annotation_within_manager_get_queryset_are_run_on_make(self): assert movie.title == movie.name +@pytest.mark.django_db class TestCreateM2MWhenBulkCreate(TestCase): - @pytest.mark.django_db def test_create(self): query_count = 12 with self.assertNumQueries(query_count): @@ -1068,6 +1068,19 @@ def test_create(self): c1, c2 = models.Classroom.objects.all()[:2] assert list(c1.students.all()) == list(c2.students.all()) == [person] + def test_make_should_create_objects_using_reverse_name(self): + classroom = baker.make(models.Classroom) + + with self.assertNumQueries(21): + students = baker.make( + models.Person, + classroom_set=[classroom], + _quantity=10, + _bulk_create=True, + ) + + assert students[0].classroom_set.count() == 1 + class TestBakerSeeded: @pytest.fixture()