Skip to content

Commit

Permalink
FIX: model loading bug (#60)
Browse files Browse the repository at this point in the history
  • Loading branch information
Vincentqyw authored Aug 21, 2024
1 parent c962765 commit e4a103b
Show file tree
Hide file tree
Showing 2 changed files with 43 additions and 13 deletions.
31 changes: 26 additions & 5 deletions hloc/matchers/cotr.py
Original file line number Diff line number Diff line change
@@ -1,14 +1,19 @@
import argparse
import subprocess
import sys
from pathlib import Path

import numpy as np
import torch
from huggingface_hub import hf_hub_download
from torchvision.transforms import ToPILImage

from .. import logger
from ..utils.base_model import BaseModel

sys.path.append(str(Path(__file__).parent / "../../third_party/COTR"))
cotr_path = Path(__file__).parent / "../../third_party/COTR"

sys.path.append(str(cotr_path))
from COTR.inference.sparse_engine import SparseEngine
from COTR.models import build_model
from COTR.options.options import * # noqa: F403
Expand All @@ -18,7 +23,6 @@
utils_cotr.fix_randomness(0)
torch.set_grad_enabled(False)

cotr_path = Path(__file__).parent / "../../third_party/COTR"

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

Expand All @@ -36,10 +40,27 @@ def _init(self, conf):
set_COTR_arguments(parser) # noqa: F405
opt = parser.parse_args()
opt.command = " ".join(sys.argv)
opt.load_weights_path = str(
cotr_path / conf["weights"] / "checkpoint.pth.tar"
)
model_path = cotr_path / conf["weights"] / "checkpoint.pth.tar"

if not model_path.exists():
model_path.parent.mkdir(exist_ok=True, parents=True)
cached_file = hf_hub_download(
repo_type="space",
repo_id="Realcat/image-matching-webui",
filename="third_party/COTR/{}/{}".format(
conf["weights"], "checkpoint.pth.tar"
),
)
logger.info("Downloaded COTR model succeeed!")
cmd = [
"cp",
str(cached_file),
str(model_path.parent),
]
subprocess.run(cmd, check=True)
logger.info(f"Copy model file `{cmd}`.")

opt.load_weights_path = str(model_path)
layer_2_channels = {
"layer1": 256,
"layer2": 512,
Expand Down
25 changes: 17 additions & 8 deletions hloc/matchers/gim.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,8 +2,8 @@
import sys
from pathlib import Path

import gdown
import torch
from huggingface_hub import hf_hub_download

from .. import logger
from ..utils.base_model import BaseModel
Expand All @@ -26,27 +26,36 @@ class GIM(BaseModel):
"image0",
"image1",
]
model_list = ["gim_lightglue_100h.ckpt", "gim_dkm_100h.ckpt"]
model_dict = {
"gim_lightglue_100h.ckpt": "https://github.com/xuelunshen/gim/blob/main/weights/gim_lightglue_100h.ckpt",
"gim_dkm_100h.ckpt": "https://drive.google.com/file/d/1gk97V4IROnR1Nprq10W9NCFUv2mxXR_-/view",
}

def _init(self, conf):
conf["model_name"] = str(conf["weights"])
if conf["model_name"] not in self.model_dict:
if conf["model_name"] not in self.model_list:
raise ValueError(f"Unknown GIM model {conf['model_name']}.")
model_path = conf["checkpoint_dir"] / conf["model_name"]

# Download the model.
if not model_path.exists():
model_path.parent.mkdir(exist_ok=True)
model_link = self.model_dict[conf["model_name"]]
if "drive.google.com" in model_link:
gdown.download(model_link, output=str(model_path), fuzzy=True)
else:
cmd = ["wget", "--quiet", model_link, "-O", str(model_path)]
subprocess.run(cmd, check=True)
cached_file = hf_hub_download(
repo_type="space",
repo_id="Realcat/image-matching-webui",
filename="third_party/gim/weights/{}".format(
conf["model_name"]
),
)
logger.info("Downloaded GIM model succeeed!")
cmd = [
"cp",
str(cached_file),
str(conf["checkpoint_dir"]),
]
subprocess.run(cmd, check=True)
logger.info(f"Copy model file `{cmd}`.")

self.aspect_ratio = 896 / 672
model = DKMv3(None, 672, 896, upsample_preds=True)
Expand Down

0 comments on commit e4a103b

Please sign in to comment.