diff --git a/hloc/__init__.py b/hloc/__init__.py index c7db5d9..2845460 100644 --- a/hloc/__init__.py +++ b/hloc/__init__.py @@ -61,3 +61,6 @@ def flush_logs(): ) DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu") + +# model hub: https://huggingface.co/Realcat/imatchui_checkpoint +MODEL_REPO_ID = "Realcat/imatchui_checkpoints" diff --git a/hloc/extractors/alike.py b/hloc/extractors/alike.py index 2f6ae55..4da3ca3 100644 --- a/hloc/extractors/alike.py +++ b/hloc/extractors/alike.py @@ -3,7 +3,7 @@ import torch -from hloc import logger +from hloc import MODEL_REPO_ID, logger from ..utils.base_model import BaseModel @@ -29,6 +29,14 @@ class Alike(BaseModel): required_inputs = ["image"] def _init(self, conf): + model_path = self._download_model( + repo_id=MODEL_REPO_ID, + filename="{}/{}.pth".format( + Path(__file__).stem, self.conf["model_name"] + ), + ) + logger.info("Loaded Alike model from {}".format(model_path)) + configs[conf["model_name"]]["model_path"] = model_path self.net = Alike_( **configs[conf["model_name"]], device=device, diff --git a/hloc/extractors/d2net.py b/hloc/extractors/d2net.py index 3f92437..98adfd4 100644 --- a/hloc/extractors/d2net.py +++ b/hloc/extractors/d2net.py @@ -1,10 +1,9 @@ -import subprocess import sys from pathlib import Path import torch -from hloc import logger +from hloc import MODEL_REPO_ID, logger from ..utils.base_model import BaseModel @@ -25,20 +24,17 @@ class D2Net(BaseModel): required_inputs = ["image"] def _init(self, conf): - model_file = conf["checkpoint_dir"] / conf["model_name"] - if not model_file.exists(): - model_file.parent.mkdir(exist_ok=True) - cmd = [ - "wget", - "--quiet", - "https://dusmanu.com/files/d2-net/" + conf["model_name"], - "-O", - str(model_file), - ] - subprocess.run(cmd, check=True) + logger.info("Loading D2Net model...") + model_path = self._download_model( + repo_id=MODEL_REPO_ID, + filename="{}/{}".format( + Path(__file__).stem, self.conf["model_name"] + ), + ) + logger.info(f"Loading model from {model_path}...") self.net = _D2Net( - model_file=model_file, use_relu=conf["use_relu"], use_cuda=False + model_file=model_path, use_relu=conf["use_relu"], use_cuda=False ) logger.info("Load D2Net model done.") diff --git a/hloc/extractors/darkfeat.py b/hloc/extractors/darkfeat.py index 41e2b3d..32ad21e 100644 --- a/hloc/extractors/darkfeat.py +++ b/hloc/extractors/darkfeat.py @@ -1,9 +1,7 @@ import sys from pathlib import Path -from huggingface_hub import hf_hub_download - -from hloc import logger +from hloc import MODEL_REPO_ID, logger from ..utils.base_model import BaseModel @@ -19,20 +17,17 @@ class DarkFeat(BaseModel): "detection_threshold": 0.5, "sub_pixel": False, } - weight_urls = { - "DarkFeat.pth": "https://drive.google.com/uc?id=1Thl6m8NcmQ7zSAF-1_xaFs3F4H8UU6HX&confirm=t", - } - proxy = "http://localhost:1080" required_inputs = ["image"] def _init(self, conf): - cached_file = hf_hub_download( - repo_type="space", - repo_id="Realcat/image-matching-webui", - filename="third_party/DarkFeat/checkpoints/DarkFeat.pth", + model_path = self._download_model( + repo_id=MODEL_REPO_ID, + filename="{}/{}".format( + Path(__file__).stem, self.conf["model_name"] + ), ) - - self.net = DarkFeat_(cached_file) + logger.info("Loaded DarkFeat model: {}".format(model_path)) + self.net = DarkFeat_(model_path) logger.info("Load DarkFeat model done.") def _forward(self, data): diff --git a/hloc/extractors/dedode.py b/hloc/extractors/dedode.py index a1d7130..d6a228d 100644 --- a/hloc/extractors/dedode.py +++ b/hloc/extractors/dedode.py @@ -1,11 +1,10 @@ -import subprocess import sys from pathlib import Path import torch import torchvision.transforms as transforms -from hloc import logger +from hloc import MODEL_REPO_ID, logger from ..utils.base_model import BaseModel @@ -30,39 +29,25 @@ class DeDoDe(BaseModel): required_inputs = [ "image", ] - weight_urls = { - "dedode_detector_L.pth": "https://github.com/Parskatt/DeDoDe/releases/download/dedode_pretrained_models/dedode_detector_L.pth", - "dedode_descriptor_B.pth": "https://github.com/Parskatt/DeDoDe/releases/download/dedode_pretrained_models/dedode_descriptor_B.pth", - } # Initialize the line matcher def _init(self, conf): - model_detector_path = ( - dedode_path / "pretrained" / conf["model_detector_name"] + model_detector_path = self._download_model( + repo_id=MODEL_REPO_ID, + filename="{}/{}".format( + Path(__file__).stem, conf["model_detector_name"] + ), ) - model_descriptor_path = ( - dedode_path / "pretrained" / conf["model_descriptor_name"] + model_descriptor_path = self._download_model( + repo_id=MODEL_REPO_ID, + filename="{}/{}".format( + Path(__file__).stem, conf["model_descriptor_name"] + ), ) - + logger.info("Loaded DarkFeat model: {}".format(model_detector_path)) self.normalizer = transforms.Normalize( mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225] ) - # Download the model. - if not model_detector_path.exists(): - model_detector_path.parent.mkdir(exist_ok=True) - link = self.weight_urls[conf["model_detector_name"]] - cmd = ["wget", "--quiet", link, "-O", str(model_detector_path)] - logger.info(f"Downloading the DeDoDe detector model with `{cmd}`.") - subprocess.run(cmd, check=True) - - if not model_descriptor_path.exists(): - model_descriptor_path.parent.mkdir(exist_ok=True) - link = self.weight_urls[conf["model_descriptor_name"]] - cmd = ["wget", "--quiet", link, "-O", str(model_descriptor_path)] - logger.info( - f"Downloading the DeDoDe descriptor model with `{cmd}`." - ) - subprocess.run(cmd, check=True) # load the model weights_detector = torch.load(model_detector_path, map_location="cpu") diff --git a/hloc/extractors/dir.py b/hloc/extractors/dir.py index 2d47256..d8a354f 100644 --- a/hloc/extractors/dir.py +++ b/hloc/extractors/dir.py @@ -43,6 +43,7 @@ class DIR(BaseModel): } def _init(self, conf): + # todo: download from google drive -> huggingface models checkpoint = Path( torch.hub.get_dir(), "dirtorch", conf["model_name"] + ".pt" ) diff --git a/hloc/extractors/lanet.py b/hloc/extractors/lanet.py index 21d908b..7869c40 100644 --- a/hloc/extractors/lanet.py +++ b/hloc/extractors/lanet.py @@ -1,11 +1,9 @@ -import subprocess import sys from pathlib import Path import torch -from huggingface_hub import hf_hub_download -from hloc import logger +from hloc import MODEL_REPO_ID, logger from ..utils.base_model import BaseModel @@ -20,30 +18,21 @@ class LANet(BaseModel): default_conf = { - "model_name": "v0", + "model_name": "PointModel_v0.pth", "keypoint_threshold": 0.1, "max_keypoints": 1024, } required_inputs = ["image"] def _init(self, conf): - model_path = ( - lanet_path / "checkpoints" / f'PointModel_{conf["model_name"]}.pth' - ) - if not model_path.exists(): - logger.warning(f"No model found at {model_path}, start downloading") - model_path.parent.mkdir(exist_ok=True) - cached_file = hf_hub_download( - repo_type="space", - repo_id="Realcat/image-matching-webui", - filename="third_party/lanet/checkpoints/PointModel_{}.pth".format( - conf["model_name"] - ), - ) - cmd = ["cp", str(cached_file), str(model_path.parent)] - logger.info(f"copy model file `{cmd}`.") - subprocess.run(cmd, check=True) + logger.info("Loading LANet model...") + model_path = self._download_model( + repo_id=MODEL_REPO_ID, + filename="{}/{}".format( + Path(__file__).stem, self.conf["model_name"] + ), + ) self.net = PointModel(is_test=True) state_dict = torch.load(model_path, map_location="cpu") self.net.load_state_dict(state_dict["model_state"]) diff --git a/hloc/extractors/r2d2.py b/hloc/extractors/r2d2.py index 359d89c..fccb96f 100644 --- a/hloc/extractors/r2d2.py +++ b/hloc/extractors/r2d2.py @@ -3,7 +3,7 @@ import torchvision.transforms as tvf -from hloc import logger +from hloc import MODEL_REPO_ID, logger from ..utils.base_model import BaseModel @@ -27,11 +27,16 @@ class R2D2(BaseModel): required_inputs = ["image"] def _init(self, conf): - model_fn = r2d2_path / "models" / conf["model_name"] + model_path = self._download_model( + repo_id=MODEL_REPO_ID, + filename="{}/{}".format( + Path(__file__).stem, self.conf["model_name"] + ), + ) self.norm_rgb = tvf.Normalize( mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225] ) - self.net = load_network(model_fn) + self.net = load_network(model_path) self.detector = NonMaxSuppression( rel_thr=conf["reliability_threshold"], rep_thr=conf["repetability_threshold"], diff --git a/hloc/extractors/rekd.py b/hloc/extractors/rekd.py index c4fbb5f..0191bce 100644 --- a/hloc/extractors/rekd.py +++ b/hloc/extractors/rekd.py @@ -3,7 +3,7 @@ import torch -from hloc import logger +from hloc import MODEL_REPO_ID, logger from ..utils.base_model import BaseModel @@ -22,8 +22,12 @@ class REKD(BaseModel): required_inputs = ["image"] def _init(self, conf): - model_path = ( - rekd_path / "checkpoints" / f'PointModel_{conf["model_name"]}.pth' + # TODO: download model + model_path = self._download_model( + repo_id=MODEL_REPO_ID, + filename="{}/{}".format( + Path(__file__).stem, self.conf["model_name"] + ), ) if not model_path.exists(): print(f"No model found at {model_path}") diff --git a/hloc/extractors/rord.py b/hloc/extractors/rord.py index 4d140af..ea0e4ee 100644 --- a/hloc/extractors/rord.py +++ b/hloc/extractors/rord.py @@ -1,11 +1,9 @@ -import subprocess import sys from pathlib import Path import torch -from huggingface_hub import hf_hub_download -from hloc import logger +from hloc import MODEL_REPO_ID, logger from ..utils.base_model import BaseModel @@ -24,31 +22,14 @@ class RoRD(BaseModel): "max_keypoints": 1024, } required_inputs = ["image"] - weight_urls = { - "rord.pth": "https://drive.google.com/uc?id=12414ZGKwgPAjNTGtNrlB4VV9l7W76B2o&confirm=t", - } - proxy = "http://localhost:1080" def _init(self, conf): - model_path = conf["checkpoint_dir"] / conf["model_name"] - link = self.weight_urls[conf["model_name"]] # noqa: F841 - - if not model_path.exists(): - model_path.parent.mkdir(exist_ok=True) - cached_file_0 = hf_hub_download( - repo_type="space", - repo_id="Realcat/image-matching-webui", - filename="third_party/RoRD/models/d2net.pth", - ) - cached_file_1 = hf_hub_download( - repo_type="space", - repo_id="Realcat/image-matching-webui", - filename="third_party/RoRD/models/rord.pth", - ) - - subprocess.run(["cp", cached_file_0, model_path], check=True) - subprocess.run(["cp", cached_file_1, model_path], check=True) - + model_path = self._download_model( + repo_id=MODEL_REPO_ID, + filename="{}/{}".format( + Path(__file__).stem, self.conf["model_name"] + ), + ) self.net = _RoRD( model_file=model_path, use_relu=conf["use_relu"], use_cuda=False ) diff --git a/hloc/extractors/sfd2.py b/hloc/extractors/sfd2.py index 1bd6188..2724ed3 100644 --- a/hloc/extractors/sfd2.py +++ b/hloc/extractors/sfd2.py @@ -3,7 +3,7 @@ import torchvision.transforms as tvf -from .. import logger +from .. import MODEL_REPO_ID, logger from ..utils.base_model import BaseModel tp_path = Path(__file__).parent / "../../third_party" @@ -24,7 +24,10 @@ def _init(self, conf): self.norm_rgb = tvf.Normalize( mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225] ) - model_path = tp_path / "pram" / "weights" / self.conf["model_name"] + model_path = self._download_model( + repo_id=MODEL_REPO_ID, + filename="{}/{}".format("pram", self.conf["model_name"]), + ) self.net = load_sfd2(weight_path=model_path).eval() logger.info("Load SFD2 model done.") diff --git a/hloc/extractors/superpoint.py b/hloc/extractors/superpoint.py index d47f33a..ee61839 100644 --- a/hloc/extractors/superpoint.py +++ b/hloc/extractors/superpoint.py @@ -48,10 +48,4 @@ def _init(self, conf): logger.info("Load SuperPoint model done.") def _forward(self, data): - pred = self.net(data, self.conf) - pred = { - "keypoints": pred["keypoints"][0][None], - "scores": pred["scores"][0][None], - "descriptors": pred["descriptors"][0][None], - } - return pred + return self.net(data, self.conf) diff --git a/hloc/matchers/aspanformer.py b/hloc/matchers/aspanformer.py index 63ffaa2..1f6bdc6 100644 --- a/hloc/matchers/aspanformer.py +++ b/hloc/matchers/aspanformer.py @@ -1,12 +1,10 @@ -import subprocess import sys from pathlib import Path import torch -from huggingface_hub import hf_hub_download -from .. import logger -from ..utils.base_model import BaseModel +from hloc import MODEL_REPO_ID, logger +from hloc.utils.base_model import BaseModel sys.path.append(str(Path(__file__).parent / "../../third_party")) from ASpanFormer.src.ASpanFormer.aspanformer import ASpanFormer as _ASpanFormer @@ -18,40 +16,15 @@ class ASpanFormer(BaseModel): default_conf = { - "weights": "outdoor", + "model_name": "outdoor.ckpt", "match_threshold": 0.2, "sinkhorn_iterations": 20, "max_keypoints": 2048, "config_path": aspanformer_path / "configs/aspan/outdoor/aspan_test.py", - "model_name": "weights_aspanformer.tar", } required_inputs = ["image0", "image1"] - proxy = "http://localhost:1080" - aspanformer_models = { - "weights_aspanformer.tar": "https://drive.google.com/uc?id=1eavM9dTkw9nbc-JqlVVfGPU5UvTTfc6k&confirm=t" - } def _init(self, conf): - model_path = ( - aspanformer_path / "weights" / Path(conf["weights"] + ".ckpt") - ) - # Download the model. - if not model_path.exists(): - cached_file = hf_hub_download( - repo_type="space", - repo_id="Realcat/image-matching-webui", - filename="third_party/ASpanFormer/weights_aspanformer.tar", - ) - cmd = [ - "tar", - "-xvf", - str(cached_file), - "-C", - str(aspanformer_path / "weights"), - ] - logger.info(f"Unzip model file `{cmd}`.") - subprocess.run(cmd, check=True) - config = get_cfg_defaults() config.merge_from_file(conf["config_path"]) _config = lower_config(config) @@ -63,8 +36,13 @@ def _init(self, conf): ] self.net = _ASpanFormer(config=_config["aspan"]) - weight_path = model_path - state_dict = torch.load(str(weight_path), map_location="cpu")[ + model_path = self._download_model( + repo_id=MODEL_REPO_ID, + filename="{}/{}".format( + Path(__file__).stem, self.conf["model_name"] + ), + ) + state_dict = torch.load(str(model_path), map_location="cpu")[ "state_dict" ] self.net.load_state_dict(state_dict, strict=False) diff --git a/hloc/matchers/cotr.py b/hloc/matchers/cotr.py index df28b1b..5986e42 100644 --- a/hloc/matchers/cotr.py +++ b/hloc/matchers/cotr.py @@ -1,19 +1,16 @@ import argparse -import subprocess import sys from pathlib import Path import numpy as np import torch -from huggingface_hub import hf_hub_download from torchvision.transforms import ToPILImage -from .. import logger -from ..utils.base_model import BaseModel +from hloc import DEVICE, MODEL_REPO_ID -cotr_path = Path(__file__).parent / "../../third_party/COTR" +from ..utils.base_model import BaseModel -sys.path.append(str(cotr_path)) +sys.path.append(str(Path(__file__).parent / "../../third_party/COTR")) from COTR.inference.sparse_engine import SparseEngine from COTR.models import build_model from COTR.options.options import * # noqa: F403 @@ -24,14 +21,12 @@ torch.set_grad_enabled(False) -device = torch.device("cuda" if torch.cuda.is_available() else "cpu") - - class COTR(BaseModel): default_conf = { "weights": "out/default", "match_threshold": 0.2, "max_keypoints": -1, + "model_name": "checkpoint.pth.tar", } required_inputs = ["image0", "image1"] @@ -40,27 +35,13 @@ def _init(self, conf): set_COTR_arguments(parser) # noqa: F405 opt = parser.parse_args() opt.command = " ".join(sys.argv) - model_path = cotr_path / conf["weights"] / "checkpoint.pth.tar" - - if not model_path.exists(): - model_path.parent.mkdir(exist_ok=True, parents=True) - cached_file = hf_hub_download( - repo_type="space", - repo_id="Realcat/image-matching-webui", - filename="third_party/COTR/{}/{}".format( - conf["weights"], "checkpoint.pth.tar" - ), - ) - logger.info("Downloaded COTR model succeeed!") - cmd = [ - "cp", - str(cached_file), - str(model_path.parent), - ] - subprocess.run(cmd, check=True) - logger.info(f"Copy model file `{cmd}`.") + opt.load_weights_path = self._download_model( + repo_id=MODEL_REPO_ID, + filename="{}/{}".format( + Path(__file__).stem, self.conf["model_name"] + ), + ) - opt.load_weights_path = str(model_path) layer_2_channels = { "layer1": 256, "layer2": 512, @@ -70,7 +51,7 @@ def _init(self, conf): opt.dim_feedforward = layer_2_channels[opt.layer] model = build_model(opt) - model = model.to(device) + model = model.to(DEVICE) weights = torch.load(opt.load_weights_path, map_location="cpu")[ "model_state_dict" ] diff --git a/hloc/matchers/dkm.py b/hloc/matchers/dkm.py index f4d702e..a3cae6e 100644 --- a/hloc/matchers/dkm.py +++ b/hloc/matchers/dkm.py @@ -1,48 +1,35 @@ -import subprocess import sys from pathlib import Path -import torch from PIL import Image -from .. import logger -from ..utils.base_model import BaseModel +from hloc import DEVICE, MODEL_REPO_ID, logger +from hloc.utils.base_model import BaseModel sys.path.append(str(Path(__file__).parent / "../../third_party")) from DKM.dkm import DKMv3_outdoor -dkm_path = Path(__file__).parent / "../../third_party/DKM" -device = torch.device("cuda" if torch.cuda.is_available() else "cpu") - class DKMv3(BaseModel): default_conf = { "model_name": "DKMv3_outdoor.pth", "match_threshold": 0.2, - "checkpoint_dir": dkm_path / "pretrained", "max_keypoints": -1, } required_inputs = [ "image0", "image1", ] - # Models exported using - dkm_models = { - "DKMv3_outdoor.pth": "https://github.com/Parskatt/storage/releases/download/dkmv3/DKMv3_outdoor.pth", - "DKMv3_indoor.pth": "https://github.com/Parskatt/storage/releases/download/dkmv3/DKMv3_indoor.pth", - } def _init(self, conf): - model_path = dkm_path / "pretrained" / conf["model_name"] + model_path = self._download_model( + repo_id=MODEL_REPO_ID, + filename="{}/{}".format( + Path(__file__).stem, self.conf["model_name"] + ), + ) - # Download the model. - if not model_path.exists(): - model_path.parent.mkdir(exist_ok=True) - link = self.dkm_models[conf["model_name"]] - cmd = ["wget", "--quiet", link, "-O", str(model_path)] - logger.info(f"Downloading the DKMv3 model with `{cmd}`.") - subprocess.run(cmd, check=True) - self.net = DKMv3_outdoor(path_to_weights=str(model_path), device=device) + self.net = DKMv3_outdoor(path_to_weights=str(model_path), device=DEVICE) logger.info("Loading DKMv3 model done") def _forward(self, data): @@ -55,7 +42,7 @@ def _forward(self, data): W_A, H_A = img0.size W_B, H_B = img1.size - warp, certainty = self.net.match(img0, img1, device=device) + warp, certainty = self.net.match(img0, img1, device=DEVICE) matches, certainty = self.net.sample( warp, certainty, num=self.conf["max_keypoints"] ) diff --git a/hloc/matchers/duster.py b/hloc/matchers/duster.py index 2243d8a..14c30c6 100644 --- a/hloc/matchers/duster.py +++ b/hloc/matchers/duster.py @@ -1,13 +1,11 @@ -import os import sys -import urllib.request from pathlib import Path import numpy as np import torch import torchvision.transforms as tfm -from .. import logger +from .. import MODEL_REPO_ID, logger from ..utils.base_model import BaseModel duster_path = Path(__file__).parent / "../../third_party/dust3r" @@ -25,30 +23,24 @@ class Duster(BaseModel): default_conf = { "name": "Duster3r", - "model_path": duster_path / "model_weights/duster_vit_large.pth", + "model_name": "duster_vit_large.pth", "max_keypoints": 3000, "vit_patch_size": 16, } def _init(self, conf): self.normalize = tfm.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)) - self.model_path = self.conf["model_path"] - self.download_weights() - # self.net = load_model(self.model_path, device) - self.net = AsymmetricCroCo3DStereo.from_pretrained( - self.model_path - # "naver/DUSt3R_ViTLarge_BaseDecoder_512_dpt" - ).to(device) + model_path = self._download_model( + repo_id=MODEL_REPO_ID, + filename="{}/{}".format( + Path(__file__).stem, self.conf["model_name"] + ), + ) + self.net = AsymmetricCroCo3DStereo.from_pretrained(model_path).to( + device + ) logger.info("Loaded Dust3r model") - def download_weights(self): - url = "https://download.europe.naverlabs.com/ComputerVision/DUSt3R/DUSt3R_ViTLarge_BaseDecoder_512_dpt.pth" - - self.model_path.parent.mkdir(parents=True, exist_ok=True) - if not os.path.isfile(self.model_path): - logger.info("Downloading Duster(ViT large)... (takes a while)") - urllib.request.urlretrieve(url, self.model_path) - def preprocess(self, img): # the super-class already makes sure that img0,img1 have # same resolution and that h == w diff --git a/hloc/matchers/eloftr.py b/hloc/matchers/eloftr.py index 4e805a4..b95d840 100644 --- a/hloc/matchers/eloftr.py +++ b/hloc/matchers/eloftr.py @@ -1,11 +1,11 @@ -import subprocess import sys import warnings from copy import deepcopy from pathlib import Path import torch -from huggingface_hub import hf_hub_download + +from hloc import MODEL_REPO_ID tp_path = Path(__file__).parent / "../../third_party" sys.path.append(str(tp_path)) @@ -24,7 +24,7 @@ class ELoFTR(BaseModel): default_conf = { - "weights": "weights/eloftr_outdoor.ckpt", + "model_name": "eloftr_outdoor.ckpt", "match_threshold": 0.2, # "sinkhorn_iterations": 20, "max_keypoints": -1, @@ -47,26 +47,12 @@ def _init(self, conf): elif self.conf["precision"] == "fp16": _default_cfg["half"] = True - model_path = tp_path / "EfficientLoFTR" / self.conf["weights"] - - # Download the model. - if not model_path.exists(): - model_path.parent.mkdir(exist_ok=True) - cached_file = hf_hub_download( - repo_type="space", - repo_id="Realcat/image-matching-webui", - filename="third_party/EfficientLoFTR/{}".format( - conf["weights"] - ), - ) - logger.info("Downloaded EfficientLoFTR model succeeed!") - cmd = [ - "cp", - str(cached_file), - str(model_path), - ] - subprocess.run(cmd, check=True) - logger.info(f"Copy model file `{cmd}`.") + model_path = self._download_model( + repo_id=MODEL_REPO_ID, + filename="{}/{}".format( + Path(__file__).stem, self.conf["model_name"] + ), + ) cfg = _default_cfg cfg["match_coarse"]["thr"] = conf["match_threshold"] @@ -78,7 +64,7 @@ def _init(self, conf): if self.conf["precision"] == "fp16": self.net = self.net.half() - logger.info(f"Loaded Efficient LoFTR with weights {conf['weights']}") + logger.info(f"Loaded Efficient LoFTR with weights {conf['model_name']}") def _forward(self, data): # For consistency with hloc pairs, we refine kpts in image0! diff --git a/hloc/matchers/gim.py b/hloc/matchers/gim.py index 0572bbe..1bd98fd 100644 --- a/hloc/matchers/gim.py +++ b/hloc/matchers/gim.py @@ -1,74 +1,120 @@ -import subprocess import sys from pathlib import Path import torch -from huggingface_hub import hf_hub_download -from .. import logger +from .. import DEVICE, MODEL_REPO_ID, logger from ..utils.base_model import BaseModel gim_path = Path(__file__).parent / "../../third_party/gim" sys.path.append(str(gim_path)) -from dkm.models.model_zoo.DKMv3 import DKMv3 -device = torch.device("cuda" if torch.cuda.is_available() else "cpu") +def load_model(weight_name, checkpoints_path): + # load model + model = None + detector = None + if weight_name == "gim_dkm": + from dkm.models.model_zoo.DKMv3 import DKMv3 + + model = DKMv3(weights=None, h=672, w=896) + elif weight_name == "gim_loftr": + from loftr.config import get_cfg_defaults + from loftr.loftr import LoFTR + from loftr.misc import lower_config + + model = LoFTR(lower_config(get_cfg_defaults())["loftr"]) + elif weight_name == "gim_lightglue": + from lightglue.models.matchers.lightglue import LightGlue + from lightglue.superpoint import SuperPoint + + detector = SuperPoint( + { + "max_num_keypoints": 2048, + "force_num_keypoints": True, + "detection_threshold": 0.0, + "nms_radius": 3, + "trainable": False, + } + ) + model = LightGlue( + { + "filter_threshold": 0.1, + "flash": False, + "checkpointed": True, + } + ) + + # load state dict + if weight_name == "gim_dkm": + state_dict = torch.load(checkpoints_path, map_location="cpu") + if "state_dict" in state_dict.keys(): + state_dict = state_dict["state_dict"] + for k in list(state_dict.keys()): + if k.startswith("model."): + state_dict[k.replace("model.", "", 1)] = state_dict.pop(k) + if "encoder.net.fc" in k: + state_dict.pop(k) + model.load_state_dict(state_dict) + + elif weight_name == "gim_loftr": + state_dict = torch.load(checkpoints_path, map_location="cpu") + if "state_dict" in state_dict.keys(): + state_dict = state_dict["state_dict"] + model.load_state_dict(state_dict) + + elif weight_name == "gim_lightglue": + state_dict = torch.load(checkpoints_path, map_location="cpu") + if "state_dict" in state_dict.keys(): + state_dict = state_dict["state_dict"] + for k in list(state_dict.keys()): + if k.startswith("model."): + state_dict.pop(k) + if k.startswith("superpoint."): + state_dict[k.replace("superpoint.", "", 1)] = state_dict.pop(k) + detector.load_state_dict(state_dict) + + state_dict = torch.load(checkpoints_path, map_location="cpu") + if "state_dict" in state_dict.keys(): + state_dict = state_dict["state_dict"] + for k in list(state_dict.keys()): + if k.startswith("superpoint."): + state_dict.pop(k) + if k.startswith("model."): + state_dict[k.replace("model.", "", 1)] = state_dict.pop(k) + model.load_state_dict(state_dict) + + # eval mode + if detector is not None: + detector = detector.eval().to(DEVICE) + model = model.eval().to(DEVICE) + return model class GIM(BaseModel): default_conf = { - "model_name": "gim_dkm_100h.ckpt", "match_threshold": 0.2, "checkpoint_dir": gim_path / "weights", + "weights": "gim_dkm", } required_inputs = [ "image0", "image1", ] - model_list = ["gim_lightglue_100h.ckpt", "gim_dkm_100h.ckpt"] - model_dict = { - "gim_lightglue_100h.ckpt": "https://github.com/xuelunshen/gim/blob/main/weights/gim_lightglue_100h.ckpt", - "gim_dkm_100h.ckpt": "https://drive.google.com/file/d/1gk97V4IROnR1Nprq10W9NCFUv2mxXR_-/view", + ckpt_name_dict = { + "gim_dkm": "gim_dkm_100h.ckpt", + "gim_loftr": "gim_loftr_50h.ckpt", + "gim_lightglue": "gim_lightglue_100h.ckpt", } def _init(self, conf): - conf["model_name"] = str(conf["weights"]) - if conf["model_name"] not in self.model_list: - raise ValueError(f"Unknown GIM model {conf['model_name']}.") - model_path = conf["checkpoint_dir"] / conf["model_name"] - - # Download the model. - if not model_path.exists(): - model_path.parent.mkdir(exist_ok=True) - cached_file = hf_hub_download( - repo_type="space", - repo_id="Realcat/image-matching-webui", - filename="third_party/gim/weights/{}".format( - conf["model_name"] - ), - ) - logger.info("Downloaded GIM model succeeed!") - cmd = [ - "cp", - str(cached_file), - str(conf["checkpoint_dir"]), - ] - subprocess.run(cmd, check=True) - logger.info(f"Copy model file `{cmd}`.") - + ckpt_name = self.ckpt_name_dict[conf["weights"]] + model_path = self._download_model( + repo_id=MODEL_REPO_ID, + filename="{}/{}".format(Path(__file__).stem, ckpt_name), + ) self.aspect_ratio = 896 / 672 - model = DKMv3(None, 672, 896, upsample_preds=True) - state_dict = torch.load(str(model_path), map_location="cpu") - if "state_dict" in state_dict.keys(): - state_dict = state_dict["state_dict"] - for k in list(state_dict.keys()): - if k.startswith("model."): - state_dict[k.replace("model.", "", 1)] = state_dict.pop(k) - if "encoder.net.fc" in k: - state_dict.pop(k) - model.load_state_dict(state_dict) - + model = load_model(conf["weights"], model_path) self.net = model logger.info("Loaded GIM model") @@ -120,6 +166,7 @@ def compute_mask(self, kpts0, kpts1, orig_shape0, orig_shape1): return mask def _forward(self, data): + # TODO: only support dkm+gim image0, image1 = self.pad_image( data["image0"], self.aspect_ratio ), self.pad_image(data["image1"], self.aspect_ratio) diff --git a/hloc/matchers/gluestick.py b/hloc/matchers/gluestick.py index b14614e..fea550a 100644 --- a/hloc/matchers/gluestick.py +++ b/hloc/matchers/gluestick.py @@ -1,10 +1,9 @@ -import subprocess import sys from pathlib import Path import torch -from .. import logger +from .. import MODEL_REPO_ID, logger from ..utils.base_model import BaseModel gluestick_path = Path(__file__).parent / "../../third_party/GlueStick" @@ -13,8 +12,6 @@ from gluestick import batch_to_np from gluestick.models.two_view_pipeline import TwoViewPipeline -device = torch.device("cuda" if torch.cuda.is_available() else "cpu") - class GlueStick(BaseModel): default_conf = { @@ -30,23 +27,15 @@ class GlueStick(BaseModel): "image1", ] - gluestick_models = { - "checkpoint_GlueStick_MD.tar": "https://github.com/cvg/GlueStick/releases/download/v0.1_arxiv/checkpoint_GlueStick_MD.tar", - } - # Initialize the line matcher def _init(self, conf): - model_path = ( - gluestick_path / "resources" / "weights" / conf["model_name"] - ) - # Download the model. - if not model_path.exists(): - model_path.parent.mkdir(exist_ok=True) - link = self.gluestick_models[conf["model_name"]] - cmd = ["wget", "--quiet", link, "-O", str(model_path)] - logger.info(f"Downloading the Gluestick model with `{cmd}`.") - subprocess.run(cmd, check=True) + model_path = self._download_model( + repo_id=MODEL_REPO_ID, + filename="{}/{}".format( + Path(__file__).stem, self.conf["model_name"] + ), + ) logger.info("Loading GlueStick model...") gluestick_conf = { diff --git a/hloc/matchers/imp.py b/hloc/matchers/imp.py index 05c3cb9..f37d218 100644 --- a/hloc/matchers/imp.py +++ b/hloc/matchers/imp.py @@ -3,7 +3,7 @@ import torch -from .. import DEVICE, logger +from .. import DEVICE, MODEL_REPO_ID, logger from ..utils.base_model import BaseModel tp_path = Path(__file__).parent / "../../third_party" @@ -31,11 +31,15 @@ class IMP(BaseModel): def _init(self, conf): self.conf = {**self.default_conf, **conf} - weight_path = tp_path / "pram" / "weights" / self.conf["model_name"] + model_path = self._download_model( + repo_id=MODEL_REPO_ID, + filename="{}/{}".format("pram", self.conf["model_name"]), + ) + # self.net = nets.gml(self.conf).eval().to(DEVICE) self.net = GML(self.conf).eval().to(DEVICE) self.net.load_state_dict( - torch.load(weight_path, map_location="cpu")["model"], strict=True + torch.load(model_path, map_location="cpu")["model"], strict=True ) logger.info("Load IMP model done.") diff --git a/hloc/matchers/lightglue.py b/hloc/matchers/lightglue.py index 4a36be6..975b554 100644 --- a/hloc/matchers/lightglue.py +++ b/hloc/matchers/lightglue.py @@ -1,7 +1,7 @@ import sys from pathlib import Path -from .. import logger +from .. import MODEL_REPO_ID, logger from ..utils.base_model import BaseModel lightglue_path = Path(__file__).parent / "../../third_party/LightGlue" @@ -33,8 +33,14 @@ class LightGlue(BaseModel): ] def _init(self, conf): - weight_path = lightglue_path / "weights" / conf["model_name"] - conf["weights"] = str(weight_path) + logger.info("Loading lightglue model, {}".format(conf["model_name"])) + model_path = self._download_model( + repo_id=MODEL_REPO_ID, + filename="{}/{}".format( + Path(__file__).stem, self.conf["model_name"] + ), + ) + conf["weights"] = str(model_path) conf["filter_threshold"] = conf["match_threshold"] self.net = LG(**conf) logger.info("Load lightglue model done.") diff --git a/hloc/matchers/mast3r.py b/hloc/matchers/mast3r.py index 46489cc..75d016b 100644 --- a/hloc/matchers/mast3r.py +++ b/hloc/matchers/mast3r.py @@ -1,13 +1,11 @@ -import os import sys -import urllib.request from pathlib import Path import numpy as np import torch import torchvision.transforms as tfm -from .. import logger +from .. import DEVICE, MODEL_REPO_ID, logger mast3r_path = Path(__file__).parent / "../../third_party/mast3r" sys.path.append(str(mast3r_path)) @@ -22,38 +20,30 @@ from hloc.matchers.duster import Duster -device = torch.device("cuda" if torch.cuda.is_available() else "cpu") - class Mast3r(Duster): default_conf = { "name": "Mast3r", - "model_path": mast3r_path - / "model_weights/MASt3R_ViTLarge_BaseDecoder_512_catmlpdpt_metric.pth", + "model_name": "MASt3R_ViTLarge_BaseDecoder_512_catmlpdpt_metric.pth", "max_keypoints": 2000, "vit_patch_size": 16, } def _init(self, conf): self.normalize = tfm.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)) - self.model_path = self.conf["model_path"] - self.download_weights() - self.net = AsymmetricMASt3R.from_pretrained(self.model_path).to(device) + model_path = self._download_model( + repo_id=MODEL_REPO_ID, + filename="{}/{}".format( + Path(__file__).stem, self.conf["model_name"] + ), + ) + self.net = AsymmetricMASt3R.from_pretrained(model_path).to(DEVICE) logger.info("Loaded Mast3r model") - def download_weights(self): - url = "https://download.europe.naverlabs.com/ComputerVision/MASt3R/MASt3R_ViTLarge_BaseDecoder_512_catmlpdpt_metric.pth" - - self.model_path.parent.mkdir(parents=True, exist_ok=True) - if not os.path.isfile(self.model_path): - logger.info("Downloading Mast3r(ViT large)... (takes a while)") - urllib.request.urlretrieve(url, self.model_path) - logger.info("Downloading Mast3r(ViT large)... done!") - def _forward(self, data): img0, img1 = data["image0"], data["image1"] - mean = torch.tensor([0.5, 0.5, 0.5]).to(device) - std = torch.tensor([0.5, 0.5, 0.5]).to(device) + mean = torch.tensor([0.5, 0.5, 0.5]).to(DEVICE) + std = torch.tensor([0.5, 0.5, 0.5]).to(DEVICE) img0 = (img0 - mean.view(1, 3, 1, 1)) / std.view(1, 3, 1, 1) img1 = (img1 - mean.view(1, 3, 1, 1)) / std.view(1, 3, 1, 1) @@ -65,7 +55,7 @@ def _forward(self, data): pairs = make_pairs( images, scene_graph="complete", prefilter=None, symmetrize=True ) - output = inference(pairs, self.net, device, batch_size=1) + output = inference(pairs, self.net, DEVICE, batch_size=1) # at this stage, you have the raw dust3r predictions _, pred1 = output["view1"], output["pred1"] @@ -81,7 +71,7 @@ def _forward(self, data): desc1, desc2, subsample_or_initxy1=2, - device=device, + device=DEVICE, dist="dot", block_size=2**13, ) diff --git a/hloc/matchers/mickey.py b/hloc/matchers/mickey.py index 3d60ff5..57a8ab0 100644 --- a/hloc/matchers/mickey.py +++ b/hloc/matchers/mickey.py @@ -1,10 +1,7 @@ -import subprocess import sys from pathlib import Path -import torch - -from .. import logger +from .. import MODEL_REPO_ID, logger from ..utils.base_model import BaseModel mickey_path = Path(__file__).parent / "../../third_party" @@ -13,8 +10,6 @@ from mickey.config.default import cfg from mickey.lib.models.builder import build_model -device = torch.device("cuda" if torch.cuda.is_available() else "cpu") - class Mickey(BaseModel): default_conf = { @@ -26,33 +21,23 @@ class Mickey(BaseModel): "image0", "image1", ] - weight_urls = "https://storage.googleapis.com/niantic-lon-static/research/mickey/assets/mickey_weights.zip" # Initialize the line matcher def _init(self, conf): - model_path = mickey_path / "mickey/mickey_weights" / conf["model_name"] - zip_path = mickey_path / "mickey/mickey_weights.zip" + model_path = self._download_model( + repo_id=MODEL_REPO_ID, + filename="{}/{}".format( + Path(__file__).stem, self.conf["model_name"] + ), + ) + # TODO: config path of mickey config_path = model_path.parent / self.conf["config_path"] - # Download the model. - if not model_path.exists(): - model_path.parent.mkdir(exist_ok=True, parents=True) - link = self.weight_urls - if not zip_path.exists(): - cmd = ["wget", "--quiet", link, "-O", str(zip_path)] - logger.info(f"Downloading the Mickey model with {cmd}.") - subprocess.run(cmd, check=True) - cmd = ["unzip", "-d", str(model_path.parent.parent), str(zip_path)] - logger.info(f"Running {cmd}.") - subprocess.run(cmd, check=True) - logger.info("Loading mickey model...") cfg.merge_from_file(config_path) self.net = build_model(cfg, checkpoint=model_path) logger.info("Load Mickey model done.") def _forward(self, data): - # data['K_color0'] = torch.from_numpy(K['im0.jpg']).unsqueeze(0).to(device) - # data['K_color1'] = torch.from_numpy(K['im1.jpg']).unsqueeze(0).to(device) pred = self.net(data) pred = { **pred, diff --git a/hloc/matchers/omniglue.py b/hloc/matchers/omniglue.py index c02d7f3..0c709de 100644 --- a/hloc/matchers/omniglue.py +++ b/hloc/matchers/omniglue.py @@ -1,11 +1,10 @@ -import subprocess import sys from pathlib import Path import numpy as np import torch -from .. import logger +from .. import MODEL_REPO_ID, logger from ..utils.base_model import BaseModel thirdparty_path = Path(__file__).parent / "../../third_party" @@ -27,19 +26,21 @@ class OmniGlue(BaseModel): def _init(self, conf): logger.info("Loading OmniGlue model") - og_model_path = omniglue_path / "models" / "omniglue.onnx" - sp_model_path = omniglue_path / "models" / "sp_v6.onnx" - dino_model_path = ( - omniglue_path / "models" / "dinov2_vitb14_pretrain.pth" # ~330MB + og_model_path = self._download_model( + repo_id=MODEL_REPO_ID, + filename="{}/{}".format(Path(__file__).stem, "omniglue.onnx"), ) - if not dino_model_path.exists(): - link = self.dino_v2_link_dict.get(dino_model_path.name, None) - if link is not None: - cmd = ["wget", "--quiet", link, "-O", str(dino_model_path)] - logger.info(f"Downloading the dinov2 model with `{cmd}`.") - subprocess.run(cmd, check=True) - else: - logger.error(f"Invalid dinov2 model: {dino_model_path.name}") + sp_model_path = self._download_model( + repo_id=MODEL_REPO_ID, + filename="{}/{}".format(Path(__file__).stem, "sp_v6.onnx"), + ) + dino_model_path = self._download_model( + repo_id=MODEL_REPO_ID, + filename="{}/{}".format( + Path(__file__).stem, "dinov2_vitb14_pretrain.pth" + ), + ) + self.net = omniglue.OmniGlue( og_export=str(og_model_path), sp_export=str(sp_model_path), diff --git a/hloc/matchers/roma.py b/hloc/matchers/roma.py index 0194916..2187373 100644 --- a/hloc/matchers/roma.py +++ b/hloc/matchers/roma.py @@ -1,11 +1,10 @@ -import subprocess import sys from pathlib import Path import torch from PIL import Image -from .. import logger +from .. import MODEL_REPO_ID, logger from ..utils.base_model import BaseModel roma_path = Path(__file__).parent / "../../third_party/RoMa" @@ -26,33 +25,22 @@ class Roma(BaseModel): "image0", "image1", ] - weight_urls = { - "roma": { - "roma_outdoor.pth": "https://github.com/Parskatt/storage/releases/download/roma/roma_outdoor.pth", - "roma_indoor.pth": "https://github.com/Parskatt/storage/releases/download/roma/roma_indoor.pth", - }, - "dinov2_vitl14_pretrain.pth": "https://dl.fbaipublicfiles.com/dinov2/dinov2_vitl14/dinov2_vitl14_pretrain.pth", - } # Initialize the line matcher def _init(self, conf): - model_path = roma_path / "pretrained" / conf["model_name"] - dinov2_weights = roma_path / "pretrained" / conf["model_utils_name"] - - # Download the model. - if not model_path.exists(): - model_path.parent.mkdir(exist_ok=True) - link = self.weight_urls["roma"][conf["model_name"]] - cmd = ["wget", "--quiet", link, "-O", str(model_path)] - logger.info(f"Downloading the Roma model with `{cmd}`.") - subprocess.run(cmd, check=True) + model_path = self._download_model( + repo_id=MODEL_REPO_ID, + filename="{}/{}".format( + Path(__file__).stem, self.conf["model_name"] + ), + ) - if not dinov2_weights.exists(): - dinov2_weights.parent.mkdir(exist_ok=True) - link = self.weight_urls[conf["model_utils_name"]] - cmd = ["wget", "--quiet", link, "-O", str(dinov2_weights)] - logger.info(f"Downloading the dinov2 model with `{cmd}`.") - subprocess.run(cmd, check=True) + dinov2_weights = self._download_model( + repo_id=MODEL_REPO_ID, + filename="{}/{}".format( + Path(__file__).stem, self.conf["model_utils_name"] + ), + ) logger.info("Loading Roma model") # load the model diff --git a/hloc/matchers/sgmnet.py b/hloc/matchers/sgmnet.py index 3b964cf..7aeb219 100644 --- a/hloc/matchers/sgmnet.py +++ b/hloc/matchers/sgmnet.py @@ -1,12 +1,10 @@ -import subprocess import sys from collections import OrderedDict, namedtuple from pathlib import Path import torch -from huggingface_hub import hf_hub_download -from .. import logger +from .. import MODEL_REPO_ID, logger from ..utils.base_model import BaseModel sgmnet_path = Path(__file__).parent / "../../third_party/SGMNet" @@ -20,7 +18,7 @@ class SGMNet(BaseModel): default_conf = { "name": "SGM", - "model_name": "model_best.pth", + "model_name": "weights/sgm/root/model_best.pth", "seed_top_k": [256, 256], "seed_radius_coe": 0.01, "net_channels": 128, @@ -38,30 +36,20 @@ class SGMNet(BaseModel): "image0", "image1", ] - weight_urls = { - "model_best.pth": "https://drive.google.com/uc?id=1Ca0WmKSSt2G6P7m8YAOlSAHEFar_TAWb&confirm=t", - } - proxy = "http://localhost:1080" # Initialize the line matcher def _init(self, conf): - sgmnet_weights = sgmnet_path / "weights/sgm/root" / conf["model_name"] - - # Download the model. - if not sgmnet_weights.exists(): - cached_file = hf_hub_download( - repo_type="space", - repo_id="Realcat/image-matching-webui", - filename="third_party/SGMNet/weights.tar.gz", - ) - cmd = ["tar", "-xvf", str(cached_file), "-C", str(sgmnet_path)] - logger.info(f"Unzip model file `{cmd}`.") - subprocess.run(cmd, check=True) + model_path = self._download_model( + repo_id=MODEL_REPO_ID, + filename="{}/{}".format( + Path(__file__).stem, self.conf["model_name"] + ), + ) # config config = namedtuple("config", conf.keys())(*conf.values()) self.net = SGM_Model(config) - checkpoint = torch.load(sgmnet_weights, map_location="cpu") + checkpoint = torch.load(model_path, map_location="cpu") # for ddp model if ( list(checkpoint["state_dict"].items())[0][0].split(".")[0] diff --git a/hloc/matchers/sold2.py b/hloc/matchers/sold2.py index e7ac07f..4cbc637 100644 --- a/hloc/matchers/sold2.py +++ b/hloc/matchers/sold2.py @@ -1,10 +1,9 @@ -import subprocess import sys from pathlib import Path import torch -from .. import logger +from .. import MODEL_REPO_ID, logger from ..utils.base_model import BaseModel sold2_path = Path(__file__).parent / "../../third_party/SOLD2" @@ -17,7 +16,7 @@ class SOLD2(BaseModel): default_conf = { - "weights": "sold2_wireframe.tar", + "model_name": "sold2_wireframe.tar", "match_threshold": 0.2, "checkpoint_dir": sold2_path / "pretrained", "detect_thresh": 0.25, @@ -31,21 +30,15 @@ class SOLD2(BaseModel): "image1", ] - weight_urls = { - "sold2_wireframe.tar": "https://www.polybox.ethz.ch/index.php/s/blOrW89gqSLoHOk/download", - } - # Initialize the line matcher def _init(self, conf): - checkpoint_path = conf["checkpoint_dir"] / conf["weights"] - - # Download the model. - if not checkpoint_path.exists(): - checkpoint_path.parent.mkdir(exist_ok=True) - link = self.weight_urls[conf["weights"]] - cmd = ["wget", "--quiet", link, "-O", str(checkpoint_path)] - logger.info(f"Downloading the SOLD2 model with `{cmd}`.") - subprocess.run(cmd, check=True) + model_path = self._download_model( + repo_id=MODEL_REPO_ID, + filename="{}/{}".format( + Path(__file__).stem, self.conf["model_name"] + ), + ) + logger.info("Loading SOLD2 model: {}".format(model_path)) mode = "dynamic" # 'dynamic' or 'static' match_config = { @@ -127,7 +120,7 @@ def _init(self, conf): } self.net = LineMatcher( match_config["model_cfg"], - checkpoint_path, + model_path, device, match_config["line_detector_cfg"], match_config["line_matcher_cfg"], diff --git a/hloc/matchers/topicfm.py b/hloc/matchers/topicfm.py index 2d4701c..544bf2c 100644 --- a/hloc/matchers/topicfm.py +++ b/hloc/matchers/topicfm.py @@ -3,6 +3,8 @@ import torch +from hloc import MODEL_REPO_ID + from ..utils.base_model import BaseModel sys.path.append(str(Path(__file__).parent / "../../third_party")) @@ -15,6 +17,7 @@ class TopicFM(BaseModel): default_conf = { "weights": "outdoor", + "model_name": "model_best.ckpt", "match_threshold": 0.2, "n_sampling_topics": 4, "max_keypoints": -1, @@ -25,9 +28,14 @@ def _init(self, conf): _conf = dict(get_model_cfg()) _conf["match_coarse"]["thr"] = conf["match_threshold"] _conf["coarse"]["n_samples"] = conf["n_sampling_topics"] - weight_path = topicfm_path / "pretrained/model_best.ckpt" + model_path = self._download_model( + repo_id=MODEL_REPO_ID, + filename="{}/{}".format( + Path(__file__).stem, self.conf["model_name"] + ), + ) self.net = _TopicFM(config=_conf) - ckpt_dict = torch.load(weight_path, map_location="cpu") + ckpt_dict = torch.load(model_path, map_location="cpu") self.net.load_state_dict(ckpt_dict["state_dict"]) def _forward(self, data): diff --git a/hloc/utils/base_model.py b/hloc/utils/base_model.py index f560a26..bd461e6 100644 --- a/hloc/utils/base_model.py +++ b/hloc/utils/base_model.py @@ -3,6 +3,7 @@ from torch import nn from copy import copy import inspect +from huggingface_hub import hf_hub_download class BaseModel(nn.Module, metaclass=ABCMeta): @@ -32,7 +33,14 @@ def _init(self, conf): def _forward(self, data): """To be implemented by the child class.""" raise NotImplementedError - + + def _download_model(self, repo_id=None, filename=None, **kwargs): + """Download model from hf hub and return the path.""" + return hf_hub_download( + repo_type="model", + repo_id=repo_id, + filename=filename, + ) def dynamic_load(root, model): module_path = f"{root.__name__}.{model}"