Skip to content

Commit

Permalink
Bulk creation of groups
Browse files Browse the repository at this point in the history
  • Loading branch information
kemar committed Dec 18, 2024
1 parent 5b48363 commit 19a39a9
Show file tree
Hide file tree
Showing 2 changed files with 43 additions and 14 deletions.
45 changes: 35 additions & 10 deletions iaso/diffing/synchronizer.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,8 @@
from dataclasses import dataclass


from iaso.models import OrgUnit, OrgUnitChangeRequest, DataSourceSynchronization
from iaso.models import Group, OrgUnit, OrgUnitChangeRequest, DataSourceSynchronization


logger = logging.getLogger(__name__)

Expand All @@ -34,14 +35,16 @@ def __init__(self, data_source_sync: DataSourceSynchronization):
self.diffs = json.loads(data_source_sync.json_diff)
self.change_requests_to_bulk_create = []
self.org_units_to_bulk_create = []
self.groups_to_bulk_create = []
self.change_requests_groups_to_bulk_create = {}
self.org_units_source_version_matching = {}
self.insert_batch_size = 100
self.json_batch_size = 10

def synchronize(self) -> None:
self._prepare_missing_org_units()
self._prepare_missing_org_units_and_groups()
self._bulk_create_missing_org_units()
self._bulk_create_missing_groups()
self._prepare_change_requests()
self._bulk_create_change_requests()
self._bulk_create_change_request_groups()
Expand All @@ -53,7 +56,7 @@ def _sort_by_path(self, diffs: dict) -> list:
def _parse_date(self, date_str: str) -> datetime.date:
return datetime.datetime.strptime(date_str, "%Y-%m-%d").date()

def _prepare_missing_org_units(self) -> None:
def _prepare_missing_org_units_and_groups(self) -> None:
# Cast the list into a generator to be able to iterate over it chunk by chunk.
missing_org_units_diff_generator = (diff for diff in self._sort_by_path(self.diffs) if diff["status"] == "new")

Expand All @@ -64,12 +67,16 @@ def _prepare_missing_org_units(self) -> None:
if not batch_diff:
break

org_unit_ids = [diff["orgunit_ref"]["id"] for diff in batch_diff]
org_units = OrgUnit.objects.filter(pk__in=org_unit_ids).select_related("parent").prefetch_related("groups")

for diff in batch_diff:
org_unit_id = diff["orgunit_ref"]["id"]
org_unit = OrgUnit.objects.select_related("parent").get(pk=org_unit_id)
org_unit = next(org_unit for org_unit in org_units if org_unit.id == org_unit_id)

new_parent = None
if org_unit.parent:
# Find the corresponding parent in the source version to update.
new_parent = OrgUnit.objects.get(
source_ref=org_unit.parent.source_ref, version=self.data_source_sync.source_version_to_update
)
Expand All @@ -80,6 +87,12 @@ def _prepare_missing_org_units(self) -> None:
old_id=org_unit_id,
)

# TODO: ensure uniqueness of groups.
for group in org_unit.groups.all():
group.pk = None
group.source_version = self.data_source_sync.source_version_to_update
self.groups_to_bulk_create.append(group)

# Duplicate the `OrgUnit` in the source version to update.
org_unit.pk = None
org_unit.validation_status = org_unit.VALIDATION_NEW
Expand All @@ -96,15 +109,27 @@ def _bulk_create_missing_org_units(self) -> None:
new_org_units_generator = (item for item in self.org_units_to_bulk_create)

while True:
batch = list(islice(new_org_units_generator, self.insert_batch_size))
batch_subset = list(islice(new_org_units_generator, self.insert_batch_size))

if not batch:
if not batch_subset:
break

new_org_units = OrgUnit.objects.bulk_create(batch, self.insert_batch_size)
new_org_units = OrgUnit.objects.bulk_create(batch_subset, self.insert_batch_size)
for new_org_unit in new_org_units:
self.org_units_source_version_matching[new_org_unit.source_ref].new_id = new_org_unit.pk

