Skip to content

Commit

Permalink
Auto download models for SOLD2, S2DNet, SuperPoint and SuperGlue (cvg#12
Browse files Browse the repository at this point in the history
)

* auto downloading sold2 model.

* auto downloading models for superpoint and superglue

* auto download for s2dnet model.

* remove download script at the global directory.
  • Loading branch information
B1ueber2y authored Nov 19, 2022
1 parent 300413d commit 0f7e57f
Show file tree
Hide file tree
Showing 6 changed files with 46 additions and 11 deletions.
5 changes: 0 additions & 5 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -11,11 +11,6 @@ python -m pip install -r requirements.txt
cd third-party/Hierarchical-Localization && python -m pip install -e . && cd ../..
```

Pretrained models for [SOLD2](https://github.com/cvg/SOLD2), [S2DNet](https://github.com/germain-hug/S2DNet-Minimal), [SuperPoint](https://github.com/magicleap/SuperGluePretrainedNetwork) and [SuperGlue](https://github.com/magicleap/SuperGluePretrainedNetwork) need to be downloaded:
```
bash download_pretrained_models.sh
```

To compile LIMAP:
```bash
sudo apt-get install libhdf5-dev
Expand Down
5 changes: 0 additions & 5 deletions download_pretrained_models.sh

This file was deleted.

13 changes: 12 additions & 1 deletion limap/features/models/s2dnet.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@

from PIL import Image
from .base_model import BaseModel
import sys
import os, sys
from torchvision import transforms
import numpy as np
import torch.nn.functional as F
Expand Down Expand Up @@ -134,12 +134,23 @@ def _init(self, conf):

if conf.pretrained == 's2dnet':
path = Path(__file__).parent / 'checkpoints/s2dnet_weights.pth'
if not os.path.isfile(path):
self.download_s2dnet_model(path)
logging.info(f'Loading S2DNet checkpoint at {path}.')
state_dict = torch.load(path, map_location='cpu')['state_dict']
params = self.state_dict()
state_dict = {k: v for k, v in state_dict.items()}
self.load_state_dict(state_dict, strict=False)

def download_s2dnet_model(self, path):
import subprocess
if not os.path.exists(os.path.dirname(path)):
os.makedirs(os.path.dirname(path))
link = "https://www.dropbox.com/s/hnv51iwu4hn82rj/s2dnet_weights.pth?dl=0"
cmd = ["wget", link, "-O", path]
print("Downloading S2DNet model...")
subprocess.run(cmd, check=True)

def _forward(self, data):
image = data#data['image']
mean, std = image.new_tensor(self.mean), image.new_tensor(self.std)
Expand Down
9 changes: 9 additions & 0 deletions limap/line2d/SOLD2/sold2_wrapper.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,10 @@
import os, sys
import numpy as np
from torch.nn.functional import softmax
import cv2
import torch
from skimage.draw import line
import subprocess

from SOLD2.experiment import load_config
from SOLD2.model.line_matcher import LineMatcher
Expand All @@ -25,6 +27,13 @@ def __init__(self, device=None, cfg_path=None, ckpt_path=None):
self.initialize_line_matcher()

def initialize_line_matcher(self):
if not os.path.isfile(self.ckpt_path):
if not os.path.exists(os.path.dirname(self.ckpt_path)):
os.makedirs(os.path.dirname(self.ckpt_path))
link = "https://www.polybox.ethz.ch/index.php/s/blOrW89gqSLoHOk/download"
cmd = ['wget', link, '-O', self.ckpt_path]
print("Downloading SOLD2 model...")
subprocess.run(cmd, check=True)
self.line_matcher = LineMatcher(self.cfg['model_cfg'], self.ckpt_path, self.device, self.cfg['line_detector_cfg'], self.cfg['line_matcher_cfg'], False)

def sold2segstosegs(self, segs_sold2):
Expand Down
13 changes: 13 additions & 0 deletions limap/point2d/superglue/superglue.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,7 @@
# --------------------------------------------------------------------*/
# %BANNER_END%

import os, sys
from copy import deepcopy
from pathlib import Path
from typing import List, Tuple
Expand Down Expand Up @@ -191,10 +192,22 @@ def __init__(self, config):
assert self.config['weights'] in ['indoor', 'outdoor']
path = Path(__file__).parent
path = path / 'weights/superglue_{}.pth'.format(self.config['weights'])
if not os.path.isfile(path):
self.download_model(path)
self.load_state_dict(torch.load(str(path)))
print('Loaded SuperGlue model (\"{}\" weights)'.format(
self.config['weights']))

def download_model(self, path):
import subprocess
if not os.path.exists(os.path.dirname(path)):
os.makedirs(os.path.dirname(path))
model_name = os.path.basename(path)
print("Downloading SuperGlue model {0}...".format(model_name))
link = "https://github.com/magicleap/SuperGluePretrainedNetwork/blob/master/models/weights/{0}?raw=true".format(model_name)
cmd = ["wget", link, "-O", path]
subprocess.run(cmd, check=True)

def forward(self, data):
"""Run SuperGlue on a pair of keypoints and descriptors"""
desc0, desc1 = data['descriptors0'], data['descriptors1']
Expand Down
12 changes: 12 additions & 0 deletions limap/point2d/superpoint/superpoint.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,7 @@
# --------------------------------------------------------------------*/
# %BANNER_END%

import os, sys
from pathlib import Path
import torch
from torch import nn
Expand Down Expand Up @@ -134,6 +135,8 @@ def __init__(self, config):
kernel_size=1, stride=1, padding=0)

path = Path(__file__).parent / 'weights/superpoint_v1.pth'
if not os.path.isfile(path):
self.download_model(path)
self.load_state_dict(torch.load(str(path)))

mk = self.config['max_keypoints']
Expand All @@ -142,6 +145,15 @@ def __init__(self, config):

print('Loaded SuperPoint model')

def download_model(self, path):
import subprocess
if not os.path.exists(os.path.dirname(path)):
os.makedirs(os.path.dirname(path))
link = "https://github.com/magicleap/SuperPointPretrainedNetwork/blob/master/superpoint_v1.pth?raw=true"
cmd = ["wget", link, "-O", path]
print("Downloading SuperPoint model...")
subprocess.run(cmd, check=True)

def compute_dense_descriptor(self, data):
""" Compute keypoints, scores, descriptors for image """
# Shared Encoder
Expand Down

0 comments on commit 0f7e57f

Please sign in to comment.