Skip to content

Commit

Permalink
Add Triplane model
Browse files Browse the repository at this point in the history
++ Temporary version
  • Loading branch information
benhenryL committed Nov 17, 2022
1 parent 6727731 commit 333fffb
Show file tree
Hide file tree
Showing 4 changed files with 267 additions and 19 deletions.
231 changes: 231 additions & 0 deletions TensoRF/models/tensoRF.py
Original file line number Diff line number Diff line change
Expand Up @@ -323,6 +323,237 @@ def TV_loss_app(self, reg):
return total


class TriPlane(TensorBase):
    """Radiance field stored as one 2D feature plane per entry of ``matMode``.

    Density and appearance each own a ``ParameterList`` of planes shaped
    ``(1, n_comp[i], gridSize[mat_id_1], gridSize[mat_id_0])``.  Optionally:

    * ``use_mask``: a learnable mask (same shape as each plane) gates the
      plane values through a straight-through binarisation.
    * ``use_dwt``: the stored planes are treated as wavelet coefficients and
      passed through ``inverse`` before being sampled.

    NOTE(review): ``matMode``, ``units``, ``grid_bit``, ``renderModule``,
    ``update_stepSize`` and the ``*_n_comp`` attributes are expected to come
    from ``TensorBase`` - confirm against the base class.
    """

    def __init__(self, aabb, gridSize, device,
                 use_mask=False, use_dwt=False, dwt_level=2,
                 trans_func='bior4.4', **kargs):
        """Build the model and, if requested, the per-plane masks.

        Args:
            aabb: scene bounding box, forwarded to ``TensorBase``.
            gridSize: per-axis grid resolution, forwarded to ``TensorBase``.
            device: device on which parameters are created.
            use_mask: create and apply learnable plane masks.
            use_dwt: interpret planes as wavelet coefficients.
            dwt_level: level argument handed to ``forward`` / ``inverse``.
            trans_func: wavelet name handed to ``forward`` / ``inverse``.
            **kargs: remaining options forwarded to ``TensorBase``.
        """
        super(TriPlane, self).__init__(aabb, gridSize, device, **kargs)
        self.use_mask = use_mask
        self.use_dwt = use_dwt
        self.dwt_level = dwt_level
        self.trans_func = trans_func

        # init_mask reads self.density_plane / self.app_plane, so the planes
        # must already exist here (presumably created by the base __init__
        # via init_svd_volume - verify against TensorBase).
        if use_mask:
            self.init_mask()

    def get_kwargs(self):
        """Return the constructor arguments needed to rebuild this model."""
        return {
            'aabb': self.aabb,
            'gridSize': self.gridSize.tolist(),
            'density_n_comp': self.density_n_comp,
            'appearance_n_comp': self.app_n_comp,
            'app_dim': self.app_dim,

            'density_shift': self.density_shift,
            'alphaMask_thres': self.alphaMask_thres,
            'distance_scale': self.distance_scale,
            'rayMarch_weight_thres': self.rayMarch_weight_thres,
            'fea2denseAct': self.fea2denseAct,

            'near_far': self.near_far,
            'step_ratio': self.step_ratio,

            'shadingMode': self.shadingMode,
            'pos_pe': self.pos_pe,
            'view_pe': self.view_pe,
            'fea_pe': self.fea_pe,
            'featureC': self.featureC,

            'grid_bit': self.grid_bit,
            'use_mask': self.use_mask,
            'use_dwt': self.use_dwt,
            'dwt_level': self.dwt_level,
            'trans_func': self.trans_func,
        }

    def init_svd_volume(self, res, device):
        """Create the density/appearance planes and the appearance basis.

        ``res`` is accepted for interface compatibility but is not used; the
        plane sizes come from ``self.gridSize``.
        """
        self.density_plane = self.init_one_svd(
            self.density_n_comp, self.gridSize, 0.1, device)
        self.app_plane = self.init_one_svd(
            self.app_n_comp, self.gridSize, 0.1, device)
        # Linear map from the concatenated appearance channels to app_dim.
        self.basis_mat = nn.Linear(
            sum(self.app_n_comp), self.app_dim, bias=False).to(device)

    @torch.no_grad()
    def init_mask(self):
        """(Re)create all-ones mask parameters matching every plane."""
        self.density_plane_mask = nn.ParameterList(
            [nn.Parameter(torch.ones_like(plane))
             for plane in self.density_plane])
        self.app_plane_mask = nn.ParameterList(
            [nn.Parameter(torch.ones_like(plane))
             for plane in self.app_plane])

    def init_one_svd(self, n_component, gridSize, scale, device):
        """Return a ``ParameterList`` of randomly initialised planes.

        Plane ``i`` has shape
        ``(1, n_component[i], gridSize[mat_id_1], gridSize[mat_id_0])`` where
        ``(mat_id_0, mat_id_1) = self.matMode[i]``, i.e. height follows the
        second axis of the pair and width the first.
        """
        plane_coef = []
        for i in range(len(self.matMode)):
            mat_id_0, mat_id_1 = self.matMode[i]
            plane_coef.append(nn.Parameter(
                scale * torch.randn((1, n_component[i], gridSize[mat_id_1],
                                     gridSize[mat_id_0]))))

        return nn.ParameterList(plane_coef).to(device)

    def get_optparam_groups(self, lr0=0.02, lr1=0.001):
        """Return optimizer parameter groups.

        Planes (and masks, when enabled) use ``lr0``; the basis matrix and the
        render module (when it is an ``nn.Module``) use ``lr1``.
        """
        grad_vars = [{'params': self.density_plane, 'lr': lr0},
                     {'params': self.app_plane, 'lr': lr0},
                     {'params': self.basis_mat.parameters(), 'lr': lr1}]

        if isinstance(self.renderModule, nn.Module):
            grad_vars += [{'params': self.renderModule.parameters(),
                           'lr': lr1}]

        if self.use_mask:
            grad_vars += [{'params': self.density_plane_mask, 'lr': lr0},
                          {'params': self.app_plane_mask, 'lr': lr0}]

        return grad_vars

    def _plane_coordinates(self, points):
        """Gather the 2D sampling coordinates of ``points`` for every plane.

        Returns a tensor viewed as ``[n_planes(=3), B, 1, 2]`` suitable as the
        ``grid`` argument of ``F.grid_sample``.
        """
        return points[..., self.matMode].transpose(0, -2).view(3, -1, 1, 2)

    def _decode_plane(self, plane, mask_logits=None):
        """Turn a stored plane into the plane that is actually sampled.

        Steps: quantise with ``min_max_quantize``; optionally apply the mask;
        optionally apply the inverse wavelet transform.
        """
        plane = min_max_quantize(plane, self.grid_bit)

        if mask_logits is not None:
            mask = torch.sigmoid(mask_logits)
            # Straight-through estimator: the forward value equals
            # plane * (mask >= 0.5) while gradients flow through plane * mask.
            plane = (plane * (mask >= 0.5) - plane * mask).detach() \
                + plane * mask

        if self.use_dwt:
            plane = inverse(plane, self.dwt_level, self.trans_func)

        return plane

    def compute_densityfeature(self, points):
        """Return the density feature for each point (shape ``[B]``).

        Every density plane is sampled at the point's 2D projection and all
        channels of all planes are summed.
        """
        # [3, B, 1, 2]
        coordinate_plane = self._plane_coordinates(points)

        sigma_feature = torch.zeros((points.shape[0],), device=points.device)

        for idx in range(len(self.density_plane)):
            mask_logits = \
                self.density_plane_mask[idx] if self.use_mask else None
            plane = self._decode_plane(self.density_plane[idx], mask_logits)

            plane_coef_point = F.grid_sample(
                plane, coordinate_plane[[idx]],
                align_corners=True).view(-1, *points.shape[:1])

            sigma_feature += torch.sum(plane_coef_point, dim=0)

        return sigma_feature

    def compute_appfeature(self, points):
        """Return appearance features for each point.

        The sampled channels of all appearance planes are concatenated and
        projected by ``self.basis_mat``.
        """
        # [3, B, 1, 2]
        coordinate_plane = self._plane_coordinates(points)

        plane_coef_point = []
        for idx in range(len(self.app_plane)):
            mask_logits = self.app_plane_mask[idx] if self.use_mask else None
            plane = self._decode_plane(self.app_plane[idx], mask_logits)

            plane_coef_point.append(F.grid_sample(
                plane, coordinate_plane[[idx]],
                align_corners=True).view(-1, *points.shape[:1]))

        plane_coef_point = torch.cat(plane_coef_point)

        return self.basis_mat(plane_coef_point.T)

    @torch.no_grad()
    def upsample_volume_grid(self, res_target):
        """Resize all planes (and masks, when enabled) to ``res_target``."""
        self.app_plane = self.up_sampling_VM(self.app_plane, res_target)
        self.density_plane = self.up_sampling_VM(
            self.density_plane, res_target)

        if self.use_mask:
            self.app_plane_mask = self.up_sampling_VM(
                self.app_plane_mask, res_target)
            self.density_plane_mask = self.up_sampling_VM(
                self.density_plane_mask, res_target)

        self.update_stepSize(res_target)
        print(f'upsamping to {res_target}')

    @torch.no_grad()
    def up_sampling_VM(self, plane_coef, res_target):
        """Bilinearly resize every plane in ``plane_coef`` in place.

        With ``use_dwt`` the plane is first brought back from the wavelet
        domain, resized, and then transformed forward again.
        """
        for i in range(len(self.matMode)):
            mat_id_0, mat_id_1 = self.matMode[i]

            if self.use_dwt:
                plane_coef[i].set_(inverse(
                    plane_coef[i], self.dwt_level, self.trans_func))

            plane_coef[i] = nn.Parameter(
                F.interpolate(
                    plane_coef[i].data,
                    size=(res_target[mat_id_1], res_target[mat_id_0]),
                    mode='bilinear', align_corners=True))

            if self.use_dwt:
                plane_coef[i].set_(forward(
                    plane_coef[i], self.dwt_level, self.trans_func))

        return plane_coef

    def _shrink_axis_coords(self, new_aabb, mode, unit, device):
        """Sample positions along axis ``mode`` for the shrunk box.

        The number of samples is the new extent measured in current grid
        units, rounded down to a multiple of ``unit`` (and never below one
        ``unit`` so the plane cannot become empty).  The returned coordinates
        are normalised to ``[-1, 1]`` relative to the *current* ``self.aabb``
        because that is what ``F.grid_sample`` expects.
        NOTE(review): assumes the current planes span ``self.aabb`` exactly
        (align_corners=True convention) - confirm against TensorBase.
        """
        low, high = new_aabb[0][mode], new_aabb[1][mode]
        steps = (high - low) / self.units[mode]
        steps = max(int(steps / unit) * unit, unit)
        coords = torch.linspace(low, high, steps).to(device)

        old_low = float(self.aabb[0][mode])
        old_high = float(self.aabb[1][mode])
        return (coords - old_low) / (old_high - old_low) * 2 - 1

    @torch.no_grad()
    def shrink(self, new_aabb):
        """Crop/resample all planes to ``new_aabb`` and adopt it as the box.

        Masks are re-initialised afterwards because the plane shapes change.
        """
        print("====> shrinking ...")
        unit = 16  # unit for DWT

        for i in range(len(self.matMode)):
            # Planes
            mode0, mode1 = self.matMode[i]
            if self.use_dwt:
                self.density_plane[i].set_(inverse(
                    self.density_plane[i], self.dwt_level, self.trans_func))
                self.app_plane[i].set_(inverse(
                    self.app_plane[i], self.dwt_level, self.trans_func))

            device = self.density_plane[i].device
            coords0 = self._shrink_axis_coords(new_aabb, mode0, unit, device)
            coords1 = self._shrink_axis_coords(new_aabb, mode1, unit, device)

            # grid_sample reads grid[..., 0] as x (width) and grid[..., 1] as
            # y (height), and the output is (H_out, W_out) = grid.shape[1:3].
            # Planes are stored as (H = mode1, W = mode0), so rows must follow
            # mode1 and columns mode0 to keep that layout.
            height, width = coords1.shape[0], coords0.shape[0]
            grid_x = coords0.view(1, width).expand(height, width)
            grid_y = coords1.view(height, 1).expand(height, width)
            grid = torch.stack((grid_x, grid_y), -1).unsqueeze(0)

            self.density_plane[i] = nn.Parameter(
                F.grid_sample(self.density_plane[i], grid,
                              align_corners=True))
            self.app_plane[i] = nn.Parameter(
                F.grid_sample(self.app_plane[i], grid, align_corners=True))

            if self.use_dwt:
                self.density_plane[i].set_(forward(
                    self.density_plane[i], self.dwt_level, self.trans_func))
                self.app_plane[i].set_(forward(
                    self.app_plane[i], self.dwt_level, self.trans_func))

        # Only switch boxes after every plane was resampled against the old
        # one.
        self.aabb = new_aabb

        Y, X = self.density_plane[0].shape[-2:]
        Z = self.density_plane[1].shape[-2]
        self.update_stepSize((X, Y, Z))

        if self.use_mask:
            self.init_mask()

    def density_L1(self):
        """Sum over density planes of the mean absolute plane value."""
        return sum(torch.mean(torch.abs(plane))
                   for plane in self.density_plane)

    def TV_loss_density(self, reg):
        """Sum of ``reg(plane) * 1e-2`` over the density planes."""
        return sum(reg(plane) * 1e-2 for plane in self.density_plane)

    def TV_loss_app(self, reg):
        """Sum of ``reg(plane) * 1e-2`` over the appearance planes."""
        return sum(reg(plane) * 1e-2 for plane in self.app_plane)


