Skip to content

Commit

Permalink
Update dsr get method to take non-whitespace column names
Browse files Browse the repository at this point in the history
  • Loading branch information
AdrianDAlessandro committed Sep 25, 2023
1 parent 5345907 commit e8eea76
Show file tree
Hide file tree
Showing 2 changed files with 9 additions and 10 deletions.
17 changes: 8 additions & 9 deletions datahub/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -156,21 +156,22 @@ def upload_dsr(file: UploadFile) -> dict[str, str | None]:

@app.get("/dsr", response_class=ORJSONResponse)
def get_dsr_data(
start: int = len(dt.dsr_data) - 1, end: int | None = None, col: str | None = None
start: int = -1, end: int | None = None, col: str | None = None
) -> ORJSONResponse:
"""GET method function for getting DSR data as JSON.
It takes optional query parameters of:
- `start`: Starting index for exported list
- `end`: Last index that will be included in exported list
- `start`: Starting index for exported list. Defaults to -1 for the most recent
entry only.
- `end`: Last index that will be included in exported list.
- `col`: A comma-separated list of which columns/keys within the data to get.
These values are all lower-case and spaces are replaced by underscores.
And returns a dictionary containing the DSR data in JSON format.
This can be converted back to a DataFrame using the following:
`pd.DataFrame(**data)`
TODO: Ensure data is json serializable or returned in binary format
\f
Args:
Expand Down Expand Up @@ -199,9 +200,7 @@ def get_dsr_data(
columns = col.lower().split(",")

for col_name in columns:
dsr_columns = [x.lower() for x in dsr_headers.keys()]

if col_name not in dsr_columns:
if col_name not in dsr_headers.values():
message = "One or more of the specified columns are invalid."
log.error(message)
raise HTTPException(status_code=400, detail=message)
Expand All @@ -211,7 +210,7 @@ def get_dsr_data(
for frame in filtered_index_data:
filtered_keys = {}
for key in frame.keys():
if key.lower() in columns:
if dsr_headers[key.title()] in columns:
filtered_keys[key] = frame[key]
filtered_data.append(filtered_keys)

Expand Down
2 changes: 1 addition & 1 deletion tests/test_dsr_api.py
Original file line number Diff line number Diff line change
Expand Up @@ -96,7 +96,7 @@ def test_get_dsr_api(dsr_data):
assert len(response.json()["data"][0].keys()) == 1
assert "Activities" in response.json()["data"][0].keys()

response = client.get("/dsr?col=activity types,kwh cost")
response = client.get("/dsr?col=activity_types,kwh_cost")
assert len(response.json()["data"][0].keys()) == 2
assert "Activity Types" in response.json()["data"][0].keys()
assert "kWh Cost" in response.json()["data"][0].keys()
Expand Down

0 comments on commit e8eea76

Please sign in to comment.