Skip to content

Commit

Permalink
Fix: sfd2 (#63)
Browse files Browse the repository at this point in the history
* Fix: #54

* del: official pram

* update: pram

* update: sfd2
  • Loading branch information
Vincentqyw authored Aug 21, 2024
1 parent 07c30b9 commit f791ada
Show file tree
Hide file tree
Showing 5 changed files with 15 additions and 20 deletions.
6 changes: 3 additions & 3 deletions .gitmodules
Original file line number Diff line number Diff line change
Expand Up @@ -55,9 +55,6 @@
[submodule "third_party/RoMa"]
path = third_party/RoMa
url = https://github.com/Vincentqyw/RoMa.git
[submodule "third_party/pram"]
path = third_party/pram
url = https://github.com/feixue94/pram.git
[submodule "third_party/SGMNet"]
path = third_party/SGMNet
url = https://github.com/agipro/SGMNet.git
Expand All @@ -67,3 +64,6 @@
[submodule "third_party/mast3r"]
path = third_party/mast3r
url = https://github.com/naver/mast3r
[submodule "third_party/pram"]
path = third_party/pram
url = https://github.com/agipro/pram.git
10 changes: 4 additions & 6 deletions hloc/extractors/sfd2.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,3 @@
# -*- coding: UTF-8 -*-
import sys
from pathlib import Path

Expand All @@ -7,10 +6,9 @@
from .. import logger
from ..utils.base_model import BaseModel

pram_path = Path(__file__).parent / "../../third_party/pram"
sys.path.append(str(pram_path))

from nets.sfd2 import load_sfd2
tp_path = Path(__file__).parent / "../../third_party"
sys.path.append(str(tp_path))
from pram.nets.sfd2 import load_sfd2


class SFD2(BaseModel):
Expand All @@ -26,7 +24,7 @@ def _init(self, conf):
self.norm_rgb = tvf.Normalize(
mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]
)
model_fn = pram_path / "weights" / self.conf["model_name"]
model_fn = tp_path / "pram" / "weights" / self.conf["model_name"]
self.net = load_sfd2(weight_path=model_fn).eval()

logger.info("Load SFD2 model done.")
Expand Down
11 changes: 5 additions & 6 deletions hloc/matchers/imp.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,3 @@
# -*- coding: UTF-8 -*-
import sys
from pathlib import Path

Expand All @@ -7,10 +6,9 @@
from .. import DEVICE, logger
from ..utils.base_model import BaseModel

pram_path = Path(__file__).parent / "../../third_party/pram"
sys.path.append(str(pram_path))

from nets.gml import GML
tp_path = Path(__file__).parent / "../../third_party"
sys.path.append(str(tp_path))
from pram.nets.gml import GML


class IMP(BaseModel):
Expand All @@ -33,7 +31,8 @@ class IMP(BaseModel):

def _init(self, conf):
self.conf = {**self.default_conf, **conf}
weight_path = pram_path / "weights" / self.conf["model_name"]
weight_path = tp_path / "pram" / "weights" / self.conf["model_name"]
# self.net = nets.gml(self.conf).eval().to(DEVICE)
self.net = GML(self.conf).eval().to(DEVICE)
self.net.load_state_dict(
torch.load(weight_path, map_location="cpu")["model"], strict=True
Expand Down
2 changes: 1 addition & 1 deletion third_party/pram
Submodule pram updated 1 files
+1 −1 nets/gml.py
6 changes: 2 additions & 4 deletions ui/config.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -381,8 +381,7 @@ matcher_zoo:
display: true

sfd2+imp:
enable: false
skip_ci: true
enable: true
matcher: imp
feature: sfd2
dense: false
Expand All @@ -395,8 +394,7 @@ matcher_zoo:
display: true

sfd2+mnn:
enable: false
skip_ci: true
enable: true
matcher: NN-mutual
feature: sfd2
dense: false
Expand Down

0 comments on commit f791ada

Please sign in to comment.