-
Notifications
You must be signed in to change notification settings - Fork 29
/
inference.py
113 lines (94 loc) · 4.15 KB
/
inference.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
import os
import torch
import einops
import argparse
import numpy as np
from PIL import Image
from PIL.Image import Resampling
from depthfm import DepthFM
import matplotlib.pyplot as plt
def get_dtype_from_str(dtype_str):
    """
    Translate a precision tag into the matching torch dtype.

    Args:
        dtype_str (`str`):
            Precision tag; one of "fp32", "fp16" or "bf16".

    Returns:
        `torch.dtype`: The floating point dtype named by `dtype_str`.

    Raises:
        KeyError: If `dtype_str` is not one of the supported tags.
    """
    precision_table = dict(
        fp32=torch.float32,
        fp16=torch.float16,
        bf16=torch.bfloat16,
    )
    return precision_table[dtype_str]
def resize_max_res(
    img: Image.Image, max_edge_resolution: int, resample_method=Resampling.BILINEAR
) -> "tuple[Image.Image, tuple[int, int]]":
    """
    Resize image to limit maximum edge length while keeping aspect ratio.

    After scaling, both sides are snapped to the nearest multiple of 64, so
    the aspect ratio is only approximately preserved and the longer edge may
    end up slightly above `max_edge_resolution`.

    Args:
        img (`Image.Image`):
            Image to be resized.
        max_edge_resolution (`int`):
            Maximum edge length (pixel). Must be positive.
        resample_method (`PIL.Image.Resampling`):
            Resampling method used to resize images.

    Returns:
        `tuple`: `(resized_img, (original_width, original_height))` - the
        resized image and the size of the input image before resizing.

    Raises:
        ValueError: If `max_edge_resolution` is not positive.
    """
    if max_edge_resolution <= 0:
        raise ValueError(
            f"max_edge_resolution must be positive, got {max_edge_resolution}"
        )

    # Sizes are snapped to this granularity (presumably a requirement of the
    # model's downsampling stack - confirm against DepthFM).
    multiple = 64

    original_width, original_height = img.size
    downscale_factor = min(
        max_edge_resolution / original_width, max_edge_resolution / original_height
    )

    new_width = int(original_width * downscale_factor)
    new_height = int(original_height * downscale_factor)

    # Snap to the nearest multiple, but never below one full multiple:
    # rounding a short side (< 32 px after scaling) would otherwise yield 0
    # and make `Image.resize` fail.
    new_width = max(multiple, round(new_width / multiple) * multiple)
    new_height = max(multiple, round(new_height / multiple) * multiple)

    print(f"Resizing image from {original_width}x{original_height} to {new_width}x{new_height}")
    resized_img = img.resize((new_width, new_height), resample=resample_method)
    return resized_img, (original_width, original_height)
def load_im(fp, processing_res=-1):
    """
    Load an image from disk and convert it into a normalized input tensor.

    Args:
        fp (`str`):
            Path to the image file.
        processing_res (`int`):
            Target length of the longer edge. A negative value uses the
            image's own longer edge, i.e. no rescaling is requested; note
            that `resize_max_res` still snaps both sides to multiples of 64.

    Returns:
        `tuple`: `(x, orig_res)` where `x` is a float32 tensor of shape
        (1, 3, h, w) with values scaled to [-1, 1], and `orig_res` is the
        `(width, height)` of the image before resizing.

    Raises:
        FileNotFoundError: If `fp` does not exist.
    """
    # Explicit check instead of `assert`, which is stripped under `python -O`.
    if not os.path.exists(fp):
        raise FileNotFoundError(f"File not found: {fp}")

    # `convert` returns an independent in-memory copy, so the file handle can
    # be released as soon as the conversion is done.
    with Image.open(fp) as src:
        im = src.convert('RGB')

    if processing_res < 0:
        processing_res = max(im.size)
    im, orig_res = resize_max_res(im, processing_res)

    x = np.array(im)
    x = einops.rearrange(x, 'h w c -> c h w')
    # Map 8-bit pixel values [0, 255] to [-1, 1].
    x = x / 127.5 - 1
    # Leading `None` index adds the batch dimension.
    x = torch.tensor(x, dtype=torch.float32)[None]
    return x, orig_res
