update: format code
Vincentqyw committed Oct 29, 2024
1 parent a8ab6a7 commit 845af8b
Showing 29 changed files with 334 additions and 439 deletions.
3 changes: 3 additions & 0 deletions hloc/__init__.py
@@ -61,3 +61,6 @@ def flush_logs():
)

DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# model hub: https://huggingface.co/Realcat/imatchui_checkpoint
MODEL_REPO_ID = "Realcat/imatchui_checkpoints"
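
Each extractor changed below now resolves its checkpoint through self._download_model(repo_id=MODEL_REPO_ID, filename=...). That helper is not part of this diff (it presumably lives on BaseModel in hloc/utils/base_model.py); a minimal sketch, assuming it is a thin wrapper around huggingface_hub.hf_hub_download, is shown here. The function name, its body, and the example weight filename are assumptions, not code from the repository. The filenames passed in the diff follow a "<module stem>/<weight file>" layout, so the Hub repo is assumed to mirror the extractor module names.

# Sketch only: an assumed stand-in for the _download_model helper called below;
# the real method is not shown in this commit.
from huggingface_hub import hf_hub_download

MODEL_REPO_ID = "Realcat/imatchui_checkpoints"


def download_model(repo_id: str, filename: str) -> str:
    """Fetch a checkpoint from the Hugging Face Hub and return its local cache path."""
    return hf_hub_download(repo_id=repo_id, filename=filename)


# Call shape used throughout the diff (the weight filename here is illustrative):
# path = download_model(MODEL_REPO_ID, "alike/alike-t.pth")
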
10 changes: 9 additions & 1 deletion hloc/extractors/alike.py
@@ -3,7 +3,7 @@

import torch

from hloc import logger
from hloc import MODEL_REPO_ID, logger

from ..utils.base_model import BaseModel

@@ -29,6 +29,14 @@ class Alike(BaseModel):
required_inputs = ["image"]

def _init(self, conf):
model_path = self._download_model(
repo_id=MODEL_REPO_ID,
filename="{}/{}.pth".format(
Path(__file__).stem, self.conf["model_name"]
),
)
logger.info("Loaded Alike model from {}".format(model_path))
configs[conf["model_name"]]["model_path"] = model_path
self.net = Alike_(
**configs[conf["model_name"]],
device=device,
24 changes: 10 additions & 14 deletions hloc/extractors/d2net.py
@@ -1,10 +1,9 @@
import subprocess
import sys
from pathlib import Path

import torch

from hloc import logger
from hloc import MODEL_REPO_ID, logger

from ..utils.base_model import BaseModel

@@ -25,20 +24,17 @@ class D2Net(BaseModel):
required_inputs = ["image"]

def _init(self, conf):
model_file = conf["checkpoint_dir"] / conf["model_name"]
if not model_file.exists():
model_file.parent.mkdir(exist_ok=True)
cmd = [
"wget",
"--quiet",
"https://dusmanu.com/files/d2-net/" + conf["model_name"],
"-O",
str(model_file),
]
subprocess.run(cmd, check=True)

logger.info("Loading D2Net model...")
model_path = self._download_model(
repo_id=MODEL_REPO_ID,
filename="{}/{}".format(
Path(__file__).stem, self.conf["model_name"]
),
)
logger.info(f"Loading model from {model_path}...")
self.net = _D2Net(
model_file=model_file, use_relu=conf["use_relu"], use_cuda=False
model_file=model_path, use_relu=conf["use_relu"], use_cuda=False
)
logger.info("Load D2Net model done.")

21 changes: 8 additions & 13 deletions hloc/extractors/darkfeat.py
@@ -1,9 +1,7 @@
import sys
from pathlib import Path

from huggingface_hub import hf_hub_download

from hloc import logger
from hloc import MODEL_REPO_ID, logger

from ..utils.base_model import BaseModel

@@ -19,20 +17,17 @@ class DarkFeat(BaseModel):
"detection_threshold": 0.5,
"sub_pixel": False,
}
weight_urls = {
"DarkFeat.pth": "https://drive.google.com/uc?id=1Thl6m8NcmQ7zSAF-1_xaFs3F4H8UU6HX&confirm=t",
}
proxy = "http://localhost:1080"
required_inputs = ["image"]

def _init(self, conf):
cached_file = hf_hub_download(
repo_type="space",
repo_id="Realcat/image-matching-webui",
filename="third_party/DarkFeat/checkpoints/DarkFeat.pth",
model_path = self._download_model(
repo_id=MODEL_REPO_ID,
filename="{}/{}".format(
Path(__file__).stem, self.conf["model_name"]
),
)

self.net = DarkFeat_(cached_file)
logger.info("Loaded DarkFeat model: {}".format(model_path))
self.net = DarkFeat_(model_path)
logger.info("Load DarkFeat model done.")

def _forward(self, data):
39 changes: 12 additions & 27 deletions hloc/extractors/dedode.py
@@ -1,11 +1,10 @@
import subprocess
import sys
from pathlib import Path

import torch
import torchvision.transforms as transforms

from hloc import logger
from hloc import MODEL_REPO_ID, logger

from ..utils.base_model import BaseModel

@@ -30,39 +29,25 @@ class DeDoDe(BaseModel):
required_inputs = [
"image",
]
weight_urls = {
"dedode_detector_L.pth": "https://github.com/Parskatt/DeDoDe/releases/download/dedode_pretrained_models/dedode_detector_L.pth",
"dedode_descriptor_B.pth": "https://github.com/Parskatt/DeDoDe/releases/download/dedode_pretrained_models/dedode_descriptor_B.pth",
}

# Initialize the line matcher
def _init(self, conf):
model_detector_path = (
dedode_path / "pretrained" / conf["model_detector_name"]
model_detector_path = self._download_model(
repo_id=MODEL_REPO_ID,
filename="{}/{}".format(
Path(__file__).stem, conf["model_detector_name"]
),
)
model_descriptor_path = (
dedode_path / "pretrained" / conf["model_descriptor_name"]
model_descriptor_path = self._download_model(
repo_id=MODEL_REPO_ID,
filename="{}/{}".format(
Path(__file__).stem, conf["model_descriptor_name"]
),
)

logger.info("Loaded DarkFeat model: {}".format(model_detector_path))
self.normalizer = transforms.Normalize(
mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]
)
# Download the model.
if not model_detector_path.exists():
model_detector_path.parent.mkdir(exist_ok=True)
link = self.weight_urls[conf["model_detector_name"]]
cmd = ["wget", "--quiet", link, "-O", str(model_detector_path)]
logger.info(f"Downloading the DeDoDe detector model with `{cmd}`.")
subprocess.run(cmd, check=True)

if not model_descriptor_path.exists():
model_descriptor_path.parent.mkdir(exist_ok=True)
link = self.weight_urls[conf["model_descriptor_name"]]
cmd = ["wget", "--quiet", link, "-O", str(model_descriptor_path)]
logger.info(
f"Downloading the DeDoDe descriptor model with `{cmd}`."
)
subprocess.run(cmd, check=True)

# load the model
weights_detector = torch.load(model_detector_path, map_location="cpu")
1 change: 1 addition & 0 deletions hloc/extractors/dir.py
@@ -43,6 +43,7 @@ class DIR(BaseModel):
}

def _init(self, conf):
# todo: download from google drive -> huggingface models
checkpoint = Path(
torch.hub.get_dir(), "dirtorch", conf["model_name"] + ".pt"
)
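
The dir.py hunk above only adds a TODO: the DIR checkpoint is still fetched into torch.hub's cache rather than from the shared Hub repo. A hypothetical follow-up, mirroring the pattern the other extractors adopt in this commit, is sketched below; the hosted path "dir/<model_name>.pt", the helper name, and the default repo are assumptions, not part of this commit.

# Hypothetical follow-up to the dir.py TODO; not part of this commit.
from pathlib import Path

from huggingface_hub import hf_hub_download


def fetch_dir_checkpoint(model_name: str,
                         repo_id: str = "Realcat/imatchui_checkpoints") -> Path:
    """Return a local path to a DIR checkpoint assumed to be hosted on the Hub."""
    # Mirrors the "<module stem>/<weight file>" layout used by the other extractors.
    return Path(hf_hub_download(repo_id=repo_id, filename=f"dir/{model_name}.pt"))
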
29 changes: 9 additions & 20 deletions hloc/extractors/lanet.py
@@ -1,11 +1,9 @@
import subprocess
import sys
from pathlib import Path

import torch
from huggingface_hub import hf_hub_download

from hloc import logger
from hloc import MODEL_REPO_ID, logger

from ..utils.base_model import BaseModel

@@ -20,30 +18,21 @@

class LANet(BaseModel):
default_conf = {
"model_name": "v0",
"model_name": "PointModel_v0.pth",
"keypoint_threshold": 0.1,
"max_keypoints": 1024,
}
required_inputs = ["image"]

def _init(self, conf):
model_path = (
lanet_path / "checkpoints" / f'PointModel_{conf["model_name"]}.pth'
)
if not model_path.exists():
logger.warning(f"No model found at {model_path}, start downloading")
model_path.parent.mkdir(exist_ok=True)
cached_file = hf_hub_download(
repo_type="space",
repo_id="Realcat/image-matching-webui",
filename="third_party/lanet/checkpoints/PointModel_{}.pth".format(
conf["model_name"]
),
)
cmd = ["cp", str(cached_file), str(model_path.parent)]
logger.info(f"copy model file `{cmd}`.")
subprocess.run(cmd, check=True)
logger.info("Loading LANet model...")

model_path = self._download_model(
repo_id=MODEL_REPO_ID,
filename="{}/{}".format(
Path(__file__).stem, self.conf["model_name"]
),
)
self.net = PointModel(is_test=True)
state_dict = torch.load(model_path, map_location="cpu")
self.net.load_state_dict(state_dict["model_state"])
11 changes: 8 additions & 3 deletions hloc/extractors/r2d2.py
@@ -3,7 +3,7 @@

import torchvision.transforms as tvf

from hloc import logger
from hloc import MODEL_REPO_ID, logger

from ..utils.base_model import BaseModel

@@ -27,11 +27,16 @@ class R2D2(BaseModel):
required_inputs = ["image"]

def _init(self, conf):
model_fn = r2d2_path / "models" / conf["model_name"]
model_path = self._download_model(
repo_id=MODEL_REPO_ID,
filename="{}/{}".format(
Path(__file__).stem, self.conf["model_name"]
),
)
self.norm_rgb = tvf.Normalize(
mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]
)
self.net = load_network(model_fn)
self.net = load_network(model_path)
self.detector = NonMaxSuppression(
rel_thr=conf["reliability_threshold"],
rep_thr=conf["repetability_threshold"],
10 changes: 7 additions & 3 deletions hloc/extractors/rekd.py
@@ -3,7 +3,7 @@

import torch

from hloc import logger
from hloc import MODEL_REPO_ID, logger

from ..utils.base_model import BaseModel

@@ -22,8 +22,12 @@ class REKD(BaseModel):
required_inputs = ["image"]

def _init(self, conf):
model_path = (
rekd_path / "checkpoints" / f'PointModel_{conf["model_name"]}.pth'
# TODO: download model
model_path = self._download_model(
repo_id=MODEL_REPO_ID,
filename="{}/{}".format(
Path(__file__).stem, self.conf["model_name"]
),
)
if not model_path.exists():
print(f"No model found at {model_path}")
33 changes: 7 additions & 26 deletions hloc/extractors/rord.py
@@ -1,11 +1,9 @@
import subprocess
import sys
from pathlib import Path

import torch
from huggingface_hub import hf_hub_download

from hloc import logger
from hloc import MODEL_REPO_ID, logger

from ..utils.base_model import BaseModel

@@ -24,31 +22,14 @@ class RoRD(BaseModel):
"max_keypoints": 1024,
}
required_inputs = ["image"]
weight_urls = {
"rord.pth": "https://drive.google.com/uc?id=12414ZGKwgPAjNTGtNrlB4VV9l7W76B2o&confirm=t",
}
proxy = "http://localhost:1080"

def _init(self, conf):
model_path = conf["checkpoint_dir"] / conf["model_name"]
link = self.weight_urls[conf["model_name"]] # noqa: F841

if not model_path.exists():
model_path.parent.mkdir(exist_ok=True)
cached_file_0 = hf_hub_download(
repo_type="space",
repo_id="Realcat/image-matching-webui",
filename="third_party/RoRD/models/d2net.pth",
)
cached_file_1 = hf_hub_download(
repo_type="space",
repo_id="Realcat/image-matching-webui",
filename="third_party/RoRD/models/rord.pth",
)

subprocess.run(["cp", cached_file_0, model_path], check=True)
subprocess.run(["cp", cached_file_1, model_path], check=True)

model_path = self._download_model(
repo_id=MODEL_REPO_ID,
filename="{}/{}".format(
Path(__file__).stem, self.conf["model_name"]
),
)
self.net = _RoRD(
model_file=model_path, use_relu=conf["use_relu"], use_cuda=False
)
7 changes: 5 additions & 2 deletions hloc/extractors/sfd2.py
@@ -3,7 +3,7 @@

import torchvision.transforms as tvf

from .. import logger
from .. import MODEL_REPO_ID, logger
from ..utils.base_model import BaseModel

tp_path = Path(__file__).parent / "../../third_party"
@@ -24,7 +24,10 @@ def _init(self, conf):
self.norm_rgb = tvf.Normalize(
mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]
)
model_path = tp_path / "pram" / "weights" / self.conf["model_name"]
model_path = self._download_model(
repo_id=MODEL_REPO_ID,
filename="{}/{}".format("pram", self.conf["model_name"]),
)
self.net = load_sfd2(weight_path=model_path).eval()

logger.info("Load SFD2 model done.")
8 changes: 1 addition & 7 deletions hloc/extractors/superpoint.py
@@ -48,10 +48,4 @@ def _init(self, conf):
logger.info("Load SuperPoint model done.")

def _forward(self, data):
pred = self.net(data, self.conf)
pred = {
"keypoints": pred["keypoints"][0][None],
"scores": pred["scores"][0][None],
"descriptors": pred["descriptors"][0][None],
}
return pred
return self.net(data, self.conf)