diff --git a/api/client.py b/api/client.py
new file mode 100644
--- /dev/null
+++ b/api/client.py
@@ -0,0 +1,136 @@
+import argparse
+import pickle
+import time
+from typing import Dict
+
+import numpy as np
+import requests
+from loguru import logger
+
+API_URL_MATCH = "http://127.0.0.1:8001/v1/match"
+API_URL_EXTRACT = "http://127.0.0.1:8001/v1/extract"
+API_URL_EXTRACT_V2 = "http://127.0.0.1:8001/v2/extract"
+
+
+def _parse_response(response: requests.Response) -> Dict[str, np.ndarray]:
+    """
+    Convert an API response into a dictionary of numpy arrays.
+
+    Args:
+        response (requests.Response): The response returned by the server.
+
+    Returns:
+        Dict[str, np.ndarray]: The JSON payload with every value converted
+        to an ndarray, or an empty dictionary (after logging the error) when
+        the status code is not 200.
+    """
+    if response.status_code != 200:
+        logger.error(
+            f"Error: Response code {response.status_code} - {response.text}"
+        )
+        return {}
+    return {key: np.array(value) for key, value in response.json().items()}
+
+
+def send_generate_request(path0: str, path1: str) -> Dict[str, np.ndarray]:
+    """
+    Send a request to the API to generate a match between two images.
+
+    Args:
+        path0 (str): The path to the first image.
+        path1 (str): The path to the second image.
+
+    Returns:
+        Dict[str, np.ndarray]: The matching result returned by the server,
+        with every value converted to an ndarray. Empty when the server does
+        not answer with status code 200.
+    """
+    # A single `with` closes the first file even if opening the second fails.
+    with open(path0, "rb") as image0, open(path1, "rb") as image1:
+        files = {"image0": image0, "image1": image1}
+        response = requests.post(API_URL_MATCH, files=files)
+    return _parse_response(response)
+
+
+def send_generate_request1(path0: str) -> Dict[str, np.ndarray]:
+    """
+    Send a request to the API to extract features from an image.
+
+    Args:
+        path0 (str): The path to the image.
+
+    Returns:
+        Dict[str, np.ndarray]: The extracted features returned by the server,
+        with every value converted to an ndarray. The server leaves
+        "descriptors" out of the response. Empty when the server does not
+        answer with status code 200.
+    """
+    with open(path0, "rb") as image:
+        response = requests.post(API_URL_EXTRACT, files={"image": image})
+    return _parse_response(response)
+
+
+def send_generate_request2(image_path: str) -> Dict[str, np.ndarray]:
+    """
+    Send a request to the API to extract features from an image by path.
+
+    Args:
+        image_path (str): The path to the image. It is sent as text and
+            opened by the server, so it must be valid on the server side.
+
+    Returns:
+        Dict[str, np.ndarray]: The extracted features returned by the server,
+        with every value converted to an ndarray. The server leaves
+        "descriptors" out of the response. Empty when the request fails or
+        the server does not answer with status code 200.
+    """
+    data = {
+        "image_path": image_path,
+        "max_keypoints": 1024,
+        "reference_points": [[0.0, 0.0], [1.0, 1.0]],
+    }
+    try:
+        response = requests.post(API_URL_EXTRACT_V2, json=data)
+        return _parse_response(response)
+    except Exception as e:
+        logger.error(f"An error occurred: {e}")
+        return {}
+
+
+if __name__ == "__main__":
+    parser = argparse.ArgumentParser(
+        description="Send images to the matching server and time the requests."
+    )
+    parser.add_argument(
+        "--image0",
+        required=False,
+        help="Path to the first image",
+        default="../datasets/sacre_coeur/mapping_rot/02928139_3448003521_rot45.jpg",
+    )
+    parser.add_argument(
+        "--image1",
+        required=False,
+        help="Path to the second image",
+        default="../datasets/sacre_coeur/mapping_rot/02928139_3448003521_rot90.jpg",
+    )
+    args = parser.parse_args()
+    for _ in range(10):
+        t1 = time.time()
+        preds = send_generate_request(args.image0, args.image1)
+        t2 = time.time()
+        logger.info(f"Time cost1: {(t2 - t1)} seconds")
+
+    for _ in range(10):
+        t1 = time.time()
+        preds = send_generate_request1(args.image0)
+        t2 = time.time()
+        logger.info(f"Time cost2: {(t2 - t1)} seconds")
+
+    for _ in range(10):
+        t1 = time.time()
+        preds = send_generate_request2(args.image0)
+        t2 = time.time()
+        logger.info(f"Time cost3: {(t2 - t1)} seconds")
+
+    with open("preds.pkl", "wb") as f:
+        pickle.dump(preds, f)
diff --git a/api/server.py b/api/server.py
new file mode 100644
--- /dev/null
+++ b/api/server.py
@@ -0,0 +1,169 @@
+# server.py
+import sys
+from pathlib import Path
+from typing import Union
+
+import numpy as np
+import uvicorn
+from fastapi import FastAPI, File, UploadFile
+from fastapi.responses import JSONResponse
+from PIL import Image
+
+# Make the repository root importable regardless of the working directory.
+sys.path.append(str(Path(__file__).resolve().parent.parent))
+from pydantic import BaseModel
+
+from ui.api import ImageMatchingAPI
+from ui.utils import DEVICE
+
+
+class ImageInfo(BaseModel):
+    """Request body of the ``/v2/extract`` endpoint."""
+
+    # Path of the image; it is opened on the server's filesystem.
+    image_path: str
+    # NOTE(review): the two fields below are validated but not used yet.
+    max_keypoints: int
+    reference_points: list
+
+
+class ImageMatchingService:
+    """FastAPI application exposing image matching and feature extraction."""
+
+    def __init__(self, conf: dict, device: str):
+        """
+        Create the matching API and register the HTTP routes.
+
+        Args:
+            conf: Configuration forwarded to ``ImageMatchingAPI``.
+            device: Device identifier forwarded to ``ImageMatchingAPI``.
+        """
+        self.api = ImageMatchingAPI(conf=conf, device=device)
+        self.app = FastAPI()
+        self.register_routes()
+
+    def register_routes(self):
+        """Attach the match and extract endpoints to ``self.app``."""
+
+        @self.app.post("/v1/match")
+        async def match(
+            image0: UploadFile = File(...), image1: UploadFile = File(...)
+        ):
+            try:
+                image0_array = self.load_image(image0)
+                image1_array = self.load_image(image1)
+
+                output = self.api(image0_array, image1_array)
+
+                skip_keys = ["image0_orig", "image1_orig"]
+                pred = self.filter_output(output, skip_keys)
+
+                return JSONResponse(content=pred)
+            except Exception as e:
+                return JSONResponse(content={"error": str(e)}, status_code=500)
+
+        @self.app.post("/v1/extract")
+        async def extract(image: UploadFile = File(...)):
+            return self._extract(image)
+
+        @self.app.post("/v2/extract")
+        async def extract_v2(image_path: ImageInfo):
+            # NOTE(security): the path comes from the client and is opened
+            # as-is on the server; `Path.resolve()` only normalises it and
+            # does not restrict which files can be read. Confine it to an
+            # allowed directory before exposing this endpoint publicly.
+            return self._extract(image_path.image_path)
+
+    def _extract(self, image: Union[str, UploadFile]) -> JSONResponse:
+        """
+        Extract features from one image and build the HTTP response.
+
+        Args:
+            image: A file path or an UploadFile object.
+
+        Returns:
+            The filtered features, or a status 500 response carrying the
+            error message when loading or extraction fails.
+        """
+        try:
+            image_array = self.load_image(image)
+            output = self.api.extract(image_array)
+            skip_keys = ["descriptors", "image", "image_orig"]
+            pred = self.filter_output(output, skip_keys)
+            return JSONResponse(content=pred)
+        except Exception as e:
+            return JSONResponse(content={"error": str(e)}, status_code=500)
+
+    def load_image(self, file_path: Union[str, UploadFile]) -> np.ndarray:
+        """
+        Reads an image from a file path or an UploadFile object.
+
+        Args:
+            file_path: A file path or an UploadFile object.
+
+        Returns:
+            A numpy array representing the image.
+        """
+        if isinstance(file_path, str):
+            file_path = Path(file_path).resolve(strict=False)
+        else:
+            file_path = file_path.file
+        with Image.open(file_path) as img:
+            image_array = np.array(img)
+        return image_array
+
+    def filter_output(self, output: dict, skip_keys: list) -> dict:
+        """
+        Keeps the ndarray entries of a result and makes them serialisable.
+
+        Args:
+            output: The raw result dictionary.
+            skip_keys: Keys to leave out of the returned dictionary.
+
+        Returns:
+            A dictionary mapping every kept key to a nested list. Values that
+            are not numpy arrays are dropped.
+        """
+        return {
+            key: value.tolist()
+            for key, value in output.items()
+            if key not in skip_keys and isinstance(value, np.ndarray)
+        }
+
+    def run(self, host: str = "0.0.0.0", port: int = 8001):
+        """Serve the application with uvicorn on the given host and port."""
+        uvicorn.run(self.app, host=host, port=port)
+
+
+if __name__ == "__main__":
+    conf = {
+        "feature": {
+            "output": "feats-superpoint-n4096-rmax1600",
+            "model": {
+                "name": "superpoint",
+                "nms_radius": 3,
+                "max_keypoints": 4096,
+                "keypoint_threshold": 0.005,
+            },
+            "preprocessing": {
+                "grayscale": True,
+                "force_resize": True,
+                "resize_max": 1600,
+                "width": 640,
+                "height": 480,
+                "dfactor": 8,
+            },
+        },
+        "matcher": {
+            "output": "matches-NN-mutual",
+            "model": {
+                "name": "nearest_neighbor",
+                "do_mutual_check": True,
+                "match_threshold": 0.2,
+            },
+        },
+        "dense": False,
+    }
+
+    service = ImageMatchingService(conf=conf, device=DEVICE)
+    service.run()
diff --git a/format.sh b/format.sh
index 377f041..ada7140 100644
--- a/format.sh
+++ b/format.sh
@@ -1,3 +1,3 @@
-python -m flake8 ui/*.py hloc/*.py hloc/matchers/*.py hloc/extractors/*.py
-python -m isort ui/*.py hloc/*.py hloc/matchers/*.py hloc/extractors/*.py
-python -m black ui/*.py hloc/*.py hloc/matchers/*.py hloc/extractors/*.py
\ No newline at end of file
+python -m flake8 ui/*.py api/*.py hloc/*.py hloc/matchers/*.py hloc/extractors/*.py
+python -m isort ui/*.py api/*.py hloc/*.py hloc/matchers/*.py hloc/extractors/*.py
+python -m black ui/*.py api/*.py hloc/*.py hloc/matchers/*.py hloc/extractors/*.py
\ No newline at end of file
diff --git a/hloc/extractors/superpoint.py b/hloc/extractors/superpoint.py
index ee61839..d47f33a 100644
--- a/hloc/extractors/superpoint.py
+++ b/hloc/extractors/superpoint.py
@@ -48,4 +48,10 @@ def _init(self, conf):
         logger.info("Load SuperPoint model done.")
 
     def _forward(self, data):
-        return self.net(data, self.conf)
+        pred = self.net(data, self.conf)
+        pred = {
+            "keypoints": pred["keypoints"][0][None],
+            "scores": pred["scores"][0][None],
+            "descriptors": pred["descriptors"][0][None],
+        }
+        return pred
diff --git a/requirements.txt b/requirements.txt
index 4ec3af7..2c0a784 100644
--- a/requirements.txt
+++ b/requirements.txt
@@ -36,3 +36,4 @@ torchvision==0.19.0
 roma #dust3r
 tqdm
 yacs
+fastapi
diff --git a/ui/api.py b/ui/api.py
index bcbd697..bbabd05 100644
--- a/ui/api.py
+++ b/ui/api.py
@@ -147,6 +147,29 @@ def _forward(self, img0, img1):
         pred = match_features.match_images(self.matcher, pred0, pred1)
         return pred
 
+    @torch.inference_mode()
+    def extract(
+        self,
+        img0: np.ndarray,
+    ) -> Dict[str, np.ndarray]:
+        """Extract features from a single image.
+
+        Args:
+            img0 (np.ndarray): image
+
+        Returns:
+            Dict[str, np.ndarray]: feature dict
+        """
+
+        pred = extract_features.extract(
+            self.extractor, img0, self.extract_conf["preprocessing"]
+        )
+        pred = {
+            k: v.cpu().detach()[0].numpy() if isinstance(v, torch.Tensor) else v
+            for k, v in pred.items()
+        }
+        return pred
+
     @torch.inference_mode()
     def forward(
         self,