From 7d9d88c0b455da50aa6e1764fd989c341a0145bd Mon Sep 17 00:00:00 2001
From: Realcat
Date: Sun, 20 Oct 2024 13:16:50 +0000
Subject: [PATCH] update: api and gradio -> 5.x

---
 api/client.py             | 179 ++++++++++++-------
 api/server.py             | 116 +++++++++---
 api/test/CMakeLists.txt   |  16 ++
 api/test/build_and_run.sh |  16 ++
 api/test/client.cpp       |  82 +++++++++
 api/test/helper.h         | 362 ++++++++++++++++++++++++++++++++++++++
 api/types.py              |  16 ++
 requirements.txt          |   3 +-
 ui/__init__.py            |   5 +
 ui/app_class.py           |  21 +--
 10 files changed, 700 insertions(+), 116 deletions(-)
 create mode 100644 api/test/CMakeLists.txt
 create mode 100644 api/test/build_and_run.sh
 create mode 100644 api/test/client.cpp
 create mode 100644 api/test/helper.h
 create mode 100644 api/types.py

diff --git a/api/client.py b/api/client.py
index 7372a0b..20274a1 100644
--- a/api/client.py
+++ b/api/client.py
@@ -1,24 +1,103 @@
 import argparse
+import base64
 import os
 import pickle
 import time
-from typing import Dict
+from typing import Dict, List
 
+import cv2
 import numpy as np
 import requests
 
-URL = "http://127.0.0.1:8001"
+ENDPOINT = "http://127.0.0.1:8001"
 if "REMOTE_URL_RAILWAY" in os.environ:
-    URL = os.environ["REMOTE_URL_RAILWAY"]
-print(f"API URL: {URL}")
+    ENDPOINT = os.environ["REMOTE_URL_RAILWAY"]
+print(f"API ENDPOINT: {ENDPOINT}")
 
-API_URL_MATCH = f"{URL}/v1/match"
-API_URL_EXTRACT = f"{URL}/v1/extract"
-API_URL_EXTRACT_V2 = f"{URL}/v2/extract"
+API_VERSION = f"{ENDPOINT}/version"
+API_URL_MATCH = f"{ENDPOINT}/v1/match"
+API_URL_EXTRACT = f"{ENDPOINT}/v1/extract"
 
 
-def send_generate_request(path0: str, path1: str) -> Dict[str, np.ndarray]:
+def read_image(path: str) -> str:
+    """
+    Read an image from a file and encode it as a base64 PNG string.
+
+    Args:
+        path (str): The path to the image to read.
+
+    Returns:
+        str: The base64 encoded image.
+    """
+    # Read the image from the file as grayscale
+    img = cv2.imread(path, cv2.IMREAD_GRAYSCALE)
+
+    # Encode the image as a PNG (lossless)
+    retval, buffer = cv2.imencode(".png", img)
+
+    # Encode the PNG bytes as a base64 string
+    b64img = base64.b64encode(buffer).decode("utf-8")
+
+    return b64img
+
+
+def do_api_requests(url=API_URL_EXTRACT, **kwargs):
+    """
+    Helper function to send an API request to the image matching service.
+
+    Args:
+        url (str): The URL of the API endpoint to use. Defaults to the
+            feature extraction endpoint.
+        **kwargs: Additional keyword arguments to pass to the API.
+
+    Returns:
+        List[Dict[str, np.ndarray]]: A list of dictionaries containing the
+            extracted features. The keys are "keypoints", "descriptors", and
+            "scores", and the values are arrays of shape (N, 2), (N, D),
+            and (N,), respectively. Returns None if the request fails.
+    """
+    # Set up the request body
+    reqbody = {
+        # List of base64 encoded image data
+        "data": [],
+        # List of maximum number of keypoints to extract from each image
+        "max_keypoints": [100, 100],
+        # List of timestamps for each image (currently unused by the server)
+ "timestamps": ["0", "1"], + # Whether to convert the images to grayscale + "grayscale": 0, + # List of image height and width + "image_hw": [[640, 480], [320, 240]], + # Type of feature to extract + "feature_type": 0, + # List of rotation angles for each image + "rotates": [0.0, 0.0], + # List of scale factors for each image + "scales": [1.0, 1.0], + # List of reference points for each image (not used) + "reference_points": [[640, 480], [320, 240]], + # Whether to binarize the descriptors + "binarize": True, + } + # Update the request body with the additional keyword arguments + reqbody.update(kwargs) + try: + # Send the request + r = requests.post(url, json=reqbody) + if r.status_code == 200: + # Return the response + return r.json() + else: + # Print an error message if the response code is not 200 + print(f"Error: Response code {r.status_code} - {r.text}") + except Exception as e: + # Print an error message if an exception occurs + print(f"An error occurred: {e}") + + +def send_request_match(path0: str, path1: str) -> Dict[str, np.ndarray]: """ Send a request to the API to generate a match between two images. @@ -34,6 +112,7 @@ def send_generate_request(path0: str, path1: str) -> Dict[str, np.ndarray]: """ files = {"image0": open(path0, "rb"), "image1": open(path1, "rb")} try: + # TODO: replace files with post json response = requests.post(API_URL_MATCH, files=files) pred = {} if response.status_code == 200: @@ -50,68 +129,37 @@ def send_generate_request(path0: str, path1: str) -> Dict[str, np.ndarray]: return pred -def send_generate_request1(path0: str) -> Dict[str, np.ndarray]: +def send_request_extract(input_images: str) -> List[Dict[str, np.ndarray]]: """ Send a request to the API to extract features from an image. Args: - path0 (str): The path to the image. + input_images (str): The path to the image. Returns: - Dict[str, np.ndarray]: A dictionary containing the extracted features. - The keys are "keypoints", "descriptors", and "scores", and the - values are ndarrays of shape (N, 2), (N, 128), and (N,), - respectively. + List[Dict[str, np.ndarray]]: A list of dictionaries containing the + extracted features. The keys are "keypoints", "descriptors", and + "scores", and the values are ndarrays of shape (N, 2), (N, 128), + and (N,), respectively. """ - files = {"image": open(path0, "rb")} - try: - response = requests.post(API_URL_EXTRACT, files=files) - pred: Dict[str, np.ndarray] = {} - if response.status_code == 200: - pred = response.json() - for key in list(pred.keys()): - pred[key] = np.array(pred[key]) - else: - print( - f"Error: Response code {response.status_code} - {response.text}" - ) - finally: - files["image"].close() - return pred - - -def send_generate_request2(image_path: str) -> Dict[str, np.ndarray]: - """ - Send a request to the API to extract features from an image. + image_data = read_image(input_images) + inputs = { + "data": [image_data], + } + response = do_api_requests( + url=API_URL_EXTRACT, + **inputs, + ) + print("Keypoints detected: {}".format(len(response[0]["keypoints"]))) + return response - Args: - image_path (str): The path to the image. - Returns: - Dict[str, np.ndarray]: A dictionary containing the extracted features. - The keys are "keypoints", "descriptors", and "scores", and the - values are ndarrays of shape (N, 2), (N, 128), and (N,), respectively. 
- """ - data = { - "image_path": image_path, - "max_keypoints": 1024, - "reference_points": [[0.0, 0.0], [1.0, 1.0]], - } - pred = {} +def get_api_version(): try: - response = requests.post(API_URL_EXTRACT_V2, json=data) - pred: Dict[str, np.ndarray] = {} - if response.status_code == 200: - pred = response.json() - for key in list(pred.keys()): - pred[key] = np.array(pred[key]) - else: - print( - f"Error: Response code {response.status_code} - {response.text}" - ) + response = requests.get(API_VERSION).json() + print("API VERSION: {}".format(response["version"])) except Exception as e: print(f"An error occurred: {e}") - return pred if __name__ == "__main__": @@ -131,23 +179,24 @@ def send_generate_request2(image_path: str) -> Dict[str, np.ndarray]: default="datasets/sacre_coeur/mapping_rot/02928139_3448003521_rot90.jpg", ) args = parser.parse_args() - for i in range(10): - t1 = time.time() - preds = send_generate_request(args.image0, args.image1) - t2 = time.time() - print(f"Time cost1: {(t2 - t1)} seconds") + # get api version + get_api_version() + + # request match for i in range(10): t1 = time.time() - preds = send_generate_request1(args.image0) + preds = send_request_match(args.image0, args.image1) t2 = time.time() - print(f"Time cost2: {(t2 - t1)} seconds") + print(f"Time cost1: {(t2 - t1)} seconds") + # request extract for i in range(10): t1 = time.time() - preds = send_generate_request2(args.image0) + preds = send_request_extract(args.image0) t2 = time.time() print(f"Time cost2: {(t2 - t1)} seconds") + # dump preds with open("preds.pkl", "wb") as f: pickle.dump(preds, f) diff --git a/api/server.py b/api/server.py index 75ef608..77092f1 100644 --- a/api/server.py +++ b/api/server.py @@ -1,4 +1,6 @@ # server.py +import base64 +import io import sys import warnings from pathlib import Path @@ -10,24 +12,37 @@ import torch import uvicorn from fastapi import FastAPI, File, UploadFile +from fastapi.exceptions import HTTPException from fastapi.responses import JSONResponse from PIL import Image -from pydantic import BaseModel sys.path.append(str(Path(__file__).parents[1])) +from api.types import ImagesInput from hloc import DEVICE, extract_features, logger, match_dense, match_features from hloc.utils.viz import add_text, plot_keypoints +from ui import get_version from ui.utils import filter_matches, get_feature_model, get_model from ui.viz import display_matches, fig2im, plot_images warnings.simplefilter("ignore") -class ImageInfo(BaseModel): - image_path: str - max_keypoints: int - reference_points: list +def decode_base64_to_image(encoding): + if encoding.startswith("data:image/"): + encoding = encoding.split(";")[1].split(",")[1] + try: + image = Image.open(io.BytesIO(base64.b64decode(encoding))) + return image + except Exception as e: + logger.warning(f"API cannot decode image: {e}") + raise HTTPException( + status_code=500, detail="Invalid encoded image" + ) from e + + +def to_base64_nparray(encoding: str) -> np.ndarray: + return np.array(decode_base64_to_image(encoding)).astype("uint8") class ImageMatchingAPI(torch.nn.Module): @@ -156,10 +171,7 @@ def _forward(self, img0, img1): return pred @torch.inference_mode() - def extract( - self, - img0: np.ndarray, - ) -> Dict[str, np.ndarray]: + def extract(self, img0: np.ndarray, **kwargs) -> Dict[str, np.ndarray]: """Extract features from a single image. 
         Args:
@@ -169,6 +181,12 @@ def extract(
 
             Dict[str, np.ndarray]: feature dict
         """
+        # Set extraction parameters from the request kwargs
+        self.extractor.conf["max_keypoints"] = kwargs.get("max_keypoints", 512)
+        self.extractor.conf["keypoint_threshold"] = kwargs.get(
+            "keypoint_threshold", 0.0
+        )
+
         pred = extract_features.extract(
             self.extractor, img0, self.extract_conf["preprocessing"]
         )
@@ -176,6 +194,11 @@ def extract(
             k: v.cpu().detach()[0].numpy() if isinstance(v, torch.Tensor) else v
             for k, v in pred.items()
         }
+        binarize = kwargs.get("binarize", False)
+        if binarize:
+            assert "descriptors" in pred
+            pred["descriptors"] = (pred["descriptors"] > 0).astype(np.uint8)
+        pred["descriptors"] = pred["descriptors"].T  # N x DIM
         return pred
 
     @torch.inference_mode()
@@ -321,50 +344,83 @@ def visualize(
 
 class ImageMatchingService:
     def __init__(self, conf: dict, device: str):
+        self.conf = conf
         self.api = ImageMatchingAPI(conf=conf, device=device)
         self.app = FastAPI()
        self.register_routes()
 
     def register_routes(self):
+
+        @self.app.get("/version")
+        async def version():
+            return {"version": get_version()}
+
         @self.app.post("/v1/match")
         async def match(
             image0: UploadFile = File(...), image1: UploadFile = File(...)
         ):
+            """
+            Handle the image matching request and return the processed result.
+
+            Args:
+                image0 (UploadFile): The first image file for matching.
+                image1 (UploadFile): The second image file for matching.
+
+            Returns:
+                JSONResponse: A JSON response containing the filtered match
+                    results, or an error message in case of failure.
+            """
             try:
+                # Load the images from the uploaded files
                 image0_array = self.load_image(image0)
                 image1_array = self.load_image(image1)
 
+                # Perform image matching using the API
                 output = self.api(image0_array, image1_array)
 
+                # Keys to skip in the output
                 skip_keys = ["image0_orig", "image1_orig"]
-                pred = self.filter_output(output, skip_keys)
+
+                # Postprocess the output to filter unwanted data
+                pred = self.postprocess(output, skip_keys)
+
+                # Return the filtered prediction as a JSON response
                 return JSONResponse(content=pred)
             except Exception as e:
+                # Return an error message with status code 500 in case of exception
                 return JSONResponse(content={"error": str(e)}, status_code=500)
 
         @self.app.post("/v1/extract")
-        async def extract(image: UploadFile = File(...)):
-            try:
-                image_array = self.load_image(image)
-                output = self.api.extract(image_array)
-                skip_keys = ["descriptors", "image", "image_orig"]
-                pred = self.filter_output(output, skip_keys)
-                return JSONResponse(content=pred)
-            except Exception as e:
-                return JSONResponse(content={"error": str(e)}, status_code=500)
+        async def extract(input_info: ImagesInput):
+            """
+            Extract keypoints and descriptors from images.
 
-        @self.app.post("/v2/extract")
-        async def extract_v2(image_path: ImageInfo):
-            img_path = image_path.image_path
+            Args:
+                input_info: An object containing the image data and options.
+
+            Returns:
+                A list of dictionaries containing the keypoints and descriptors.
+ """ try: - safe_path = Path(img_path).resolve(strict=False) - image_array = self.load_image(str(safe_path)) - output = self.api.extract(image_array) - skip_keys = ["descriptors", "image", "image_orig"] - pred = self.filter_output(output, skip_keys) - return JSONResponse(content=pred) + preds = [] + for i, input_image in enumerate(input_info.data): + # Load the image from the input data + image_array = to_base64_nparray(input_image) + # Extract keypoints and descriptors + output = self.api.extract( + image_array, + max_keypoints=input_info.max_keypoints[i], + binarize=input_info.binarize, + ) + # Do not return the original image and image_orig + skip_keys = ["image", "image_orig"] + # Postprocess the output + pred = self.postprocess(output, skip_keys) + preds.append(pred) + # Return the list of extracted features + return JSONResponse(content=preds) except Exception as e: + # Return an error message if an exception occurs return JSONResponse(content={"error": str(e)}, status_code=500) def load_image(self, file_path: Union[str, UploadFile]) -> np.ndarray: @@ -385,7 +441,9 @@ def load_image(self, file_path: Union[str, UploadFile]) -> np.ndarray: image_array = np.array(img) return image_array - def filter_output(self, output: dict, skip_keys: list) -> dict: + def postprocess( + self, output: dict, skip_keys: list, binarize: bool = True + ) -> dict: pred = {} for key, value in output.items(): if key in skip_keys: diff --git a/api/test/CMakeLists.txt b/api/test/CMakeLists.txt new file mode 100644 index 0000000..200c17d --- /dev/null +++ b/api/test/CMakeLists.txt @@ -0,0 +1,16 @@ +cmake_minimum_required(VERSION 3.10) +project(imatchui) + +set(OpenCV_DIR /usr/include/opencv4) +find_package(OpenCV REQUIRED) + +find_package(Boost REQUIRED COMPONENTS system) +if(Boost_FOUND) + include_directories(${Boost_INCLUDE_DIRS}) +endif() + +add_executable(client client.cpp) + +target_include_directories(client PRIVATE ${Boost_LIBRARIES} ${OpenCV_INCLUDE_DIRS}) + +target_link_libraries(client PRIVATE curl jsoncpp b64 ${OpenCV_LIBS}) diff --git a/api/test/build_and_run.sh b/api/test/build_and_run.sh new file mode 100644 index 0000000..40921bb --- /dev/null +++ b/api/test/build_and_run.sh @@ -0,0 +1,16 @@ +# g++ main.cpp -I/usr/include/opencv4 -lcurl -ljsoncpp -lb64 -lopencv_core -lopencv_imgcodecs -o main +# sudo apt-get update +# sudo apt-get install libboost-all-dev -y +# sudo apt-get install libcurl4-openssl-dev libjsoncpp-dev libb64-dev libopencv-dev -y + +cd build +cmake .. +make -j12 + +echo " ======== RUN DEMO ========" + +./client + +echo " ======== END DEMO ========" + +cd .. 
diff --git a/api/test/client.cpp b/api/test/client.cpp
new file mode 100644
index 0000000..8b09e89
--- /dev/null
+++ b/api/test/client.cpp
@@ -0,0 +1,83 @@
+#include <curl/curl.h>
+#include <iostream>
+#include "helper.h"
+
+int main() {
+    std::string img_path = "../../../datasets/sacre_coeur/mapping_rot/02928139_3448003521_rot45.jpg";
+    cv::Mat original_img = cv::imread(img_path, cv::IMREAD_GRAYSCALE);
+
+    if (original_img.empty()) {
+        throw std::runtime_error("Failed to decode image");
+    }
+
+    // Convert the image to Base64
+    std::string base64_img = image_to_base64(original_img);
+
+    // Convert the Base64 back to an image
+    cv::Mat decoded_img = base64_to_image(base64_img);
+    cv::imwrite("decoded_image.jpg", decoded_img);
+    cv::imwrite("original_img.jpg", original_img);
+
+    // The images should be identical
+    if (cv::countNonZero(original_img != decoded_img) != 0) {
+        std::cerr << "The images are not identical" << std::endl;
+        return -1;
+    } else {
+        std::cout << "The images are identical!" << std::endl;
+    }
+
+    // Construct the request parameters
+    APIParams params{
+        .data = {base64_img},
+        .max_keypoints = {100, 100},
+        .timestamps = {"0", "1"},
+        .grayscale = false,
+        .image_hw = {{480, 640}, {240, 320}},
+        .feature_type = 0,
+        .rotates = {0.0f, 0.0f},
+        .scales = {1.0f, 1.0f},
+        .reference_points = {
+            {1.23e+2f, 1.2e+1f},
+            {5.0e-1f, 3.0e-1f},
+            {2.3e+2f, 2.2e+1f},
+            {6.0e-1f, 4.0e-1f}
+        },
+        .binarize = true
+    };
+
+    // Convert the parameters to JSON
+    Json::Value jsonData = paramsToJson(params);
+    std::string url = "http://127.0.0.1:8001/v1/extract";
+    Json::StreamWriterBuilder writer;
+    std::string output = Json::writeString(writer, jsonData);
+
+    CURL* curl;
+    CURLcode res;
+    std::string readBuffer;
+
+    curl_global_init(CURL_GLOBAL_DEFAULT);
+    curl = curl_easy_init();
+    if (curl) {
+        struct curl_slist* hs = NULL;
+        hs = curl_slist_append(hs, "Content-Type: application/json");
+        curl_easy_setopt(curl, CURLOPT_HTTPHEADER, hs);
+        curl_easy_setopt(curl, CURLOPT_URL, url.c_str());
+        curl_easy_setopt(curl, CURLOPT_POSTFIELDS, output.c_str());
+        curl_easy_setopt(curl, CURLOPT_WRITEFUNCTION, WriteCallback);
+        curl_easy_setopt(curl, CURLOPT_WRITEDATA, &readBuffer);
+        res = curl_easy_perform(curl);
+
+        if (res != CURLE_OK)
+            fprintf(stderr, "curl_easy_perform() failed: %s\n",
+                    curl_easy_strerror(res));
+        else {
+            // std::cout << "Response from server: " << readBuffer << std::endl;
+            decode_response(readBuffer);
+        }
+        curl_slist_free_all(hs);
+        curl_easy_cleanup(curl);
+    }
+    curl_global_cleanup();
+
+    return 0;
+}

diff --git a/api/test/helper.h b/api/test/helper.h
new file mode 100644
index 0000000..e2a228b
--- /dev/null
+++ b/api/test/helper.h
@@ -0,0 +1,362 @@
+#pragma once
+#include <iostream>
+#include <string>
+#include <vector>
+#include <opencv2/opencv.hpp>
+#include <json/json.h>
+#include <curl/curl.h>
+
+// base64 to image
+#include <boost/archive/iterators/base64_from_binary.hpp>
+#include <boost/archive/iterators/binary_from_base64.hpp>
+#include <boost/archive/iterators/transform_width.hpp>
+
+/// Parameters used in the API
+struct APIParams {
+    /// A list of images, base64 encoded
+    std::vector<std::string> data;
+
+    /// The maximum number of keypoints to detect for each image
+    std::vector<int> max_keypoints;
+
+    /// The timestamps of the images
+    std::vector<std::string> timestamps;
+
+    /// Whether to convert the images to grayscale
+    bool grayscale;
+
+    /// The height and width of each image
+    std::vector<std::vector<int>> image_hw;
+
+    /// The type of feature detector to use
+    int feature_type;
+
+    /// The rotations of the images
+    std::vector<float> rotates;
+
+    /// The scales of the images
+    std::vector<float> scales;
+
+    /// The reference points of the images
+    std::vector<std::vector<float>> reference_points;
+
+    /// Whether to binarize the descriptors
+    bool binarize;
+};
+
+/**
+ * @brief Contains the results of a keypoint detector.
+ *
+ * @details Stores the keypoints and descriptors for each image.
+ */
+class KeyPointResults {
+public:
+    /**
+     * @brief Constructor.
+     *
+     * @param kp The keypoints for each image.
+     */
+    KeyPointResults(const std::vector<std::vector<cv::KeyPoint>>& kp)
+        : keypoints(kp) {}
+
+    /**
+     * @brief Append keypoints to the result.
+     *
+     * @param kpts The keypoints to append.
+     */
+    inline void append_keypoints(std::vector<cv::KeyPoint>& kpts) {
+        keypoints.emplace_back(kpts);
+    }
+
+    /**
+     * @brief Append descriptors to the result.
+     *
+     * @param desc The descriptors to append.
+     */
+    inline void append_descriptors(cv::Mat& desc) {
+        descriptors.emplace_back(desc);
+    }
+
+private:
+    std::vector<std::vector<cv::KeyPoint>> keypoints;
+    std::vector<cv::Mat> descriptors;
+    std::vector<std::vector<float>> scores;
+};
+
+
+/**
+ * @brief Decodes a base64 encoded string.
+ *
+ * @param base64 The base64 encoded string to decode.
+ * @return The decoded string.
+ */
+std::string base64_decode(const std::string& base64) {
+    using namespace boost::archive::iterators;
+    using It = transform_width<binary_from_base64<std::string::const_iterator>, 8, 6>;
+
+    // Find the position of the last non-whitespace character, also skipping
+    // any '=' padding, which binary_from_base64 cannot consume
+    auto end = base64.find_last_not_of(" \t\n\r=");
+    if (end != std::string::npos) {
+        // Move one past the last non-whitespace character
+        end += 1;
+    }
+
+    // Decode the base64 string and return the result
+    return std::string(It(base64.begin()), It(base64.begin() + end));
+}
+
+
+/**
+ * @brief Decodes a base64 string into an OpenCV image
+ *
+ * @param base64 The base64 encoded string
+ * @return The decoded OpenCV image
+ */
+cv::Mat base64_to_image(const std::string& base64) {
+    // Decode the base64 string
+    std::string decodedStr = base64_decode(base64);
+
+    // Decode the image
+    std::vector<uchar> data(decodedStr.begin(), decodedStr.end());
+    cv::Mat img = cv::imdecode(data, cv::IMREAD_GRAYSCALE);
+
+    // Check for errors
+    if (img.empty()) {
+        throw std::runtime_error("Failed to decode image");
+    }
+
+    return img;
+}
+
+
+/**
+ * @brief Encodes an OpenCV image into a base64 string
+ *
+ * This function takes an OpenCV image and encodes it into a base64 string.
+ * The image is first encoded as a PNG image, and then the resulting
+ * bytes are encoded as a base64 string.
+ *
+ * @param img The OpenCV image
+ * @return The base64 encoded string
+ *
+ * @throws std::runtime_error if the image is empty or encoding fails
+ */
+std::string image_to_base64(cv::Mat& img) {
+    if (img.empty()) {
+        throw std::runtime_error("Failed to read image");
+    }
+
+    // Encode the image as a PNG
+    std::vector<uchar> buf;
+    if (!cv::imencode(".png", img, buf)) {
+        throw std::runtime_error("Failed to encode image");
+    }
+
+    // Encode the bytes as a base64 string
+    using namespace boost::archive::iterators;
+    using It = base64_from_binary<transform_width<std::vector<uchar>::const_iterator, 6, 8>>;
+    std::string base64(It(buf.begin()), It(buf.end()));
+
+    // Pad the string with '=' characters to a multiple of 4 bytes
+    base64.append((3 - buf.size() % 3) % 3, '=');
+
+    return base64;
+}
+
+
+/**
+ * @brief Callback function for libcurl to write data to a string
+ *
+ * libcurl invokes this callback for each chunk of the response body and
+ * appends the chunk to the std::string passed via CURLOPT_WRITEDATA.
+ *
+ * @param contents The data to write
+ * @param size The size of each data member
+ * @param nmemb The number of members in the data
+ * @param s The string to write the data to
+ * @return The number of bytes written
+ */
+size_t WriteCallback(void* contents, size_t size, size_t nmemb, std::string* s) {
+    size_t newLength = size * nmemb;
+    try {
+        // Resize the string to fit the new data
+        s->resize(s->size() + newLength);
+    } catch (std::bad_alloc& e) {
+        // If there's an error allocating memory, return 0
+        return 0;
+    }
+
+    // Copy the data to the string
+    std::copy(static_cast<char*>(contents),
+              static_cast<char*>(contents) + newLength,
+              s->begin() + s->size() - newLength);
+    return newLength;
+}
+
+// Helper functions
+
+/**
+ * @brief Helper function to convert a type to a Json::Value
+ *
+ * This function takes a value of type T and converts it to a Json::Value.
+ * It is used to simplify the process of converting a type to a Json::Value.
+ *
+ * @param val The value to convert
+ * @return The converted Json::Value
+ */
+template <typename T>
+Json::Value toJson(const T& val) {
+    return Json::Value(val);
+}
+
+/**
+ * @brief Converts a vector to a Json::Value
+ *
+ * This function takes a vector of type T and converts it to a Json::Value.
+ * Each element in the vector is appended to the Json::Value array.
+ *
+ * @param vec The vector to convert to Json::Value
+ * @return The Json::Value representing the vector
+ */
+template <typename T>
+Json::Value vectorToJson(const std::vector<T>& vec) {
+    Json::Value json(Json::arrayValue);
+    for (const auto& item : vec) {
+        json.append(item);
+    }
+    return json;
+}
+
+/**
+ * @brief Converts a nested vector to a Json::Value
+ *
+ * This function takes a nested vector of type T and converts it to a
+ * Json::Value. Each sub-vector is converted to a Json::Value array and
+ * appended to the main Json::Value array.
+ *
+ * @param vec The nested vector to convert to Json::Value
+ * @return The Json::Value representing the nested vector
+ */
+template <typename T>
+Json::Value nestedVectorToJson(const std::vector<std::vector<T>>& vec) {
+    Json::Value json(Json::arrayValue);
+    for (const auto& subVec : vec) {
+        json.append(vectorToJson(subVec));
+    }
+    return json;
+}
+
+
+/**
+ * @brief Converts the APIParams struct to a Json::Value
+ *
+ * This function takes an APIParams struct and converts it to a Json::Value.
+ * The Json::Value is a JSON object with the following fields:
+ * - data: a JSON array of base64 encoded images
+ * - max_keypoints: a JSON array of integers, max number of keypoints for each image
+ * - timestamps: a JSON array of timestamps, one for each image
+ * - grayscale: a JSON boolean, whether to convert images to grayscale
+ * - image_hw: a nested JSON array, each sub-array contains the height and width of an image
+ * - feature_type: a JSON integer, the type of feature detector to use
+ * - rotates: a JSON array of floats, the rotation of each image
+ * - scales: a JSON array of floats, the scale of each image
+ * - reference_points: a nested JSON array, each sub-array contains the reference points of an image
+ * - binarize: a JSON boolean, whether to binarize the descriptors
+ *
+ * @param params The APIParams struct to convert
+ * @return The Json::Value representing the APIParams struct
+ */
+Json::Value paramsToJson(const APIParams& params) {
+    Json::Value json;
+    json["data"] = vectorToJson(params.data);
+    json["max_keypoints"] = vectorToJson(params.max_keypoints);
+    json["timestamps"] = vectorToJson(params.timestamps);
+    json["grayscale"] = toJson(params.grayscale);
+    json["image_hw"] = nestedVectorToJson(params.image_hw);
+    json["feature_type"] = toJson(params.feature_type);
+    json["rotates"] = vectorToJson(params.rotates);
+    json["scales"] = vectorToJson(params.scales);
+    json["reference_points"] = nestedVectorToJson(params.reference_points);
+    json["binarize"] = toJson(params.binarize);
+    return json;
+}
+
+template <typename T>
+cv::Mat jsonToMat(const Json::Value& json) {
+    int rows = json.size();
+    int cols = json[0].size();
+
+    // Create a single array to hold all the data.
+    std::vector<T> data;
+    data.reserve(rows * cols);
+
+    for (int i = 0; i < rows; i++) {
+        for (int j = 0; j < cols; j++) {
+            data.push_back(static_cast<T>(json[i][j].asInt()));
+        }
+    }
+
+    // Create a cv::Mat header that points at the local buffer.
+    cv::Mat mat(rows, cols, CV_8UC1, data.data());  // Change the type if necessary.
+
+    // Clone before returning: `data` is a local buffer, so the Mat must
+    // own its own copy once this function returns.
+    return mat.clone();
+}
+
+
+/**
+ * @brief Decodes the response of the server and prints the keypoints
+ *
+ * @param response The response of the server
+ */
+void decode_response(const std::string& response) {
+    Json::CharReaderBuilder builder;
+    Json::CharReader* reader = builder.newCharReader();
+
+    Json::Value jsonData;
+    std::string errors;
+
+    // Parse the JSON response
+    bool parsingSuccessful = reader->parse(response.c_str(),
+        response.c_str() + response.size(), &jsonData, &errors);
+    delete reader;
+
+    if (!parsingSuccessful) {
+        // Handle error
+        std::cout << "Failed to parse the JSON, errors:" << std::endl;
+        std::cout << errors << std::endl;
+        return;
+    }
+
+    // Iterate over the images
+    for (const auto& jsonItem : jsonData) {
+        auto jkeypoints = jsonItem["keypoints"];
+        auto jdescriptors = jsonItem["descriptors"];
+        auto jscores = jsonItem["scores"];
+        auto jimageSize = jsonItem["image_size"];
+        auto joriginalSize = jsonItem["original_size"];
+        auto jsize = jsonItem["size"];
+
+        std::vector<cv::KeyPoint> vkeypoints;
+        std::vector<float> vscores;
+
+        // Iterate over the keypoints
+        int counter = 0;
+        for (const auto& keypoint : jkeypoints) {
+            if (counter < 10) {
+                // Print the first 10 keypoints
+                std::cout << keypoint[0].asFloat() << ", "
+                          << keypoint[1].asFloat() << std::endl;
+            }
+            counter++;
+            // Convert the Json::Value to a cv::KeyPoint
+            vkeypoints.emplace_back(keypoint[0].asFloat(),
+                                    keypoint[1].asFloat(), 0.0f);
+        }
+
+        // Parse the descriptors into a cv::Mat
+        cv::Mat descriptors = jsonToMat<uchar>(jdescriptors);
+        std::cout << "len keypoints: " << vkeypoints.size() << std::endl;
+    }
+}

diff --git a/api/types.py b/api/types.py
new file mode 100644
index 0000000..db17dce
--- /dev/null
+++ b/api/types.py
@@ -0,0 +1,16 @@
+from typing import List
+
+from pydantic import BaseModel
+
+
+class ImagesInput(BaseModel):
+    data: List[str] = []
+    max_keypoints: List[int] = []
+    timestamps: List[str] = []
+    grayscale: bool = False
+    image_hw: List[List[int]] = [[], []]
+    feature_type: int = 0
+    rotates: List[float] = []
+    scales: List[float] = []
+    reference_points: List[List[float]] = []
+    binarize: bool = False

diff --git a/requirements.txt b/requirements.txt
index 2c0a784..d3e0960 100644
--- a/requirements.txt
+++ b/requirements.txt
@@ -2,8 +2,7 @@ e2cnn
 einops
 easydict
 gdown
-gradio==4.44.0
-gradio_client==1.3.0
+gradio==5.1.0
 h5py
 huggingface_hub
 imageio

diff --git a/ui/__init__.py b/ui/__init__.py
index e69de29..ac6ccf5 100644
--- a/ui/__init__.py
+++ b/ui/__init__.py
@@ -0,0 +1,5 @@
+__version__ = "1.0.1"
+
+
+def get_version():
+    return __version__

diff --git a/ui/app_class.py b/ui/app_class.py
index 94f02f6..628a9a7 100644
--- a/ui/app_class.py
+++ b/ui/app_class.py
@@ -9,7 +9,6 @@
 
 sys.path.append(str(Path(__file__).parents[1]))
 
-from hloc import flush_logs, read_logs
 from ui.sfm import SfmEngine
 from ui.utils import (
     GRADIO_VERSION,
@@ -275,24 +274,6 @@ def init_interface(self):
                     self.display_supported_algorithms()
 
             with gr.Column():
-                with gr.Accordion("Open for More: Logs", open=False):
-                    logs = gr.Textbox(
-                        placeholder="\n" * 10,
-                        label="Logs",
-                        info="Verbose from inference will be displayed below.",
-                        lines=10,
-                        max_lines=10,
-                        autoscroll=True,
-                        elem_id="logs",
-                        show_copy_button=True,
-                        container=True,
-                        elem_classes="logs_class",
-                    )
-                    self.app.load(read_logs, None, logs, every=1)
-                    btn_clear_logs = gr.Button(
-                        "Clear logs", elem_id="logs-button"
-                    )
-                    btn_clear_logs.click(flush_logs, [], [])
-
                 with gr.Accordion(
                     "Open for More: Keypoints", open=True
@@ -526,7 +507,7 @@ def ui_reset_state(
         key: str = list(self.matcher_zoo.keys())[
             0
         ]  # Get the first key from matcher_zoo
-        flush_logs()
+        # flush_logs()
         return (
             None,  # image0: Optional[np.ndarray]
             None,  # image1: Optional[np.ndarray]
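As an end-to-end smoke test of the updated endpoints, a minimal sketch (assuming the server is running locally on port 8001 and that the sample image path from the client defaults exists):

    import numpy as np
    from api.client import get_api_version, send_request_extract

    get_api_version()  # prints the version reported by GET /version
    preds = send_request_extract(
        "datasets/sacre_coeur/mapping_rot/02928139_3448003521_rot45.jpg"
    )
    kpts = np.asarray(preds[0]["keypoints"])    # (N, 2) keypoint coordinates
    desc = np.asarray(preds[0]["descriptors"])  # (N, D); 0/1-valued when binarize=True
    print(kpts.shape, desc.shape)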