Skip to content

Commit

Permalink
Add LBD, L2D2 and LineTR (cvg#13)
Browse files Browse the repository at this point in the history
* add LBD, L2D2, LineTR.

* remove download scripts for pretrained models.

* minor. fix module name for sold2 matcher.
  • Loading branch information
B1ueber2y authored Nov 19, 2022
1 parent 0f7e57f commit 62adcfe
Show file tree
Hide file tree
Showing 26 changed files with 1,501 additions and 20 deletions.
1 change: 1 addition & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
*.swp
*.zip
*.tar.gz
*.th

experiments
build
Expand Down
3 changes: 3 additions & 0 deletions .gitmodules
Original file line number Diff line number Diff line change
Expand Up @@ -20,3 +20,6 @@
[submodule "third-party/pytlsd"]
path = third-party/pytlsd
url = [email protected]:iago-suarez/pytlsd.git
[submodule "third-party/pytlbd"]
path = third-party/pytlbd
url = [email protected]:iago-suarez/pytlbd.git
70 changes: 70 additions & 0 deletions limap/line2d/L2D2/RAL_net_cov.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,70 @@
from __future__ import division, print_function
import torch
import torch.nn.init
import torch.nn as nn


class L2Norm(nn.Module):
def __init__(self):
super(L2Norm,self).__init__()
self.eps = 1e-10
def forward(self, x):
norm = torch.sqrt(torch.sum(x * x, dim = 1) + self.eps)
x= x / norm.unsqueeze(-1).expand_as(x)
return x


class L2Net(nn.Module):
def __init__(self):
super(L2Net, self).__init__()
self.features = nn.Sequential(
nn.Conv2d(1, 32, kernel_size=3, padding=1, bias = False),
nn.BatchNorm2d(32, affine=False),
nn.ReLU(),
nn.Conv2d(32, 32, kernel_size=3, padding=1, bias = False),
nn.BatchNorm2d(32, affine=False),
nn.ReLU(),
nn.Conv2d(32, 64, kernel_size=(4,3), stride=2, padding=1, bias = False),#3
nn.BatchNorm2d(64, affine=False),
nn.ReLU(),
nn.Conv2d(64, 64, kernel_size=3, padding=1, bias = False),
nn.BatchNorm2d(64, affine=False),
nn.ReLU(),
nn.Conv2d(64, 128, kernel_size=(4,3), stride=2,padding=1, bias = False),#3
nn.BatchNorm2d(128, affine=False),
nn.ReLU(),
nn.Conv2d(128, 128, kernel_size=3, padding=1, bias = False),
nn.BatchNorm2d(128, affine=False),
nn.ReLU(),
nn.Dropout(0.3),
nn.Conv2d(128, 128, kernel_size=(12,8), bias = False),#8
nn.BatchNorm2d(128, affine=False),

)
self.features.apply(weights_init)
return

def input_norm(self,x):
flat = x.view(x.size(0), -1)
mp = torch.mean(flat, dim=1)
sp = torch.std(flat, dim=1) + 1e-7
return (x - mp.detach().unsqueeze(-1).unsqueeze(-1).unsqueeze(-1).expand_as(x)) / sp.detach().unsqueeze(-1).unsqueeze(-1).unsqueeze(1).expand_as(x)

def forward(self, input):
x_features = self.features(self.input_norm(input))
x = x_features.view(x_features.size(0), -1)

return L2Norm()(x)

def weights_init(m):
if isinstance(m, nn.Conv2d):
nn.init.orthogonal_(m.weight.data, gain=0.6)
try:
nn.init.constant_(m.bias.data, 0.01)

except:
pass
return

def get_net():
return L2Net()
2 changes: 2 additions & 0 deletions limap/line2d/L2D2/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,2 @@
from .extractor import L2D2Extractor
from .matcher import L2D2Matcher
118 changes: 118 additions & 0 deletions limap/line2d/L2D2/extractor.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,118 @@
import os, sys
import numpy as np
import cv2
import torch

sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
from base_detector import BaseDetector, BaseDetectorOptions

sys.path.append(os.path.dirname(os.path.dirname(os.path.dirname(os.path.abspath(__file__)))))
import limap.util.io as limapio

sys.path.append(os.path.dirname(__file__))
from .RAL_net_cov import get_net


class L2D2Extractor(BaseDetector):
def __init__(self, options = BaseDetectorOptions(), device=None):
super(L2D2Extractor, self).__init__(options)
self.mini_batch = 20
self.device = 'cuda' if device is None else device
ckpt = os.path.join(os.path.dirname(__file__),
'checkpoint_line_descriptor.th')
if not os.path.isfile(ckpt):
self.download_model(ckpt)
self.model = torch.load(ckpt).to(self.device)
self.model.eval()

def download_model(self, path):
import subprocess
if not os.path.exists(os.path.dirname(path)):
os.makedirs(os.path.dirname(path))
link = "https://github.com/hichem-abdellali/L2D2/blob/main/IN_OUT_DATA/INPUT_NETWEIGHT/checkpoint_line_descriptor.th?raw=true"
cmd = ["wget", link, "-O", path]
print("Downloading L2D2 model...")
subprocess.run(cmd, check=True)

def get_module_name(self):
return "l2d2"

def get_descinfo_fname(self, descinfo_folder, img_id):
fname = os.path.join(descinfo_folder, "descinfo_{0}.npz".format(img_id))
return fname

def save_descinfo(self, descinfo_folder, img_id, descinfo):
limapio.check_makedirs(descinfo_folder)
fname = self.get_descinfo_fname(descinfo_folder, img_id)
limapio.save_npz(fname, descinfo)

def read_descinfo(self, descinfo_folder, img_id):
fname = self.get_descinfo_fname(descinfo_folder, img_id)
descinfo = limapio.read_npz(fname)
return descinfo

def extract(self, camview, segs):
img = camview.read_image(set_gray=self.set_gray)
descinfo = self.compute_descinfo(img, segs)
return descinfo

def get_patch(self, img, line):
""" Extract a 48x32 patch around a line [2, 2]. """
h, w = img.shape

# Keep a consistent endpoint ordering
if line[1, 1] < line[0, 1]:
line = line[[1, 0]]

# Get the rotation angle
angle = np.arctan2(line[1, 0] - line[0, 0], line[1, 1] - line[0, 1])

# Compute the affine transform to center and rotate the line
midpoint = line.mean(axis=0)
T_midpoint_to_origin = np.array([[1., 0., -midpoint[0]],
[0., 1., -midpoint[1]],
[0., 0., 1.]])
T_rot = np.array([[np.cos(angle), -np.sin(angle), 0.],
[np.sin(angle), np.cos(angle), 0.],
[0., 0., 1.]])
T_origin_to_center = np.array([[1., 0., w // 2],
[0., 1., h // 2],
[0., 0., 1.]])
A = T_origin_to_center @ T_rot @ T_midpoint_to_origin

# Translate and rotate the image
patch = cv2.warpAffine(img, A[:2], (w, h))

# Crop and resize the patch
length = np.linalg.norm(line[0] - line[1])
new_h = max(int(np.round(length)), 5) # use a minimum height of 5 for short segments
new_w = new_h * 32 // 48
patch = patch[h // 2 - new_h // 2: h // 2 + new_h // 2,
w // 2 - new_w // 2: w // 2 + new_w // 2]
patch = cv2.resize(patch, (32, 48))
return patch

def compute_descinfo(self, img, segs):
""" A desc_info is composed of the following tuple / np arrays:
- the line descriptors [N, 128]
"""
# Extract patches and compute a line descriptor for each patch
lines = segs.reshape(-1, 2, 2)
if len(lines) == 0:
return {'line_descriptors': np.empty((0, 128))}

patches, line_desc = [], []
for i, l in enumerate(lines):
patches.append(self.get_patch(img, l))

if ((i + 1) % self.mini_batch == 0
or i == len(lines) - 1):
# Extract the descriptors
patches = torch.tensor(np.array(patches), dtype=torch.float,
device=self.device)[:, None] / 255.
patches = (patches - 0.492967568115862) / 0.272086182765434
with torch.no_grad():
line_desc.append(self.model(patches))
patches = []
line_desc = torch.cat(line_desc, dim=0) # [n_lines, 128]
return {'line_descriptors': line_desc.cpu().numpy()}
62 changes: 62 additions & 0 deletions limap/line2d/L2D2/matcher.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,62 @@
import os, sys
import numpy as np

sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
from base_matcher import BaseMatcher, BaseMatcherOptions


class L2D2Matcher(BaseMatcher):
def __init__(self, extractor, options = BaseMatcherOptions()):
super(L2D2Matcher, self).__init__(extractor, options)

def get_module_name(self):
return "l2d2"

def check_compatibility(self, extractor):
return extractor.get_module_name() == "l2d2"

def match_pair(self, descinfo1, descinfo2):
if self.topk == 0:
return self.match_segs_with_descinfo(descinfo1, descinfo2)
else:
return self.match_segs_with_descinfo_topk(descinfo1, descinfo2, topk=self.topk)

def match_segs_with_descinfo(self, descinfo1, descinfo2):
desc1 = descinfo1['line_descriptors']
desc2 = descinfo2['line_descriptors']

# Default case when an image has no lines
if len(desc1) == 0 or len(desc2) == 0:
return np.empty((0, 2))

# Mutual nearest neighbor matching
score_mat = desc1 @ desc2.T
nearest1 = np.argmax(score_mat, axis=1)
nearest2 = np.argmax(score_mat, axis=0)
mutual = nearest2[nearest1] == np.arange(len(desc1))
nearest1[~mutual] = -1

# Transform matches to [n_matches, 2]
id_list_1 = np.arange(0, len(nearest1))[mutual]
id_list_2 = nearest1[mutual]
matches_t = np.stack([id_list_1, id_list_2], 1)
return matches_t

def match_segs_with_descinfo_topk(self, descinfo1, descinfo2, topk=10):
desc1 = descinfo1['line_descriptors']
desc2 = descinfo2['line_descriptors']

# Default case when an image has no lines
if len(desc1) == 0 or len(desc2) == 0:
return np.empty((0, 2))

# Top k nearest neighbor matching
score_mat = desc1 @ desc2.T
matches = np.argsort(score_mat, axis=1)[:, -topk:]
matches = np.flip(matches, axis=1)

# Transform matches to [n_matches, 2]
n_lines = len(matches)
matches_t = np.stack([np.arange(n_lines).repeat(topk),
matches.flatten()], axis=1)
return matches_t
2 changes: 2 additions & 0 deletions limap/line2d/LBD/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,2 @@
from .extractor import LBDExtractor
from .matcher import LBDMatcher
84 changes: 84 additions & 0 deletions limap/line2d/LBD/extractor.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,84 @@
import os, sys
import numpy as np
import cv2

import pytlsd
import pytlbd

sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
from base_detector import BaseDetector, BaseDetectorOptions

sys.path.append(os.path.dirname(os.path.dirname(os.path.dirname(os.path.abspath(__file__)))))
import limap.util.io as limapio


def process_pyramid(img, detector, n_levels=5, level_scale=np.sqrt(2), presmooth=True):
octave_img = img.copy()
pre_sigma2 = 0
cur_sigma2 = 1.0
pyramid = []
multiscale_segs = []
for i in range(n_levels):
increase_sigma = np.sqrt(cur_sigma2 - pre_sigma2)
blurred = cv2.GaussianBlur(octave_img, (5, 5), increase_sigma, borderType=cv2.BORDER_REPLICATE)
pyramid.append(blurred)

if presmooth:
multiscale_segs.append(detector(blurred))
else:
multiscale_segs.append(detector(octave_img))

# cv2.imshow(f"Mine L{i}", blurred)
# down sample the current octave image to get the next octave image
new_size = (int(octave_img.shape[1] / level_scale), int(octave_img.shape[0] / level_scale))
octave_img = cv2.resize(blurred, new_size, 0, 0, interpolation=cv2.INTER_NEAREST)
pre_sigma2 = cur_sigma2
cur_sigma2 = cur_sigma2 * 2

return multiscale_segs, pyramid


def to_multiscale_lines(lines):
ms_lines = []
for l in lines.reshape(-1, 4):
ll = np.append(l, [0, np.linalg.norm(l[:2] - l[2:4])])
ms_lines.append([(0, ll)] + [(i, ll / (i * np.sqrt(2))) for i in range(1, 5)])
return ms_lines


class LBDExtractor(BaseDetector):
def __init__(self, options = BaseDetectorOptions()):
super(LBDExtractor, self).__init__(options)

def get_module_name(self):
return "lbd"

def get_descinfo_fname(self, descinfo_folder, img_id):
fname = os.path.join(descinfo_folder, "descinfo_{0}.npz".format(img_id))
return fname

def save_descinfo(self, descinfo_folder, img_id, descinfo):
limapio.check_makedirs(descinfo_folder)
fname = self.get_descinfo_fname(descinfo_folder, img_id)
limapio.save_npz(fname, descinfo)

def read_descinfo(self, descinfo_folder, img_id):
fname = self.get_descinfo_fname(descinfo_folder, img_id)
descinfo = limapio.read_npz(fname)
return descinfo

def extract(self, camview, segs):
img = camview.read_image(set_gray=self.set_gray)
descinfo = self.compute_descinfo(img, segs)
return descinfo

def compute_descinfo(self, img, segs):
""" A desc_info is composed of the following tuple / np arrays:
- the multiscale lines [N, 5] containing tuples of (scale, scaled_line)
- the line descriptors [N, dim]
"""
ms_lines = to_multiscale_lines(segs)
_, pyramid = process_pyramid(img, pytlsd.lsd, presmooth=False)
descriptors = pytlbd.lbd_multiscale_pyr(pyramid, ms_lines, 9, 7)

return {'ms_lines': ms_lines, 'line_descriptors': descriptors}
38 changes: 38 additions & 0 deletions limap/line2d/LBD/matcher.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,38 @@
import os, sys
import numpy as np

import pytlbd

sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
from base_matcher import BaseMatcher, BaseMatcherOptions

class LBDMatcher(BaseMatcher):
def __init__(self, extractor, options = BaseMatcherOptions()):
super(LBDMatcher, self).__init__(extractor, options)

def get_module_name(self):
return "lbd"

def check_compatibility(self, extractor):
return extractor.get_module_name() == "lbd"

def match_pair(self, descinfo1, descinfo2):
if self.topk == 0:
return self.match_segs_with_descinfo(descinfo1, descinfo2)
else:
return self.match_segs_with_descinfo_topk(descinfo1, descinfo2, topk=self.topk)

def match_segs_with_descinfo(self, descinfo1, descinfo2):
try:
matches = pytlbd.lbd_matching_multiscale(
descinfo1['ms_lines'].tolist(),
descinfo2['ms_lines'].tolist(),
descinfo1['line_descriptors'].tolist(),
descinfo2['line_descriptors'].tolist())
matches = np.array(matches)[:, :2]
except RuntimeError:
matches = np.zeros((0, 2))
return matches

def match_segs_with_descinfo_topk(self, descinfo1, descinfo2, topk=10):
raise NotImplementedError()
Loading

0 comments on commit 62adcfe

Please sign in to comment.