diff --git a/hloc/matchers/cotr.py b/hloc/matchers/cotr.py index 44d74f6..df28b1b 100644 --- a/hloc/matchers/cotr.py +++ b/hloc/matchers/cotr.py @@ -1,14 +1,19 @@ import argparse +import subprocess import sys from pathlib import Path import numpy as np import torch +from huggingface_hub import hf_hub_download from torchvision.transforms import ToPILImage +from .. import logger from ..utils.base_model import BaseModel -sys.path.append(str(Path(__file__).parent / "../../third_party/COTR")) +cotr_path = Path(__file__).parent / "../../third_party/COTR" + +sys.path.append(str(cotr_path)) from COTR.inference.sparse_engine import SparseEngine from COTR.models import build_model from COTR.options.options import * # noqa: F403 @@ -18,7 +23,6 @@ utils_cotr.fix_randomness(0) torch.set_grad_enabled(False) -cotr_path = Path(__file__).parent / "../../third_party/COTR" device = torch.device("cuda" if torch.cuda.is_available() else "cpu") @@ -36,10 +40,27 @@ def _init(self, conf): set_COTR_arguments(parser) # noqa: F405 opt = parser.parse_args() opt.command = " ".join(sys.argv) - opt.load_weights_path = str( - cotr_path / conf["weights"] / "checkpoint.pth.tar" - ) + model_path = cotr_path / conf["weights"] / "checkpoint.pth.tar" + + if not model_path.exists(): + model_path.parent.mkdir(exist_ok=True, parents=True) + cached_file = hf_hub_download( + repo_type="space", + repo_id="Realcat/image-matching-webui", + filename="third_party/COTR/{}/{}".format( + conf["weights"], "checkpoint.pth.tar" + ), + ) + logger.info("Downloaded COTR model succeeed!") + cmd = [ + "cp", + str(cached_file), + str(model_path.parent), + ] + subprocess.run(cmd, check=True) + logger.info(f"Copy model file `{cmd}`.") + opt.load_weights_path = str(model_path) layer_2_channels = { "layer1": 256, "layer2": 512, diff --git a/hloc/matchers/gim.py b/hloc/matchers/gim.py index d0ccffa..0572bbe 100644 --- a/hloc/matchers/gim.py +++ b/hloc/matchers/gim.py @@ -2,8 +2,8 @@ import sys from pathlib import Path -import gdown import torch +from huggingface_hub import hf_hub_download from .. import logger from ..utils.base_model import BaseModel @@ -26,6 +26,7 @@ class GIM(BaseModel): "image0", "image1", ] + model_list = ["gim_lightglue_100h.ckpt", "gim_dkm_100h.ckpt"] model_dict = { "gim_lightglue_100h.ckpt": "https://github.com/xuelunshen/gim/blob/main/weights/gim_lightglue_100h.ckpt", "gim_dkm_100h.ckpt": "https://drive.google.com/file/d/1gk97V4IROnR1Nprq10W9NCFUv2mxXR_-/view", @@ -33,20 +34,28 @@ class GIM(BaseModel): def _init(self, conf): conf["model_name"] = str(conf["weights"]) - if conf["model_name"] not in self.model_dict: + if conf["model_name"] not in self.model_list: raise ValueError(f"Unknown GIM model {conf['model_name']}.") model_path = conf["checkpoint_dir"] / conf["model_name"] # Download the model. if not model_path.exists(): model_path.parent.mkdir(exist_ok=True) - model_link = self.model_dict[conf["model_name"]] - if "drive.google.com" in model_link: - gdown.download(model_link, output=str(model_path), fuzzy=True) - else: - cmd = ["wget", "--quiet", model_link, "-O", str(model_path)] - subprocess.run(cmd, check=True) + cached_file = hf_hub_download( + repo_type="space", + repo_id="Realcat/image-matching-webui", + filename="third_party/gim/weights/{}".format( + conf["model_name"] + ), + ) logger.info("Downloaded GIM model succeeed!") + cmd = [ + "cp", + str(cached_file), + str(conf["checkpoint_dir"]), + ] + subprocess.run(cmd, check=True) + logger.info(f"Copy model file `{cmd}`.") self.aspect_ratio = 896 / 672 model = DKMv3(None, 672, 896, upsample_preds=True)