Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Forecast endpoint #734

Draft
wants to merge 6 commits into
base: master
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from 5 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
"""Mixin for querying from the air quality data config table."""

from __future__ import annotations
from typing import Optional, Any, TYPE_CHECKING
from typing import Optional, Any, TYPE_CHECKING, List
from ...databases.tables import AirQualityDataTable
from ...decorators import db_query

Expand All @@ -19,6 +19,8 @@ def query_data_config(
self,
train_start_date: Optional[str] = None,
pred_start_date: Optional[str] = None,
static_features: Optional[List[str]] = None,
dynamic_features: Optional[List[str]] = None,
) -> Any:
"""Get the data ids and data configs that match the arguments.

Expand All @@ -30,16 +32,26 @@ def query_data_config(
A database query with columns for data id, data config and preprocessing.
"""
with self.dbcnxn.open_session() as session:
readings = session.query(AirQualityDataTable)
data_ids = session.query(AirQualityDataTable)
if train_start_date:
# NOTE uses a json b query to get the entry of the dictionary
readings = readings.filter(
data_ids = data_ids.filter(
AirQualityDataTable.data_config["train_start_date"].astext
>= train_start_date
)
if pred_start_date:
readings = readings.filter(
data_ids = data_ids.filter(
AirQualityDataTable.data_config["pred_start_date"].astext
>= pred_start_date
)
return readings
if static_features:
data_ids = data_ids.filter(
AirQualityDataTable.data_config["features"]
== static_features
)
if dynamic_features:
data_ids = data_ids.filter(
AirQualityDataTable.data_config["features"]
== dynamic_features
)
return data_ids
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,9 @@
HexGrid,
)
from cleanair.decorators import db_query

from cleanair.mixins.query_mixins.data_config_query_mixin import AirQualityDataConfigQueryMixin
from cleanair.mixins import InstanceQueryMixin, ResultQueryMixin
from cleanair.params.shared_params import PRODUCTION_DYNAMIC_FEATURES, PRODUCTION_STATIC_FEATURES
from ..database import all_or_404
from ..schemas.air_quality_forecast import ForecastResultGeoJson, GeometryGeoJson

Expand All @@ -22,28 +24,54 @@

@db_query()
def query_instance_ids(
db: Session, start_datetime: datetime, end_datetime: datetime,
db: Session,
start_datetime: datetime,
end_datetime: datetime,
static_features: List[str],
dynamic_features: List[str]
) -> Query:
"""
Check which model IDs produced forecasts between start_datetime and end_datetime.
"""
query = (
db.query(
AirQualityResultTable.instance_id,
AirQualityResultTable.measurement_start_utc,
)
.join(
AirQualityInstanceTable,
AirQualityInstanceTable.instance_id == AirQualityResultTable.instance_id,
)
.filter(
AirQualityInstanceTable.tag == "production",
AirQualityInstanceTable.model_name == "svgp",
AirQualityResultTable.measurement_start_utc >= start_datetime,
AirQualityResultTable.measurement_start_utc < end_datetime,
)

data_ids = AirQualityDataConfigQueryMixin().query_data_config(
static_features=static_features,
dynamic_features=dynamic_features
).subquery()

instance_query = InstanceQueryMixin().get_instances(
tag='production',
models=['mrdgp'],
data_ids=data_ids,
).subquery()

results_query = ResultQueryMixin().query_results(
start=start_datetime,
upto=end_datetime,
).subquery()
Copy link
Collaborator

@PatrickOHara PatrickOHara Jul 28, 2021

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Or perhaps because this is a subquery? Not a query

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

that might be it


query = results_query.innerjoin(
instance_query,
instance_query.c.instance_id == results_query.c.instance_id
)
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Which would mean the type of query is also a subquery


# query = (
# db.query(
# AirQualityResultTable.instance_id,
# AirQualityResultTable.measurement_start_utc,
# )
# .join(
# AirQualityInstanceTable,
# AirQualityInstanceTable.instance_id == AirQualityResultTable.instance_id,
# )
# .filter(
# AirQualityInstanceTable.tag == "production",
# AirQualityInstanceTable.model_name == "mrdgp",
# AirQualityResultTable.measurement_start_utc >= start_datetime,
# AirQualityResultTable.measurement_start_utc < end_datetime,
# )
# )

# Return only instance IDs and distinct values
query = query.with_entities(
AirQualityInstanceTable.instance_id, AirQualityInstanceTable.fit_start_time
Expand All @@ -58,20 +86,27 @@ def query_instance_ids(
key=lambda _, *args, **kwargs: hashkey(*args, **kwargs),
)
def cached_instance_ids(
db: Session, start_datetime: datetime, end_datetime: datetime,
db: Session, start_datetime: datetime, end_datetime: datetime,
) -> Optional[List[Tuple]]:
"""Cache available model instances"""
logger.info(
"Querying available instance IDs between %s and %s",
start_datetime,
end_datetime,
)
return query_instance_ids(db, start_datetime, end_datetime).all()

return query_instance_ids(
db,
start_datetime,
end_datetime,
PRODUCTION_STATIC_FEATURES,
PRODUCTION_DYNAMIC_FEATURES
).all()
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

You haven't specified output_type="..." as an argument. At the moment I think it returns a query object. Are you expecting a query?



@db_query()
def query_geometries_hexgrid(
db: Session, bounding_box: Optional[Tuple[float]] = None,
db: Session, bounding_box: Optional[Tuple[float]] = None,
) -> Query:
"""
Query geometries for combining with plain JSON forecasts
Expand All @@ -93,7 +128,7 @@ def query_geometries_hexgrid(
cache=LRUCache(maxsize=256), key=lambda _, *args, **kwargs: hashkey(*args, **kwargs)
)
def cached_geometries_hexgrid(
db: Session, bounding_box: Optional[Tuple[float]] = None,
db: Session, bounding_box: Optional[Tuple[float]] = None,
) -> GeometryGeoJson:
"""Cache geometries with optional bounding box"""
logger.info("Querying hexgrid geometries")
Expand All @@ -107,12 +142,12 @@ def cached_geometries_hexgrid(

@db_query()
def query_forecasts_hexgrid(
db: Session,
instance_id: str,
start_datetime: datetime,
end_datetime: datetime,
with_geometry: bool,
bounding_box: Optional[Tuple[float]] = None,
db: Session,
instance_id: str,
start_datetime: datetime,
end_datetime: datetime,
with_geometry: bool,
bounding_box: Optional[Tuple[float]] = None,
) -> Query:
"""
Get all forecasts for a given model instance in the given datetime range
Expand Down Expand Up @@ -155,12 +190,12 @@ def query_forecasts_hexgrid(
cache=LRUCache(maxsize=256), key=lambda _, *args, **kwargs: hashkey(*args, **kwargs)
)
def cached_forecast_hexgrid_json(
db: Session,
instance_id: str,
start_datetime: datetime,
end_datetime: datetime,
with_geometry: bool,
bounding_box: Optional[Tuple[float]] = None,
db: Session,
instance_id: str,
start_datetime: datetime,
end_datetime: datetime,
with_geometry: bool,
bounding_box: Optional[Tuple[float]] = None,
) -> Optional[List[Tuple]]:
"""Cache forecasts with geometry with optional bounding box"""
logger.info(
Expand All @@ -186,11 +221,11 @@ def cached_forecast_hexgrid_json(
cache=LRUCache(maxsize=256), key=lambda _, *args, **kwargs: hashkey(*args, **kwargs)
)
def cached_forecast_hexgrid_geojson(
db: Session,
instance_id: str,
start_datetime: datetime,
end_datetime: datetime,
bounding_box: Optional[Tuple[float]] = None,
db: Session,
instance_id: str,
start_datetime: datetime,
end_datetime: datetime,
bounding_box: Optional[Tuple[float]] = None,
) -> ForecastResultGeoJson:
"""Cache forecasts with geometry with optional bounding box"""
query_results = cached_forecast_hexgrid_json(
Expand Down