From e8eea76fc0feada5451fd606a772cfa6082ab308 Mon Sep 17 00:00:00 2001 From: Adrian D'Alessandro Date: Mon, 25 Sep 2023 17:22:37 +0100 Subject: [PATCH] Update dsr get method to take non-whitespace column names --- datahub/main.py | 17 ++++++++--------- tests/test_dsr_api.py | 2 +- 2 files changed, 9 insertions(+), 10 deletions(-) diff --git a/datahub/main.py b/datahub/main.py index 9896a11..c41eb95 100644 --- a/datahub/main.py +++ b/datahub/main.py @@ -156,21 +156,22 @@ def upload_dsr(file: UploadFile) -> dict[str, str | None]: @app.get("/dsr", response_class=ORJSONResponse) def get_dsr_data( - start: int = len(dt.dsr_data) - 1, end: int | None = None, col: str | None = None + start: int = -1, end: int | None = None, col: str | None = None ) -> ORJSONResponse: """GET method function for getting DSR data as JSON. It takes optional query parameters of: - - `start`: Starting index for exported list - - `end`: Last index that will be included in exported list + - `start`: Starting index for exported list. Defaults to -1 for the most recent + entry only. + - `end`: Last index that will be included in exported list. + - `col`: A comma-separated list of which columns/keys within the data to get. + These values are all lower-case and spaces are replaced by underscores. And returns a dictionary containing the DSR data in JSON format. This can be converted back to a DataFrame using the following: `pd.DataFrame(**data)` - TODO: Ensure data is json serializable or returned in binary format - \f Args: @@ -199,9 +200,7 @@ def get_dsr_data( columns = col.lower().split(",") for col_name in columns: - dsr_columns = [x.lower() for x in dsr_headers.keys()] - - if col_name not in dsr_columns: + if col_name not in dsr_headers.values(): message = "One or more of the specified columns are invalid." log.error(message) raise HTTPException(status_code=400, detail=message) @@ -211,7 +210,7 @@ def get_dsr_data( for frame in filtered_index_data: filtered_keys = {} for key in frame.keys(): - if key.lower() in columns: + if dsr_headers[key.title()] in columns: filtered_keys[key] = frame[key] filtered_data.append(filtered_keys) diff --git a/tests/test_dsr_api.py b/tests/test_dsr_api.py index 7dda35b..baa8c39 100644 --- a/tests/test_dsr_api.py +++ b/tests/test_dsr_api.py @@ -96,7 +96,7 @@ def test_get_dsr_api(dsr_data): assert len(response.json()["data"][0].keys()) == 1 assert "Activities" in response.json()["data"][0].keys() - response = client.get("/dsr?col=activity types,kwh cost") + response = client.get("/dsr?col=activity_types,kwh_cost") assert len(response.json()["data"][0].keys()) == 2 assert "Activity Types" in response.json()["data"][0].keys() assert "kWh Cost" in response.json()["data"][0].keys()