Skip to content

Commit

Permalink
- Changed DeDoDe and LightGlue to kornia implementation
Browse files Browse the repository at this point in the history
- Changed DeDoDe to G variant
- Added DeDoDe v2
- Added LightGlue support for DeDoDe
  • Loading branch information
Dawars committed Nov 23, 2024
1 parent 05d1ba6 commit 4e8978b
Show file tree
Hide file tree
Showing 5 changed files with 111 additions and 81 deletions.
38 changes: 36 additions & 2 deletions hloc/extract_features.py
Original file line number Diff line number Diff line change
Expand Up @@ -259,10 +259,44 @@
"dfactor": 8,
},
},
"dedode": {
"output": "feats-dedode-n5000-r1600",
"dedodeg": {
"output": "feats-dedodeg-n5000-r1600",
"model": {
"name": "dedode",
"detector_model": "L-upright",
"descriptor_model": "G-upright",
"max_keypoints": 5000,
},
"preprocessing": {
"grayscale": False,
"force_resize": True,
"resize_max": 1600,
"width": 768,
"height": 768,
"dfactor": 8,
},
},
"dedodeg-v2": {
"output": "feats-dedodegv2-n5000-r1600",
"model": {
"name": "dedode",
"detector_model": "L-C4-v2",
"descriptor_model": "G-upright",
"max_keypoints": 5000,
},
"preprocessing": {
"grayscale": False,
"force_resize": True,
"resize_max": 1600,
"width": 768,
"height": 768,
"dfactor": 8,
},
},
"mickey": {
"output": "feats-mickey",
"model": {
"name": "mickey",
"max_keypoints": 5000,
},
"preprocessing": {
Expand Down
67 changes: 9 additions & 58 deletions hloc/extractors/dedode.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,21 +7,16 @@
from hloc import MODEL_REPO_ID, logger

from ..utils.base_model import BaseModel

dedode_path = Path(__file__).parent / "../../third_party/DeDoDe"
sys.path.append(str(dedode_path))

from DeDoDe import dedode_descriptor_B, dedode_detector_L
from DeDoDe.utils import to_pixel_coords
from kornia import feature as KF

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")


class DeDoDe(BaseModel):
default_conf = {
"name": "dedode",
"model_detector_name": "dedode_detector_L.pth",
"model_descriptor_name": "dedode_descriptor_B.pth",
"detector_model": "L-upright",
"descriptor_model": "G-upright",
"max_keypoints": 2000,
"match_threshold": 0.2,
"dense": False, # Now fixed to be false
Expand All @@ -32,34 +27,7 @@ class DeDoDe(BaseModel):

# Initialize the line matcher
def _init(self, conf):
model_detector_path = self._download_model(
repo_id=MODEL_REPO_ID,
filename="{}/{}".format(
Path(__file__).stem, conf["model_detector_name"]
),
)
model_descriptor_path = self._download_model(
repo_id=MODEL_REPO_ID,
filename="{}/{}".format(
Path(__file__).stem, conf["model_descriptor_name"]
),
)
logger.info("Loaded DarkFeat model: {}".format(model_detector_path))
self.normalizer = transforms.Normalize(
mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]
)

# load the model
weights_detector = torch.load(model_detector_path, map_location="cpu")
weights_descriptor = torch.load(
model_descriptor_path, map_location="cpu"
)
self.detector = dedode_detector_L(
weights=weights_detector, device=device
)
self.descriptor = dedode_descriptor_B(
weights=weights_descriptor, device=device
)
self.model = KF.DeDoDe.from_pretrained(detector_weights=conf["detector_model"], descriptor_weights=conf["descriptor_model"])
logger.info("Load DeDoDe model done.")

def _forward(self, data):
Expand All @@ -68,29 +36,12 @@ def _forward(self, data):
image shape: N x C x H x W
color mode: RGB
"""
img0 = self.normalizer(data["image"].squeeze()).float()[None]
H_A, W_A = img0.shape[2:]

# step 1: detect keypoints
detections_A = None
batch_A = {"image": img0}
if self.conf["dense"]:
detections_A = self.detector.detect_dense(batch_A)
else:
detections_A = self.detector.detect(
batch_A, num_keypoints=self.conf["max_keypoints"]
)
keypoints_A, P_A = detections_A["keypoints"], detections_A["confidence"]
img0 = data["image"].float()

# step 2: describe keypoints
# dim: 1 x N x 256
description_A = self.descriptor.describe_keypoints(
batch_A, keypoints_A
)["descriptions"]
keypoints_A = to_pixel_coords(keypoints_A, H_A, W_A)
coords, scores, descriptions = self.model(img0, n=self.conf["max_keypoints"])

return {
"keypoints": keypoints_A, # 1 x N x 2
"descriptors": description_A.permute(0, 2, 1), # 1 x 256 x N
"scores": P_A, # 1 x N
"keypoints": coords, # 1 x N x 2
"descriptors": descriptions.permute(0, 2, 1), # 1 x 512 x N
"scores": scores, # 1 x N
}
32 changes: 32 additions & 0 deletions hloc/match_features.py
Original file line number Diff line number Diff line change
Expand Up @@ -80,6 +80,38 @@
"force_resize": False,
},
},
"dedode-lightglue": {
"output": "matches-dedode-lightglue",
"model": {
"name": "lightglue",
"match_threshold": 0.2,
"width_confidence": 0.99, # for point pruning
"depth_confidence": 0.95, # for early stopping,
"features": "dedodeg",
},
"preprocessing": {
"grayscale": True,
"resize_max": 1024,
"dfactor": 8,
"force_resize": False,
},
},
"dedodev2-lightglue": {
"output": "matches-dedode-lightglue",
"model": {
"name": "lightglue",
"match_threshold": 0.2,
"width_confidence": 0.99, # for point pruning
"depth_confidence": 0.95, # for early stopping,
"features": "dedodeg-v2",
},
"preprocessing": {
"grayscale": True,
"resize_max": 1024,
"dfactor": 8,
"force_resize": False,
},
},
"sift-lightglue": {
"output": "matches-sift-lightglue",
"model": {
Expand Down
29 changes: 10 additions & 19 deletions hloc/matchers/lightglue.py
Original file line number Diff line number Diff line change
@@ -1,14 +1,9 @@
import sys
from pathlib import Path
import torch
import kornia.feature as KF

from .. import MODEL_REPO_ID, logger
from ..utils.base_model import BaseModel

lightglue_path = Path(__file__).parent / "../../third_party/LightGlue"
sys.path.append(str(lightglue_path))
from lightglue import LightGlue as LG


class LightGlue(BaseModel):
default_conf = {
"match_threshold": 0.2,
Expand All @@ -33,22 +28,17 @@ class LightGlue(BaseModel):
]

def _init(self, conf):
logger.info("Loading lightglue model, {}".format(conf["model_name"]))
model_path = self._download_model(
repo_id=MODEL_REPO_ID,
filename="{}/{}".format(
Path(__file__).stem, self.conf["model_name"]
),
)
conf["weights"] = str(model_path)
conf["filter_threshold"] = conf["match_threshold"]
self.net = LG(**conf)
features = conf.pop("features")
logger.info("Loading lightglue model, {}".format(features))

self.net = KF.LightGlue(features, **conf)
logger.info("Load lightglue model done.")

def _forward(self, data):
input = {}
input["image0"] = {
"image": data["image0"],
# "image": data["image0"],
"image_size": torch.tensor(data["image0"].shape[-2:][::-1])[None],
"keypoints": data["keypoints0"],
"descriptors": data["descriptors0"].permute(0, 2, 1),
}
Expand All @@ -58,7 +48,8 @@ def _forward(self, data):
input["image0"] = {**input["image0"], "oris": data["oris0"]}

input["image1"] = {
"image": data["image1"],
# "image": data["image1"],
"image_size": torch.tensor(data["image1"].shape[-2:][::-1])[None],
"keypoints": data["keypoints1"],
"descriptors": data["descriptors1"].permute(0, 2, 1),
}
Expand Down
26 changes: 24 additions & 2 deletions ui/config.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -179,9 +179,31 @@ matcher_zoo:
paper: https://arxiv.org/abs/2404.19174
project: null
display: false
dedode:
dedodev2:
matcher: Dual-Softmax
feature: dedode
feature: dedodeg-v2
dense: false
info:
name: DeDoDeV2 #dispaly name
source: "3DV 2024"
github: https://github.com/Parskatt/DeDoDe
paper: https://arxiv.org/abs/2308.08479
project: null
display: true
dedode+lightglue:
matcher: dedode-lightglue
feature: dedodeg
dense: false
info:
name: DeDoDe #dispaly name
source: "3DV 2024"
github: https://github.com/Parskatt/DeDoDe
paper: https://arxiv.org/abs/2308.08479
project: null
display: true
dedodev2+lightglue:
matcher: dedode-lightglue
feature: dedodeg-v2
dense: false
info:
name: DeDoDe #dispaly name
Expand Down

0 comments on commit 4e8978b

Please sign in to comment.