forked from TRI-ML/vidar
-
Notifications
You must be signed in to change notification settings - Fork 0
/
hubconf.py
71 lines (55 loc) · 2.5 KB
/
hubconf.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
# TRI-VIDAR - Copyright 2022 Toyota Research Institute. All rights reserved.
dependencies = ["torch"]
import urllib
import torch
import torch.nn.functional as F
from vidar.core.evaluator import Evaluator
from vidar.utils.config import read_config
from vidar.utils.setup import setup_arch
def DeFiNe(pretrained=True, **kwargs):
    """
    DeFiNe model for monocular depth estimation

    The hub configuration is downloaded to ``define_config.yaml`` in the
    current working directory and parsed with ``read_config``.

    pretrained (bool): load pretrained weights into model
    **kwargs: accepted but currently unused

    Returns:
        pretrained=True: the network built by ``setup_arch(cfg.arch)`` with the
            released checkpoint loaded (non-strict).
        pretrained=False: the ``Evaluator`` built from the configuration.

    Usage:
        define_model = torch.hub.load("TRI-ML/vidar", "DeFiNe", pretrained=True)
        batch = {}
        batch["rgb"] = ...         # a list of images as 1x3xHxW torch.tensors
        batch["intrinsics"] = ...  # a list of 1x3x3 torch.tensor intrinsics matrices (one for each image)
        batch["pose"] = ...        # a batch of 1x4x4 relative poses to reference frame (one will be identity)
        depth_preds = define_model(batch)  # list of depths, one for each image
    """
    # A bare `import urllib` does not guarantee that the `request` submodule
    # is loaded, so import it explicitly before using it.
    import urllib.request

    cfg_url = "https://raw.githubusercontent.com/IgorVasiljevic-TRI/vidar/main/configs/papers/define/hub_define_temporal.yaml"
    cfg_path = "define_config.yaml"
    # urlretrieve returns (filename, headers); only the file written to disk
    # is needed, so the return value is intentionally discarded.
    urllib.request.urlretrieve(cfg_url, cfg_path)
    cfg = read_config(cfg_path)

    # NOTE(review): the Evaluator is only returned when pretrained=False; it is
    # still constructed unconditionally to keep the original behavior in case
    # its construction has side effects.
    model = Evaluator(cfg)

    if pretrained:
        url = "https://tri-ml-public.s3.amazonaws.com/github/vidar/models/define_temporal.ckpt"
        model = setup_arch(cfg.arch, verbose=True)
        state_dict = torch.hub.load_state_dict_from_url(url, map_location="cpu")

        # Keep only the checkpoint entries stored under the "module." prefix
        # (presumably left by a wrapper module at training time -- confirm)
        # and strip that prefix so the keys line up with the bare network.
        prefix = "module."
        n_clip = len(prefix)
        adapted_dict = {
            k[n_clip:]: v
            for k, v in state_dict["state_dict"].items()
            if k.startswith(prefix)
        }
        # strict=False: missing or unexpected keys do not raise.
        model.load_state_dict(adapted_dict, strict=False)

    return model
def PackNet(pretrained=True, **kwargs):
    """
    PackNet model for monocular depth estimation

    The hub configuration is downloaded to ``packnet_config.yaml`` in the
    current working directory and parsed with ``read_config``.

    pretrained (bool): load pretrained weights into model
    **kwargs: accepted but currently unused

    Returns:
        pretrained=True: the network built by ``setup_arch(cfg.arch)`` with the
            released checkpoint loaded (non-strict).
        pretrained=False: the ``Evaluator`` built from the configuration.

    Usage:
        model = torch.hub.load("TRI-ML/vidar", "PackNet", pretrained=True)
        rgb_image = torch.rand(1, 3, H, W)
        depth_pred = model(rgb_image)
    """
    # A bare `import urllib` does not guarantee that the `request` submodule
    # is loaded, so import it explicitly before using it.
    import urllib.request

    cfg_url = "https://raw.githubusercontent.com/TRI-ML/vidar/main/configs/papers/packnet/hub_packnet.yaml"
    cfg_path = "packnet_config.yaml"
    # urlretrieve returns (filename, headers); only the file written to disk
    # is needed, so the return value is intentionally discarded.
    urllib.request.urlretrieve(cfg_url, cfg_path)
    cfg = read_config(cfg_path)

    # NOTE(review): the Evaluator is only returned when pretrained=False; it is
    # still constructed unconditionally to keep the original behavior in case
    # its construction has side effects.
    model = Evaluator(cfg)

    if pretrained:
        url = "https://tri-ml-public.s3.amazonaws.com/github/vidar/models/PackNet_MR_selfsup_KITTI.ckpt"
        model = setup_arch(cfg.arch, verbose=True)
        state_dict = torch.hub.load_state_dict_from_url(url, map_location="cpu")
        # strict=False: missing or unexpected keys do not raise.
        model.load_state_dict(state_dict["state_dict"], strict=False)

    return model