Skip to content

Commit

Permalink
fix error (renderer evaluation issue)
Browse files (browse the repository at this point in the history)
  • Loading branch information
daniel03c1 committed Oct 17, 2022
1 parent 84ccb73 commit 6db9196
Show file tree
Hide file tree
Showing 3 changed files with 83 additions and 75 deletions.
36 changes: 19 additions & 17 deletions main.py
Original file line number | Diff line number | Diff line change
Expand Up @@ -98,8 +98,9 @@ def main(args):
'lr': 1e-3},
{'params': appearance_net.second_stage.parameters(), 'lr': 1e-3}],
betas=(0.9, 0.99))
scheduler = get_cos_warmup_scheduler(optimizer, args.n_iters, 0)
scaler = torch.cuda.amp.GradScaler()
scheduler = get_cos_warmup_scheduler(optimizer, args.n_iters, 0,
min_ratio=0.1)
# scaler = torch.cuda.amp.GradScaler()

pbar = tqdm(range(args.n_iters))
for i in pbar:
Expand All @@ -109,24 +110,24 @@ def main(args):

optimizer.zero_grad()

with torch.cuda.amp.autocast(enabled=True):
rgb_map, depth_map = renderer(rays_train)
# with torch.cuda.amp.autocast(enabled=True):
rgb_map, depth_map = renderer(rays_train)

loss = F.mse_loss(rgb_map, rgb_train)
loss = F.mse_loss(rgb_map, rgb_train)

# loss
total_loss = loss
if args.tv_weight > 0:
total_loss += renderer.compute_tv() * args.tv_weight
# loss
total_loss = loss
if args.tv_weight > 0:
total_loss += renderer.compute_tv() * args.tv_weight

assert not torch.isnan(loss)
scaler.scale(total_loss).backward(retain_graph=True)
assert not torch.isnan(loss)
# scaler.scale(total_loss).backward(retain_graph=True)

# total_loss.backward()
# optimizer.step()
scaler.unscale_(optimizer)
scaler.step(optimizer)
scaler.update()
total_loss.backward()
optimizer.step()
# scaler.unscale_(optimizer)
# scaler.step(optimizer)
# scaler.update()

scheduler.step()

Expand All @@ -136,7 +137,8 @@ def main(args):
if i + 1 in [500, 1000, 2500, 5000, 10000, 20000]:
print()

PSNRs_test = rendere.evaluation(test_dataset)
# Evaluation
PSNRs_test = renderer.evaluation(test_dataset)
print(f'======> {args.expname} test all psnr: {np.mean(PSNRs_test)} '
f'<========================')

Expand Down
100 changes: 55 additions & 45 deletions models/__init__.py
Original file line number | Diff line number | Diff line change
@@ -1,7 +1,9 @@
import numpy as np
import os
import torch
import torch.nn as nn
import torch.nn.functional as F
import tqdm

from .modules import get_activation