def main(args):
    """
    Run DepthFM on a single image and save the predicted depth map.

    The result is written next to the input as `<args.img>_depth.png`,
    resized back to the input image's original resolution.

    Args:
        args (`argparse.Namespace`):
            Parsed command line options. Reads `img`, `ckpt`, `num_steps`,
            `ensemble_size`, `no_color`, `device`, `processing_res` and
            `dtype`.
    """
    print(f"{'Input':<10}: {args.img}")
    print(f"{'Steps':<10}: {args.num_steps}")
    print(f"{'Ensemble':<10}: {args.ensemble_size}")

    # Load the model
    model = DepthFM(args.ckpt)
    model.cuda(args.device).eval()

    # Load an image
    im, orig_res = load_im(args.img, args.processing_res)
    im = im.cuda(args.device)

    # Generate depth
    dtype = get_dtype_from_str(args.dtype)
    # NOTE(review): presumably tells the wrapped network which precision to
    # run in - confirm against the DepthFM implementation.
    model.model.dtype = dtype
    with torch.autocast(device_type="cuda", dtype=dtype):
        depth = model.predict_depth(im, num_steps=args.num_steps, ensemble_size=args.ensemble_size)
    # Upcast before leaving torch: NumPy has no bfloat16, so `.numpy()` would
    # raise if the model hands back a bf16 tensor under `--dtype bf16`.
    depth = depth.squeeze(0).squeeze(0).float().cpu().numpy()  # (h, w), expected in [0, 1]

    # Convert depth to [0, 255] range
    if args.no_color:
        # The uint8 cast wraps around instead of saturating, so clamp any
        # slight numerical overshoot first (no-op for values in [0, 1]).
        depth = (np.clip(depth, 0.0, 1.0) * 255).astype(np.uint8)
    else:
        # Colormap returns RGBA bytes; drop the alpha channel.
        depth = plt.get_cmap('magma')(depth, bytes=True)[..., :3]

    # Save the depth map
    depth_fp = args.img + '_depth.png'
    depth_img = Image.fromarray(depth)
    # Undo the processing-resolution resize so the output matches the input.
    if depth_img.size != orig_res:
        depth_img = depth_img.resize(orig_res, Resampling.BILINEAR)
    depth_img.save(depth_fp)
    print(f"==> Saved depth map to {depth_fp}")
if __name__ == "__main__":
    # Command line entry point: declare the options, parse them, run inference.
    cli = argparse.ArgumentParser(prog="DepthFM Inference")

    # Input / checkpoint locations
    cli.add_argument(
        "--img",
        type=str,
        default="assets/dog.png",
        help="Path to the input image",
    )
    cli.add_argument(
        "--ckpt",
        type=str,
        default="checkpoints/depthfm-v1.ckpt",
        help="Path to the model checkpoint",
    )

    # Sampling options
    cli.add_argument(
        "--num_steps",
        type=int,
        default=2,
        help="Number of steps for ODE solver",
    )
    cli.add_argument(
        "--ensemble_size",
        type=int,
        default=4,
        help="Number of ensemble members",
    )

    # Output and runtime options
    cli.add_argument(
        "--no_color",
        action="store_true",
        help="If set, the depth map will be grayscale",
    )
    cli.add_argument(
        "--device",
        type=int,
        default=0,
        help="GPU to use",
    )
    cli.add_argument(
        "--processing_res",
        type=int,
        default=-1,
        help="Longer edge of the image will be resized to this resolution. -1 to disable resizing.",
    )
    cli.add_argument(
        "--dtype",
        type=str,
        choices=["fp32", "bf16", "fp16"],
        default="fp16",
        help="Run with specific precision. Speeds up inference with subtle loss",
    )

    args = cli.parse_args()
    main(args)