Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

add: api #77

Merged
merged 5 commits into from
Oct 4, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
147 changes: 147 additions & 0 deletions api/client.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,147 @@
import argparse
import pickle
import time
from typing import Dict

import numpy as np
import requests
from loguru import logger

API_URL_MATCH = "http://127.0.0.1:8001/v1/match"
API_URL_EXTRACT = "http://127.0.0.1:8001/v1/extract"
API_URL_EXTRACT_V2 = "http://127.0.0.1:8001/v2/extract"


def send_generate_request(path0: str, path1: str) -> Dict[str, np.ndarray]:
"""
Send a request to the API to generate a match between two images.

Args:
path0 (str): The path to the first image.
path1 (str): The path to the second image.

Returns:
Dict[str, np.ndarray]: A dictionary containing the generated matches.
The keys are "keypoints0", "keypoints1", "matches0", and "matches1",
and the values are ndarrays of shape (N, 2), (N, 2), (N, 2), and
(N, 2), respectively.
"""
files = {"image0": open(path0, "rb"), "image1": open(path1, "rb")}
try:
response = requests.post(API_URL_MATCH, files=files)
pred = {}
if response.status_code == 200:
pred = response.json()
for key in list(pred.keys()):
pred[key] = np.array(pred[key])
else:
print(
f"Error: Response code {response.status_code} - {response.text}"
)
finally:
files["image0"].close()
files["image1"].close()
return pred


def send_generate_request1(path0: str) -> Dict[str, np.ndarray]:
"""
Send a request to the API to extract features from an image.

Args:
path0 (str): The path to the image.

Returns:
Dict[str, np.ndarray]: A dictionary containing the extracted features.
The keys are "keypoints", "descriptors", and "scores", and the
values are ndarrays of shape (N, 2), (N, 128), and (N,),
respectively.
"""
files = {"image": open(path0, "rb")}
try:
response = requests.post(API_URL_EXTRACT, files=files)
pred: Dict[str, np.ndarray] = {}
if response.status_code == 200:
pred = response.json()
for key in list(pred.keys()):
pred[key] = np.array(pred[key])
else:
print(
f"Error: Response code {response.status_code} - {response.text}"
)
finally:
files["image"].close()
return pred


def send_generate_request2(image_path: str) -> Dict[str, np.ndarray]:
"""
Send a request to the API to extract features from an image.

Args:
image_path (str): The path to the image.

Returns:
Dict[str, np.ndarray]: A dictionary containing the extracted features.
The keys are "keypoints", "descriptors", and "scores", and the
values are ndarrays of shape (N, 2), (N, 128), and (N,), respectively.
"""
data = {
"image_path": image_path,
"max_keypoints": 1024,
"reference_points": [[0.0, 0.0], [1.0, 1.0]],
}
pred = {}
try:
response = requests.post(API_URL_EXTRACT_V2, json=data)
pred: Dict[str, np.ndarray] = {}
if response.status_code == 200:
pred = response.json()
for key in list(pred.keys()):
pred[key] = np.array(pred[key])
else:
print(
f"Error: Response code {response.status_code} - {response.text}"
)
except Exception as e:
print(f"An error occurred: {e}")
return pred


if __name__ == "__main__":
parser = argparse.ArgumentParser(
description="Send text to stable audio server and receive generated audio."
)
parser.add_argument(
"--image0",
required=False,
help="Path for the file's melody",
default="../datasets/sacre_coeur/mapping_rot/02928139_3448003521_rot45.jpg",
)
parser.add_argument(
"--image1",
required=False,
help="Path for the file's melody",
default="../datasets/sacre_coeur/mapping_rot/02928139_3448003521_rot90.jpg",
)
args = parser.parse_args()
for i in range(10):
t1 = time.time()
preds = send_generate_request(args.image0, args.image1)
t2 = time.time()
logger.info(f"Time cost1: {(t2 - t1)} seconds")

for i in range(10):
t1 = time.time()
preds = send_generate_request1(args.image0)
t2 = time.time()
logger.info(f"Time cost2: {(t2 - t1)} seconds")

for i in range(10):
t1 = time.time()
preds = send_generate_request2(args.image0)
t2 = time.time()
logger.info(f"Time cost2: {(t2 - t1)} seconds")

with open("preds.pkl", "wb") as f:
pickle.dump(preds, f)
135 changes: 135 additions & 0 deletions api/server.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,135 @@
# server.py
import sys
from pathlib import Path
from typing import Union

import numpy as np
import uvicorn
from fastapi import FastAPI, File, UploadFile
from fastapi.responses import JSONResponse
from PIL import Image

sys.path.append("..")
from pydantic import BaseModel

from ui.api import ImageMatchingAPI
from ui.utils import DEVICE


class ImageInfo(BaseModel):
image_path: str
max_keypoints: int
reference_points: list


class ImageMatchingService:
def __init__(self, conf: dict, device: str):
self.api = ImageMatchingAPI(conf=conf, device=device)
self.app = FastAPI()
self.register_routes()