class TensorCP(TensorBase):
def __init__(self, aabb, gridSize, device, **kargs):
super(TensorCP, self).__init__(aabb, gridSize, device, **kargs)
Expand Down
2 changes: 1 addition & 1 deletion TensoRF/opt.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@ def config_parser(cmd=None):
parser.add_argument('--downsample_test', type=float, default=1.0)

parser.add_argument('--model_name', type=str, default='TensorVMSplit',
choices=['TensorVMSplit', 'TensorCP'])
choices=['TensorVMSplit', 'TensorCP', 'TriPlane'])

# loader options
parser.add_argument("--batch_size", type=int, default=4096)
Expand Down
2 changes: 1 addition & 1 deletion TensoRF/renderer.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@

from dataLoader.ray_utils import get_rays
from dataLoader.ray_utils import ndc_rays_blender
from models.tensoRF import TensorCP, raw2alpha, TensorVMSplit, AlphaGridMask
from models.tensoRF import TensorCP, raw2alpha, TensorVMSplit, AlphaGridMask, TriPlane
from utils import *


Expand Down
51 changes: 34 additions & 17 deletions TensoRF/train.py
Original file line number Diff line number Diff line change
Expand Up @@ -75,6 +75,11 @@ def render_test(args):
tensorf = eval(args.model_name)(**kwargs)
tensorf.load(ckpt)

