Skip to content

Commit

Permalink
update: use fastapi
Browse files Browse the repository at this point in the history
  • Loading branch information
Vincentqyw committed Sep 24, 2024
1 parent 1796282 commit 0e244d3
Show file tree
Hide file tree
Showing 3 changed files with 228 additions and 98 deletions.
136 changes: 112 additions & 24 deletions api/client.py
Original file line number Diff line number Diff line change
@@ -1,36 +1,112 @@
import argparse
import requests
import numpy as np
import cv2
import json
import pickle
import time
from loguru import logger
import sys
import pickle
from typing import Dict

sys.path.append("..")
API_URL_MATCH = "http://127.0.0.1:8001/v1/match"
API_URL_EXTRACT = "http://127.0.0.1:8001/v1/extract"
API_URL_EXTRACT_V2 = "http://127.0.0.1:8001/v2/extract"


def send_generate_request(path0: str, path1: str) -> Dict[str, np.ndarray]:
"""
Send a request to the API to generate a match between two images.
Args:
path0 (str): The path to the first image.
path1 (str): The path to the second image.
Returns:
Dict[str, np.ndarray]: A dictionary containing the generated matches.
The keys are "keypoints0", "keypoints1", "matches0", and "matches1",
and the values are ndarrays of shape (N, 2), (N, 2), (N, 2), and
(N, 2), respectively.
"""
files = {"image0": open(path0, "rb"), "image1": open(path1, "rb")}
try:
response = requests.post(API_URL_MATCH, files=files)
pred = {}
if response.status_code == 200:
pred = response.json()
for key in list(pred.keys()):
pred[key] = np.array(pred[key])
else:
print(
f"Error: Response code {response.status_code} - {response.text}"
)
finally:
files["image0"].close()
files["image1"].close()
return pred


# Update this URL to your server's URL if hosted remotely
API_URL = "http://127.0.0.1:8001/v1/predict"
def send_generate_request1(path0: str) -> Dict[str, np.ndarray]:
"""
Send a request to the API to extract features from an image.
Args:
path0 (str): The path to the image.
def send_generate_request(path0, path1):
with open(path0, "rb") as f:
file0 = f.read()
Returns:
Dict[str, np.ndarray]: A dictionary containing the extracted features.
The keys are "keypoints", "descriptors", and "scores", and the
values are ndarrays of shape (N, 2), (N, 128), and (N,),
respectively.
"""
files = {"image": open(path0, "rb")}
try:
response = requests.post(API_URL_EXTRACT, files=files)
pred: Dict[str, np.ndarray] = {}
if response.status_code == 200:
pred = response.json()
for key in list(pred.keys()):
pred[key] = np.array(pred[key])
else:
print(
f"Error: Response code {response.status_code} - {response.text}"
)
finally:
files["image"].close()
return pred


def send_generate_request2(image_path: str) -> Dict[str, np.ndarray]:
"""
Send a request to the API to extract features from an image.
with open(path1, "rb") as f:
file1 = f.read()
Args:
image_path (str): The path to the image.
files = {
"image0": ("image0", file0),
"image1": ("image1", file1),
Returns:
Dict[str, np.ndarray]: A dictionary containing the extracted features.
The keys are "keypoints", "descriptors", and "scores", and the
values are ndarrays of shape (N, 2), (N, 128), and (N,), respectively.
"""
data = {
"image_path": image_path,
"max_keypoints": 1024,
"reference_points": [[0.0, 0.0], [1.0, 1.0]],
}
response = requests.post(API_URL, files=files)
pred = {}
if response.status_code == 200:
response_json = response.json()
pred = json.loads(response_json)
for key in list(pred.keys()):
pred[key] = np.array(pred[key])
else:
print(f"Error: Response code {response.status_code} - {response.text}")
try:
response = requests.post(API_URL_EXTRACT_V2, json=data)
pred: Dict[str, np.ndarray] = {}
if response.status_code == 200:
pred = response.json()
for key in list(pred.keys()):
pred[key] = np.array(pred[key])
else:
print(
f"Error: Response code {response.status_code} - {response.text}"
)
except Exception as e:
print(f"An error occurred: {e}")
return pred


Expand All @@ -51,11 +127,23 @@ def send_generate_request(path0, path1):
default="../datasets/sacre_coeur/mapping_rot/02928139_3448003521_rot90.jpg",
)
args = parser.parse_args()
for i in range(100):
for i in range(10):
t1 = time.time()
preds = send_generate_request(args.image0, args.image1)
t2 = time.time()
logger.info(f"Time cost: {(t2 - t1)} seconds")
logger.info(f"Time cost1: {(t2 - t1)} seconds")

for i in range(10):
t1 = time.time()
preds = send_generate_request1(args.image0)
t2 = time.time()
logger.info(f"Time cost2: {(t2 - t1)} seconds")

for i in range(10):
t1 = time.time()
preds = send_generate_request2(args.image0)
t2 = time.time()
logger.info(f"Time cost2: {(t2 - t1)} seconds")

# with open("preds.pkl", "wb") as f:
# pickle.dump(preds, f)
with open("preds1.pkl", "wb") as f:
pickle.dump(preds, f)
188 changes: 115 additions & 73 deletions api/server.py
Original file line number Diff line number Diff line change
@@ -1,91 +1,133 @@
# server.py
import litserve as ls

