-
Notifications
You must be signed in to change notification settings - Fork 1
/
main.py
154 lines (123 loc) · 6.08 KB
/
main.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
import numpy as np
import os
import time
import torch
import tqdm
from nerf_shared import config_parser
from nerf_shared import utils
from torch.utils.tensorboard.writer import SummaryWriter
# Pick the compute device once at import time: CUDA when PyTorch reports it
# as available, otherwise the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")

# Seed both random number generators (PyTorch and NumPy) with the same fixed
# value so that runs are repeatable.
torch.manual_seed(0)
np.random.seed(0)

# Module-wide debug switch; disabled by default.
DEBUG = False
def _decayed_lrate(args, global_step, decay_rate=0.1):
    """Return the exponentially decayed learning rate for ``global_step``.

    The value is ``args.lrate * decay_rate ** (global_step / decay_steps)``
    with ``decay_steps = args.lrate_decay * 1000``.

    Args:
        args: Parsed arguments; ``lrate`` and ``lrate_decay`` are read.
        global_step: Number of optimisation steps completed so far.
        decay_rate: Multiplicative factor applied every ``decay_steps`` steps.

    Returns:
        The learning rate to install in every optimizer parameter group.
    """
    decay_steps = args.lrate_decay * 1000
    return args.lrate * (decay_rate ** (global_step / decay_steps))


def _train(args):
    """Build the models, optimizer and renderer, then run the training loop.

    Periodically saves checkpoints (every ``args.i_weights`` iterations),
    renders the test poses (every ``args.i_testset`` iterations) and prints
    statistics (every ``args.i_print`` iterations). When ``args.tensorboard``
    is set, a ``SummaryWriter`` is created and is always closed on exit, even
    if training raises, so that buffered events are written out.

    Args:
        args: Parsed arguments produced by the project's config parser.
    """
    # Loads dataset info like images and Ground Truth poses and camera intrinsics
    images, poses, render_poses, hwf, i_split, K, bds_dict = utils.load_datasets(args)

    # Train, val, test split
    i_train, i_val, i_test = i_split

    # Resolution (H, W); the third entry of hwf (focal length) is not needed here.
    H, W, _ = hwf

    # Copy config file to log file
    utils.copy_log_dir(args)

    # Tensorboard Support
    tb_writer = None
    if args.tensorboard:
        tbdir = os.path.join(args.basedir, args.expname, "tb_logs")
        tb_writer = SummaryWriter(log_dir=tbdir)

    try:
        # Create coarse/fine NeRF models.
        coarse_model, fine_model = utils.create_nerf_models(args)

        # Create optimizer for trainable params.
        optimizer = utils.get_optimizer(coarse_model, fine_model, args)

        # Load any available checkpoints.
        start = utils.load_checkpoint(coarse_model, fine_model, optimizer, args,
                                      b_load_ckpnt_as_trainable=True)

        renderer = utils.get_renderer(args, bds_dict)

        global_step = start

        # Move testing data to GPU
        render_poses = torch.Tensor(render_poses).to(device)

        # Batch the training data
        images, poses, rays_rgb, use_batching, N_rand, i_batch = utils.batch_training_data(
            args, poses, hwf, K, images, i_train)

        N_iters = 200000 + 1
        print('Begin')
        print('TRAIN views are', i_train)
        print('TEST views are', i_test)
        print('VAL views are', i_val)

        # Resume from the iteration after the one stored in the checkpoint.
        start = start + 1
        for i in tqdm.trange(start, N_iters):
            renderer.train()

            # Randomly select a batch of rays across images, or randomly sample
            # from a single image per iteration, determined by boolean use_batching
            batch_rays, target_s, rays_rgb, i_batch = utils.sample_random_ray_batch(
                args, images, poses, rays_rgb, N_rand, use_batching, i_batch,
                i_train, hwf, K, start, i)

            ##### Core optimization loop #####
            rgb, _, _, extras = renderer.render_from_rays(H,
                                                          W,
                                                          K,
                                                          chunk=args.chunk,
                                                          rays=batch_rays,
                                                          coarse_model=coarse_model,
                                                          fine_model=fine_model,
                                                          retraw=True)

            optimizer.zero_grad()

            # Mean squared error between rendered ray RGB vs. Ground Truth RGB
            # using the fine model
            img_loss = utils.img2mse(rgb, target_s)
            loss = img_loss
            psnr = utils.mse2psnr(img_loss)

            # If using both the coarse and fine model, add the coarse
            # reconstruction loss (coarse render vs. GT RGB) to the fine one.
            if 'rgb0' in extras:
                loss = loss + utils.img2mse(extras['rgb0'], target_s)

            # TODO(pculbert, chengine): Debug optimization; performance does not match
            # original implementation.

            loss.backward()
            optimizer.step()

            # NOTE: IMPORTANT!
            ### update learning rate ###
            new_lrate = _decayed_lrate(args, global_step)
            for param_group in optimizer.param_groups:
                param_group['lr'] = new_lrate
            ################################

            # Logging
            # Periodically saves weights
            if i % args.i_weights == 0:
                utils.save_checkpoints(args, coarse_model, fine_model, optimizer, global_step, i)

            # NOTE: panoramic video rendering (utils.render_training_video every
            # args.i_video iterations) is disabled; the old call depended on a
            # render_kwargs_test value that is not defined in this function.

            # Renders out the test poses (i.e. poses[i_test]) to visually evaluate
            # NeRF quality. Gradients are not needed for these renders, so autograd
            # is switched off to avoid building a graph.
            # NOTE(review): the renderer is still in train() mode here, as in the
            # original code - confirm whether an eval mode should be used instead.
            if i % args.i_testset == 0 and i > 0:
                with torch.no_grad():
                    renderer.render_from_batch_poses(
                        H,
                        W,
                        K,
                        chunk=args.chunk,
                        batch_c2w=poses[i_test],
                        coarse_model=coarse_model,
                        fine_model=fine_model,
                        retraw=True,
                        save_directory=os.path.join(args.basedir, args.expname,
                                                    'testset_{:06d}'.format(i)),
                        b_combine_as_video=False,
                        tb_writer=tb_writer)

            # Displays loss and PSNR (Peak signal to noise ratio) of the fine
            # reconstruction loss
            if i % args.i_print == 0:
                utils.print_statistics(args, loss, psnr, i, tb_writer=tb_writer)

            global_step += 1
    finally:
        # Flush and release the event file even when training is interrupted.
        if tb_writer is not None:
            tb_writer.close()


def run():
    """Parse the configuration and run training when ``args.training`` is True.

    Any other value of ``args.training`` falls through to the placeholder
    branch, which currently does nothing.
    """
    parser = config_parser.config_parser()
    args = parser.parse_args()

    # The identity check against True is kept as in the original so that only a
    # real boolean True (not merely a truthy value) starts training.
    if args.training is True:
        _train(args)
    else:
        ### Define Custom Functionality Here
        pass
# Script entry point: configure PyTorch for the selected device, then run.
if __name__=='__main__':
    if device.type != 'cpu':
        # Make CUDA float tensors the default tensor type when not on the CPU.
        torch.set_default_tensor_type('torch.cuda.FloatTensor')
        # NOTE(review): the original indentation was lost; this call is assumed
        # to belong to the non-CPU branch - confirm against the upstream file.
        torch.cuda.empty_cache()
    run()