09.pytorch-nerf模型训练4之损失函数.md

Overturning Traditional 3D Reconstruction with NeRF (Part 14): pytorch-nerf Model Training 4, the Loss Function

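# Render the current batch of rays; retraw=True also returns the raw
# network output in extras['raw'].
rgb, disp, acc, extras = render(H, W, K, chunk=args.chunk, rays=batch_rays,
                                verbose=i < 10, retraw=True,
                                **render_kwargs_train)

optimizer.zero_grad()
# Photometric loss: MSE between the rendered colors and the target pixels
img_loss = img2mse(rgb, target_s)
trans = extras['raw'][..., -1]  # last channel of the raw output; unused in this snippet
loss = img_loss
psnr = mse2psnr(img_loss)

# If the renderer also returns a coarse prediction ('rgb0'),
# add its MSE so both predictions are supervised by the same target.
if 'rgb0' in extras:
    img_loss0 = img2mse(extras['rgb0'], target_s)
    loss = loss + img_loss0
    psnr0 = mse2psnr(img_loss0)

# Backpropagate the summed loss and update the network weights
loss.backward()
optimizer.step()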
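
The helpers `img2mse` and `mse2psnr` used above are not defined in this snippet. Assuming they are the usual mean squared error and the PSNR for colors in [0, 1], the arithmetic can be sketched in plain Python as follows. The names `mse` and `mse_to_psnr` and the sample colors are hypothetical stand-ins for illustration; the real helpers would have to operate on tensors so that `loss.backward()` can compute gradients.

```python
import math

def mse(pred, target):
    """Mean squared error over every color channel of every ray."""
    diffs = [(p - t) ** 2
             for p_rgb, t_rgb in zip(pred, target)
             for p, t in zip(p_rgb, t_rgb)]
    return sum(diffs) / len(diffs)

def mse_to_psnr(mse_value):
    """PSNR in dB, assuming colors lie in [0, 1]: -10 * log10(MSE)."""
    return -10.0 * math.log10(mse_value)

# Two rays with one RGB color each (made-up values, not real data)
target = [[0.2, 0.4, 0.6], [0.8, 0.8, 0.8]]
fine = [[0.3, 0.4, 0.6], [0.8, 0.7, 0.8]]    # plays the role of rgb
coarse = [[0.4, 0.4, 0.6], [0.8, 0.6, 0.8]]  # plays the role of extras['rgb0']

fine_loss = mse(fine, target)
coarse_loss = mse(coarse, target)
total_loss = fine_loss + coarse_loss  # same sum as loss = img_loss + img_loss0

print("fine MSE:", fine_loss, "PSNR:", mse_to_psnr(fine_loss))
print("coarse MSE:", coarse_loss, "PSNR:", mse_to_psnr(coarse_loss))
print("total loss:", total_loss)
```

A lower MSE gives a higher PSNR, which is why PSNR is logged as the readable quality metric while the summed MSE is what gets optimized.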

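# NOTE: IMPORTANT!
###   update learning rate   ###
# Exponential decay: the learning rate shrinks by a factor of decay_rate
# every decay_steps iterations (args.lrate_decay is given in thousands of steps).
decay_rate = 0.1
decay_steps = args.lrate_decay * 1000
new_lrate = args.lrate * (decay_rate ** (global_step / decay_steps))
# Write the new rate into every parameter group of the optimizer
for param_group in optimizer.param_groups:
    param_group['lr'] = new_lrate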
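
To see what this schedule does over a training run, the same formula can be wrapped in a small function and evaluated at a few steps. This is only a sketch: the function name `decayed_lrate` is hypothetical, and the values `5e-4` and `250` are assumed example settings for `args.lrate` and `args.lrate_decay`, not values taken from this snippet.

```python
def decayed_lrate(base_lrate, lrate_decay, global_step, decay_rate=0.1):
    """Same formula as the training loop: exponential decay of the learning rate."""
    decay_steps = lrate_decay * 1000  # lrate_decay is expressed in thousands of steps
    return base_lrate * (decay_rate ** (global_step / decay_steps))

# Assumed example settings (hypothetical, for illustration only)
base_lrate = 5e-4
lrate_decay = 250

for step in (0, 125_000, 250_000, 500_000):
    print(step, decayed_lrate(base_lrate, lrate_decay, step))
```

With these assumed settings the rate starts at `5e-4`, has dropped by a factor of 10 after 250,000 steps, and by a factor of 100 after 500,000 steps. The decay is continuous rather than stepwise because `global_step / decay_steps` is a float division. An equivalent schedule could probably be expressed with PyTorch's `torch.optim.lr_scheduler.ExponentialLR` using `gamma = decay_rate ** (1 / decay_steps)` and stepping once per iteration, but the manual loop over `optimizer.param_groups` shown above keeps the formula explicit.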