diff --git a/api/__init__.py b/api/__init__.py index e69de29..5a7f0a6 100644 --- a/api/__init__.py +++ b/api/__init__.py @@ -0,0 +1,42 @@ +import sys +from typing import List +from pydantic import BaseModel +import base64 +import io +import numpy as np +from fastapi.exceptions import HTTPException +from PIL import Image +from pathlib import Path + +sys.path.append(str(Path(__file__).parents[1])) +from hloc import logger + + +class ImagesInput(BaseModel): + data: List[str] = [] + max_keypoints: List[int] = [] + timestamps: List[str] = [] + grayscale: bool = False + image_hw: List[List[int]] = [[], []] + feature_type: int = 0 + rotates: List[float] = [] + scales: List[float] = [] + reference_points: List[List[float]] = [] + binarize: bool = False + + +def decode_base64_to_image(encoding): + if encoding.startswith("data:image/"): + encoding = encoding.split(";")[1].split(",")[1] + try: + image = Image.open(io.BytesIO(base64.b64decode(encoding))) + return image + except Exception as e: + logger.warning(f"API cannot decode image: {e}") + raise HTTPException( + status_code=500, detail="Invalid encoded image" + ) from e + + +def to_base64_nparray(encoding: str) -> np.ndarray: + return np.array(decode_base64_to_image(encoding)).astype("uint8") diff --git a/api/client.py b/api/client.py index 4fd751c..a281adc 100644 --- a/api/client.py +++ b/api/client.py @@ -9,7 +9,7 @@ import numpy as np import requests -ENDPOINT = "http://127.0.0.1:8001" +ENDPOINT = "http://127.0.0.1:8000" if "REMOTE_URL_RAILWAY" in os.environ: ENDPOINT = os.environ["REMOTE_URL_RAILWAY"] @@ -152,7 +152,8 @@ def send_request_extract( url=API_URL_EXTRACT, **inputs, ) - print("Keypoints detected: {}".format(len(response[0]["keypoints"]))) + # breakpoint() + # print("Keypoints detected: {}".format(len(response[0]["keypoints"]))) # draw matching, debug only if viz: @@ -214,7 +215,7 @@ def get_api_version(): # ) # request extract - for i in range(10): + for i in range(1000): t1 = time.time() preds = 
send_request_extract(args.image0) t2 = time.time() diff --git a/api/config/api.yaml b/api/config/api.yaml new file mode 100644 index 0000000..06f1caa --- /dev/null +++ b/api/config/api.yaml @@ -0,0 +1,51 @@ +# This file was generated using the `serve build` command on Ray v2.38.0. + +proxy_location: EveryNode +http_options: + host: 0.0.0.0 + port: 8000 + +grpc_options: + port: 9000 + grpc_servicer_functions: [] + +logging_config: + encoding: TEXT + log_level: INFO + logs_dir: null + enable_access_log: true + +applications: +- name: app1 + route_prefix: / + import_path: api.server:service + runtime_env: {} + deployments: + - name: ImageMatchingService + num_replicas: 4 + ray_actor_options: + num_cpus: 2.0 + num_gpus: 1.0 + +api: + feature: + output: feats-superpoint-n4096-rmax1600 + model: + name: superpoint + nms_radius: 3 + max_keypoints: 4096 + keypoint_threshold: 0.005 + preprocessing: + grayscale: True + force_resize: True + resize_max: 1600 + width: 640 + height: 480 + dfactor: 8 + matcher: + output: matches-NN-mutual + model: + name: nearest_neighbor + do_mutual_check: True + match_threshold: 0.2 + dense: False diff --git a/api/server.py b/api/server.py index 1a1edc5..19d7d85 100644 --- a/api/server.py +++ b/api/server.py @@ -1,24 +1,20 @@ # server.py -import base64 -import io -import sys import warnings from pathlib import Path from typing import Any, Dict, Optional, Union +import yaml + +from ray import serve import cv2 import matplotlib.pyplot as plt import numpy as np import torch -import uvicorn from fastapi import FastAPI, File, UploadFile -from fastapi.exceptions import HTTPException from fastapi.responses import JSONResponse from PIL import Image -sys.path.append(str(Path(__file__).parents[1])) - -from api.types import ImagesInput +from api import ImagesInput, to_base64_nparray from hloc import DEVICE, extract_features, logger, match_dense, match_features from hloc.utils.viz import add_text, plot_keypoints from ui import get_version @@ -26,23 +22,7 @@ 
from ui.viz import display_matches, fig2im, plot_images warnings.simplefilter("ignore") - - -def decode_base64_to_image(encoding): - if encoding.startswith("data:image/"): - encoding = encoding.split(";")[1].split(",")[1] - try: - image = Image.open(io.BytesIO(base64.b64decode(encoding))) - return image - except Exception as e: - logger.warning(f"API cannot decode image: {e}") - raise HTTPException( - status_code=500, detail="Invalid encoded image" - ) from e - - -def to_base64_nparray(encoding: str) -> np.ndarray: - return np.array(decode_base64_to_image(encoding)).astype("uint8") +app = FastAPI() class ImageMatchingAPI(torch.nn.Module): @@ -170,6 +150,17 @@ def _forward(self, img0, img1): pred = match_features.match_images(self.matcher, pred0, pred1) return pred + def _convert_pred(self, pred): + ret = { + k: v.cpu().detach()[0].numpy() if isinstance(v, torch.Tensor) else v + for k, v in pred.items() + } + ret = { + k: v[0].cpu().detach().numpy() if isinstance(v, list) else v + for k, v in ret.items() + } + return ret + @torch.inference_mode() def extract(self, img0: np.ndarray, **kwargs) -> Dict[str, np.ndarray]: """Extract features from a single image. 
@@ -190,17 +181,13 @@ def extract(self, img0: np.ndarray, **kwargs) -> Dict[str, np.ndarray]: pred = extract_features.extract( self.extractor, img0, self.extract_conf["preprocessing"] ) - pred = { - k: v.cpu().detach()[0].numpy() if isinstance(v, torch.Tensor) else v - for k, v in pred.items() - } + pred = self._convert_pred(pred) # back to origin scale s0 = pred["original_size"] / pred["size"] pred["keypoints_orig"] = ( match_features.scale_keypoints(pred["keypoints"] + 0.5, s0) - 0.5 ) # TODO: rotate back - binarize = kwargs.get("binarize", False) if binarize: assert "descriptors" in pred @@ -349,88 +336,92 @@ def visualize( plt.close("all") +@serve.deployment( + num_replicas=4, ray_actor_options={"num_cpus": 2, "num_gpus": 1} +) +@serve.ingress(app) class ImageMatchingService: def __init__(self, conf: dict, device: str): self.conf = conf self.api = ImageMatchingAPI(conf=conf, device=device) - self.app = FastAPI() - self.register_routes() - def register_routes(self): + @app.get("/") + def root(self): + return "Hello, world!" - @self.app.get("/version") - async def version(): - return {"version": get_version()} + @app.get("/version") + async def version(self): + return {"version": get_version()} - @self.app.post("/v1/match") - async def match( - image0: UploadFile = File(...), image1: UploadFile = File(...) - ): - """ - Handle the image matching request and return the processed result. + @app.post("/v1/match") + async def match( + self, image0: UploadFile = File(...), image1: UploadFile = File(...) + ): + """ + Handle the image matching request and return the processed result. - Args: - image0 (UploadFile): The first image file for matching. - image1 (UploadFile): The second image file for matching. + Args: + image0 (UploadFile): The first image file for matching. + image1 (UploadFile): The second image file for matching. - Returns: - JSONResponse: A JSON response containing the filtered match results - or an error message in case of failure. 
- """ - try: - # Load the images from the uploaded files - image0_array = self.load_image(image0) - image1_array = self.load_image(image1) + Returns: + JSONResponse: A JSON response containing the filtered match results + or an error message in case of failure. + """ + try: + # Load the images from the uploaded files + image0_array = self.load_image(image0) + image1_array = self.load_image(image1) - # Perform image matching using the API - output = self.api(image0_array, image1_array) + # Perform image matching using the API + output = self.api(image0_array, image1_array) - # Keys to skip in the output - skip_keys = ["image0_orig", "image1_orig"] + # Keys to skip in the output + skip_keys = ["image0_orig", "image1_orig"] - # Postprocess the output to filter unwanted data - pred = self.postprocess(output, skip_keys) + # Postprocess the output to filter unwanted data + pred = self.postprocess(output, skip_keys) - # Return the filtered prediction as a JSON response - return JSONResponse(content=pred) - except Exception as e: - # Return an error message with status code 500 in case of exception - return JSONResponse(content={"error": str(e)}, status_code=500) + # Return the filtered prediction as a JSON response + return JSONResponse(content=pred) + except Exception as e: + # Return an error message with status code 500 in case of exception + return JSONResponse(content={"error": str(e)}, status_code=500) - @self.app.post("/v1/extract") - async def extract(input_info: ImagesInput): - """ - Extract keypoints and descriptors from images. + @app.post("/v1/extract") + async def extract(self, input_info: ImagesInput): + """ + Extract keypoints and descriptors from images. - Args: - input_info: An object containing the image data and options. + Args: + input_info: An object containing the image data and options. - Returns: - A list of dictionaries containing the keypoints and descriptors. 
- """ - try: - preds = [] - for i, input_image in enumerate(input_info.data): - # Load the image from the input data - image_array = to_base64_nparray(input_image) - # Extract keypoints and descriptors - output = self.api.extract( - image_array, - max_keypoints=input_info.max_keypoints[i], - binarize=input_info.binarize, - ) - # Do not return the original image and image_orig - # skip_keys = ["image", "image_orig"] - skip_keys = [] - - # Postprocess the output - pred = self.postprocess(output, skip_keys) - preds.append(pred) - # Return the list of extracted features - return JSONResponse(content=preds) - except Exception as e: - # Return an error message if an exception occurs - return JSONResponse(content={"error": str(e)}, status_code=500) + Returns: + A list of dictionaries containing the keypoints and descriptors. + """ + try: + preds = [] + for i, input_image in enumerate(input_info.data): + # Load the image from the input data + image_array = to_base64_nparray(input_image) + # Extract keypoints and descriptors + output = self.api.extract( + image_array, + max_keypoints=input_info.max_keypoints[i], + binarize=input_info.binarize, + ) + # Do not return the original image and image_orig + # skip_keys = ["image", "image_orig"] + skip_keys = [] + + # Postprocess the output + pred = self.postprocess(output, skip_keys) + preds.append(pred) + # Return the list of extracted features + return JSONResponse(content=preds) + except Exception as e: + # Return an error message if an exception occurs + return JSONResponse(content={"error": str(e)}, status_code=500) def load_image(self, file_path: Union[str, UploadFile]) -> np.ndarray: """ @@ -462,38 +453,23 @@ def postprocess( return pred def run(self, host: str = "0.0.0.0", port: int = 8001): - uvicorn.run(self.app, host=host, port=port) - - -if __name__ == "__main__": - conf = { - "feature": { - "output": "feats-superpoint-n4096-rmax1600", - "model": { - "name": "superpoint", - "nms_radius": 3, - "max_keypoints": 4096, - 
"keypoint_threshold": 0.005, - }, - "preprocessing": { - "grayscale": True, - "force_resize": True, - "resize_max": 1600, - "width": 640, - "height": 480, - "dfactor": 8, - }, - }, - "matcher": { - "output": "matches-NN-mutual", - "model": { - "name": "nearest_neighbor", - "do_mutual_check": True, - "match_threshold": 0.2, - }, - }, - "dense": False, - } + import uvicorn + uvicorn.run(app, host=host, port=port) + + +def read_config(config_path: Path) -> dict: + with open(config_path, "r") as f: + conf = yaml.safe_load(f) + return conf + + +# api server +conf = read_config(Path(__file__).parent / "config/api.yaml") +service = ImageMatchingService.bind(conf=conf["api"], device=DEVICE) + +# handle = serve.run(service, route_prefix="/") +# serve run api.server:service - service = ImageMatchingService(conf=conf, device=DEVICE) - service.run() +# build to generate config file +# serve build api.server:service -o api/config/ray.yaml +# serve run api/config/ray.yaml diff --git a/api/types.py b/api/types.py deleted file mode 100644 index db17dce..0000000 --- a/api/types.py +++ /dev/null @@ -1,16 +0,0 @@ -from typing import List - -from pydantic import BaseModel - - -class ImagesInput(BaseModel): - data: List[str] = [] - max_keypoints: List[int] = [] - timestamps: List[str] = [] - grayscale: bool = False - image_hw: List[List[int]] = [[], []] - feature_type: int = 0 - rotates: List[float] = [] - scales: List[float] = [] - reference_points: List[List[float]] = [] - binarize: bool = False diff --git a/requirements.txt b/requirements.txt index 7bbeed9..a0b13aa 100644 --- a/requirements.txt +++ b/requirements.txt @@ -36,3 +36,6 @@ roma #dust3r tqdm yacs fastapi +uvicorn +ray +ray[serve]