forked from cvg/Hierarchical-Localization
-
Notifications
You must be signed in to change notification settings - Fork 0
/
nearest_neighbor.py
62 lines (53 loc) · 2.24 KB
/
nearest_neighbor.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
import torch
from ..utils.base_model import BaseModel
def find_nn(sim, ratio_thresh, distance_thresh):
    """Find the best candidate for each row of a similarity tensor.

    Similarities are turned into distances with ``2 * (1 - sim)``. This is the
    squared L2 distance when the compared descriptors are L2-normalized
    (assumed by the caller; not checked here).

    Args:
        sim: similarity tensor whose last dimension indexes the candidates,
            e.g. ``(batch, n, m)``.
        ratio_thresh: if not None, keep a match only when the distance to the
            best candidate is at most ``ratio_thresh**2`` times the distance
            to the second-best candidate. Requires at least two candidates
            along the last dimension (``topk(2)`` is used).
        distance_thresh: if not None, keep a match only when its distance is
            at most ``distance_thresh**2``.

    Returns:
        matches: index of the best candidate per row, or -1 when rejected.
        scores: ``(sim + 1) / 2`` of the best candidate, or 0 when rejected.
    """
    use_ratio = ratio_thresh is not None
    sim_nn, ind_nn = sim.topk(2 if use_ratio else 1, dim=-1, largest=True)
    dist_nn = 2 * (1 - sim_nn)
    mask = torch.ones(ind_nn.shape[:-1], dtype=torch.bool, device=sim.device)
    # Compare against None explicitly: a threshold of 0.0 is a valid (strict)
    # threshold and must not silently disable the corresponding test.
    if use_ratio:
        mask = mask & (dist_nn[..., 0] <= (ratio_thresh ** 2) * dist_nn[..., 1])
    if distance_thresh is not None:
        mask = mask & (dist_nn[..., 0] <= distance_thresh ** 2)
    matches = torch.where(mask, ind_nn[..., 0], ind_nn.new_tensor(-1))
    scores = torch.where(mask, (sim_nn[..., 0] + 1) / 2, sim_nn.new_tensor(0))
    return matches, scores
def mutual_check(m0, m1):
    """Keep only the forward matches that are confirmed by the reverse matches.

    Args:
        m0: forward match indices along the last dimension (-1 = unmatched).
        m1: reverse match indices, indexed by the values stored in ``m0``.

    Returns:
        A copy of ``m0`` where entry i survives only if ``m1[m0[i]] == i``;
        every other entry is -1.
    """
    valid = m0 > -1
    # -1 is not a legal gather index; substitute 0 and rely on `valid` to
    # discard whatever gets looked up for those positions.
    safe_m0 = torch.where(valid, m0, m0.new_tensor(0))
    back = torch.gather(m1, -1, safe_m0)
    own_index = torch.arange(m0.shape[-1], device=m0.device)
    keep = valid & (back == own_index)
    return torch.where(keep, m0, m0.new_tensor(-1))
class NearestNeighbor(BaseModel):
    """Nearest-neighbor descriptor matcher with optional ratio/distance tests.

    Configuration keys:
        ratio_threshold: forwarded to ``find_nn`` as the ratio-test threshold;
            None disables the test.
        distance_threshold: forwarded to ``find_nn`` as the distance
            threshold; None disables the test.
        do_mutual_check: if True, keep a match i -> j only when the search in
            the opposite direction maps j back to i.
    """

    default_conf = {
        'ratio_threshold': None,
        'distance_threshold': None,
        'do_mutual_check': True,
    }
    required_inputs = ['descriptors0', 'descriptors1']

    def _init(self, conf):
        # Nothing to build: the matcher has no parameters.
        pass

    def _forward(self, data):
        """Match ``descriptors0`` against ``descriptors1``.

        Both descriptor tensors are laid out as ``(batch, dim, num_keypoints)``
        (see the einsum below).

        Returns:
            dict with
            'matches0': ``(batch, num_keypoints0)`` index into the second set,
                or -1 when unmatched.
            'matching_scores0': ``(batch, num_keypoints0)`` match scores,
                0 when unmatched.
        """
        desc0 = data['descriptors0']
        desc1 = data['descriptors1']
        batch, num0, num1 = desc0.shape[0], desc0.shape[-1], desc1.shape[-1]

        if num0 == 0 or num1 == 0:
            # Nothing can be matched. Outputs are indexed by the keypoints of
            # the first set, i.e. (batch, num0) -- not shape[:2], which is
            # (batch, dim) -- and scores are floating point as in the regular
            # path.
            matches0 = torch.full(
                (batch, num0), -1, dtype=torch.long, device=desc0.device)
            scores0 = torch.zeros(
                (batch, num0), dtype=desc0.dtype, device=desc0.device)
            return {
                'matches0': matches0,
                'matching_scores0': scores0,
            }

        ratio_threshold = self.conf['ratio_threshold']
        # The ratio test takes the top-2 candidates along the searched axis,
        # which is impossible when either side has a single keypoint.
        if num0 == 1 or num1 == 1:
            ratio_threshold = None
        distance_threshold = self.conf['distance_threshold']

        sim = torch.einsum('bdn,bdm->bnm', desc0, desc1)
        matches0, scores0 = find_nn(sim, ratio_threshold, distance_threshold)
        if self.conf['do_mutual_check']:
            matches1, _ = find_nn(
                sim.transpose(1, 2), ratio_threshold, distance_threshold)
            matches0 = mutual_check(matches0, matches1)
            # Zero the scores of matches rejected by the mutual check so that,
            # as for matches rejected inside find_nn, every -1 has score 0.
            scores0 = torch.where(
                matches0 > -1, scores0, scores0.new_tensor(0))
        return {
            'matches0': matches0,
            'matching_scores0': scores0,
        }