From dceccb6ad123762e1ae7646329e89c1dca578bb7 Mon Sep 17 00:00:00 2001 From: Realcat Date: Sun, 7 Jul 2024 11:04:57 +0000 Subject: [PATCH] update: device --- common/config.yaml | 2 ++ common/utils.py | 10 +++++----- hloc/__init__.py | 3 ++- test_app_cli.py | 8 ++++---- 4 files changed, 13 insertions(+), 10 deletions(-) diff --git a/common/config.yaml b/common/config.yaml index b50682d..1160ddd 100644 --- a/common/config.yaml +++ b/common/config.yaml @@ -372,6 +372,7 @@ matcher_zoo: sfd2+imp: matcher: imp feature: sfd2 + enable: false dense: false skip_ci: true info: @@ -385,6 +386,7 @@ matcher_zoo: sfd2+mnn: matcher: NN-mutual feature: sfd2 + enable: false dense: false skip_ci: true info: diff --git a/common/utils.py b/common/utils.py index a171f70..e01c459 100644 --- a/common/utils.py +++ b/common/utils.py @@ -24,13 +24,13 @@ match_dense, match_features, matchers, + DEVICE ) from hloc.utils.base_model import dynamic_load from .viz import display_keypoints, display_matches, fig2im, plot_images warnings.simplefilter("ignore") -device = "cuda" if torch.cuda.is_available() else "cpu" ROOT = Path(__file__).parent.parent # some default values @@ -173,7 +173,7 @@ def get_model(match_conf: Dict[str, Any]): A matcher model instance. """ Model = dynamic_load(matchers, match_conf["model"]["name"]) - model = Model(match_conf["model"]).eval().to(device) + model = Model(match_conf["model"]).eval().to(DEVICE) return model @@ -188,7 +188,7 @@ def get_feature_model(conf: Dict[str, Dict[str, Any]]): A feature extraction model instance. """ Model = dynamic_load(extractors, conf["model"]["name"]) - model = Model(conf["model"]).eval().to(device) + model = Model(conf["model"]).eval().to(DEVICE) return model @@ -879,7 +879,7 @@ def run_matching( output_matches_ransac = None # super slow! - if "roma" in key.lower() and device == "cpu": + if "roma" in key.lower() and DEVICE == "cpu": gr.Info( f"Success! Please be patient and allow for about 2-3 minutes." f" Due to CPU inference, {key} is quiet slow." @@ -904,7 +904,7 @@ def run_matching( if model["dense"]: pred = match_dense.match_images( - matcher, image0, image1, match_conf["preprocessing"], device=device + matcher, image0, image1, match_conf["preprocessing"], device=DEVICE ) del matcher extract_conf = None diff --git a/hloc/__init__.py b/hloc/__init__.py index c9f9d95..7c2e3dd 100644 --- a/hloc/__init__.py +++ b/hloc/__init__.py @@ -32,4 +32,5 @@ minimal_version, found_version, ) -device = torch.device("cuda" if torch.cuda.is_available() else "cpu") + +DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu") diff --git a/test_app_cli.py b/test_app_cli.py index e3e30b6..e9ca772 100644 --- a/test_app_cli.py +++ b/test_app_cli.py @@ -6,7 +6,7 @@ from common.utils import ( get_matcher_zoo, load_config, - device, + DEVICE, ROOT, ) from common.api import ImageMatchingAPI @@ -26,7 +26,7 @@ def test_all(config: dict = None): skip_ci = config["matcher_zoo"][k].get("skip_ci", False) if enable and not skip_ci: logger.info(f"Testing {k} ...") - api = ImageMatchingAPI(conf=v, device=device) + api = ImageMatchingAPI(conf=v, device=DEVICE) api(image0, image1) log_path = ROOT / "experiments" / "all" log_path.mkdir(exist_ok=True, parents=True) @@ -70,7 +70,7 @@ def test_one(): }, "dense": False, } - api = ImageMatchingAPI(conf=conf, device=device) + api = ImageMatchingAPI(conf=conf, device=DEVICE) api(image0, image1) log_path = ROOT / "experiments" / "one" log_path.mkdir(exist_ok=True, parents=True) @@ -100,7 +100,7 @@ def test_one(): "dense": True, } - api = ImageMatchingAPI(conf=conf, device=device) + api = ImageMatchingAPI(conf=conf, device=DEVICE) api(image0, image1) log_path = ROOT / "experiments" / "one" log_path.mkdir(exist_ok=True, parents=True)