Add: mast3r (#48)
* add: sfm

* add: mast3r
Vincentqyw authored Jul 16, 2024
1 parent a229212 commit b1baa6d
Showing 47 changed files with 3,515 additions and 699 deletions.
5 changes: 5 additions & 0 deletions .gitignore
@@ -22,3 +22,8 @@ third_party/QuadTreeAttention
desktop.ini
*.egg-info
output.pkl
+experiments*
+gen_example.py
+datasets/lines/terrace0.JPG
+datasets/lines/terrace1.JPG
+datasets/South-Building*
3 changes: 3 additions & 0 deletions .gitmodules
@@ -64,3 +64,6 @@
[submodule "third_party/DarkFeat"]
path = third_party/DarkFeat
url = https://github.com/agipro/DarkFeat.git
+[submodule "third_party/mast3r"]
+	path = third_party/mast3r
+	url = https://github.com/naver/mast3r
2 changes: 1 addition & 1 deletion Dockerfile
@@ -20,7 +20,7 @@ ENV PATH /opt/conda/envs/imw/bin:$PATH
# Make RUN commands use the new environment
SHELL ["conda", "run", "-n", "imw", "/bin/bash", "-c"]
RUN pip install --upgrade pip
-RUN pip install -r env-docker.txt
+RUN pip install -r requirements.txt
RUN apt-get update && apt-get install ffmpeg libsm6 libxext6 -y

# Export port
1 change: 1 addition & 0 deletions README.md
@@ -23,6 +23,7 @@ Here is a demo of the tool:
https://github.com/Vincentqyw/image-matching-webui/assets/18531182/263534692-c3484d1b-cc00-4fdc-9b31-e5b7af07ecd9

The tool currently supports various popular image matching algorithms, namely:
+- [x] [MASt3R](https://github.com/naver/mast3r), CVPR 2024
- [x] [DUSt3R](https://github.com/naver/dust3r), CVPR 2024
- [x] [OmniGlue](https://github.com/Vincentqyw/omniglue-onnx), CVPR 2024
- [x] [XFeat](https://github.com/verlab/accelerated_features), CVPR 2024
24 changes: 14 additions & 10 deletions hloc/__init__.py
@@ -3,7 +3,7 @@
import torch
from packaging import version

__version__ = "1.3"
__version__ = "1.5"

formatter = logging.Formatter(
fmt="[%(asctime)s %(name)s %(levelname)s] %(message)s",
@@ -23,14 +23,18 @@
except ImportError:
logger.warning("pycolmap is not installed, some features may not work.")
else:
-    minimal_version = version.parse("0.3.0")
-    found_version = version.parse(getattr(pycolmap, "__version__"))
-    if found_version < minimal_version:
-        logger.warning(
-            "hloc now requires pycolmap>=%s but found pycolmap==%s, "
-            "please upgrade with `pip install --upgrade pycolmap`",
-            minimal_version,
-            found_version,
-        )
+    min_version = version.parse("0.6.0")
+    found_version = pycolmap.__version__
+    if found_version != "dev":
+        version = version.parse(found_version)
+        if version < min_version:
+            s = f"pycolmap>={min_version}"
+            logger.warning(
+                "hloc requires %s but found pycolmap==%s, "
+                'please upgrade with `pip install --upgrade "%s"`',
+                s,
+                found_version,
+                s,
+            )

DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
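
The new check special-cases development builds (`pycolmap.__version__ == "dev"`) and only compares released versions against the 0.6.0 minimum. A minimal standalone sketch of the same gate (version strings invented for illustration):

```python
from packaging import version

MIN_VERSION = version.parse("0.6.0")

def check_pycolmap_version(found: str) -> None:
    # "dev" builds are assumed newer than any release and skip the check.
    if found != "dev" and version.parse(found) < MIN_VERSION:
        print(f'please upgrade with `pip install --upgrade "pycolmap>={MIN_VERSION}"`')

check_pycolmap_version("0.4.0")  # warns
check_pycolmap_version("dev")    # silent
```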
220 changes: 220 additions & 0 deletions hloc/colmap_from_nvm.py
@@ -0,0 +1,220 @@
import argparse
import sqlite3
from collections import defaultdict
from pathlib import Path

import numpy as np
from tqdm import tqdm

from . import logger
from .utils.read_write_model import (
CAMERA_MODEL_NAMES,
Camera,
Image,
Point3D,
write_model,
)


def recover_database_images_and_ids(database_path):
images = {}
cameras = {}
db = sqlite3.connect(str(database_path))
ret = db.execute("SELECT name, image_id, camera_id FROM images;")
for name, image_id, camera_id in ret:
images[name] = image_id
cameras[name] = camera_id
db.close()
logger.info(
f"Found {len(images)} images and {len(cameras)} cameras in database."
)
return images, cameras
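
The lookup above reads only the `images` table of a COLMAP database. A hypothetical smoke test against an in-memory SQLite stand-in (schema trimmed to the three columns the query uses):

```python
import sqlite3

# In-memory stand-in for a COLMAP database.db (columns trimmed for illustration).
db = sqlite3.connect(":memory:")
db.execute("CREATE TABLE images (image_id INTEGER, name TEXT, camera_id INTEGER);")
db.execute("INSERT INTO images VALUES (1, 'img0001.jpg', 1);")
rows = db.execute("SELECT name, image_id, camera_id FROM images;").fetchall()
assert rows == [("img0001.jpg", 1, 1)]
```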


def quaternion_to_rotation_matrix(qvec):
qvec = qvec / np.linalg.norm(qvec)
w, x, y, z = qvec
R = np.array(
[
[
1 - 2 * y * y - 2 * z * z,
2 * x * y - 2 * z * w,
2 * x * z + 2 * y * w,
],
[
2 * x * y + 2 * z * w,
1 - 2 * x * x - 2 * z * z,
2 * y * z - 2 * x * w,
],
[
2 * x * z - 2 * y * w,
2 * y * z + 2 * x * w,
1 - 2 * x * x - 2 * y * y,
],
]
)
return R


def camera_center_to_translation(c, qvec):
R = quaternion_to_rotation_matrix(qvec)
return (-1) * np.matmul(R, c)
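
COLMAP stores world-to-camera poses `[R | t]` while NVM stores camera centers `c`, so the helper above computes `t = -R @ c`. A quick sanity check with invented values:

```python
import numpy as np

# Identity rotation: the translation is just the negated camera center.
qvec = np.array([1.0, 0.0, 0.0, 0.0])  # (w, x, y, z)
c = np.array([1.0, 2.0, 3.0])
assert np.allclose(quaternion_to_rotation_matrix(qvec), np.eye(3))
assert np.allclose(camera_center_to_translation(c, qvec), -c)
```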


def read_nvm_model(
nvm_path, intrinsics_path, image_ids, camera_ids, skip_points=False
):
with open(intrinsics_path, "r") as f:
raw_intrinsics = f.readlines()

logger.info(f"Reading {len(raw_intrinsics)} cameras...")
cameras = {}
for intrinsics in raw_intrinsics:
intrinsics = intrinsics.strip("\n").split(" ")
name, camera_model, width, height = intrinsics[:4]
params = [float(p) for p in intrinsics[4:]]
camera_model = CAMERA_MODEL_NAMES[camera_model]
assert len(params) == camera_model.num_params
camera_id = camera_ids[name]
camera = Camera(
id=camera_id,
model=camera_model.model_name,
width=int(width),
height=int(height),
params=params,
)
cameras[camera_id] = camera

nvm_f = open(nvm_path, "r")
line = nvm_f.readline()
while line == "\n" or line.startswith("NVM_V3"):
line = nvm_f.readline()
num_images = int(line)
assert num_images == len(cameras)

logger.info(f"Reading {num_images} images...")
image_idx_to_db_image_id = []
image_data = []
i = 0
while i < num_images:
line = nvm_f.readline()
if line == "\n":
continue
data = line.strip("\n").split(" ")
image_data.append(data)
image_idx_to_db_image_id.append(image_ids[data[0]])
i += 1

line = nvm_f.readline()
while line == "\n":
line = nvm_f.readline()
num_points = int(line)

if skip_points:
logger.info(f"Skipping {num_points} points.")
num_points = 0
else:
logger.info(f"Reading {num_points} points...")
points3D = {}
image_idx_to_keypoints = defaultdict(list)
i = 0
pbar = tqdm(total=num_points, unit="pts")
while i < num_points:
line = nvm_f.readline()
if line == "\n":
continue

data = line.strip("\n").split(" ")
x, y, z, r, g, b, num_observations = data[:7]
obs_image_ids, point2D_idxs = [], []
for j in range(int(num_observations)):
s = 7 + 4 * j
img_index, kp_index, kx, ky = data[s : s + 4]
image_idx_to_keypoints[int(img_index)].append(
(int(kp_index), float(kx), float(ky), i)
)
db_image_id = image_idx_to_db_image_id[int(img_index)]
obs_image_ids.append(db_image_id)
point2D_idxs.append(kp_index)

point = Point3D(
id=i,
xyz=np.array([x, y, z], float),
rgb=np.array([r, g, b], int),
error=1.0, # fake
image_ids=np.array(obs_image_ids, int),
point2D_idxs=np.array(point2D_idxs, int),
)
points3D[i] = point

i += 1
pbar.update(1)
pbar.close()

logger.info("Parsing image data...")
images = {}
for i, data in enumerate(image_data):
# Skip the focal length. Skip the distortion and terminal 0.
name, _, qw, qx, qy, qz, cx, cy, cz, _, _ = data
qvec = np.array([qw, qx, qy, qz], float)
c = np.array([cx, cy, cz], float)
t = camera_center_to_translation(c, qvec)

if i in image_idx_to_keypoints:
# NVM only stores triangulated 2D keypoints: add dummy ones
keypoints = image_idx_to_keypoints[i]
point2D_idxs = np.array([d[0] for d in keypoints])
tri_xys = np.array([[x, y] for _, x, y, _ in keypoints])
tri_ids = np.array([i for _, _, _, i in keypoints])

num_2Dpoints = max(point2D_idxs) + 1
xys = np.zeros((num_2Dpoints, 2), float)
point3D_ids = np.full(num_2Dpoints, -1, int)
xys[point2D_idxs] = tri_xys
point3D_ids[point2D_idxs] = tri_ids
else:
xys = np.zeros((0, 2), float)
point3D_ids = np.full(0, -1, int)

image_id = image_ids[name]
image = Image(
id=image_id,
qvec=qvec,
tvec=t,
camera_id=camera_ids[name],
name=name,
xys=xys,
point3D_ids=point3D_ids,
)
images[image_id] = image

return cameras, images, points3D
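
Each point record consumed above has the layout `x y z r g b N`, followed by four values (`image_index keypoint_index kx ky`) per observation. A hypothetical single-observation record, split the same way the parser does:

```python
# Hypothetical NVM point line (values invented for illustration).
data = "1.0 2.0 3.0 255 128 0 1 0 5 10.5 20.5".split(" ")
x, y, z, r, g, b, num_observations = data[:7]
for j in range(int(num_observations)):
    s = 7 + 4 * j
    img_index, kp_index, kx, ky = data[s : s + 4]
    assert (int(img_index), int(kp_index)) == (0, 5)
```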


def main(nvm, intrinsics, database, output, skip_points=False):
assert nvm.exists(), nvm
assert intrinsics.exists(), intrinsics
assert database.exists(), database

image_ids, camera_ids = recover_database_images_and_ids(database)

logger.info("Reading the NVM model...")
model = read_nvm_model(
nvm, intrinsics, image_ids, camera_ids, skip_points=skip_points
)

logger.info("Writing the COLMAP model...")
output.mkdir(exist_ok=True, parents=True)
write_model(*model, path=str(output), ext=".bin")
logger.info("Done.")


if __name__ == "__main__":
parser = argparse.ArgumentParser()
parser.add_argument("--nvm", required=True, type=Path)
parser.add_argument("--intrinsics", required=True, type=Path)
parser.add_argument("--database", required=True, type=Path)
parser.add_argument("--output", required=True, type=Path)
parser.add_argument("--skip_points", action="store_true")
args = parser.parse_args()
main(**args.__dict__)
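
Besides the CLI, `main` can be called programmatically; a sketch with hypothetical paths:

```python
from pathlib import Path
from hloc.colmap_from_nvm import main

# Paths are hypothetical; point them at your own reconstruction outputs.
main(
    nvm=Path("outputs/reconstruction.nvm"),
    intrinsics=Path("outputs/intrinsics.txt"),
    database=Path("outputs/database.db"),
    output=Path("outputs/sfm_colmap"),
    skip_points=False,
)
```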
5 changes: 5 additions & 0 deletions hloc/extract_features.py
@@ -329,6 +329,11 @@
"model": {"name": "cosplace"},
"preprocessing": {"resize_max": 1024},
},
"eigenplaces": {
"output": "global-feats-eigenplaces",
"model": {"name": "eigenplaces"},
"preprocessing": {"resize_max": 1024},
},
}


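A minimal sketch of extracting EigenPlaces retrieval features with the new config via hloc's usual `extract_features.main` entry point (dataset and output paths are hypothetical):

```python
from pathlib import Path
from hloc import extract_features

conf = extract_features.confs["eigenplaces"]
images = Path("datasets/my_scene/images")  # hypothetical image directory
outputs = Path("outputs/my_scene")         # hypothetical export directory
extract_features.main(conf, images, export_dir=outputs)
```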
57 changes: 57 additions & 0 deletions hloc/extractors/eigenplaces.py
@@ -0,0 +1,57 @@
"""
Code for loading models trained with EigenPlaces (or CosPlace) as a global
features extractor for geolocalization through image retrieval.
Multiple models are available with different backbones. Below is a summary of
the available models (backbone: list of available output descriptor
dimensionalities). For example, you can use a ResNet50-based model with
1024-dimensional descriptors.
EigenPlaces trained models:
ResNet18: [ 256, 512]
ResNet50: [128, 256, 512, 2048]
ResNet101: [128, 256, 512, 2048]
VGG16: [ 512]
CosPlace trained models:
ResNet18: [32, 64, 128, 256, 512]
ResNet50: [32, 64, 128, 256, 512, 1024, 2048]
ResNet101: [32, 64, 128, 256, 512, 1024, 2048]
ResNet152: [32, 64, 128, 256, 512, 1024, 2048]
VGG16: [ 64, 128, 256, 512]
EigenPlaces paper (ICCV 2023): https://arxiv.org/abs/2308.10832
CosPlace paper (CVPR 2022): https://arxiv.org/abs/2204.02287
"""

import torch
import torchvision.transforms as tvf

from ..utils.base_model import BaseModel


class EigenPlaces(BaseModel):
default_conf = {
"variant": "EigenPlaces",
"backbone": "ResNet101",
"fc_output_dim": 2048,
}
required_inputs = ["image"]

def _init(self, conf):
self.net = torch.hub.load(
"gmberton/" + conf["variant"],
"get_trained_model",
backbone=conf["backbone"],
fc_output_dim=conf["fc_output_dim"],
).eval()

mean = [0.485, 0.456, 0.406]
std = [0.229, 0.224, 0.225]
self.norm_rgb = tvf.Normalize(mean=mean, std=std)

def _forward(self, data):
image = self.norm_rgb(data["image"])
desc = self.net(image)
return {
"global_descriptor": desc,
}
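
For standalone experimentation, the extractor can also be instantiated directly; a sketch assuming hloc's `BaseModel` call convention (dummy input; weights download via `torch.hub` on first use):

```python
import torch
from hloc.extractors.eigenplaces import EigenPlaces

# Default conf: EigenPlaces variant, ResNet101 backbone, 2048-d descriptors.
model = EigenPlaces({}).eval()
with torch.no_grad():
    out = model({"image": torch.rand(1, 3, 512, 512)})  # dummy RGB batch in [0, 1]
print(out["global_descriptor"].shape)  # expected: torch.Size([1, 2048])
```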