def _bulk_create_missing_groups(self) -> None:
    """Insert every `Group` queued in `self.groups_to_bulk_create`, chunk by chunk.

    The queue is consumed in slices of `self.insert_batch_size`, and each slice is
    handed to `Group.objects.bulk_create`. Nothing is returned.
    """
    # A single iterator is shared by all `islice` calls, so each call resumes
    # where the previous one stopped.
    new_groups_iterator = iter(self.groups_to_bulk_create)

    while True:
        # Slice with `insert_batch_size`, the same size that is passed to
        # `bulk_create` below and that the other bulk-create helpers of this class
        # use. Slicing with the smaller `json_batch_size` only multiplied the
        # number of `bulk_create` calls.
        batch_subset = list(islice(new_groups_iterator, self.insert_batch_size))

        if not batch_subset:
            break

        Group.objects.bulk_create(batch_subset, self.insert_batch_size)

def _prepare_change_requests(self) -> None:
# Cast the list into a generator to be able to iterate over it chunk by chunk.
change_requests_diff_generator = (diff for diff in self.diffs if diff["status"] in ["new", "modified"])
Expand Down Expand Up @@ -265,12 +290,12 @@ def _bulk_create_change_requests(self) -> None:
new_change_requests_generator = (item for item in self.change_requests_to_bulk_create)

while True:
batch = list(islice(new_change_requests_generator, self.insert_batch_size))
batch_subset = list(islice(new_change_requests_generator, self.insert_batch_size))

if not batch:
if not batch_subset:
break

change_requests = OrgUnitChangeRequest.objects.bulk_create(batch, self.insert_batch_size)
change_requests = OrgUnitChangeRequest.objects.bulk_create(batch_subset, self.insert_batch_size)

for change_request in change_requests:
groups_to_bulk_create = self.change_requests_groups_to_bulk_create.get(change_request.org_unit_id)
Expand Down
12 changes: 8 additions & 4 deletions iaso/tests/models/test_data_source_synchronization.py
Original file line number Diff line number Diff line change
Expand Up @@ -55,7 +55,6 @@ def setUpTestData(cls):
opening_date=datetime.date(2022, 11, 28),
closed_date=datetime.date(2025, 11, 28),
)
cls.angola_country_to_update.calculate_paths()
cls.angola_country_to_update.groups.set([cls.group_b, cls.group_c])

cls.angola_region_to_update = m.OrgUnit.objects.create(
Expand All @@ -82,7 +81,6 @@ def setUpTestData(cls):
opening_date=datetime.date(2022, 11, 28),
closed_date=datetime.date(2025, 11, 28),
)
cls.angola_country_to_compare_with.calculate_paths()
cls.angola_country_to_compare_with.groups.set([cls.group_a])

cls.angola_region_to_compare_with = m.OrgUnit.objects.create(
Expand All @@ -95,7 +93,6 @@ def setUpTestData(cls):
opening_date=datetime.date(2022, 11, 28),
closed_date=datetime.date(2025, 11, 28),
)
cls.angola_region_to_compare_with.calculate_paths()

cls.angola_district_to_compare_with = m.OrgUnit.objects.create(
parent=cls.angola_region_to_compare_with,
Expand All @@ -107,11 +104,15 @@ def setUpTestData(cls):
opening_date=datetime.date(2022, 11, 28),
closed_date=datetime.date(2025, 11, 28),
)
cls.angola_district_to_compare_with.calculate_paths()
cls.angola_district_to_compare_with.groups.set([cls.group_a])

cls.account = m.Account.objects.create(name="Account")
cls.user = cls.create_user_with_profile(username="user", account=cls.account)

# Calculate paths.
cls.angola_country_to_update.calculate_paths()
cls.angola_country_to_compare_with.calculate_paths()

@time_machine.travel(DT, tick=False)
def test_create(self):
kwargs = {
Expand Down Expand Up @@ -333,3 +334,6 @@ def test_create_change_requests(self):
self.assertEqual(new_org_unit.parent.version, data_source_sync.source_version_to_update)
self.assertEqual(new_org_unit.creator, data_source_sync.created_by)
self.assertEqual(new_org_unit.validation_status, new_org_unit.VALIDATION_NEW)

new_group = m.Group.objects.get(name="Group A", source_version=data_source_sync.source_version_to_update)
self.assertEqual(new_group.org_units.count(), 0)

0 comments on commit 19a39a9

Please sign in to comment.