diff --git a/.gitmodules b/.gitmodules index e4ee749..d61f3e0 100644 --- a/.gitmodules +++ b/.gitmodules @@ -55,9 +55,6 @@ [submodule "third_party/RoMa"] path = third_party/RoMa url = https://github.com/Vincentqyw/RoMa.git -[submodule "third_party/pram"] - path = third_party/pram - url = https://github.com/feixue94/pram.git [submodule "third_party/SGMNet"] path = third_party/SGMNet url = https://github.com/agipro/SGMNet.git @@ -67,3 +64,6 @@ [submodule "third_party/mast3r"] path = third_party/mast3r url = https://github.com/naver/mast3r +[submodule "third_party/pram"] + path = third_party/pram + url = https://github.com/agipro/pram.git diff --git a/hloc/extractors/sfd2.py b/hloc/extractors/sfd2.py index 9fb76ed..c7ce322 100644 --- a/hloc/extractors/sfd2.py +++ b/hloc/extractors/sfd2.py @@ -1,4 +1,3 @@ -# -*- coding: UTF-8 -*- import sys from pathlib import Path @@ -7,10 +6,9 @@ from .. import logger from ..utils.base_model import BaseModel -pram_path = Path(__file__).parent / "../../third_party/pram" -sys.path.append(str(pram_path)) - -from nets.sfd2 import load_sfd2 +tp_path = Path(__file__).parent / "../../third_party" +sys.path.append(str(tp_path)) +from pram.nets.sfd2 import load_sfd2 class SFD2(BaseModel): @@ -26,7 +24,7 @@ def _init(self, conf): self.norm_rgb = tvf.Normalize( mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225] ) - model_fn = pram_path / "weights" / self.conf["model_name"] + model_fn = tp_path / "pram" / "weights" / self.conf["model_name"] self.net = load_sfd2(weight_path=model_fn).eval() logger.info("Load SFD2 model done.") diff --git a/hloc/matchers/imp.py b/hloc/matchers/imp.py index ca64980..05c3cb9 100644 --- a/hloc/matchers/imp.py +++ b/hloc/matchers/imp.py @@ -1,4 +1,3 @@ -# -*- coding: UTF-8 -*- import sys from pathlib import Path @@ -7,10 +6,9 @@ from .. import DEVICE, logger from ..utils.base_model import BaseModel -pram_path = Path(__file__).parent / "../../third_party/pram" -sys.path.append(str(pram_path)) - -from nets.gml import GML +tp_path = Path(__file__).parent / "../../third_party" +sys.path.append(str(tp_path)) +from pram.nets.gml import GML class IMP(BaseModel): @@ -33,7 +31,8 @@ class IMP(BaseModel): def _init(self, conf): self.conf = {**self.default_conf, **conf} - weight_path = pram_path / "weights" / self.conf["model_name"] + weight_path = tp_path / "pram" / "weights" / self.conf["model_name"] + # self.net = nets.gml(self.conf).eval().to(DEVICE) self.net = GML(self.conf).eval().to(DEVICE) self.net.load_state_dict( torch.load(weight_path, map_location="cpu")["model"], strict=True diff --git a/third_party/pram b/third_party/pram index 96929f8..742ff42 160000 --- a/third_party/pram +++ b/third_party/pram @@ -1 +1 @@ -Subproject commit 96929f8c7f10b158036e078f7cc5363c4a2dcbcb +Subproject commit 742ff4241105bfaee0039e01556c79fb9be5c8b8 diff --git a/ui/config.yaml b/ui/config.yaml index 63fe9d1..741a79a 100644 --- a/ui/config.yaml +++ b/ui/config.yaml @@ -381,8 +381,7 @@ matcher_zoo: display: true sfd2+imp: - enable: false - skip_ci: true + enable: true matcher: imp feature: sfd2 dense: false @@ -395,8 +394,7 @@ matcher_zoo: display: true sfd2+mnn: - enable: false - skip_ci: true + enable: true matcher: NN-mutual feature: sfd2 dense: false diff --git a/ui/utils.py b/ui/utils.py index 8c66f12..270d12a 100644 --- a/ui/utils.py +++ b/ui/utils.py @@ -910,8 +910,9 @@ def run_matching( t1 = time.time() if model["dense"]: - if not match_conf["preprocessing"]["force_resize"]: + if not match_conf["preprocessing"].get("force_resize", False): match_conf["preprocessing"]["force_resize"] = force_resize + else: logger.info("preprocessing is already resized") if force_resize: match_conf["preprocessing"]["height"] = image_height @@ -941,7 +942,7 @@ def run_matching( else: extractor = get_feature_model(extract_conf) - if not extract_conf["preprocessing"]["force_resize"]: + if not extract_conf["preprocessing"].get("force_resize", False): extract_conf["preprocessing"]["force_resize"] = force_resize else: logger.info("preprocessing is already resized")