-
Notifications
You must be signed in to change notification settings - Fork 70
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
1 parent
1796282
commit 0e244d3
Showing
3 changed files
with
228 additions
and
98 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,91 +1,133 @@ | ||
# server.py | ||
import litserve as ls | ||
|
||
import cv2 | ||
import warnings | ||
from fastapi import FastAPI, File, UploadFile | ||
from fastapi.responses import JSONResponse | ||
import uvicorn | ||
from PIL import Image | ||
import numpy as np | ||
from pathlib import Path | ||
import sys | ||
import json | ||
from pathlib import Path | ||
from typing import Union | ||
|
||
sys.path.append("..") | ||
from PIL import Image | ||
from ui.api import ImageMatchingAPI | ||
from ui.utils import load_config, DEVICE | ||
from pydantic import BaseModel | ||
|
||
from ui.utils import ( | ||
get_matcher_zoo, | ||
load_config, | ||
DEVICE, | ||
ROOT, | ||
) | ||
|
||
from ui.api import ImageMatchingAPI | ||
class ImageInfo(BaseModel): | ||
image_path: str | ||
max_keypoints: int | ||
reference_points: list | ||
|
||
|
||
# (STEP 1) - DEFINE THE API (compound AI system) | ||
class SimpleLitAPI(ls.LitAPI): | ||
def setup(self, device): | ||
conf = { | ||
"feature": { | ||
"output": "feats-superpoint-n4096-rmax1600", | ||
"model": { | ||
"name": "superpoint", | ||
"nms_radius": 3, | ||
"max_keypoints": 4096, | ||
"keypoint_threshold": 0.005, | ||
}, | ||
"preprocessing": { | ||
"grayscale": True, | ||
"force_resize": True, | ||
"resize_max": 1600, | ||
"width": 640, | ||
"height": 480, | ||
"dfactor": 8, | ||
}, | ||
}, | ||
"matcher": { | ||
"output": "matches-NN-mutual", | ||
"model": { | ||
"name": "nearest_neighbor", | ||
"do_mutual_check": True, | ||
"match_threshold": 0.2, | ||
}, | ||
}, | ||
"dense": False, | ||
} | ||
self.api = ImageMatchingAPI(conf=conf, device=DEVICE) | ||
|
||
def decode_request(self, request): | ||
# Convert the request payload to model input. | ||
return { | ||
"image0": request["image0"].file, | ||
"image1": request["image1"].file, | ||
} | ||
|
||
def predict(self, data): | ||
# Easily build compound systems. Run inference and return the output. | ||
image0 = np.array(Image.open(data["image0"])) | ||
image1 = np.array(Image.open(data["image1"])) | ||
output = self.api(image0, image1) | ||
print(output.keys()) | ||
return output | ||
|
||
def encode_response(self, output): | ||
skip_keys = ["image0_orig", "image1_orig"] | ||
class ImageMatchingService: | ||
def __init__(self, conf: dict, device: str): | ||
self.api = ImageMatchingAPI(conf=conf, device=device) | ||
self.app = FastAPI() | ||
self.register_routes() | ||
|
||
def register_routes(self): | ||
@self.app.post("/v1/match") | ||
async def match( | ||
image0: UploadFile = File(...), image1: UploadFile = File(...) | ||
): | ||
try: | ||
image0_array = self.load_image(image0) | ||
image1_array = self.load_image(image1) | ||
|
||
output = self.api(image0_array, image1_array) | ||
|
||
skip_keys = ["image0_orig", "image1_orig"] | ||
pred = self.filter_output(output, skip_keys) | ||
|
||
return JSONResponse(content=pred) | ||
except Exception as e: | ||
return JSONResponse(content={"error": str(e)}, status_code=500) | ||
|
||
@self.app.post("/v1/extract") | ||
async def extract(image: UploadFile = File(...)): | ||
try: | ||
image_array = self.load_image(image) | ||
output = self.api.extract(image_array) | ||
skip_keys = ["descriptors", "image", "image_orig"] | ||
pred = self.filter_output(output, skip_keys) | ||
return JSONResponse(content=pred) | ||
except Exception as e: | ||
return JSONResponse(content={"error": str(e)}, status_code=500) | ||
|
||
@self.app.post("/v2/extract") | ||
async def extract_v2(image_path: ImageInfo): | ||
img_path = image_path.image_path | ||
try: | ||
safe_path = Path(img_path).resolve(strict=False) | ||
image_array = self.load_image(str(safe_path)) | ||
output = self.api.extract(image_array) | ||
skip_keys = ["descriptors", "image", "image_orig"] | ||
pred = self.filter_output(output, skip_keys) | ||
return JSONResponse(content=pred) | ||
except Exception as e: | ||
return JSONResponse(content={"error": str(e)}, status_code=500) | ||
|
||
def load_image(self, file_path: Union[str, UploadFile]) -> np.ndarray: | ||
""" | ||
Reads an image from a file path or an UploadFile object. | ||
Args: | ||
file_path: A file path or an UploadFile object. | ||
Returns: | ||
A numpy array representing the image. | ||
""" | ||
if isinstance(file_path, str): | ||
file_path = Path(file_path).resolve(strict=False) | ||
else: | ||
file_path = file_path.file | ||
with Image.open(file_path) as img: | ||
image_array = np.array(img) | ||
return image_array | ||
|
||
def filter_output(self, output: dict, skip_keys: list) -> dict: | ||
pred = {} | ||
for key, value in output.items(): | ||
if key in skip_keys: | ||
continue | ||
if isinstance(value, np.ndarray): | ||
pred[key] = value.tolist() | ||
return json.dumps(pred) | ||
return pred | ||
|
||
def run(self, host: str = "0.0.0.0", port: int = 8001): | ||
uvicorn.run(self.app, host=host, port=port) | ||
|
||
|
||
# (STEP 2) - START THE SERVER | ||
if __name__ == "__main__": | ||
server = ls.LitServer( | ||
SimpleLitAPI(), | ||
accelerator="auto", | ||
api_path="/v1/predict", | ||
max_batch_size=1, | ||
) | ||
server.run(port=8001) | ||
conf = { | ||
"feature": { | ||
"output": "feats-superpoint-n4096-rmax1600", | ||
"model": { | ||
"name": "superpoint", | ||
"nms_radius": 3, | ||
"max_keypoints": 4096, | ||
"keypoint_threshold": 0.005, | ||
}, | ||
"preprocessing": { | ||
"grayscale": True, | ||
"force_resize": True, | ||
"resize_max": 1600, | ||
"width": 640, | ||
"height": 480, | ||
"dfactor": 8, | ||
}, | ||
}, | ||
"matcher": { | ||
"output": "matches-NN-mutual", | ||
"model": { | ||
"name": "nearest_neighbor", | ||
"do_mutual_check": True, | ||
"match_threshold": 0.2, | ||
}, | ||
}, | ||
"dense": False, | ||
} | ||
|
||
service = ImageMatchingService(conf=conf, device=DEVICE) | ||
service.run() |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -36,4 +36,4 @@ torchvision==0.19.0 | |
roma #dust3r | ||
tqdm | ||
yacs | ||
litserve | ||
fastapi |