From 57d96167c70c30563441994b8f5300f106902188 Mon Sep 17 00:00:00 2001 From: James Brandreth Date: Wed, 28 Jul 2021 14:48:51 +0100 Subject: [PATCH] Split features dynamic/static --- .../query_mixins/data_config_query_mixin.py | 12 +++- .../databases/queries/air_quality_forecast.py | 61 +++++++++++-------- 2 files changed, 45 insertions(+), 28 deletions(-) diff --git a/containers/cleanair/cleanair/mixins/query_mixins/data_config_query_mixin.py b/containers/cleanair/cleanair/mixins/query_mixins/data_config_query_mixin.py index 124700d7b..187f95713 100644 --- a/containers/cleanair/cleanair/mixins/query_mixins/data_config_query_mixin.py +++ b/containers/cleanair/cleanair/mixins/query_mixins/data_config_query_mixin.py @@ -19,7 +19,8 @@ def query_data_config( self, train_start_date: Optional[str] = None, pred_start_date: Optional[str] = None, - features: Optional[List[str]] = None, + static_features: Optional[List[str]] = None, + dynamic_features: Optional[List[str]] = None, ) -> Any: """Get the data ids and data configs that match the arguments. @@ -44,9 +45,14 @@ def query_data_config( AirQualityDataTable.data_config["pred_start_date"].astext >= pred_start_date ) - if features: + if static_features: data_ids = data_ids.filter( set(AirQualityDataTable.data_config["features"]) - == set(features) + == set(static_features) + ) + if dynamic_features: + data_ids = data_ids.filter( + set(AirQualityDataTable.data_config["features"]) + == set(dynamic_features) ) return data_ids diff --git a/containers/urbanair/urbanair/databases/queries/air_quality_forecast.py b/containers/urbanair/urbanair/databases/queries/air_quality_forecast.py index 79d8b7e99..2c125ab71 100644 --- a/containers/urbanair/urbanair/databases/queries/air_quality_forecast.py +++ b/containers/urbanair/urbanair/databases/queries/air_quality_forecast.py @@ -15,7 +15,7 @@ from cleanair.decorators import db_query from cleanair.mixins.query_mixins.data_config_query_mixin import AirQualityDataConfigQueryMixin from cleanair.mixins import InstanceQueryMixin, ResultQueryMixin -from cleanair.params.shared_params import PRODUCTION_DYNAMIC_FEATURES, PRODUCTION_STATIC_FEATURES, +from cleanair.params.shared_params import PRODUCTION_DYNAMIC_FEATURES, PRODUCTION_STATIC_FEATURES from ..database import all_or_404 from ..schemas.air_quality_forecast import ForecastResultGeoJson, GeometryGeoJson @@ -24,14 +24,19 @@ @db_query() def query_instance_ids( - db: Session, start_datetime: datetime, end_datetime: datetime, features: List[str] + db: Session, + start_datetime: datetime, + end_datetime: datetime, + static_features: List[str], + dynamic_features: List[str] ) -> Query: """ Check which model IDs produced forecasts between start_datetime and end_datetime. """ data_ids = AirQualityDataConfigQueryMixin().query_data_config( - features=PRODUCTION_STATIC_FEATURES+PRODUCTION_DYNAMIC_FEATURES + static_features=static_features, + dynamic_features=dynamic_features ).subquery() instance_query = InstanceQueryMixin().get_instances( @@ -81,7 +86,7 @@ def query_instance_ids( key=lambda _, *args, **kwargs: hashkey(*args, **kwargs), ) def cached_instance_ids( - db: Session, start_datetime: datetime, end_datetime: datetime, + db: Session, start_datetime: datetime, end_datetime: datetime, ) -> Optional[List[Tuple]]: """Cache available model instances""" logger.info( @@ -89,13 +94,19 @@ def cached_instance_ids( start_datetime, end_datetime, ) - # Replace this with mixin - return query_instance_ids(db, start_datetime, end_datetime).all() + + return query_instance_ids( + db, + start_datetime, + end_datetime, + PRODUCTION_STATIC_FEATURES, + PRODUCTION_DYNAMIC_FEATURES + ).all() @db_query() def query_geometries_hexgrid( - db: Session, bounding_box: Optional[Tuple[float]] = None, + db: Session, bounding_box: Optional[Tuple[float]] = None, ) -> Query: """ Query geometries for combining with plain JSON forecasts @@ -117,7 +128,7 @@ def query_geometries_hexgrid( cache=LRUCache(maxsize=256), key=lambda _, *args, **kwargs: hashkey(*args, **kwargs) ) def cached_geometries_hexgrid( - db: Session, bounding_box: Optional[Tuple[float]] = None, + db: Session, bounding_box: Optional[Tuple[float]] = None, ) -> GeometryGeoJson: """Cache geometries with optional bounding box""" logger.info("Querying hexgrid geometries") @@ -131,12 +142,12 @@ def cached_geometries_hexgrid( @db_query() def query_forecasts_hexgrid( - db: Session, - instance_id: str, - start_datetime: datetime, - end_datetime: datetime, - with_geometry: bool, - bounding_box: Optional[Tuple[float]] = None, + db: Session, + instance_id: str, + start_datetime: datetime, + end_datetime: datetime, + with_geometry: bool, + bounding_box: Optional[Tuple[float]] = None, ) -> Query: """ Get all forecasts for a given model instance in the given datetime range @@ -179,12 +190,12 @@ def query_forecasts_hexgrid( cache=LRUCache(maxsize=256), key=lambda _, *args, **kwargs: hashkey(*args, **kwargs) ) def cached_forecast_hexgrid_json( - db: Session, - instance_id: str, - start_datetime: datetime, - end_datetime: datetime, - with_geometry: bool, - bounding_box: Optional[Tuple[float]] = None, + db: Session, + instance_id: str, + start_datetime: datetime, + end_datetime: datetime, + with_geometry: bool, + bounding_box: Optional[Tuple[float]] = None, ) -> Optional[List[Tuple]]: """Cache forecasts with geometry with optional bounding box""" logger.info( @@ -210,11 +221,11 @@ def cached_forecast_hexgrid_json( cache=LRUCache(maxsize=256), key=lambda _, *args, **kwargs: hashkey(*args, **kwargs) ) def cached_forecast_hexgrid_geojson( - db: Session, - instance_id: str, - start_datetime: datetime, - end_datetime: datetime, - bounding_box: Optional[Tuple[float]] = None, + db: Session, + instance_id: str, + start_datetime: datetime, + end_datetime: datetime, + bounding_box: Optional[Tuple[float]] = None, ) -> ForecastResultGeoJson: """Cache forecasts with geometry with optional bounding box""" query_results = cached_forecast_hexgrid_json(