diff --git a/hub/management/commands/base_importers.py b/hub/management/commands/base_importers.py index 7a070aa54..3aeb48aa2 100644 --- a/hub/management/commands/base_importers.py +++ b/hub/management/commands/base_importers.py @@ -1,3 +1,4 @@ +from functools import cache from time import sleep from django.core.management.base import BaseCommand @@ -7,7 +8,7 @@ import pandas as pd from tqdm import tqdm -from hub.models import Area, AreaData, DataSet, DataType +from hub.models import Area, AreaData, AreaType, DataSet, DataType from utils.mapit import ( BadRequestException, ForbiddenException, @@ -20,6 +21,7 @@ class BaseAreaImportCommand(BaseCommand): cast_field = IntegerField + area_type = "WMC" def __init__(self): super().__init__() @@ -38,6 +40,10 @@ def delete_data(self): for data_type in self.data_types.values(): AreaData.objects.filter(data_type=data_type).delete() + @cache + def get_area_type(self): + return AreaType.objects.get(code=self.area_type) + def add_data_sets(self, df=None): for name, config in self.data_sets.items(): label = self.get_label(config) @@ -70,10 +76,12 @@ def add_data_sets(self, df=None): **config["defaults"], }, ) + data_set.areas_available.add(self.get_area_type()) data_type, created = DataType.objects.update_or_create( data_set=data_set, name=name, + area_type=self.get_area_type(), defaults={ "data_type": config["defaults"]["data_type"], "label": label, @@ -164,7 +172,6 @@ def update_max_min(self): class BaseImportFromDataFrameCommand(BaseAreaImportCommand): uses_gss = True - area_type = "WMC" def get_row_data(self, row, conf): return row[conf["col"]]