diff --git a/hloc/extractors/dedode.py b/hloc/extractors/dedode.py index 100d1d3..4bb52c3 100644 --- a/hloc/extractors/dedode.py +++ b/hloc/extractors/dedode.py @@ -3,11 +3,11 @@ import torch import torchvision.transforms as transforms +from kornia import feature as KF from hloc import MODEL_REPO_ID, logger from ..utils.base_model import BaseModel -from kornia import feature as KF device = torch.device("cuda" if torch.cuda.is_available() else "cpu") @@ -27,7 +27,10 @@ class DeDoDe(BaseModel): # Initialize the line matcher def _init(self, conf): - self.model = KF.DeDoDe.from_pretrained(detector_weights=conf["detector_model"], descriptor_weights=conf["descriptor_model"]) + self.model = KF.DeDoDe.from_pretrained( + detector_weights=conf["detector_model"], + descriptor_weights=conf["descriptor_model"], + ) logger.info("Load DeDoDe model done.") def _forward(self, data): @@ -38,7 +41,9 @@ def _forward(self, data): """ img0 = data["image"].float() - coords, scores, descriptions = self.model(img0, n=self.conf["max_keypoints"]) + coords, scores, descriptions = self.model( + img0, n=self.conf["max_keypoints"] + ) return { "keypoints": coords, # 1 x N x 2 diff --git a/hloc/matchers/lightglue.py b/hloc/matchers/lightglue.py index e778867..1dfcbe4 100644 --- a/hloc/matchers/lightglue.py +++ b/hloc/matchers/lightglue.py @@ -1,9 +1,10 @@ -import torch import kornia.feature as KF +import torch from .. import MODEL_REPO_ID, logger from ..utils.base_model import BaseModel + class LightGlue(BaseModel): default_conf = { "match_threshold": 0.2,