import cv2
import warnings
from fastapi import FastAPI, File, UploadFile
from fastapi.responses import JSONResponse
import uvicorn
from PIL import Image
import numpy as np
from pathlib import Path
import sys
import json
from pathlib import Path
from typing import Union

sys.path.append("..")
from PIL import Image
from ui.api import ImageMatchingAPI
from ui.utils import load_config, DEVICE
from pydantic import BaseModel

from ui.utils import (
get_matcher_zoo,
load_config,
DEVICE,
ROOT,
)

from ui.api import ImageMatchingAPI
class ImageInfo(BaseModel):
image_path: str
max_keypoints: int
reference_points: list


# (STEP 1) - DEFINE THE API (compound AI system)
class SimpleLitAPI(ls.LitAPI):
def setup(self, device):
conf = {
"feature": {
"output": "feats-superpoint-n4096-rmax1600",
"model": {
"name": "superpoint",
"nms_radius": 3,
"max_keypoints": 4096,
"keypoint_threshold": 0.005,
},
"preprocessing": {
"grayscale": True,
"force_resize": True,
"resize_max": 1600,
"width": 640,
"height": 480,
"dfactor": 8,
},
},
"matcher": {
"output": "matches-NN-mutual",
"model": {
"name": "nearest_neighbor",
"do_mutual_check": True,
"match_threshold": 0.2,
},
},
"dense": False,
}
self.api = ImageMatchingAPI(conf=conf, device=DEVICE)

def decode_request(self, request):
# Convert the request payload to model input.
return {
"image0": request["image0"].file,
"image1": request["image1"].file,
}

def predict(self, data):
# Easily build compound systems. Run inference and return the output.
image0 = np.array(Image.open(data["image0"]))
image1 = np.array(Image.open(data["image1"]))
output = self.api(image0, image1)
print(output.keys())
return output

def encode_response(self, output):
skip_keys = ["image0_orig", "image1_orig"]
class ImageMatchingService:
def __init__(self, conf: dict, device: str):
self.api = ImageMatchingAPI(conf=conf, device=device)
self.app = FastAPI()
self.register_routes()

def register_routes(self):
@self.app.post("/v1/match")
async def match(
image0: UploadFile = File(...), image1: UploadFile = File(...)
):
try:
image0_array = self.load_image(image0)
image1_array = self.load_image(image1)

output = self.api(image0_array, image1_array)

skip_keys = ["image0_orig", "image1_orig"]
pred = self.filter_output(output, skip_keys)

return JSONResponse(content=pred)
except Exception as e:
return JSONResponse(content={"error": str(e)}, status_code=500)

@self.app.post("/v1/extract")
async def extract(image: UploadFile = File(...)):
try:
image_array = self.load_image(image)
output = self.api.extract(image_array)
skip_keys = ["descriptors", "image", "image_orig"]
pred = self.filter_output(output, skip_keys)
return JSONResponse(content=pred)
except Exception as e:
return JSONResponse(content={"error": str(e)}, status_code=500)

@self.app.post("/v2/extract")
async def extract_v2(image_path: ImageInfo):
img_path = image_path.image_path
try:
safe_path = Path(img_path).resolve(strict=False)
image_array = self.load_image(str(safe_path))
output = self.api.extract(image_array)
skip_keys = ["descriptors", "image", "image_orig"]
pred = self.filter_output(output, skip_keys)
return JSONResponse(content=pred)
except Exception as e:
return JSONResponse(content={"error": str(e)}, status_code=500)

def load_image(self, file_path: Union[str, UploadFile]) -> np.ndarray:
"""
Reads an image from a file path or an UploadFile object.
Args:
file_path: A file path or an UploadFile object.
Returns:
A numpy array representing the image.
"""
if isinstance(file_path, str):
file_path = Path(file_path).resolve(strict=False)
else:
file_path = file_path.file
with Image.open(file_path) as img:
image_array = np.array(img)
return image_array

def filter_output(self, output: dict, skip_keys: list) -> dict:
pred = {}
for key, value in output.items():
if key in skip_keys:
continue
if isinstance(value, np.ndarray):
pred[key] = value.tolist()
return json.dumps(pred)
return pred

def run(self, host: str = "0.0.0.0", port: int = 8001):
uvicorn.run(self.app, host=host, port=port)


# (STEP 2) - START THE SERVER
if __name__ == "__main__":
server = ls.LitServer(
SimpleLitAPI(),
accelerator="auto",
api_path="/v1/predict",
max_batch_size=1,
)
server.run(port=8001)
conf = {
"feature": {
"output": "feats-superpoint-n4096-rmax1600",
"model": {
"name": "superpoint",
"nms_radius": 3,
"max_keypoints": 4096,
"keypoint_threshold": 0.005,
},
"preprocessing": {
"grayscale": True,
"force_resize": True,
"resize_max": 1600,
"width": 640,
"height": 480,
"dfactor": 8,
},
},
"matcher": {
"output": "matches-NN-mutual",
"model": {
"name": "nearest_neighbor",
"do_mutual_check": True,
"match_threshold": 0.2,
},
},
"dense": False,
}

service = ImageMatchingService(conf=conf, device=DEVICE)
service.run()
2 changes: 1 addition & 1 deletion requirements.txt
Original file line number Diff line number Diff line change
Expand Up @@ -36,4 +36,4 @@ torchvision==0.19.0
roma #dust3r
tqdm
yacs
litserve
fastapi

0 comments on commit 0e244d3

Please sign in to comment.