Skip to content

Commit

Permalink
refactor: transform model vals & use args/kwargs
Browse files Browse the repository at this point in the history
First transform model values from enums to dataclasses (e.g. TopicEnum
-> TopicDefinition).

Then use unpacking of arguments and keyword arguments to make function
calls simpler.
  • Loading branch information
matthiasschaub committed Nov 17, 2024
1 parent 530c512 commit 7c2f24a
Show file tree
Hide file tree
Showing 5 changed files with 62 additions and 64 deletions.
30 changes: 11 additions & 19 deletions ohsome_quality_api/api/api.py
Original file line number Diff line number Diff line change
Expand Up @@ -275,7 +275,8 @@ async def post_attribute_completeness(
if isinstance(parameters, AttributeCompletenessKeyRequest):
for attribute in parameters.attribute_keys:
validate_attribute_topic_combination(
attribute.value, parameters.topic_key.value
attribute,
parameters.topic,
)

return await _post_indicator(request, "attribute-completeness", parameters)
Expand Down Expand Up @@ -308,23 +309,12 @@ async def post_indicator(


async def _post_indicator(
request: Request, key: str, parameters: IndicatorRequest
request: Request,
key: str,
parameters: IndicatorRequest,
) -> Any:
validate_indicator_topic_combination(key, parameters.topic_key.value)
attribute_keys = getattr(parameters, "attribute_keys", None)
if attribute_keys:
attribute_keys = [attribute.value for attribute in attribute_keys]
attribute_filter = getattr(parameters, "attribute_filter", None)
attribute_names = getattr(parameters, "attribute_names", None)
indicators = await oqt.create_indicator(
key=key,
bpolys=parameters.bpolys,
topic=get_topic_preset(parameters.topic_key.value),
include_figure=parameters.include_figure,
attribute_keys=attribute_keys,
attribute_filter=attribute_filter,
attribute_names=attribute_names,
)
validate_indicator_topic_combination(key, parameters.topic)
indicators = await oqt.create_indicator(key=key, **dict(parameters))

if request.headers["accept"] == MEDIA_TYPE_JSON:
return {
Expand All @@ -345,10 +335,12 @@ async def _post_indicator(
}
else:
detail = "Content-Type needs to be either {0} or {1}".format(
MEDIA_TYPE_JSON, MEDIA_TYPE_GEOJSON
MEDIA_TYPE_JSON,
MEDIA_TYPE_GEOJSON,
)
raise HTTPException(
status_code=status.HTTP_415_UNSUPPORTED_MEDIA_TYPE, detail=detail
status_code=status.HTTP_415_UNSUPPORTED_MEDIA_TYPE,
detail=detail,
)


Expand Down
25 changes: 20 additions & 5 deletions ohsome_quality_api/api/request_models.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,11 +2,16 @@

import geojson
from geojson_pydantic import Feature, FeatureCollection, MultiPolygon, Polygon
from pydantic import BaseModel, ConfigDict, Field, field_validator
from pydantic import (
BaseModel,
ConfigDict,
Field,
field_validator,
)

from ohsome_quality_api.attributes.definitions import AttributeEnum
from ohsome_quality_api.topics.definitions import TopicEnum
from ohsome_quality_api.topics.models import TopicData
from ohsome_quality_api.topics.definitions import TopicEnum, get_topic_preset
from ohsome_quality_api.topics.models import TopicData, TopicDefinition
from ohsome_quality_api.utils.helper import snake_to_lower_camel


Expand Down Expand Up @@ -49,21 +54,26 @@ class BaseBpolys(BaseConfig):

@field_validator("bpolys")
@classmethod
def transform(cls, value) -> geojson.FeatureCollection:
def transform_bpolys(cls, value):
# NOTE: `geojson_pydantic` library is used only for validation and openAPI-spec
# generation. To avoid refactoring all code the FeatureCollection object of
# the `geojson` library is still used every else.
return geojson.loads(value.model_dump_json())


class IndicatorRequest(BaseBpolys):
topic_key: TopicEnum = Field(
topic: TopicEnum = Field(
...,
title="Topic Key",
alias="topic",
)
include_figure: bool = True

@field_validator("topic")
@classmethod
def transform_topic(cls, value) -> TopicDefinition:
return get_topic_preset(value.value)


class AttributeCompletenessKeyRequest(IndicatorRequest):
attribute_keys: List[AttributeEnum] = Field(
Expand All @@ -72,6 +82,11 @@ class AttributeCompletenessKeyRequest(IndicatorRequest):
alias="attributes",
)

@field_validator("attribute_keys")
@classmethod
def transform_attributes(cls, value) -> list[str]:
return [attribute.value for attribute in value]


class AttributeCompletenessFilterRequest(IndicatorRequest):
attribute_filter: str = Field(
Expand Down
34 changes: 14 additions & 20 deletions ohsome_quality_api/oqt.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
"""Controller for computing Indicators."""

import logging
from typing import Coroutine, List
from typing import Coroutine

from geojson import Feature, FeatureCollection

Expand All @@ -18,9 +18,8 @@ async def create_indicator(
bpolys: FeatureCollection,
topic: TopicData | TopicDefinition,
include_figure: bool = True,
attribute_keys: List[str] | None = None,
attribute_filter: str | None = None,
attribute_names: List[str] | None = None,
*args,
**kwargs,
) -> list[Indicator]:
"""Create indicator(s) for features of a GeoJSON FeatureCollection.
Expand All @@ -47,9 +46,8 @@ async def create_indicator(
feature,
topic,
include_figure,
attribute_keys,
attribute_filter,
attribute_names,
*args,
**kwargs,
)
)
return await gather_with_semaphore(tasks)
Expand All @@ -60,9 +58,8 @@ async def _create_indicator(
feature: Feature,
topic: Topic,
include_figure: bool = True,
attribute_keys: List[str] | None = None,
attribute_filter: str | None = None,
attribute_names: List[str] | None = None,
*args,
**kwargs,
) -> Indicator:
"""Create an indicator from scratch."""

Expand All @@ -71,19 +68,16 @@ async def _create_indicator(
logging.info("Feature id: {0:4}".format(feature.get("id", "None")))

indicator_class = get_class_from_key(class_type="indicator", key=key)
if key == "attribute-completeness":
indicator = indicator_class(
topic,
feature,
attribute_keys,
attribute_filter,
attribute_names,
)
else:
indicator = indicator_class(topic, feature)
indicator = indicator_class(
topic,
feature,
*args,
**kwargs,
)

logging.info("Run preprocessing")
await indicator.preprocess()

logging.info("Run calculation")
indicator.calculate()

Expand Down
10 changes: 5 additions & 5 deletions ohsome_quality_api/utils/validators.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@
)
from ohsome_quality_api.config import get_config_value
from ohsome_quality_api.indicators.definitions import get_valid_indicators
from ohsome_quality_api.topics.definitions import TopicEnum
from ohsome_quality_api.topics.models import BaseTopic
from ohsome_quality_api.utils.exceptions import (
AttributeTopicCombinationError,
GeoJSONError,
Expand All @@ -20,19 +20,19 @@
from ohsome_quality_api.utils.helper_geo import calculate_area


def validate_attribute_topic_combination(attribute: AttributeEnum, topic: TopicEnum):
def validate_attribute_topic_combination(attribute: AttributeEnum, topic: BaseTopic):
"""As attributes are only meaningful for a certain topic,
we need to check if the given combination is valid."""

valid_attributes_for_topic = get_attributes()[topic]
valid_attributes_for_topic = get_attributes()[topic.key]
valid_attribute_names = [attribute for attribute in valid_attributes_for_topic]

if attribute not in valid_attributes_for_topic:
raise AttributeTopicCombinationError(attribute, topic, valid_attribute_names)


def validate_indicator_topic_combination(indicator: str, topic: str):
if indicator not in get_valid_indicators(topic):
def validate_indicator_topic_combination(indicator: str, topic: BaseTopic):
if indicator not in get_valid_indicators(topic.key):
raise IndicatorTopicCombinationError(indicator, topic)


Expand Down
27 changes: 12 additions & 15 deletions tests/unittests/test_validators.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,20 +20,17 @@
from tests.unittests.utils import get_geojson_fixture


def test_validate_attribute_topic_combination_with_valid_combination():
validate_attribute_topic_combination("maxspeed", "roads")


def test_validate_attribute_topic_combination_with_invalid_topic():
"""As the method under test requires individually valid arguments
the arguments given lead to a KeyError."""
with pytest.raises(KeyError):
validate_attribute_topic_combination("maxspeed", "xxxxx")
def test_validate_attribute_topic_combination_with_valid_combination(
topic_major_roads_length,
):
validate_attribute_topic_combination("maxspeed", topic_major_roads_length)


def test_validate_attribute_topic_combination_with_invalid_combination():
def test_validate_attribute_topic_combination_with_invalid_combination(
topic_building_count,
):
with pytest.raises(AttributeTopicCombinationError):
validate_attribute_topic_combination("maxspeed", "building-count")
validate_attribute_topic_combination("maxspeed", topic_building_count)


def test_validate_geojson_feature_collection_single(
Expand Down Expand Up @@ -77,13 +74,13 @@ def test_validate_geojson_unsupported_geometry_type(
validate_geojson(feature_collection_unsupported_geometry_type)


def test_validate_indicator_topic_combination():
validate_indicator_topic_combination("minimal", "minimal")
def test_validate_indicator_topic_combination(topic_minimal):
validate_indicator_topic_combination("minimal", topic_minimal)


def test_validate_indicator_topic_combination_invalid():
def test_validate_indicator_topic_combination_invalid(topic_building_count):
with pytest.raises(IndicatorTopicCombinationError):
validate_indicator_topic_combination("minimal", "building-count")
validate_indicator_topic_combination("minimal", topic_building_count)


@mock.patch.dict(
Expand Down

0 comments on commit 7c2f24a

Please sign in to comment.