Skip to content

Commit

Permalink
Split features dynamic/static
Browse files Browse the repository at this point in the history
  • Loading branch information
jamesbrandreth committed Jul 28, 2021
1 parent a992738 commit 57d9616
Show file tree
Hide file tree
Showing 2 changed files with 45 additions and 28 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,8 @@ def query_data_config(
self,
train_start_date: Optional[str] = None,
pred_start_date: Optional[str] = None,
features: Optional[List[str]] = None,
static_features: Optional[List[str]] = None,
dynamic_features: Optional[List[str]] = None,
) -> Any:
"""Get the data ids and data configs that match the arguments.
Expand All @@ -44,9 +45,14 @@ def query_data_config(
AirQualityDataTable.data_config["pred_start_date"].astext
>= pred_start_date
)
if features:
if static_features:
data_ids = data_ids.filter(
set(AirQualityDataTable.data_config["features"])
== set(features)
== set(static_features)
)
if dynamic_features:
data_ids = data_ids.filter(
set(AirQualityDataTable.data_config["features"])
== set(dynamic_features)
)
return data_ids
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@
from cleanair.decorators import db_query
from cleanair.mixins.query_mixins.data_config_query_mixin import AirQualityDataConfigQueryMixin
from cleanair.mixins import InstanceQueryMixin, ResultQueryMixin
from cleanair.params.shared_params import PRODUCTION_DYNAMIC_FEATURES, PRODUCTION_STATIC_FEATURES,
from cleanair.params.shared_params import PRODUCTION_DYNAMIC_FEATURES, PRODUCTION_STATIC_FEATURES
from ..database import all_or_404
from ..schemas.air_quality_forecast import ForecastResultGeoJson, GeometryGeoJson

Expand All @@ -24,14 +24,19 @@

@db_query()
def query_instance_ids(
db: Session, start_datetime: datetime, end_datetime: datetime, features: List[str]
db: Session,
start_datetime: datetime,
end_datetime: datetime,
static_features: List[str],
dynamic_features: List[str]
) -> Query:
"""
Check which model IDs produced forecasts between start_datetime and end_datetime.
"""

data_ids = AirQualityDataConfigQueryMixin().query_data_config(
features=PRODUCTION_STATIC_FEATURES+PRODUCTION_DYNAMIC_FEATURES
static_features=static_features,
dynamic_features=dynamic_features
).subquery()

instance_query = InstanceQueryMixin().get_instances(
Expand Down Expand Up @@ -81,21 +86,27 @@ def query_instance_ids(
key=lambda _, *args, **kwargs: hashkey(*args, **kwargs),
)
def cached_instance_ids(
db: Session, start_datetime: datetime, end_datetime: datetime,
db: Session, start_datetime: datetime, end_datetime: datetime,
) -> Optional[List[Tuple]]:
"""Cache available model instances"""
logger.info(
"Querying available instance IDs between %s and %s",
start_datetime,
end_datetime,
)
# Replace this with mixin
return query_instance_ids(db, start_datetime, end_datetime).all()

return query_instance_ids(
db,
start_datetime,
end_datetime,
PRODUCTION_STATIC_FEATURES,
PRODUCTION_DYNAMIC_FEATURES
).all()


@db_query()
def query_geometries_hexgrid(
db: Session, bounding_box: Optional[Tuple[float]] = None,
db: Session, bounding_box: Optional[Tuple[float]] = None,
) -> Query:
"""
Query geometries for combining with plain JSON forecasts
Expand All @@ -117,7 +128,7 @@ def query_geometries_hexgrid(
cache=LRUCache(maxsize=256), key=lambda _, *args, **kwargs: hashkey(*args, **kwargs)
)
def cached_geometries_hexgrid(
db: Session, bounding_box: Optional[Tuple[float]] = None,
db: Session, bounding_box: Optional[Tuple[float]] = None,
) -> GeometryGeoJson:
"""Cache geometries with optional bounding box"""
logger.info("Querying hexgrid geometries")
Expand All @@ -131,12 +142,12 @@ def cached_geometries_hexgrid(

@db_query()
def query_forecasts_hexgrid(
db: Session,
instance_id: str,
start_datetime: datetime,
end_datetime: datetime,
with_geometry: bool,
bounding_box: Optional[Tuple[float]] = None,
db: Session,
instance_id: str,
start_datetime: datetime,
end_datetime: datetime,
with_geometry: bool,
bounding_box: Optional[Tuple[float]] = None,
) -> Query:
"""
Get all forecasts for a given model instance in the given datetime range
Expand Down Expand Up @@ -179,12 +190,12 @@ def query_forecasts_hexgrid(
cache=LRUCache(maxsize=256), key=lambda _, *args, **kwargs: hashkey(*args, **kwargs)
)
def cached_forecast_hexgrid_json(
db: Session,
instance_id: str,
start_datetime: datetime,
end_datetime: datetime,
with_geometry: bool,
bounding_box: Optional[Tuple[float]] = None,
db: Session,
instance_id: str,
start_datetime: datetime,
end_datetime: datetime,
with_geometry: bool,
bounding_box: Optional[Tuple[float]] = None,
) -> Optional[List[Tuple]]:
"""Cache forecasts with geometry with optional bounding box"""
logger.info(
Expand All @@ -210,11 +221,11 @@ def cached_forecast_hexgrid_json(
cache=LRUCache(maxsize=256), key=lambda _, *args, **kwargs: hashkey(*args, **kwargs)
)
def cached_forecast_hexgrid_geojson(
db: Session,
instance_id: str,
start_datetime: datetime,
end_datetime: datetime,
bounding_box: Optional[Tuple[float]] = None,
db: Session,
instance_id: str,
start_datetime: datetime,
end_datetime: datetime,
bounding_box: Optional[Tuple[float]] = None,
) -> ForecastResultGeoJson:
"""Cache forecasts with geometry with optional bounding box"""
query_results = cached_forecast_hexgrid_json(
Expand Down

0 comments on commit 57d9616

Please sign in to comment.