Expand Down Expand Up @@ -52,7 +54,7 @@ def __init__(self, main_net: nn.Module, appearance_net=None,
near=2, far=7, white_bg=True,
use_alpha=True, min_alpha_requirement=1e-4,
normalize_coords=True,
density_scale=25,
density_scale=1, # 25,
density_activation=None, appearance_activation=None):
super().__init__()
"""
Expand Down Expand Up @@ -91,6 +93,7 @@ def forward(self, rays, batch_size=None):
outputs = []
for rays_minibatch in torch.split(rays, batch_size):
outputs.append(self.render(rays_minibatch))

return list(map(torch.cat, zip(*outputs)))

def render(self, rays):
Expand Down Expand Up @@ -150,7 +153,7 @@ def render(self, rays):
rgb[app_mask] += self.appearance_activation(
self.appearance_net(pts[app_mask], viewdirs[app_mask]))
else:
outs = self.main_net(pts[valid_rays])
outs = self.main_net(pts[valid_rays], viewdirs[valid_rays])
sigma = self.density_activation(outs[..., 0] * self.density_scale)
rgb = self.appearance_activation(outs[..., 1:])

Expand All @@ -162,6 +165,7 @@ def render(self, rays):

acc_map = torch.sum(weights, -1)
rgb_map = torch.sum(weights[..., None] * rgb, -2)

if self.white_bg:
rgb_map = rgb_map + (1. - acc_map[..., None])
rgb_map = rgb_map.clamp(min=0, max=1)
Expand All @@ -170,15 +174,14 @@ def render(self, rays):

return rgb_map, depth_map

def evaluation(self, test_dataset, save_path=None,
def evaluation(self, test_dataset, batch_size=4096, save_path=None,
compute_extra_metrics=True):
self.eval()

PSNRs, rgb_maps, depth_maps = [], [], []

if save_path is not None:
import imageio
import os
os.makedirs(save_path, exist_ok=True)
os.makedirs(os.path.join(save_path, "rgbd"), exist_ok=True)

Expand All @@ -195,65 +198,72 @@ def evaluation(self, test_dataset, save_path=None,
except Exception:
gt_exist = False

for idx in range(len(test_dataset)):
W, H = test_dataset.img_wh
with torch.no_grad():
for idx in tqdm.tqdm(range(len(test_dataset))):
W, H = test_dataset.img_wh

rays = test_dataset.all_rays[idx]
rays = rays.view(-1, rays.shape[-1]).cuda(non_blocking=True)

rgb_map, depth_map = self.forward(rays, batch_size=batch_size)

rays = test_dataset.all_rays[idx]
rays = rays.view(-1, rays.shape[-1]).cuda(non_blocking=True)
rgb_map = rgb_map.reshape(H, W, 3)
depth_map = depth_map.reshape(H, W)

with torch.no_grad():
rgb_map, depth_map = self.forward(rays)
if gt_exist:
gt_rgb = test_dataset.all_rgbs[idx].view(H, W, 3) \
.cuda(non_blocking=True)

rgb_map = rgb_map.reshape(H, W, 3)
depth_map = depth_map.reshape(H, W)
loss = F.mse_loss(rgb_map, gt_rgb)
PSNRs.append(-10.0 * torch.log(loss) / np.log(10.0))

if gt_exist:
gt_rgb = test_dataset.all_rgbs[idx].view(H, W, 3) \
.cuda(non_blocking=True)
if compute_extra_metrics:
gt = gt_rgb.permute([2, 0, 1]).contiguous()
im = rgb_map.permute([2, 0, 1]).contiguous()

loss = F.mse_loss(rgb_map, gt_rgb)
PSNRs.append(-10.0 * torch.log(loss) / np.log(10.0))
ssims.append(ssim(im[None], gt[None], data_range=1))
l_alex.append(lpips_alex(gt, im, normalize=True))
l_vgg.append(lpips_vgg(gt, im, normalize=True))

if compute_extra_metrics:
gt = gt_rgb.permute([2, 0, 1]).contiguous()
im = rgb_map.permute([2, 0, 1]).contiguous()
del gt_rgb

ssims.append(ssim(im[None], gt[None], data_range=1))
l_alex.append(lpips_alex(gt, im, normalize=True))
l_vgg.append(lpips_vgg(gt, im, normalize=True))
rgb_map = (rgb_map * 255).int()
rgb_maps.append(rgb_map.cpu())
depth_maps.append(depth_map.cpu())

rgb_map = (rgb_map * 255).int()
rgb_maps.append(rgb_map)
depth_maps.append(depth_map)
if save_path is not None:
imageio.imwrite(os.path.join(save_path, f'{idx:03d}.png'),
rgb_map.cpu().numpy())
rgb_map = torch.concat((rgb_map, depth_map), axis=1)
imageio.imwrite(os.path.join(save_path,
f'rgbd/{idx:03d}.png'),
rgb_map.cpu().numpy())

if save_path is not None:
imageio.imwrite(os.path.join(save_path, f'{idx:03d}.png'),
rgb_map.cpu().numpy())
rgb_map = torch.concat((rgb_map, depth_map), axis=1)
imageio.imwrite(os.path.join(save_path, f'rgbd/{idx:03d}.png'),
rgb_map.cpu().numpy())
del rays, rgb_map, depth_map

if save_path is not None:
imageio.mimwrite(os.path.join(save_path, f'video.mp4'),
torch.stack(rgb_maps).cpu().numpy(),
torch.stack(rgb_maps).numpy(),
fps=30, quality=10)
imageio.mimwrite(os.path.join(save_path, f'depthvideo.mp4'),
torch.stack(depth_maps).cpu().numpy(),
torch.stack(depth_maps).numpy(),
fps=30, quality=10)

if gt_exists:
psnr = torch.cat(PSNRs).mean().cpu().numpy()
if gt_exist:
psnr = torch.stack(PSNRs).mean().cpu().numpy()

if compute_extra_metrics:
avg_ssim = torch.cat(ssims).mean().cpu().numpy()
avg_l_a = torch.cat(l_alex).mean().cpu().numpy()
avg_l_v = torch.cat(l_vgg).mean().cpu().numpy()
print(f'ssim: {avg_ssim}, LPIPS(alexnet): {avg_l_a}, '
f'LPIPS(vgg): {avg_l_v}')
np.savetxt(f'{save_path}/{prtx}mean.txt',
np.asarray([psnr, avg_ssim, avg_l_a, avg_l_v]))
else:
np.savetxt(f'{save_path}/{prtx}mean.txt', np.asarray([psnr]))
avg_ssim = torch.stack(ssims).mean().cpu().numpy()
avg_l_a = torch.stack(l_alex).mean().cpu().numpy()
avg_l_v = torch.stack(l_vgg).mean().cpu().numpy()
print(f'ssim: {avg_ssim:.4f}, LPIPS(alexnet): {avg_l_a:.4f}, '
f'LPIPS(vgg): {avg_l_v:.4f}')
if save_path is not None:
np.savetxt(os.path.join(save_path, 'mean.txt'),
np.asarray([psnr, avg_ssim, avg_l_a, avg_l_v]))
elif save_path is not None:
np.savetxt(os.path.join(save_path, 'mean.txt'),
np.asarray([psnr]))

self.train()

Expand Down
22 changes: 9 additions & 13 deletions models/grid_based.py
Original file line number | Diff line number | Diff line change
Expand Up @@ -4,6 +4,7 @@
import torch.nn.functional as F

import models.cosine_transform as ct
from .modules import positional_encoding


class FreqGrid(nn.Module):
Expand Down Expand Up @@ -167,39 +168,34 @@ def __init__(self, resolution: int, n_chan: int, out_dim=1):
1e-1 * torch.randn(3, n_chan, 1, resolution),
requires_grad=True)

if out_dim > 1:
self.basis_mat = nn.Linear(n_chan*3, out_dim, bias=False)
else:
self.basis_mat = None
self.basis_mat = nn.Linear(n_chan*3, out_dim, bias=False)

def forward(self, coords, *args, **kwargs):
# [B, 3] to [1, B, 1, 3]
coords = coords.reshape(1, -1, 1, coords.shape[-1])

# features from planes
grid = self.planes
p_feats = F.grid_sample(grid, torch.cat([coords[..., (1, 2)],
feats0 = F.grid_sample(grid, torch.cat([coords[..., (1, 2)],
coords[..., (0, 2)],
coords[..., (0, 1)]], 0),
mode='bilinear',
padding_mode='zeros', align_corners=True)
p_feats = p_feats.squeeze(-1).permute(2, 1, 0) # [B, C, 3]
feats0 = feats0.squeeze(-1).permute(2, 1, 0) # [B, C, 3]

# features from Vectors
grid = self.vectors
v_feats = F.grid_sample(grid, F.pad(torch.cat([coords[..., (0,)],
feats1 = F.grid_sample(grid, F.pad(torch.cat([coords[..., (0,)],
coords[..., (1,)],
coords[..., (2,)]], 0),
(1, 0)),
mode='bilinear',
padding_mode='zeros', align_corners=True)
v_feats = v_feats.squeeze(-1).permute(2, 1, 0) # [B, C, 3]

features = (p_feats * v_feats).flatten(1, -1) # [B, C*3]
feats1 = feats1.squeeze(-1).permute(2, 1, 0) # [B, C, 3]

if self.basis_mat is not None:
return self.basis_mat(features).squeeze(-1)
return features.sum(-1)
feats0 = (feats0 * feats1).flatten(1, -1) # [B, C*3]
del feats1
return self.basis_mat(feats0).squeeze(-1)

def compute_tv(self):
return F.mse_loss(self.planes[..., 1:, :], self.planes[..., :-1, :]) \
Expand Down

0 comments on commit 6db9196

Please sign in to comment.