diff --git a/datahub/dsr.py b/datahub/dsr.py index 2091da0..ee07de0 100644 --- a/datahub/dsr.py +++ b/datahub/dsr.py @@ -7,28 +7,30 @@ from numpy.typing import NDArray from pydantic import BaseModel, Field +from . import log + class DSRModel(BaseModel): """Define required key values for Demand Side Response data.""" - amount: list = Field(alias="Amount", shape=(13,)) - cost: list = Field(alias="Cost", shape=(1440, 13)) - kwh_cost: list = Field(alias="kWh Cost", shape=(2,)) - activities: list = Field(alias="Activities", shape=(1440, 7)) + amount: list = Field(alias="Amount", shape=(1, 13)) + cost: list = Field(alias="Cost", shape=(13, 1440)) + kwh_cost: list = Field(alias="kWh Cost", shape=(2, 1)) + activities: list = Field(alias="Activities", shape=(7, 1440)) activities_outside_home: list = Field( - alias="Activities Outside Home", shape=(1440, 7) + alias="Activities Outside Home", shape=(7, 1440) ) - activity_types: list = Field(alias="Activity Types", shape=(7,)) - ev_id_matrix: list = Field(alias="EV ID Matrix", default=None, shape=(1440, 4329)) - ev_dt: list = Field(alias="EV DT", shape=(1440, 2)) - ev_locations: list = Field(alias="EV Locations", default=None, shape=(1440, 4329)) - ev_battery: list = Field(alias="EV Battery", default=None, shape=(1440, 4329)) - ev_state: list = Field(alias="EV State", shape=(1440, 4329)) - ev_mask: list = Field(alias="EV Mask", default=None, shape=(1440, 4329)) - baseline_ev: list = Field(alias="Baseline EV", shape=(1440,)) - baseline_non_ev: list = Field(alias="Baseline Non-EV", shape=(1440,)) - actual_ev: list = Field(alias="Actual EV", shape=(1440,)) - actual_non_ev: list = Field(alias="Actual Non-EV", shape=(1440,)) + activity_types: list = Field(alias="Activity Types", shape=(1, 7)) + ev_id_matrix: list = Field(alias="EV ID Matrix", default=[], shape=(None, 1440)) + ev_dt: list = Field(alias="EV DT", shape=(2, 1440)) + ev_locations: list = Field(alias="EV Locations", default=[], shape=(None, 1440)) + ev_battery: list = Field(alias="EV Battery", default=[], shape=(None, 1440)) + ev_state: list = Field(alias="EV State", shape=(None, 1440)) + ev_mask: list = Field(alias="EV Mask", default=[], shape=(None, 1440)) + baseline_ev: list = Field(alias="Baseline EV", shape=(1, 1440)) + baseline_non_ev: list = Field(alias="Baseline Non-EV", shape=(1, 1440)) + actual_ev: list = Field(alias="Actual EV", shape=(1, 1440)) + actual_non_ev: list = Field(alias="Actual Non-EV", shape=(1, 1440)) name: str = Field(alias="Name", default="") warn: str = Field(alias="Warn", default="") @@ -54,6 +56,7 @@ def validate_dsr_data(data: dict[str, NDArray | str]) -> None: Raises: A HTTPException is there are mising failing fields if there are. """ + log.debug("Validating DSR data") missing_fields = [ field for field in DSRModel.schema()["required"] if field not in data.keys() ] @@ -68,18 +71,30 @@ def validate_dsr_data(data: dict[str, NDArray | str]) -> None: try: array = data[alias] except KeyError: - if field: + if "default" not in field.keys(): aliases.append(alias) + log.error(f"Missing '{alias}' data") continue if field["type"] == "array" and not isinstance(array, str): - if array.shape != field["shape"] or not np.issubdtype( - array.dtype, np.number + shape = field["shape"] + if shape[0] is None: + shape = (array.shape[0], shape[1]) + if array.shape != shape: + aliases.append(alias) + log.error(f"'{alias}' has shape {array.shape}, expected {shape}") + continue + if not np.issubdtype(array.dtype, np.number) and not np.issubdtype( + array.dtype, np.character ): aliases.append(alias) + log.error( + f"'{alias}' is type {array.dtype}, expected number or character" + ) + if aliases: raise HTTPException( status_code=422, - detail=f"Invalid size for: {', '.join(aliases)}.", + detail=f"Invalid size or data type for: {', '.join(aliases)}.", ) diff --git a/datahub/main.py b/datahub/main.py index e26b7a6..49078d3 100644 --- a/datahub/main.py +++ b/datahub/main.py @@ -28,7 +28,7 @@ def create_opal_data(data: OpalModel | OpalArrayData) -> dict[str, str]: Returns: A Dict of the Opal data that has just been added to the Dataframe """ # noqa: D301 - log.info("Recieved Opal data.") + log.info("Received Opal data.") raw_data = data.dict() @@ -145,7 +145,7 @@ def upload_dsr(file: UploadFile) -> dict[str, str | None]: Returns: dict[str, str]: dictionary with the filename """ # noqa: D301 - log.info("Recieved Opal data.") + log.info("Received DSR data.") data = read_dsr_file(file.file) validate_dsr_data(data) diff --git a/tests/conftest.py b/tests/conftest.py index 854a2a9..00f2075 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -58,9 +58,10 @@ def dsr_data_path(tmp_path): if field.annotation == str: h5file[field.alias] = "Name or Warning" else: - h5file[field.alias] = np.random.rand( - *field.field_info.extra["shape"] - ).astype("float32") + shape = field.field_info.extra["shape"] + if shape[0] is None: + shape = (10, shape[1]) + h5file[field.alias] = np.random.rand(*shape).astype("float32") # Return the path to the file return file_path diff --git a/tests/test_dsr.py b/tests/test_dsr.py index cbc8ecc..c4221bc 100644 --- a/tests/test_dsr.py +++ b/tests/test_dsr.py @@ -17,7 +17,7 @@ def test_validate_dsr_data(dsr_data): with pytest.raises(HTTPException) as err: validate_dsr_data(dsr_data) - assert err.value.detail == "Invalid size for: Amount, Cost." + assert err.value.detail == "Invalid size or data type for: Amount, Cost." dsr_data.pop("Amount") diff --git a/tests/test_dsr_api.py b/tests/test_dsr_api.py index baa8c39..057a9f8 100644 --- a/tests/test_dsr_api.py +++ b/tests/test_dsr_api.py @@ -39,7 +39,9 @@ def test_post_dsr_api_invalid(dsr_data_path): with open(dsr_data_path, "rb") as dsr_data: response = client.post("/dsr", files={"file": dsr_data}) assert response.status_code == 422 - assert response.json()["detail"] == "Invalid size for: Amount, Cost." + assert ( + response.json()["detail"] == "Invalid size or data type for: Amount, Cost." + ) # Check missing fields raises an error with h5py.File(dsr_data_path, "r+") as dsr_data: