Commit

update: device
Vincentqyw committed Jul 7, 2024
1 parent ff75143 commit dceccb6
Showing 4 changed files with 13 additions and 10 deletions.
common/config.yaml (2 changes: 2 additions & 0 deletions)

@@ -372,6 +372,7 @@ matcher_zoo:
   sfd2+imp:
     matcher: imp
     feature: sfd2
+    enable: false
     dense: false
     skip_ci: true
     info:
@@ -385,6 +386,7 @@ matcher_zoo:
   sfd2+mnn:
     matcher: NN-mutual
     feature: sfd2
+    enable: false
     dense: false
     skip_ci: true
     info:
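The new enable: false flags take the two sfd2 combinations out of the default rotation; test_app_cli.py below reads them together with skip_ci. A minimal sketch of that gating, assuming the YAML is parsed with PyYAML and that a missing enable key defaults to true (only the skip_ci lookup appears verbatim in the diff; enabled_matchers is a hypothetical helper):

    import yaml  # assumption: PyYAML is how config.yaml would be parsed here

    def enabled_matchers(config_path="common/config.yaml"):
        """List matcher_zoo entries that are enabled and not skipped in CI."""
        with open(config_path) as f:
            config = yaml.safe_load(f)
        names = []
        for name, entry in config["matcher_zoo"].items():
            enable = entry.get("enable", True)     # sfd2+imp / sfd2+mnn now set this to false
            skip_ci = entry.get("skip_ci", False)  # same lookup as in test_app_cli.py
            if enable and not skip_ci:
                names.append(name)
        return names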
common/utils.py (10 changes: 5 additions & 5 deletions)

@@ -24,13 +24,13 @@
     match_dense,
     match_features,
     matchers,
+    DEVICE
 )
 from hloc.utils.base_model import dynamic_load

 from .viz import display_keypoints, display_matches, fig2im, plot_images

 warnings.simplefilter("ignore")
-device = "cuda" if torch.cuda.is_available() else "cpu"

 ROOT = Path(__file__).parent.parent
 # some default values
@@ -173,7 +173,7 @@ def get_model(match_conf: Dict[str, Any]):
         A matcher model instance.
     """
     Model = dynamic_load(matchers, match_conf["model"]["name"])
-    model = Model(match_conf["model"]).eval().to(device)
+    model = Model(match_conf["model"]).eval().to(DEVICE)
     return model


@@ -188,7 +188,7 @@ def get_feature_model(conf: Dict[str, Dict[str, Any]]):
         A feature extraction model instance.
     """
     Model = dynamic_load(extractors, conf["model"]["name"])
-    model = Model(conf["model"]).eval().to(device)
+    model = Model(conf["model"]).eval().to(DEVICE)
     return model


@@ -879,7 +879,7 @@ def run_matching(
     output_matches_ransac = None

     # super slow!
-    if "roma" in key.lower() and device == "cpu":
+    if "roma" in key.lower() and DEVICE == "cpu":
         gr.Info(
             f"Success! Please be patient and allow for about 2-3 minutes."
             f" Due to CPU inference, {key} is quiet slow."
@@ -904,7 +904,7 @@

     if model["dense"]:
         pred = match_dense.match_images(
-            matcher, image0, image1, match_conf["preprocessing"], device=device
+            matcher, image0, image1, match_conf["preprocessing"], device=DEVICE
         )
         del matcher
         extract_conf = None
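For reference, a minimal sketch of the pattern common/utils.py follows after this change: the module-level device string is gone and models go onto the shared DEVICE constant. It assumes the parenthesized import block at the top of the file is a from hloc import (...) statement (the hunk only shows its interior); the body of get_model is as in the diff:

    # Sketch: load a matcher onto the shared DEVICE constant instead of a
    # locally computed `device` string.
    from hloc import DEVICE, matchers             # DEVICE now lives in hloc/__init__.py
    from hloc.utils.base_model import dynamic_load

    def get_model(match_conf: dict):
        """Instantiate the matcher named in match_conf["model"]["name"] on DEVICE."""
        Model = dynamic_load(matchers, match_conf["model"]["name"])
        model = Model(match_conf["model"]).eval().to(DEVICE)
        return model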
hloc/__init__.py (3 changes: 2 additions & 1 deletion)

@@ -32,4 +32,5 @@
     minimal_version,
     found_version,
 )
-device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+
+DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
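DEVICE is now a module-level constant in hloc, chosen once at import time, and downstream modules import it instead of recomputing a device string. A minimal sketch of the definition in context (the import torch line is an assumption; the hunk only shows the constant itself):

    # hloc/__init__.py (sketch)
    import torch

    # Chosen once at import time; every consumer sees the same device.
    DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    # Elsewhere:
    # from hloc import DEVICE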
test_app_cli.py (8 changes: 4 additions & 4 deletions)

@@ -6,7 +6,7 @@
 from common.utils import (
     get_matcher_zoo,
     load_config,
-    device,
+    DEVICE,
     ROOT,
 )
 from common.api import ImageMatchingAPI
@@ -26,7 +26,7 @@ def test_all(config: dict = None):
         skip_ci = config["matcher_zoo"][k].get("skip_ci", False)
         if enable and not skip_ci:
             logger.info(f"Testing {k} ...")
-            api = ImageMatchingAPI(conf=v, device=device)
+            api = ImageMatchingAPI(conf=v, device=DEVICE)
             api(image0, image1)
             log_path = ROOT / "experiments" / "all"
             log_path.mkdir(exist_ok=True, parents=True)
@@ -70,7 +70,7 @@ def test_one():
        },
        "dense": False,
    }
-    api = ImageMatchingAPI(conf=conf, device=device)
+    api = ImageMatchingAPI(conf=conf, device=DEVICE)
     api(image0, image1)
     log_path = ROOT / "experiments" / "one"
     log_path.mkdir(exist_ok=True, parents=True)
@@ -100,7 +100,7 @@ def test_one():
        "dense": True,
    }

-    api = ImageMatchingAPI(conf=conf, device=device)
+    api = ImageMatchingAPI(conf=conf, device=DEVICE)
     api(image0, image1)
     log_path = ROOT / "experiments" / "one"
     log_path.mkdir(exist_ok=True, parents=True)
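Both test paths now hand the shared constant straight to ImageMatchingAPI. A minimal sketch of that call pattern, with conf, image0, and image1 standing in for the values test_one() builds (their contents are not reproduced here, and run_once is a hypothetical wrapper):

    from common.api import ImageMatchingAPI
    from common.utils import DEVICE, ROOT

    def run_once(conf, image0, image1, name="one"):
        """Run one matcher configuration on the shared DEVICE and log under ROOT."""
        api = ImageMatchingAPI(conf=conf, device=DEVICE)  # DEVICE instead of the old lowercase device
        api(image0, image1)
        log_path = ROOT / "experiments" / name
        log_path.mkdir(exist_ok=True, parents=True)
        return log_path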