def register_routes(self):
@self.app.post("/v1/match")
async def match(
image0: UploadFile = File(...), image1: UploadFile = File(...)
):
try:
image0_array = self.load_image(image0)
image1_array = self.load_image(image1)

output = self.api(image0_array, image1_array)

skip_keys = ["image0_orig", "image1_orig"]
pred = self.filter_output(output, skip_keys)

return JSONResponse(content=pred)
except Exception as e:
return JSONResponse(content={"error": str(e)}, status_code=500)

@self.app.post("/v1/extract")
async def extract(image: UploadFile = File(...)):
try:
image_array = self.load_image(image)
output = self.api.extract(image_array)
skip_keys = ["descriptors", "image", "image_orig"]
pred = self.filter_output(output, skip_keys)
return JSONResponse(content=pred)
except Exception as e:
return JSONResponse(content={"error": str(e)}, status_code=500)

@self.app.post("/v2/extract")
async def extract_v2(image_path: ImageInfo):
img_path = image_path.image_path
try:
safe_path = Path(img_path).resolve(strict=False)
image_array = self.load_image(str(safe_path))
output = self.api.extract(image_array)
skip_keys = ["descriptors", "image", "image_orig"]
pred = self.filter_output(output, skip_keys)
return JSONResponse(content=pred)
except Exception as e:
return JSONResponse(content={"error": str(e)}, status_code=500)

def load_image(self, file_path: Union[str, UploadFile]) -> np.ndarray:
"""
Reads an image from a file path or an UploadFile object.

Args:
file_path: A file path or an UploadFile object.

Returns:
A numpy array representing the image.
"""
if isinstance(file_path, str):
file_path = Path(file_path).resolve(strict=False)
else:
file_path = file_path.file
with Image.open(file_path) as img:
image_array = np.array(img)
return image_array

def filter_output(self, output: dict, skip_keys: list) -> dict:
pred = {}
for key, value in output.items():
if key in skip_keys:
continue
if isinstance(value, np.ndarray):
pred[key] = value.tolist()
return pred

def run(self, host: str = "0.0.0.0", port: int = 8001):
uvicorn.run(self.app, host=host, port=port)


if __name__ == "__main__":
conf = {
"feature": {
"output": "feats-superpoint-n4096-rmax1600",
"model": {
"name": "superpoint",
"nms_radius": 3,
"max_keypoints": 4096,
"keypoint_threshold": 0.005,
},
"preprocessing": {
"grayscale": True,
"force_resize": True,
"resize_max": 1600,
"width": 640,
"height": 480,
"dfactor": 8,
},
},
"matcher": {
"output": "matches-NN-mutual",
"model": {
"name": "nearest_neighbor",
"do_mutual_check": True,
"match_threshold": 0.2,
},
},
"dense": False,
}

service = ImageMatchingService(conf=conf, device=DEVICE)
service.run()
6 changes: 3 additions & 3 deletions format.sh
Original file line number Diff line number Diff line change
@@ -1,3 +1,3 @@
python -m flake8 ui/*.py hloc/*.py hloc/matchers/*.py hloc/extractors/*.py
python -m isort ui/*.py hloc/*.py hloc/matchers/*.py hloc/extractors/*.py
python -m black ui/*.py hloc/*.py hloc/matchers/*.py hloc/extractors/*.py
python -m flake8 ui/*.py api/*.py hloc/*.py hloc/matchers/*.py hloc/extractors/*.py
python -m isort ui/*.py api/*.py hloc/*.py hloc/matchers/*.py hloc/extractors/*.py
python -m black ui/*.py api/*.py hloc/*.py hloc/matchers/*.py hloc/extractors/*.py
8 changes: 7 additions & 1 deletion hloc/extractors/superpoint.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,4 +48,10 @@ def _init(self, conf):
logger.info("Load SuperPoint model done.")

def _forward(self, data):
return self.net(data, self.conf)
pred = self.net(data, self.conf)
pred = {
"keypoints": pred["keypoints"][0][None],
"scores": pred["scores"][0][None],
"descriptors": pred["descriptors"][0][None],
}
return pred
1 change: 1 addition & 0 deletions requirements.txt
Original file line number Diff line number Diff line change
Expand Up @@ -36,3 +36,4 @@ torchvision==0.19.0
roma #dust3r
tqdm
yacs
fastapi
23 changes: 23 additions & 0 deletions ui/api.py
Original file line number Diff line number Diff line change
Expand Up @@ -147,6 +147,29 @@ def _forward(self, img0, img1):
pred = match_features.match_images(self.matcher, pred0, pred1)
return pred

@torch.inference_mode()
def extract(
self,
img0: np.ndarray,
) -> Dict[str, np.ndarray]:
"""Extract features from a single image.

Args:
img0 (np.ndarray): image

Returns:
Dict[str, np.ndarray]: feature dict
"""

pred = extract_features.extract(
self.extractor, img0, self.extract_conf["preprocessing"]
)
pred = {
k: v.cpu().detach()[0].numpy() if isinstance(v, torch.Tensor) else v
for k, v in pred.items()
}
return pred

@torch.inference_mode()
def forward(
self,
Expand Down
Loading