diff --git a/api/client.py b/api/client.py
index c357642..7372a0b 100644
--- a/api/client.py
+++ b/api/client.py
@@ -1,17 +1,17 @@
-import os
 import argparse
+import os
 import pickle
 import time
 from typing import Dict
+
 import numpy as np
 import requests
-from loguru import logger
 
 URL = "http://127.0.0.1:8001"
 if "REMOTE_URL_RAILWAY" in os.environ:
     URL = os.environ["REMOTE_URL_RAILWAY"]
-logger.info(f"API URL: {URL}")
+print(f"API URL: {URL}")
 
 API_URL_MATCH = f"{URL}/v1/match"
 API_URL_EXTRACT = f"{URL}/v1/extract"
@@ -135,19 +135,19 @@ def send_generate_request2(image_path: str) -> Dict[str, np.ndarray]:
     t1 = time.time()
     preds = send_generate_request(args.image0, args.image1)
     t2 = time.time()
-    logger.info(f"Time cost1: {(t2 - t1)} seconds")
+    print(f"Time cost1: {(t2 - t1)} seconds")
 
     for i in range(10):
         t1 = time.time()
         preds = send_generate_request1(args.image0)
         t2 = time.time()
-        logger.info(f"Time cost2: {(t2 - t1)} seconds")
+        print(f"Time cost2: {(t2 - t1)} seconds")
 
     for i in range(10):
         t1 = time.time()
         preds = send_generate_request2(args.image0)
         t2 = time.time()
-        logger.info(f"Time cost2: {(t2 - t1)} seconds")
+        print(f"Time cost2: {(t2 - t1)} seconds")
 
     with open("preds.pkl", "wb") as f:
         pickle.dump(preds, f)
diff --git a/api/server.py b/api/server.py
index a066a11..75ef608 100644
--- a/api/server.py
+++ b/api/server.py
@@ -1,20 +1,27 @@
 # server.py
 import sys
+import warnings
 from pathlib import Path
-from typing import Union
+from typing import Any, Dict, Optional, Union
 
+import cv2
+import matplotlib.pyplot as plt
 import numpy as np
+import torch
 import uvicorn
 from fastapi import FastAPI, File, UploadFile
 from fastapi.responses import JSONResponse
 from PIL import Image
-
 from pydantic import BaseModel
 
-sys.path.append(str(Path(__file__).parent.parent))
+sys.path.append(str(Path(__file__).parents[1]))
+
+from hloc import DEVICE, extract_features, logger, match_dense, match_features
+from hloc.utils.viz import add_text, plot_keypoints
+from ui.utils import filter_matches, get_feature_model, get_model
+from ui.viz import display_matches, fig2im, plot_images
 
-from ui.api import ImageMatchingAPI
-from ui.utils import DEVICE
+warnings.simplefilter("ignore")
 
 
 class ImageInfo(BaseModel):
@@ -23,6 +30,295 @@ class ImageInfo(BaseModel):
     reference_points: list
 
 
+class ImageMatchingAPI(torch.nn.Module):
+    default_conf = {
+        "ransac": {
+            "enable": True,
+            "estimator": "poselib",
+            "geometry": "homography",
+            "method": "RANSAC",
+            "reproj_threshold": 3,
+            "confidence": 0.9999,
+            "max_iter": 10000,
+        },
+    }
+
+    def __init__(
+        self,
+        conf: dict = {},
+        device: str = "cpu",
+        detect_threshold: float = 0.015,
+        max_keypoints: int = 1024,
+        match_threshold: float = 0.2,
+    ) -> None:
+        """
+        Initializes an instance of the ImageMatchingAPI class.
+
+        Args:
+            conf (dict): A dictionary containing the configuration parameters.
+            device (str, optional): The device to use for computation. Defaults to "cpu".
+            detect_threshold (float, optional): The threshold for detecting keypoints. Defaults to 0.015.
+            max_keypoints (int, optional): The maximum number of keypoints to extract. Defaults to 1024.
+            match_threshold (float, optional): The threshold for matching keypoints. Defaults to 0.2.
+
+        Returns:
+            None
+        """
+        super().__init__()
+        self.device = device
+        self.conf = {**self.default_conf, **conf}
+        self._updata_config(detect_threshold, max_keypoints, match_threshold)
+        self._init_models()
+        if device == "cuda":
+            memory_allocated = torch.cuda.memory_allocated(device)
+            memory_reserved = torch.cuda.memory_reserved(device)
+            logger.info(
+                f"GPU memory allocated: {memory_allocated / 1024**2:.3f} MB"
+            )
+            logger.info(
+                f"GPU memory reserved: {memory_reserved / 1024**2:.3f} MB"
+            )
+        self.pred = None
+
+    def parse_match_config(self, conf):
+        if conf["dense"]:
+            return {
+                **conf,
+                "matcher": match_dense.confs.get(
+                    conf["matcher"]["model"]["name"]
+                ),
+                "dense": True,
+            }
+        else:
+            return {
+                **conf,
+                "feature": extract_features.confs.get(
+                    conf["feature"]["model"]["name"]
+                ),
+                "matcher": match_features.confs.get(
+                    conf["matcher"]["model"]["name"]
+                ),
+                "dense": False,
+            }
+
+    def _updata_config(
+        self,
+        detect_threshold: float = 0.015,
+        max_keypoints: int = 1024,
+        match_threshold: float = 0.2,
+    ):
+        self.dense = self.conf["dense"]
+        if self.conf["dense"]:
+            try:
+                self.conf["matcher"]["model"][
+                    "match_threshold"
+                ] = match_threshold
+            except TypeError as e:
+                logger.error(e)
+        else:
+            self.conf["feature"]["model"]["max_keypoints"] = max_keypoints
+            self.conf["feature"]["model"][
+                "keypoint_threshold"
+            ] = detect_threshold
+            self.extract_conf = self.conf["feature"]
+
+        self.match_conf = self.conf["matcher"]
+
+    def _init_models(self):
+        # initialize matcher
+        self.matcher = get_model(self.match_conf)
+        # initialize extractor
+        if self.dense:
+            self.extractor = None
+        else:
+            self.extractor = get_feature_model(self.conf["feature"])
+
+    def _forward(self, img0, img1):
+        if self.dense:
+            pred = match_dense.match_images(
+                self.matcher,
+                img0,
+                img1,
+                self.match_conf["preprocessing"],
+                device=self.device,
+            )
+            last_fixed = "{}".format(  # noqa: F841
+                self.match_conf["model"]["name"]
+            )
+        else:
+            pred0 = extract_features.extract(
+                self.extractor, img0, self.extract_conf["preprocessing"]
+            )
+            pred1 = extract_features.extract(
+                self.extractor, img1, self.extract_conf["preprocessing"]
+            )
+            pred = match_features.match_images(self.matcher, pred0, pred1)
+        return pred
+
+    @torch.inference_mode()
+    def extract(
+        self,
+        img0: np.ndarray,
+    ) -> Dict[str, np.ndarray]:
+        """Extract features from a single image.
+
+        Args:
+            img0 (np.ndarray): image
+
+        Returns:
+            Dict[str, np.ndarray]: feature dict
+        """
+
+        pred = extract_features.extract(
+            self.extractor, img0, self.extract_conf["preprocessing"]
+        )
+        pred = {
+            k: v.cpu().detach()[0].numpy() if isinstance(v, torch.Tensor) else v
+            for k, v in pred.items()
+        }
+        return pred
+
+    @torch.inference_mode()
+    def forward(
+        self,
+        img0: np.ndarray,
+        img1: np.ndarray,
+    ) -> Dict[str, np.ndarray]:
+        """
+        Forward pass of the image matching API.
+
+        Args:
+            img0: A 3D NumPy array of shape (H, W, C) representing the first image.
+                  Values are in the range [0, 1] and are in RGB mode.
+            img1: A 3D NumPy array of shape (H, W, C) representing the second image.
+                  Values are in the range [0, 1] and are in RGB mode.
+
+        Returns:
+            A dictionary containing the following keys:
+            - image0_orig: The original image 0.
+            - image1_orig: The original image 1.
+            - keypoints0_orig: The keypoints detected in image 0.
+            - keypoints1_orig: The keypoints detected in image 1.
+            - mkeypoints0_orig: The raw matches between image 0 and image 1.
+            - mkeypoints1_orig: The raw matches between image 1 and image 0.
+            - mmkeypoints0_orig: The RANSAC inliers in image 0.
+            - mmkeypoints1_orig: The RANSAC inliers in image 1.
+            - mconf: The confidence scores for the raw matches.
+            - mmconf: The confidence scores for the RANSAC inliers.
+        """
+        # Take as input a pair of images (not a batch)
+        assert isinstance(img0, np.ndarray)
+        assert isinstance(img1, np.ndarray)
+        self.pred = self._forward(img0, img1)
+        if self.conf["ransac"]["enable"]:
+            self.pred = self._geometry_check(self.pred)
+        return self.pred
+
+    def _geometry_check(
+        self,
+        pred: Dict[str, Any],
+    ) -> Dict[str, Any]:
+        """
+        Filter matches using RANSAC. If keypoints are available, filter by keypoints.
+        If lines are available, filter by lines. If both keypoints and lines are
+        available, filter by keypoints.
+
+        Args:
+            pred (Dict[str, Any]): dict of matches, including original keypoints.
+                See :func:`filter_matches` for the expected keys.
+
+        Returns:
+            Dict[str, Any]: filtered matches
+        """
+        pred = filter_matches(
+            pred,
+            ransac_method=self.conf["ransac"]["method"],
+            ransac_reproj_threshold=self.conf["ransac"]["reproj_threshold"],
+            ransac_confidence=self.conf["ransac"]["confidence"],
+            ransac_max_iter=self.conf["ransac"]["max_iter"],
+        )
+        return pred
+
+    def visualize(
+        self,
+        log_path: Optional[Path] = None,
+    ) -> None:
+        """
+        Visualize the matches.
+
+        Args:
+            log_path (Path, optional): The directory to save the images. Defaults to None.
+
+        Returns:
+            None
+        """
+        if self.conf["dense"]:
+            postfix = str(self.conf["matcher"]["model"]["name"])
+        else:
+            postfix = "{}_{}".format(
+                str(self.conf["feature"]["model"]["name"]),
+                str(self.conf["matcher"]["model"]["name"]),
+            )
+        titles = [
+            "Image 0 - Keypoints",
+            "Image 1 - Keypoints",
+        ]
+        pred: Dict[str, Any] = self.pred
+        image0: np.ndarray = pred["image0_orig"]
+        image1: np.ndarray = pred["image1_orig"]
+        output_keypoints: np.ndarray = plot_images(
+            [image0, image1], titles=titles, dpi=300
+        )
+        if (
+            "keypoints0_orig" in pred.keys()
+            and "keypoints1_orig" in pred.keys()
+        ):
+            plot_keypoints([pred["keypoints0_orig"], pred["keypoints1_orig"]])
+            text: str = (
+                f"# keypoints0: {len(pred['keypoints0_orig'])} \n"
+                + f"# keypoints1: {len(pred['keypoints1_orig'])}"
+            )
+            add_text(0, text, fs=15)
+        output_keypoints = fig2im(output_keypoints)
+        # plot images with raw matches
+        titles = [
+            "Image 0 - Raw matched keypoints",
+            "Image 1 - Raw matched keypoints",
+        ]
+        output_matches_raw, num_matches_raw = display_matches(
+            pred, titles=titles, tag="KPTS_RAW"
+        )
+        # plot images with ransac matches
+        titles = [
+            "Image 0 - Ransac matched keypoints",
+            "Image 1 - Ransac matched keypoints",
+        ]
+        output_matches_ransac, num_matches_ransac = display_matches(
+            pred, titles=titles, tag="KPTS_RANSAC"
+        )
+        if log_path is not None:
+            img_keypoints_path: Path = log_path / f"img_keypoints_{postfix}.png"
+            img_matches_raw_path: Path = (
+                log_path / f"img_matches_raw_{postfix}.png"
+            )
+            img_matches_ransac_path: Path = (
+                log_path / f"img_matches_ransac_{postfix}.png"
+            )
+            cv2.imwrite(
+                str(img_keypoints_path),
+                output_keypoints[:, :, ::-1].copy(),  # RGB -> BGR
+            )
+            cv2.imwrite(
+                str(img_matches_raw_path),
+                output_matches_raw[:, :, ::-1].copy(),  # RGB -> BGR
+            )
+            cv2.imwrite(
+                str(img_matches_ransac_path),
+                output_matches_ransac[:, :, ::-1].copy(),  # RGB -> BGR
+            )
+        plt.close("all")
+
+
 class ImageMatchingService:
     def __init__(self, conf: dict, device: str):
         self.api = ImageMatchingAPI(conf=conf, device=device)
diff --git a/app.py b/app.py
index b168e26..31fb805 100644
--- a/app.py
+++ b/app.py
@@ -1,5 +1,6 @@
 import argparse
 from pathlib import Path
+
 from ui.app_class import ImageMatchingApp
 
 if __name__ == "__main__":
diff --git a/test_app_cli.py b/test_app_cli.py
index 36144d6..f75ff6b 100644
--- a/test_app_cli.py
+++ b/test_app_cli.py
@@ -1,12 +1,13 @@
+import sys
+from pathlib import Path
+
 import cv2
+
 from hloc import logger
-from ui.utils import (
-    get_matcher_zoo,
-    load_config,
-    DEVICE,
-    ROOT,
-)
-from ui.api import ImageMatchingAPI
+from ui.utils import DEVICE, ROOT, get_matcher_zoo, load_config
+
+sys.path.append(str(Path(__file__).parents[1]))
+from api.server import ImageMatchingAPI
 
 
 def test_all(config: dict = None):
diff --git a/ui/api.py b/ui/api.py
deleted file mode 100644
index bbabd05..0000000
--- a/ui/api.py
+++ /dev/null
@@ -1,316 +0,0 @@
-import warnings
-from pathlib import Path
-from typing import Any, Dict, Optional
-
-import cv2
-import matplotlib.pyplot as plt
-import numpy as np
-import torch
-
-from hloc import extract_features, logger, match_dense, match_features
-from hloc.utils.viz import add_text, plot_keypoints
-
-from .utils import (
-    ROOT,
-    filter_matches,
-    get_feature_model,
-    get_model,
-    load_config,
-)
-from .viz import display_matches, fig2im, plot_images
-
-warnings.simplefilter("ignore")
-
-
-class ImageMatchingAPI(torch.nn.Module):
-    default_conf = {
-        "ransac": {
-            "enable": True,
-            "estimator": "poselib",
-            "geometry": "homography",
-            "method": "RANSAC",
-            "reproj_threshold": 3,
-            "confidence": 0.9999,
-            "max_iter": 10000,
-        },
-    }
-
-    def __init__(
-        self,
-        conf: dict = {},
-        device: str = "cpu",
-        detect_threshold: float = 0.015,
-        max_keypoints: int = 1024,
-        match_threshold: float = 0.2,
-    ) -> None:
-        """
-        Initializes an instance of the ImageMatchingAPI class.
-
-        Args:
-            conf (dict): A dictionary containing the configuration parameters.
-            device (str, optional): The device to use for computation. Defaults to "cpu".
-            detect_threshold (float, optional): The threshold for detecting keypoints. Defaults to 0.015.
-            max_keypoints (int, optional): The maximum number of keypoints to extract. Defaults to 1024.
-            match_threshold (float, optional): The threshold for matching keypoints. Defaults to 0.2.
-
-        Returns:
-            None
-        """
-        super().__init__()
-        self.device = device
-        self.conf = {**self.default_conf, **conf}
-        self._updata_config(detect_threshold, max_keypoints, match_threshold)
-        self._init_models()
-        if device == "cuda":
-            memory_allocated = torch.cuda.memory_allocated(device)
-            memory_reserved = torch.cuda.memory_reserved(device)
-            logger.info(
-                f"GPU memory allocated: {memory_allocated / 1024**2:.3f} MB"
-            )
-            logger.info(
-                f"GPU memory reserved: {memory_reserved / 1024**2:.3f} MB"
-            )
-        self.pred = None
-
-    def parse_match_config(self, conf):
-        if conf["dense"]:
-            return {
-                **conf,
-                "matcher": match_dense.confs.get(
-                    conf["matcher"]["model"]["name"]
-                ),
-                "dense": True,
-            }
-        else:
-            return {
-                **conf,
-                "feature": extract_features.confs.get(
-                    conf["feature"]["model"]["name"]
-                ),
-                "matcher": match_features.confs.get(
-                    conf["matcher"]["model"]["name"]
-                ),
-                "dense": False,
-            }
-
-    def _updata_config(
-        self,
-        detect_threshold: float = 0.015,
-        max_keypoints: int = 1024,
-        match_threshold: float = 0.2,
-    ):
-        self.dense = self.conf["dense"]
-        if self.conf["dense"]:
-            try:
-                self.conf["matcher"]["model"][
-                    "match_threshold"
-                ] = match_threshold
-            except TypeError as e:
-                logger.error(e)
-        else:
-            self.conf["feature"]["model"]["max_keypoints"] = max_keypoints
-            self.conf["feature"]["model"][
-                "keypoint_threshold"
-            ] = detect_threshold
-            self.extract_conf = self.conf["feature"]
-
-        self.match_conf = self.conf["matcher"]
-
-    def _init_models(self):
-        # initialize matcher
-        self.matcher = get_model(self.match_conf)
-        # initialize extractor
-        if self.dense:
-            self.extractor = None
-        else:
-            self.extractor = get_feature_model(self.conf["feature"])
-
-    def _forward(self, img0, img1):
-        if self.dense:
-            pred = match_dense.match_images(
-                self.matcher,
-                img0,
-                img1,
-                self.match_conf["preprocessing"],
-                device=self.device,
-            )
-            last_fixed = "{}".format(  # noqa: F841
-                self.match_conf["model"]["name"]
-            )
-        else:
-            pred0 = extract_features.extract(
-                self.extractor, img0, self.extract_conf["preprocessing"]
-            )
-            pred1 = extract_features.extract(
-                self.extractor, img1, self.extract_conf["preprocessing"]
-            )
-            pred = match_features.match_images(self.matcher, pred0, pred1)
-        return pred
-
-    @torch.inference_mode()
-    def extract(
-        self,
-        img0: np.ndarray,
-    ) -> Dict[str, np.ndarray]:
-        """Extract features from a single image.
-
-        Args:
-            img0 (np.ndarray): image
-
-        Returns:
-            Dict[str, np.ndarray]: feature dict
-        """
-
-        pred = extract_features.extract(
-            self.extractor, img0, self.extract_conf["preprocessing"]
-        )
-        pred = {
-            k: v.cpu().detach()[0].numpy() if isinstance(v, torch.Tensor) else v
-            for k, v in pred.items()
-        }
-        return pred
-
-    @torch.inference_mode()
-    def forward(
-        self,
-        img0: np.ndarray,
-        img1: np.ndarray,
-    ) -> Dict[str, np.ndarray]:
-        """
-        Forward pass of the image matching API.
-
-        Args:
-            img0: A 3D NumPy array of shape (H, W, C) representing the first image.
-                  Values are in the range [0, 1] and are in RGB mode.
-            img1: A 3D NumPy array of shape (H, W, C) representing the second image.
-                  Values are in the range [0, 1] and are in RGB mode.
-
-        Returns:
-            A dictionary containing the following keys:
-            - image0_orig: The original image 0.
-            - image1_orig: The original image 1.
-            - keypoints0_orig: The keypoints detected in image 0.
-            - keypoints1_orig: The keypoints detected in image 1.
-            - mkeypoints0_orig: The raw matches between image 0 and image 1.
-            - mkeypoints1_orig: The raw matches between image 1 and image 0.
-            - mmkeypoints0_orig: The RANSAC inliers in image 0.
-            - mmkeypoints1_orig: The RANSAC inliers in image 1.
-            - mconf: The confidence scores for the raw matches.
-            - mmconf: The confidence scores for the RANSAC inliers.
-        """
-        # Take as input a pair of images (not a batch)
-        assert isinstance(img0, np.ndarray)
-        assert isinstance(img1, np.ndarray)
-        self.pred = self._forward(img0, img1)
-        if self.conf["ransac"]["enable"]:
-            self.pred = self._geometry_check(self.pred)
-        return self.pred
-
-    def _geometry_check(
-        self,
-        pred: Dict[str, Any],
-    ) -> Dict[str, Any]:
-        """
-        Filter matches using RANSAC. If keypoints are available, filter by keypoints.
-        If lines are available, filter by lines. If both keypoints and lines are
-        available, filter by keypoints.
-
-        Args:
-            pred (Dict[str, Any]): dict of matches, including original keypoints.
-                See :func:`filter_matches` for the expected keys.
-
-        Returns:
-            Dict[str, Any]: filtered matches
-        """
-        pred = filter_matches(
-            pred,
-            ransac_method=self.conf["ransac"]["method"],
-            ransac_reproj_threshold=self.conf["ransac"]["reproj_threshold"],
-            ransac_confidence=self.conf["ransac"]["confidence"],
-            ransac_max_iter=self.conf["ransac"]["max_iter"],
-        )
-        return pred
-
-    def visualize(
-        self,
-        log_path: Optional[Path] = None,
-    ) -> None:
-        """
-        Visualize the matches.
-
-        Args:
-            log_path (Path, optional): The directory to save the images. Defaults to None.
-
-        Returns:
-            None
-        """
-        if self.conf["dense"]:
-            postfix = str(self.conf["matcher"]["model"]["name"])
-        else:
-            postfix = "{}_{}".format(
-                str(self.conf["feature"]["model"]["name"]),
-                str(self.conf["matcher"]["model"]["name"]),
-            )
-        titles = [
-            "Image 0 - Keypoints",
-            "Image 1 - Keypoints",
-        ]
-        pred: Dict[str, Any] = self.pred
-        image0: np.ndarray = pred["image0_orig"]
-        image1: np.ndarray = pred["image1_orig"]
-        output_keypoints: np.ndarray = plot_images(
-            [image0, image1], titles=titles, dpi=300
-        )
-        if (
-            "keypoints0_orig" in pred.keys()
-            and "keypoints1_orig" in pred.keys()
-        ):
-            plot_keypoints([pred["keypoints0_orig"], pred["keypoints1_orig"]])
-            text: str = (
-                f"# keypoints0: {len(pred['keypoints0_orig'])} \n"
-                + f"# keypoints1: {len(pred['keypoints1_orig'])}"
-            )
-            add_text(0, text, fs=15)
-        output_keypoints = fig2im(output_keypoints)
-        # plot images with raw matches
-        titles = [
-            "Image 0 - Raw matched keypoints",
-            "Image 1 - Raw matched keypoints",
-        ]
-        output_matches_raw, num_matches_raw = display_matches(
-            pred, titles=titles, tag="KPTS_RAW"
-        )
-        # plot images with ransac matches
-        titles = [
-            "Image 0 - Ransac matched keypoints",
-            "Image 1 - Ransac matched keypoints",
-        ]
-        output_matches_ransac, num_matches_ransac = display_matches(
-            pred, titles=titles, tag="KPTS_RANSAC"
-        )
-        if log_path is not None:
-            img_keypoints_path: Path = log_path / f"img_keypoints_{postfix}.png"
-            img_matches_raw_path: Path = (
-                log_path / f"img_matches_raw_{postfix}.png"
-            )
-            img_matches_ransac_path: Path = (
-                log_path / f"img_matches_ransac_{postfix}.png"
-            )
-            cv2.imwrite(
-                str(img_keypoints_path),
-                output_keypoints[:, :, ::-1].copy(),  # RGB -> BGR
-            )
-            cv2.imwrite(
-                str(img_matches_raw_path),
-                output_matches_raw[:, :, ::-1].copy(),  # RGB -> BGR
-            )
-            cv2.imwrite(
-                str(img_matches_ransac_path),
-                output_matches_ransac[:, :, ::-1].copy(),  # RGB -> BGR
-            )
-        plt.close("all")
-
-
-if __name__ == "__main__":
-    config = load_config(ROOT / "ui/config.yaml")
-    api = ImageMatchingAPI(config)
diff --git a/ui/app_class.py b/ui/app_class.py
index 761f4cc..94f02f6 100644
--- a/ui/app_class.py
+++ b/ui/app_class.py
@@ -1,3 +1,4 @@
+import sys
 from pathlib import Path
 from typing import Any, Dict, Optional, Tuple
 
@@ -6,6 +7,8 @@ from easydict import EasyDict as edict
 from omegaconf import OmegaConf
 
+sys.path.append(str(Path(__file__).parents[1]))
+
 from hloc import flush_logs, read_logs
 from ui.sfm import SfmEngine
 from ui.utils import (
diff --git a/ui/sfm.py b/ui/sfm.py
index 9176556..2fd90bd 100644
--- a/ui/sfm.py
+++ b/ui/sfm.py
@@ -1,8 +1,11 @@
 import shutil
+import sys
 import tempfile
 from pathlib import Path
 from typing import Any, Dict, List
 
+sys.path.append(str(Path(__file__).parents[1]))
+
 from hloc import (
     extract_features,
     logger,
@@ -17,7 +20,7 @@ except ImportError:
     logger.warning("pycolmap not installed, some features may not work")
 
-from .viz import fig2im
+from ui.viz import fig2im
 
 
 class SfmEngine:
diff --git a/ui/utils.py b/ui/utils.py
index 270d12a..cbd935e 100644
--- a/ui/utils.py
+++ b/ui/utils.py
@@ -2,6 +2,7 @@ import pickle
 import random
 import shutil
+import sys
 import time
 import warnings
 from itertools import combinations
@@ -16,6 +17,8 @@ import psutil
 from PIL import Image
 
+sys.path.append(str(Path(__file__).parents[1]))
+
 from hloc import (
     DEVICE,
     extract_features,
@@ -26,8 +29,7 @@ matchers,
 )
 from hloc.utils.base_model import dynamic_load
-
-from .viz import display_keypoints, display_matches, fig2im, plot_images
+from ui.viz import display_keypoints, display_matches, fig2im, plot_images
 
 warnings.simplefilter("ignore")
diff --git a/ui/viz.py b/ui/viz.py
index d1d53c8..6533f0b 100644
--- a/ui/viz.py
+++ b/ui/viz.py
@@ -1,3 +1,4 @@
+import sys
 import typing
 from pathlib import Path
 from typing import Dict, List, Optional, Tuple, Union
@@ -8,6 +9,8 @@ import numpy as np
 import seaborn as sns
 
+sys.path.append(str(Path(__file__).parents[1]))
+
 from hloc.utils.viz import add_text, plot_keypoints
 
 np.random.seed(1995)
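
Usage sketch (not part of the diff): with ImageMatchingAPI relocated to api/server.py, callers such as test_app_cli.py now import it from api.server. The snippet below is a minimal sketch that mirrors the __main__ block removed from ui/api.py together with forward()'s docstring; the image paths, the cv2-based [0, 1] RGB conversion, and passing ui/config.yaml straight into the constructor are illustrative assumptions, not a canonical entry point.

    import cv2

    from api.server import ImageMatchingAPI  # new location after this PR
    from ui.utils import ROOT, load_config

    # Same config the removed ui/api.py __main__ used (assumed valid here).
    conf = load_config(ROOT / "ui/config.yaml")
    api = ImageMatchingAPI(conf=conf, device="cpu")  # or "cuda" if available

    # forward() expects RGB arrays with values in [0, 1] (per its docstring);
    # the file names below are placeholders.
    img0 = cv2.cvtColor(cv2.imread("img0.jpg"), cv2.COLOR_BGR2RGB) / 255.0
    img1 = cv2.cvtColor(cv2.imread("img1.jpg"), cv2.COLOR_BGR2RGB) / 255.0

    pred = api(img0, img1)  # raw matches plus RANSAC-filtered matches
    print(f"RANSAC inliers: {len(pred['mmkeypoints0_orig'])}")
    api.visualize(log_path=None)  # pass a Path to also write PNGs to disk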