From f791ada46722b7d0968e71b540bca7e64fd2ce32 Mon Sep 17 00:00:00 2001 From: Realcat Date: Wed, 21 Aug 2024 23:24:55 +0800 Subject: [PATCH] Fix: sfd2 (#63) * Fix: https://github.com/Vincentqyw/image-matching-webui/issues/54 * del: official pram * update: pram * update: sfd2 --- .gitmodules | 6 +++--- hloc/extractors/sfd2.py | 10 ++++------ hloc/matchers/imp.py | 11 +++++------ third_party/pram | 2 +- ui/config.yaml | 6 ++---- 5 files changed, 15 insertions(+), 20 deletions(-) diff --git a/.gitmodules b/.gitmodules index e4ee749..d61f3e0 100644 --- a/.gitmodules +++ b/.gitmodules @@ -55,9 +55,6 @@ [submodule "third_party/RoMa"] path = third_party/RoMa url = https://github.com/Vincentqyw/RoMa.git -[submodule "third_party/pram"] - path = third_party/pram - url = https://github.com/feixue94/pram.git [submodule "third_party/SGMNet"] path = third_party/SGMNet url = https://github.com/agipro/SGMNet.git @@ -67,3 +64,6 @@ [submodule "third_party/mast3r"] path = third_party/mast3r url = https://github.com/naver/mast3r +[submodule "third_party/pram"] + path = third_party/pram + url = https://github.com/agipro/pram.git diff --git a/hloc/extractors/sfd2.py b/hloc/extractors/sfd2.py index 9fb76ed..c7ce322 100644 --- a/hloc/extractors/sfd2.py +++ b/hloc/extractors/sfd2.py @@ -1,4 +1,3 @@ -# -*- coding: UTF-8 -*- import sys from pathlib import Path @@ -7,10 +6,9 @@ from .. import logger from ..utils.base_model import BaseModel -pram_path = Path(__file__).parent / "../../third_party/pram" -sys.path.append(str(pram_path)) - -from nets.sfd2 import load_sfd2 +tp_path = Path(__file__).parent / "../../third_party" +sys.path.append(str(tp_path)) +from pram.nets.sfd2 import load_sfd2 class SFD2(BaseModel): @@ -26,7 +24,7 @@ def _init(self, conf): self.norm_rgb = tvf.Normalize( mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225] ) - model_fn = pram_path / "weights" / self.conf["model_name"] + model_fn = tp_path / "pram" / "weights" / self.conf["model_name"] self.net = load_sfd2(weight_path=model_fn).eval() logger.info("Load SFD2 model done.") diff --git a/hloc/matchers/imp.py b/hloc/matchers/imp.py index ca64980..05c3cb9 100644 --- a/hloc/matchers/imp.py +++ b/hloc/matchers/imp.py @@ -1,4 +1,3 @@ -# -*- coding: UTF-8 -*- import sys from pathlib import Path @@ -7,10 +6,9 @@ from .. import DEVICE, logger from ..utils.base_model import BaseModel -pram_path = Path(__file__).parent / "../../third_party/pram" -sys.path.append(str(pram_path)) - -from nets.gml import GML +tp_path = Path(__file__).parent / "../../third_party" +sys.path.append(str(tp_path)) +from pram.nets.gml import GML class IMP(BaseModel): @@ -33,7 +31,8 @@ class IMP(BaseModel): def _init(self, conf): self.conf = {**self.default_conf, **conf} - weight_path = pram_path / "weights" / self.conf["model_name"] + weight_path = tp_path / "pram" / "weights" / self.conf["model_name"] + # self.net = nets.gml(self.conf).eval().to(DEVICE) self.net = GML(self.conf).eval().to(DEVICE) self.net.load_state_dict( torch.load(weight_path, map_location="cpu")["model"], strict=True diff --git a/third_party/pram b/third_party/pram index 96929f8..742ff42 160000 --- a/third_party/pram +++ b/third_party/pram @@ -1 +1 @@ -Subproject commit 96929f8c7f10b158036e078f7cc5363c4a2dcbcb +Subproject commit 742ff4241105bfaee0039e01556c79fb9be5c8b8 diff --git a/ui/config.yaml b/ui/config.yaml index 63fe9d1..741a79a 100644 --- a/ui/config.yaml +++ b/ui/config.yaml @@ -381,8 +381,7 @@ matcher_zoo: display: true sfd2+imp: - enable: false - skip_ci: true + enable: true matcher: imp feature: sfd2 dense: false @@ -395,8 +394,7 @@ matcher_zoo: display: true sfd2+mnn: - enable: false - skip_ci: true + enable: true matcher: NN-mutual feature: sfd2 dense: false