Skip to content

Commit

Permalink
Added task instance_bulk_gps_push
Browse files Browse the repository at this point in the history
Refactored the bulk gps check in order to call it inside the task.
Fixed typo in HasCreateOrgUnitPermission.
  • Loading branch information
tdethier committed Dec 20, 2024
1 parent 84d43de commit 606b33c
Show file tree
Hide file tree
Showing 9 changed files with 662 additions and 64 deletions.
1 change: 1 addition & 0 deletions hat/audit/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
ORG_UNIT_API_BULK = "org_unit_api_bulk"
GROUP_SET_API = "group_set_api"
INSTANCE_API = "instance_api"
INSTANCE_API_BULK = "instance_api_bulk"
FORM_API = "form_api"
GPKG_IMPORT = "gpkg_import"
CAMPAIGN_API = "campaign_api"
Expand Down
80 changes: 29 additions & 51 deletions iaso/api/instances.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,9 +38,10 @@
)
from iaso.utils import timestamp_to_datetime
from iaso.utils.file_utils import get_file_type
from .org_units import HasCreateOrgUnitPermission

from ..models.forms import CR_MODE_IF_REFERENCE_FORM
from ..utils.models.common import get_creator_name
from ..utils.models.common import get_creator_name, check_instance_bulk_gps_push
from . import common
from .comment import UserSerializerForComment
from .common import (
Expand Down Expand Up @@ -91,7 +92,7 @@ def validate_period(self, value):

class HasInstancePermission(permissions.BasePermission):
def has_permission(self, request: Request, view):
if request.method == "POST":
if request.method == "POST": # to handle anonymous submissions sent by mobile
return True

return request.user.is_authenticated and (
Expand All @@ -112,6 +113,19 @@ def has_object_permission(self, request: Request, view, obj: Instance):
return False


class HasInstanceBulkPermission(permissions.BasePermission):
"""
Designed for POST endpoints that are not designed to receive new submissions.
"""
def has_permission(self, request: Request, view):
return request.user.is_authenticated and (
request.user.has_perm(permission.FORMS)
or request.user.has_perm(permission.SUBMISSIONS)
or request.user.has_perm(permission.REGISTRY_WRITE)
or request.user.has_perm(permission.REGISTRY_READ)
)


class InstanceFileSerializer(serializers.Serializer):
id = serializers.IntegerField(read_only=True)
instance_id = serializers.IntegerField()
Expand Down Expand Up @@ -605,7 +619,11 @@ def bulkdelete(self, request):
status=201,
)

@action(detail=False, methods=["GET"], permission_classes=[permissions.IsAuthenticated, HasInstancePermission])
@action(
detail=False,
methods=["GET"],
permission_classes=[permissions.IsAuthenticated, HasInstanceBulkPermission, HasCreateOrgUnitPermission],
)
def check_bulk_gps_push(self, request):
# first, let's parse all parameters received from the URL
select_all, selected_ids, unselected_ids = self._parse_check_bulk_gps_push_parameters(request.GET)
Expand All @@ -628,48 +646,15 @@ def check_bulk_gps_push(self, request):
else:
instances_query = instances_query.exclude(pk__in=unselected_ids)

overwrite_ids = []
no_location_ids = []
org_units_to_instances_dict = {}
set_org_units_ids = set()
success, errors, warnings = check_instance_bulk_gps_push(instances_query)

for instance in instances_query:
if not instance.location:
no_location_ids.append(instance.id) # there is nothing to push to the OrgUnit
continue
if not success:
errors["result"] = "errors"
return Response(errors, status=status.HTTP_400_BAD_REQUEST)

org_unit = instance.org_unit
if org_unit.id in org_units_to_instances_dict:
# we can't push this instance's location since there was another instance linked to this OrgUnit
org_units_to_instances_dict[org_unit.id].append(instance.id)
continue
else:
org_units_to_instances_dict[org_unit.id] = [instance.id]

set_org_units_ids.add(org_unit.id)
if org_unit.location or org_unit.geom:
overwrite_ids.append(instance.id) # if the user proceeds, he will erase existing location
continue

# Before returning, we need to check if we've had multiple hits on an OrgUnit
error_same_org_unit_ids = self._check_bulk_gps_repeated_org_units(org_units_to_instances_dict)

if len(error_same_org_unit_ids):
return Response(
{"result": "error", "error_ids": error_same_org_unit_ids},
status=status.HTTP_400_BAD_REQUEST,
)

if len(no_location_ids) or len(overwrite_ids):
dict_response = {
"result": "warnings",
}
if len(no_location_ids):
dict_response["warning_no_location"] = no_location_ids
if len(overwrite_ids):
dict_response["warning_overwrite"] = overwrite_ids

return Response(dict_response, status=status.HTTP_200_OK)
if warnings:
warnings["result"] = "warnings"
return Response(warnings, status=status.HTTP_200_OK)

return Response(
{
Expand All @@ -680,7 +665,7 @@ def check_bulk_gps_push(self, request):

def _parse_check_bulk_gps_push_parameters(self, query_parameters):
raw_select_all = query_parameters.get("select_all", True)
select_all = raw_select_all not in ["false", "False", "0"]
select_all = raw_select_all not in ["false", "False", "0", 0, False]

raw_selected_ids = query_parameters.get("selected_ids", None)
if raw_selected_ids:
Expand All @@ -696,13 +681,6 @@ def _parse_check_bulk_gps_push_parameters(self, query_parameters):

return select_all, selected_ids, unselected_ids

def _check_bulk_gps_repeated_org_units(self, org_units_to_instance_ids: Dict[int, List[int]]) -> List[int]:
error_instance_ids = []
for _, instance_ids in org_units_to_instance_ids.items():
if len(instance_ids) >= 2:
error_instance_ids.extend(instance_ids)
return error_instance_ids

QUERY = """
select DATE_TRUNC('month', COALESCE(iaso_instance.source_created_at, iaso_instance.created_at)) as month,
(select name from iaso_form where id = iaso_instance.form_id) as form_name,
Expand Down
4 changes: 2 additions & 2 deletions iaso/api/org_units.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,7 @@
# noinspection PyMethodMayBeStatic


class HasCreateOrUnitPermission(permissions.BasePermission):
class HasCreateOrgUnitPermission(permissions.BasePermission):
def has_permission(self, request, view):
if not request.user.is_authenticated:
return False
Expand Down Expand Up @@ -614,7 +614,7 @@ def get_date(self, date: str) -> Union[datetime.date, None]:
pass
return None

@action(detail=False, methods=["POST"], permission_classes=[permissions.IsAuthenticated, HasCreateOrUnitPermission])
@action(detail=False, methods=["POST"], permission_classes=[permissions.IsAuthenticated, HasCreateOrgUnitPermission])
def create_org_unit(self, request):
"""This endpoint is used by the React frontend"""
errors = []
Expand Down
33 changes: 33 additions & 0 deletions iaso/api/tasks/create/instance_bulk_gps_push.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,33 @@
from rest_framework import viewsets, permissions, status
from rest_framework.response import Response

from iaso.api.instances import HasInstanceBulkPermission
from iaso.api.org_units import HasCreateOrgUnitPermission
from iaso.api.tasks import TaskSerializer
from iaso.tasks.instance_bulk_gps_push import instance_bulk_gps_push


class InstanceBulkGpsPushViewSet(viewsets.ViewSet):
"""Bulk push gps location from Instances to their related OrgUnit.
This task will override existing location on OrgUnits and might set `None` if the Instance doesn't have any location.
Calling this endpoint implies that the InstanceViewSet.check_bulk_gps_push() method has been called before and has returned no error.
"""

permission_classes = [permissions.IsAuthenticated, HasInstanceBulkPermission, HasCreateOrgUnitPermission]

def create(self, request):
raw_select_all = request.data.get("select_all", True)
select_all = raw_select_all not in [False, "false", "False", "0", 0]
selected_ids = request.data.get("selected_ids", [])
unselected_ids = request.data.get("unselected_ids", [])

user = self.request.user

task = instance_bulk_gps_push(
select_all=select_all, selected_ids=selected_ids, unselected_ids=unselected_ids, user=user
)
return Response(
{"task": TaskSerializer(instance=task).data},
status=status.HTTP_201_CREATED,
)
76 changes: 76 additions & 0 deletions iaso/tasks/instance_bulk_gps_push.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,76 @@
from copy import deepcopy
from logging import getLogger
from time import time
from typing import Optional, List

from django.contrib.auth.models import User
from django.db import transaction

from beanstalk_worker import task_decorator
from hat.audit import models as audit_models
from iaso.models import Task, Instance
from iaso.utils.gis import convert_2d_point_to_3d
from iaso.utils.models.common import check_instance_bulk_gps_push

logger = getLogger(__name__)


def push_single_instance_gps_to_org_unit(user: Optional[User], instance: Instance):
org_unit = instance.org_unit
original_copy = deepcopy(org_unit)
org_unit.location = convert_2d_point_to_3d(instance.location) if instance.location else None
org_unit.save()
if not original_copy.location:
logger.info(f"updating {org_unit.name} {org_unit.id} with {org_unit.location}")
else:
logger.info(
f"updating {org_unit.name} {org_unit.id} - overwriting {original_copy.location} with {org_unit.location}"
)
audit_models.log_modification(original_copy, org_unit, source=audit_models.INSTANCE_API_BULK, user=user)


@task_decorator(task_name="instance_bulk_gps_push")
def instance_bulk_gps_push(
select_all: bool,
selected_ids: List[int],
unselected_ids: List[int],
task: Task,
):
"""Background Task to bulk push instance gps to org units.
/!\ Danger: calling this task without having received a successful response from the check_bulk_gps_push
endpoint will have unexpected results that might cause data loss.
"""
start = time()
task.report_progress_and_stop_if_killed(progress_message="Searching for Instances for pushing gps data")

user = task.launcher

queryset = Instance.non_deleted_objects.get_queryset().filter_for_user(user)
queryset = queryset.select_related("org_unit")

if not select_all:
queryset = queryset.filter(pk__in=selected_ids)
else:
queryset = queryset.exclude(pk__in=unselected_ids)

if not queryset:
raise Exception("No matching instances found")

# Checking if any gps push can be performed with what was requested
success, errors, _ = check_instance_bulk_gps_push(queryset)
if not success:
raise Exception("Cannot proceed with the gps push due to errors: %s" % errors)

total = queryset.count()

with transaction.atomic():
for index, instance in enumerate(queryset.iterator()):
res_string = "%.2f sec, processed %i instances" % (time() - start, index)
task.report_progress_and_stop_if_killed(progress_message=res_string, end_value=total, progress_value=index)
push_single_instance_gps_to_org_unit(
user,
instance,
)

task.report_success(message="%d modified" % total)
63 changes: 52 additions & 11 deletions iaso/tests/api/test_instances.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,7 @@ def setUpTestData(cls):
cls.sw_version = sw_version

cls.yoda = cls.create_user_with_profile(
username="yoda", last_name="Da", first_name="Yo", account=star_wars, permissions=["iaso_submissions"]
username="yoda", last_name="Da", first_name="Yo", account=star_wars, permissions=["iaso_submissions", "iaso_org_units"]
)
cls.guest = cls.create_user_with_profile(username="guest", account=star_wars, permissions=["iaso_submissions"])
cls.supervisor = cls.create_user_with_profile(
Expand Down Expand Up @@ -72,10 +72,10 @@ def setUpTestData(cls):
version=sw_version,
)
cls.jedi_council_endor = m.OrgUnit.objects.create(
name="Endor Jedi Council", source_ref="jedi_council_endor_ref"
name="Endor Jedi Council", source_ref="jedi_council_endor_ref", version=sw_version,
)
cls.jedi_council_endor_region = m.OrgUnit.objects.create(
name="Endor Region Jedi Council", parent=cls.jedi_council_endor, source_ref="jedi_council_endor_region_ref"
name="Endor Region Jedi Council", parent=cls.jedi_council_endor, source_ref="jedi_council_endor_region_ref", version=sw_version,
)

cls.project = m.Project.objects.create(
Expand Down Expand Up @@ -1963,8 +1963,8 @@ def test_check_bulk_push_gps_select_all_ok(self):
response_json = response.json()
self.assertEqual(response_json["result"], "success")

def test_check_bulk_push_gps_select_all_error(self):
# setting gps data for instances that were not deleted
def test_check_bulk_push_gps_select_all_error_same_org_unit(self):
# changing location for some instances to have multiple hits on multiple org_units
self.instance_1.org_unit = self.jedi_council_endor
self.instance_2.org_unit = self.jedi_council_endor
new_location = Point(1, 2, 3)
Expand All @@ -1973,14 +1973,55 @@ def test_check_bulk_push_gps_select_all_error(self):
instance.save()
instance.refresh_from_db()

# Let's delete some instances, the result will be the same
for instance in [self.instance_6, self.instance_8]:
instance.deleted_at = datetime.datetime.now()
instance.deleted = True
instance.save()

self.client.force_authenticate(self.yoda)
response = self.client.get(f"/api/instances/check_bulk_gps_push/") # by default, select_all = True
self.assertEqual(response.status_code, status.HTTP_400_BAD_REQUEST)
response_json = response.json()
self.assertEqual(response_json["result"], "error")
self.assertCountEqual(response_json["error_ids"], [self.instance_1.id, self.instance_2.id])
self.assertEqual(response_json["result"], "errors")
self.assertCountEqual(response_json["error_same_org_unit"], [self.instance_1.id, self.instance_2.id, self.instance_3.id, self.instance_4.id, self.instance_5.id])

def test_check_bulk_push_gps_select_all_error_read_only_source(self):
# Making the source read only
self.sw_source.read_only = True
self.sw_source.save()

# Changing some instance.org_unit so that all the results don't appear only in "error_same_org_unit"
self.instance_2.org_unit = self.jedi_council_endor
self.instance_3.org_unit = self.jedi_council_endor_region
self.instance_8.org_unit = self.ou_top_1
for instance in [self.instance_2, self.instance_3, self.instance_8]:
instance.save()
instance.refresh_from_db()

self.client.force_authenticate(self.yoda)
response = self.client.get(f"/api/instances/check_bulk_gps_push/") # by default, select_all = True
self.assertEqual(response.status_code, status.HTTP_400_BAD_REQUEST)
response_json = response.json()
self.assertEqual(response_json["result"], "errors")
# instance_6 included because it's the first one with the remaining org_unit and the queryset has a default order of "-id"
self.assertCountEqual(response_json["error_read_only_source"], [self.instance_8.id, self.instance_2.id, self.instance_3.id, self.instance_6.id])

def test_check_bulk_push_gps_select_all_warning_no_location(self):
# Changing some instance.org_unit so that all the results don't appear only in "error_same_org_unit"
self.instance_2.org_unit = self.jedi_council_endor
self.instance_3.org_unit = self.jedi_council_endor_region
self.instance_8.org_unit = self.ou_top_1
for instance in [self.instance_2, self.instance_3, self.instance_8]:
instance.save()
instance.refresh_from_db()

# Let's delete some instances to avoid getting "error_same_org-unit"
for instance in [self.instance_4, self.instance_5, self.instance_6, self.instance_8]:
instance.deleted_at = datetime.datetime.now()
instance.deleted = True
instance.save()

self.client.force_authenticate(self.yoda)
response = self.client.get(f"/api/instances/check_bulk_gps_push/")
self.assertEqual(response.status_code, status.HTTP_200_OK)
Expand Down Expand Up @@ -2095,8 +2136,8 @@ def test_check_bulk_push_gps_selected_ids_error(self):
self.assertEqual(response.status_code, status.HTTP_400_BAD_REQUEST)
response_json = response.json()
# All these Instances target the same OrgUnit, so it's impossible to push gps data
self.assertEqual(response_json["result"], "error")
self.assertCountEqual(response_json["error_ids"], [self.instance_1.id, self.instance_2.id, self.instance_3.id])
self.assertEqual(response_json["result"], "errors")
self.assertCountEqual(response_json["error_same_org_unit"], [self.instance_1.id, self.instance_2.id, self.instance_3.id])

def test_check_bulk_push_gps_selected_ids_error_unknown_id(self):
self.client.force_authenticate(self.yoda)
Expand Down Expand Up @@ -2254,8 +2295,8 @@ def test_check_bulk_push_gps_unselected_ids_error(self):
)
self.assertEqual(response.status_code, status.HTTP_400_BAD_REQUEST)
response_json = response.json()
self.assertEqual(response_json["result"], "error")
self.assertCountEqual(response_json["error_ids"], [self.instance_1.id, self.instance_2.id])
self.assertEqual(response_json["result"], "errors")
self.assertCountEqual(response_json["error_same_org_unit"], [self.instance_1.id, self.instance_2.id])

def test_check_bulk_push_gps_unselected_ids_error_unknown_id(self):
self.client.force_authenticate(self.yoda)
Expand Down
Loading

0 comments on commit 606b33c

Please sign in to comment.