_, _, Z, Y, X = tensorf.alphaMask.alpha_volume.shape
tensorf.alphaMask = None
tensorf.alpha_offset = 0
tensorf.updateAlphaMask((X,Y,Z))

logfolder = os.path.dirname(args.ckpt)
if args.render_train:
os.makedirs(f'{logfolder}/imgs_train_all', exist_ok=True)
Expand All @@ -85,8 +90,9 @@ def render_test(args):

if args.render_test:
os.makedirs(f'{logfolder}/{args.expname}/imgs_test_all', exist_ok=True)
evaluation(test_dataset,tensorf, args, renderer, f'{logfolder}/{args.expname}/imgs_test_all/',
PSNRs_test = evaluation(test_dataset,tensorf, args, renderer, f'{logfolder}/{args.expname}/imgs_test_all/',
N_vis=-1, N_samples=-1, white_bg = white_bg, ndc_ray=ndc_ray,device=device)
print(f'======> {args.expname} train all psnr: {np.mean(PSNRs_test)} <========================')

if args.render_path:
c2ws = test_dataset.render_path
Expand Down Expand Up @@ -222,10 +228,11 @@ def reconstruction(args):
mask_loss = sum([p.sum()
for p in tensorf.density_plane_mask.parameters()])\
+ sum([p.sum()
for p in tensorf.app_plane_mask.parameters()])
if hasattr(tensorf, "density_line_mask"):
mask_loss += sum([p.sum()
for p in tensorf.density_line_mask.parameters()])\
+ sum([p.sum()
for p in tensorf.app_plane_mask.parameters()])\
+ sum([p.sum()
+ sum([p.sum()
for p in tensorf.app_line_mask.parameters()])
total_loss = total_loss + args.mask_weight * mask_loss

Expand Down Expand Up @@ -306,16 +313,18 @@ def reconstruction(args):
with torch.no_grad():
for i in range(3):
tensorf.density_plane[i].set_(min_max_quantize(tensorf.density_plane[i], args.grid_bit) * (tensorf.density_plane_mask[i] >= 0))
tensorf.density_line[i].set_(min_max_quantize(tensorf.density_line[i], args.grid_bit) * (tensorf.density_line_mask[i] >= 0))
tensorf.app_plane[i].set_(min_max_quantize(tensorf.app_plane[i], args.grid_bit) * (tensorf.app_plane_mask[i] >= 0))
tensorf.app_line[i].set_(min_max_quantize(tensorf.app_line[i], args.grid_bit) * (tensorf.app_line_mask[i] >= 0))
if hasattr(tensorf, "density_line_mask"):
tensorf.density_line[i].set_(min_max_quantize(tensorf.density_line[i], args.grid_bit) * (tensorf.density_line_mask[i] >= 0))
tensorf.app_line[i].set_(min_max_quantize(tensorf.app_line[i], args.grid_bit) * (tensorf.app_line_mask[i] >= 0))

tensorf.use_mask = False

del tensorf.density_plane_mask
del tensorf.density_line_mask
del tensorf.app_plane_mask
del tensorf.app_line_mask
if hasattr(tensorf, "density_line_mask"):
del tensorf.density_line_mask
del tensorf.app_line_mask

grid, non_grid = tensorf_param_count(tensorf)
grid_bytes = grid * args.grid_bit / 8
Expand All @@ -325,27 +334,35 @@ def reconstruction(args):
f'(N: {non_grid_bytes/1_048_576:3f}MB)')

if args.use_mask:
flat_mask = torch.cat([torch.cat([min_max_quantize(p[0].flatten(), args.grid_bit),
min_max_quantize(p[1].flatten(), args.grid_bit),
min_max_quantize(p[2].flatten(), args.grid_bit)])
for p in [tensorf.density_plane,
tensorf.density_line,
tensorf.app_plane,
tensorf.app_line]])
if hasattr(tensorf, "density_line"):
flat_mask = torch.cat([torch.cat([min_max_quantize(p[0].flatten(), args.grid_bit),
min_max_quantize(p[1].flatten(), args.grid_bit),
min_max_quantize(p[2].flatten(), args.grid_bit)])
for p in [tensorf.density_plane,
tensorf.density_line,
tensorf.app_plane,
tensorf.app_line]])
else:
flat_mask = torch.cat([torch.cat([min_max_quantize(p[0].flatten(), args.grid_bit),
min_max_quantize(p[1].flatten(), args.grid_bit),
min_max_quantize(p[2].flatten(), args.grid_bit)])
for p in [tensorf.density_plane,
tensorf.app_plane]])

ratio = (flat_mask != 0).float().mean()
print(f'non-masked ratio: {ratio:.4f}')
grid_bytes = grid_bytes * ratio
print(f'masked_total: {(grid_bytes + non_grid_bytes)/1_048_576:.3f}MB '
f'(G ({args.grid_bit}bit): {grid_bytes/1_048_576:.3f}MB) '
f'(N: {non_grid_bytes/1_048_576:3f}MB)')

tensorf.save(f'{logfolder}/{args.expname}.th')
# Alpha mask reconstruction
_, _, Z, Y, X = tensorf.alphaMask.alpha_volume.shape
tensorf.alphaMask = None
tensorf.alpha_offset = 0
tensorf.updateAlphaMask((X,Y,Z))

tensorf.save(f'{logfolder}/{args.expname}.th')

if args.render_train:
os.makedirs(f'{logfolder}/imgs_train_all', exist_ok=True)
Expand Down Expand Up @@ -395,4 +412,4 @@ def reconstruction(args):
render_test(args)
else:
reconstruction(args)

0 comments on commit 333fffb

Please sign in to comment.