From 0e244d3086e6d275f4e9df51440afa44f5257a59 Mon Sep 17 00:00:00 2001 From: Realcat Date: Tue, 24 Sep 2024 17:44:53 +0000 Subject: [PATCH] update: use fastapi --- api/client.py | 136 ++++++++++++++++++++++++++++------ api/server.py | 188 +++++++++++++++++++++++++++++------------------ requirements.txt | 2 +- 3 files changed, 228 insertions(+), 98 deletions(-) diff --git a/api/client.py b/api/client.py index aa2d20b..54dbdea 100644 --- a/api/client.py +++ b/api/client.py @@ -1,36 +1,112 @@ import argparse import requests import numpy as np -import cv2 import json -import pickle import time from loguru import logger +import sys +import pickle +from typing import Dict + +sys.path.append("..") +API_URL_MATCH = "http://127.0.0.1:8001/v1/match" +API_URL_EXTRACT = "http://127.0.0.1:8001/v1/extract" +API_URL_EXTRACT_V2 = "http://127.0.0.1:8001/v2/extract" + + +def send_generate_request(path0: str, path1: str) -> Dict[str, np.ndarray]: + """ + Send a request to the API to generate a match between two images. + + Args: + path0 (str): The path to the first image. + path1 (str): The path to the second image. + + Returns: + Dict[str, np.ndarray]: A dictionary containing the generated matches. + The keys are "keypoints0", "keypoints1", "matches0", and "matches1", + and the values are ndarrays of shape (N, 2), (N, 2), (N, 2), and + (N, 2), respectively. + """ + files = {"image0": open(path0, "rb"), "image1": open(path1, "rb")} + try: + response = requests.post(API_URL_MATCH, files=files) + pred = {} + if response.status_code == 200: + pred = response.json() + for key in list(pred.keys()): + pred[key] = np.array(pred[key]) + else: + print( + f"Error: Response code {response.status_code} - {response.text}" + ) + finally: + files["image0"].close() + files["image1"].close() + return pred + -# Update this URL to your server's URL if hosted remotely -API_URL = "http://127.0.0.1:8001/v1/predict" +def send_generate_request1(path0: str) -> Dict[str, np.ndarray]: + """ + Send a request to the API to extract features from an image. + Args: + path0 (str): The path to the image. -def send_generate_request(path0, path1): - with open(path0, "rb") as f: - file0 = f.read() + Returns: + Dict[str, np.ndarray]: A dictionary containing the extracted features. + The keys are "keypoints", "descriptors", and "scores", and the + values are ndarrays of shape (N, 2), (N, 128), and (N,), + respectively. + """ + files = {"image": open(path0, "rb")} + try: + response = requests.post(API_URL_EXTRACT, files=files) + pred: Dict[str, np.ndarray] = {} + if response.status_code == 200: + pred = response.json() + for key in list(pred.keys()): + pred[key] = np.array(pred[key]) + else: + print( + f"Error: Response code {response.status_code} - {response.text}" + ) + finally: + files["image"].close() + return pred + + +def send_generate_request2(image_path: str) -> Dict[str, np.ndarray]: + """ + Send a request to the API to extract features from an image. - with open(path1, "rb") as f: - file1 = f.read() + Args: + image_path (str): The path to the image. - files = { - "image0": ("image0", file0), - "image1": ("image1", file1), + Returns: + Dict[str, np.ndarray]: A dictionary containing the extracted features. + The keys are "keypoints", "descriptors", and "scores", and the + values are ndarrays of shape (N, 2), (N, 128), and (N,), respectively. + """ + data = { + "image_path": image_path, + "max_keypoints": 1024, + "reference_points": [[0.0, 0.0], [1.0, 1.0]], } - response = requests.post(API_URL, files=files) pred = {} - if response.status_code == 200: - response_json = response.json() - pred = json.loads(response_json) - for key in list(pred.keys()): - pred[key] = np.array(pred[key]) - else: - print(f"Error: Response code {response.status_code} - {response.text}") + try: + response = requests.post(API_URL_EXTRACT_V2, json=data) + pred: Dict[str, np.ndarray] = {} + if response.status_code == 200: + pred = response.json() + for key in list(pred.keys()): + pred[key] = np.array(pred[key]) + else: + print( + f"Error: Response code {response.status_code} - {response.text}" + ) + except Exception as e: + print(f"An error occurred: {e}") return pred @@ -51,11 +127,23 @@ def send_generate_request(path0, path1): default="../datasets/sacre_coeur/mapping_rot/02928139_3448003521_rot90.jpg", ) args = parser.parse_args() - for i in range(100): + for i in range(10): t1 = time.time() preds = send_generate_request(args.image0, args.image1) t2 = time.time() - logger.info(f"Time cost: {(t2 - t1)} seconds") + logger.info(f"Time cost1: {(t2 - t1)} seconds") + + for i in range(10): + t1 = time.time() + preds = send_generate_request1(args.image0) + t2 = time.time() + logger.info(f"Time cost2: {(t2 - t1)} seconds") + + for i in range(10): + t1 = time.time() + preds = send_generate_request2(args.image0) + t2 = time.time() + logger.info(f"Time cost2: {(t2 - t1)} seconds") - # with open("preds.pkl", "wb") as f: - # pickle.dump(preds, f) + with open("preds1.pkl", "wb") as f: + pickle.dump(preds, f) diff --git a/api/server.py b/api/server.py index c41ad16..ac9d831 100644 --- a/api/server.py +++ b/api/server.py @@ -1,91 +1,133 @@ # server.py -import litserve as ls - -import cv2 -import warnings +from fastapi import FastAPI, File, UploadFile +from fastapi.responses import JSONResponse +import uvicorn +from PIL import Image import numpy as np -from pathlib import Path import sys -import json +from pathlib import Path +from typing import Union sys.path.append("..") -from PIL import Image +from ui.api import ImageMatchingAPI +from ui.utils import load_config, DEVICE +from pydantic import BaseModel -from ui.utils import ( - get_matcher_zoo, - load_config, - DEVICE, - ROOT, -) -from ui.api import ImageMatchingAPI +class ImageInfo(BaseModel): + image_path: str + max_keypoints: int + reference_points: list -# (STEP 1) - DEFINE THE API (compound AI system) -class SimpleLitAPI(ls.LitAPI): - def setup(self, device): - conf = { - "feature": { - "output": "feats-superpoint-n4096-rmax1600", - "model": { - "name": "superpoint", - "nms_radius": 3, - "max_keypoints": 4096, - "keypoint_threshold": 0.005, - }, - "preprocessing": { - "grayscale": True, - "force_resize": True, - "resize_max": 1600, - "width": 640, - "height": 480, - "dfactor": 8, - }, - }, - "matcher": { - "output": "matches-NN-mutual", - "model": { - "name": "nearest_neighbor", - "do_mutual_check": True, - "match_threshold": 0.2, - }, - }, - "dense": False, - } - self.api = ImageMatchingAPI(conf=conf, device=DEVICE) - - def decode_request(self, request): - # Convert the request payload to model input. - return { - "image0": request["image0"].file, - "image1": request["image1"].file, - } - - def predict(self, data): - # Easily build compound systems. Run inference and return the output. - image0 = np.array(Image.open(data["image0"])) - image1 = np.array(Image.open(data["image1"])) - output = self.api(image0, image1) - print(output.keys()) - return output - - def encode_response(self, output): - skip_keys = ["image0_orig", "image1_orig"] +class ImageMatchingService: + def __init__(self, conf: dict, device: str): + self.api = ImageMatchingAPI(conf=conf, device=device) + self.app = FastAPI() + self.register_routes() + + def register_routes(self): + @self.app.post("/v1/match") + async def match( + image0: UploadFile = File(...), image1: UploadFile = File(...) + ): + try: + image0_array = self.load_image(image0) + image1_array = self.load_image(image1) + + output = self.api(image0_array, image1_array) + + skip_keys = ["image0_orig", "image1_orig"] + pred = self.filter_output(output, skip_keys) + + return JSONResponse(content=pred) + except Exception as e: + return JSONResponse(content={"error": str(e)}, status_code=500) + + @self.app.post("/v1/extract") + async def extract(image: UploadFile = File(...)): + try: + image_array = self.load_image(image) + output = self.api.extract(image_array) + skip_keys = ["descriptors", "image", "image_orig"] + pred = self.filter_output(output, skip_keys) + return JSONResponse(content=pred) + except Exception as e: + return JSONResponse(content={"error": str(e)}, status_code=500) + + @self.app.post("/v2/extract") + async def extract_v2(image_path: ImageInfo): + img_path = image_path.image_path + try: + safe_path = Path(img_path).resolve(strict=False) + image_array = self.load_image(str(safe_path)) + output = self.api.extract(image_array) + skip_keys = ["descriptors", "image", "image_orig"] + pred = self.filter_output(output, skip_keys) + return JSONResponse(content=pred) + except Exception as e: + return JSONResponse(content={"error": str(e)}, status_code=500) + + def load_image(self, file_path: Union[str, UploadFile]) -> np.ndarray: + """ + Reads an image from a file path or an UploadFile object. + + Args: + file_path: A file path or an UploadFile object. + + Returns: + A numpy array representing the image. + """ + if isinstance(file_path, str): + file_path = Path(file_path).resolve(strict=False) + else: + file_path = file_path.file + with Image.open(file_path) as img: + image_array = np.array(img) + return image_array + + def filter_output(self, output: dict, skip_keys: list) -> dict: pred = {} for key, value in output.items(): if key in skip_keys: continue if isinstance(value, np.ndarray): pred[key] = value.tolist() - return json.dumps(pred) + return pred + + def run(self, host: str = "0.0.0.0", port: int = 8001): + uvicorn.run(self.app, host=host, port=port) -# (STEP 2) - START THE SERVER if __name__ == "__main__": - server = ls.LitServer( - SimpleLitAPI(), - accelerator="auto", - api_path="/v1/predict", - max_batch_size=1, - ) - server.run(port=8001) + conf = { + "feature": { + "output": "feats-superpoint-n4096-rmax1600", + "model": { + "name": "superpoint", + "nms_radius": 3, + "max_keypoints": 4096, + "keypoint_threshold": 0.005, + }, + "preprocessing": { + "grayscale": True, + "force_resize": True, + "resize_max": 1600, + "width": 640, + "height": 480, + "dfactor": 8, + }, + }, + "matcher": { + "output": "matches-NN-mutual", + "model": { + "name": "nearest_neighbor", + "do_mutual_check": True, + "match_threshold": 0.2, + }, + }, + "dense": False, + } + + service = ImageMatchingService(conf=conf, device=DEVICE) + service.run() diff --git a/requirements.txt b/requirements.txt index e5e1921..2c0a784 100644 --- a/requirements.txt +++ b/requirements.txt @@ -36,4 +36,4 @@ torchvision==0.19.0 roma #dust3r tqdm yacs -litserve +fastapi