Skip to content

Commit

Permalink
feat: added counts-by-taxon endpoint
Browse files Browse the repository at this point in the history
  • Loading branch information
rohan-b-84 committed Aug 6, 2024
1 parent 1006a37 commit 59667f4
Showing 1 changed file with 84 additions and 2 deletions.
86 changes: 84 additions & 2 deletions src/api/endpoints.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,10 @@
import asyncio
import csv
import json
import os
from typing import Dict, List
from typing import Dict, List, Optional

from fastapi import APIRouter, Depends, HTTPException
from fastapi import APIRouter, Depends, HTTPException, Query, Request
from fastapi.responses import FileResponse, JSONResponse
from fastapi.security import APIKeyHeader
from pydantic import BaseModel
Expand All @@ -12,6 +13,7 @@

Check failure on line 13 in src/api/endpoints.py

View workflow job for this annotation

GitHub Actions / flake8 Lint

isort found an unexpected blank line in imports

RUN_SUMMARY_FILEPATH = "summary.json"
COUNTS_FILEPATH = "cluster_counts_by_taxon.txt"


class InputSchema(BaseModel):
Expand Down Expand Up @@ -139,6 +141,86 @@ async def get_run_summary(session_id: str = Depends(header_scheme)):
) from e


@router.get("/kinfin/counts-by-taxon")
async def get_counts_by_tanon(
request: Request,
session_id: str = Depends(header_scheme),
include_clusters: Optional[str] = Query(None),
exclude_clusters: Optional[str] = Query(None),
min_count: Optional[int] = Query(None),
max_count: Optional[int] = Query(None),
include_taxons: Optional[str] = Query(None),
exclude_taxons: Optional[str] = Query(None),
):
try:
result_dir = query_manager.get_session_dir(session_id)
if not result_dir:
raise HTTPException(status_code=401, detail="Invalid Session ID provided")

file_path = os.path.join(result_dir, COUNTS_FILEPATH)
if not os.path.exists(file_path):
raise HTTPException(
status_code=404,
detail=f"{COUNTS_FILEPATH} File Not Found",
)

included_clusters = (
set(include_clusters.split(",")) if include_clusters else None
)
excluded_clusters = (
set(exclude_clusters.split(",")) if exclude_clusters else None
)
include_taxons_set = set(include_taxons.split(",")) if include_taxons else None
exclude_taxons_set = set(exclude_taxons.split(",")) if exclude_taxons else None

result = {}
with open(file_path, "r", newline="") as file:
reader = csv.DictReader(file, delimiter="\t")
for row in reader:
cluster_id = row["#ID"]

if included_clusters and cluster_id not in included_clusters:
continue
if excluded_clusters and cluster_id in excluded_clusters:
continue

filtered_values = {}
for taxon, count in row.items():
if taxon == "#ID":
continue

count = int(count)

if min_count is not None and count < min_count:
continue

if max_count is not None and count > max_count:
continue

if include_taxons_set and taxon not in include_taxons_set:
continue

if exclude_taxons_set and taxon in exclude_taxons_set:
continue

filtered_values[taxon] = count

if filtered_values:
result[cluster_id] = filtered_values

response = {
"query": str(request.url),
"result": result,
}
return JSONResponse(response)

except Exception as e:
raise HTTPException(
status_code=500,
detail=f"Internal Server Error: {str(e)}",
) from e


@router.get("/plot/{plot_type}")
async def get_plot(
plot_type: str,
Expand Down

0 comments on commit 59667f4

Please sign in to comment.