From 8111899abf90ec3c7eb8dc7a014e88f26a5e3209 Mon Sep 17 00:00:00 2001
From: daniel03c1
Date: Fri, 4 Nov 2022 12:14:52 +0900
Subject: [PATCH] update TensoRF (mask + dwt + quantization)

---
 TensoRF/configs/chair.txt    |  28 ++--
 TensoRF/models/dwt.py        |  48 ++++++
 TensoRF/models/tensoRF.py    | 297 ++++++++++++++++-------------------
 TensoRF/models/tensorBase.py |  40 +++--
 TensoRF/opt.py               |  14 +-
 TensoRF/renderer.py          |  10 +-
 TensoRF/train.py             | 119 +++++++++-----
 TensoRF/utils.py             |  10 +-
 8 files changed, 320 insertions(+), 246 deletions(-)
 create mode 100644 TensoRF/models/dwt.py

diff --git a/TensoRF/configs/chair.txt b/TensoRF/configs/chair.txt
index 4eaf055..ba46e04 100644
--- a/TensoRF/configs/chair.txt
+++ b/TensoRF/configs/chair.txt
@@ -1,6 +1,5 @@
-
 dataset_name = blender
-datadir = ../nerf_synthetic/chair
+datadir = ../../nerf_synthetic/chair
 expname = tensorf_lego_VM
 basedir = ./log
 
@@ -15,30 +14,33 @@ update_AlphaMask_list = [2000, 4000]
 N_vis = 5
 vis_every = 10000
 
-# lr_init = 0.005 # 0.001 # 0.5 # 0.02 # test
+# lr_init = 0.01 # 0.001 # 0.5 # 0.02 # test
 # lr_basis = 0.005 # 0.001 # 0.02 # 0.001 # test
 
 render_test = 1
 
-n_lamb_sigma = [16, 16, 16]
-n_lamb_sh = [48, 48, 48]
+n_lamb_sigma = [16, 16, 16] # 3, 3, 3] # 16, 16, 16]
+n_lamb_sh = [48, 48, 48] # 6, 6, 6] # 48, 48, 48]
 model_name = TensorVMSplit
 
 shadingMode = MLP_Fea
 fea2denseAct = softplus
 
-view_pe = 2
-fea_pe = 2
+pos_pe = 0 # 6 # None
+view_pe = 2 # 3 # 2
+fea_pe = 2 # 7 # 3 # 2
+featureC = 128 # 116 # 128
+# data_dim_color = 64 # 22 # 8 # 2
 
 L1_weight_inital = 0 # 8e-5
 L1_weight_rest = 0 # 4e-5
 rm_weight_mask_thre = 1e-4
 
 ## please uncomment following configuration if hope to training on cp model
-#model_name = TensorCP
-#n_lamb_sigma = [96]
-#n_lamb_sh = [288]
-#N_voxel_final = 125000000 # 500**3
-#L1_weight_inital = 1e-5
-#L1_weight_rest = 1e-5
+# model_name = TensorCP
+# n_lamb_sigma = [96]
+# n_lamb_sh = [288]
+# N_voxel_final = 125000000 # 500**3
+# L1_weight_inital = 1e-5
+# L1_weight_rest = 1e-5
diff --git a/TensoRF/models/dwt.py b/TensoRF/models/dwt.py
new file mode 100644
index 0000000..43c783c
--- /dev/null
+++ b/TensoRF/models/dwt.py
@@ -0,0 +1,48 @@
+import torch
+from pytorch_wavelets import DWTInverse, DWTForward
+
+
+def inverse(inputs, level=4):
+    assert inputs.size(-1) % 2**level == 0
+    assert inputs.size(-2) % 2**level == 0
+
+    res0, res1 = inputs.shape[-2:]
+
+    yl = inputs[..., :res0//(2**level), :res1//(2**level)]
+
+    yh = [
+        torch.stack([inputs[..., :res0//(2**(i+1)),
+                            res1//(2**(i+1)):res1//(2**i)],
+                     inputs[..., res0//(2**(i+1)):res0//(2**i),
+                            :res1//(2**(i+1))],
+                     inputs[..., res0//(2**(i+1)):res0//(2**i),
+                            res1//(2**(i+1)):res1//(2**i)]], 2)/(level-i+1)
+        for i in range(level)
+    ]
+
+    return DWTInverse(wave='bior4.4',
+                      mode='periodization').to(inputs.device)((yl, yh))
+
+
+def forward(inputs, level=4):
+    assert inputs.size(-1) % 2**level == 0
+    assert inputs.size(-2) % 2**level == 0
+
+    yl, yh = DWTForward(wave='bior4.4', J=level,
+                        mode='periodization').to(inputs.device)(inputs)
+    outs = yl
+
+    for i in range(level):
+        cf = yh[-i-1] * (i+2)
+        outs = torch.cat([torch.cat([outs, cf[..., 0, :, :]], -1),
+                          torch.cat([cf[..., 1, :, :], cf[..., 2, :, :]], -1)],
+                         -2)
+    return outs
+
+
+if __name__ == '__main__':
+    a = torch.randn(3, 5, 64, 80).cuda() * 10
+    print(a.shape, inverse(a).shape)
+    print((a - forward(inverse(a))).abs().max())
+    print((a - inverse(forward(a))).abs().max())
+
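
For orientation, `forward` packs the multi-level coefficients into a single plane in the usual pyramid layout: the low-pass band sits in the top-left corner and each level's three detail bands fill the remaining quadrants at that scale, with the level-dependent scaling that `inverse` divides back out. A minimal sketch of the level-1 layout, assuming an 8x8 input; variable names here are illustrative only:

import torch
from pytorch_wavelets import DWTForward

x = torch.randn(1, 1, 8, 8)
yl, yh = DWTForward(wave='bior4.4', J=1, mode='periodization')(x)

# yl: (1, 1, 4, 4) low-pass band; yh[0]: (1, 1, 3, 4, 4) detail bands.
# Packed plane (same size as the input):
#   [ low-pass | detail 0 ]
#   [ detail 1 | detail 2 ]
packed = torch.cat([torch.cat([yl, yh[0][:, :, 0]], -1),
                    torch.cat([yh[0][:, :, 1], yh[0][:, :, 2]], -1)], -2)
print(packed.shape)  # torch.Size([1, 1, 8, 8])
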
diff --git a/TensoRF/models/tensoRF.py b/TensoRF/models/tensoRF.py
index e237b4e..aec1048 100644
--- a/TensoRF/models/tensoRF.py
+++ b/TensoRF/models/tensoRF.py
@@ -1,4 +1,9 @@
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+
 from .tensorBase import *
+from .dwt import forward, inverse
 
 
 def min_max_quantize(inputs, bits):
@@ -8,155 +13,78 @@ def min_max_quantize(inputs, bits):
     # rounding
     min_value = torch.amin(inputs)
     max_value = torch.amax(inputs)
-    scale = (max_value - min_value).clamp(min=1e-8) / (bits ** 2 - 1)
+    scale = (max_value - min_value).clamp(min=1e-8) / (2 ** bits - 1)
     rounded = torch.round((inputs - min_value) / scale) * scale + min_value
     return (rounded - inputs).detach() + inputs
 
 
-class TensorVM(TensorBase):
-    def __init__(self, aabb, gridSize, device, **kargs):
-        super(TensorVM, self).__init__(aabb, gridSize, device, **kargs)
-
-    def init_svd_volume(self, res, device):
-        self.plane_coef = torch.nn.Parameter(
-            0.1 * torch.randn((3, self.app_n_comp + self.density_n_comp, res, res), device=device))
-        self.line_coef = torch.nn.Parameter(
-            0.1 * torch.randn((3, self.app_n_comp + self.density_n_comp, res, 1), device=device))
-        self.basis_mat = torch.nn.Linear(self.app_n_comp * 3, self.app_dim, bias=False, device=device)
-
-    def get_optparam_groups(self, lr_init_spatialxyz = 0.02, lr_init_network = 0.001):
-        grad_vars = [{'params': self.line_coef, 'lr': lr_init_spatialxyz}, {'params': self.plane_coef, 'lr': lr_init_spatialxyz},
-                     {'params': self.basis_mat.parameters(), 'lr':lr_init_network}]
-        if isinstance(self.renderModule, torch.nn.Module):
-            grad_vars += [{'params':self.renderModule.parameters(), 'lr':lr_init_network}]
-        return grad_vars
-
-    def compute_features(self, xyz_sampled):
-        coordinate_plane = torch.stack((xyz_sampled[..., self.matMode[0]], xyz_sampled[..., self.matMode[1]], xyz_sampled[..., self.matMode[2]])).detach()
-        coordinate_line = torch.stack((xyz_sampled[..., self.vecMode[0]], xyz_sampled[..., self.vecMode[1]], xyz_sampled[..., self.vecMode[2]]))
-        coordinate_line = torch.stack((torch.zeros_like(coordinate_line), coordinate_line), dim=-1).detach()
-
-        plane_feats = F.grid_sample(self.plane_coef[:, -self.density_n_comp:], coordinate_plane, align_corners=True).view(
-                                        -1, *xyz_sampled.shape[:1])
-        line_feats = F.grid_sample(self.line_coef[:, -self.density_n_comp:], coordinate_line, align_corners=True).view(
-                                        -1, *xyz_sampled.shape[:1])
-
-        sigma_feature = torch.sum(plane_feats * line_feats, dim=0)
-
-        plane_feats = F.grid_sample(self.plane_coef[:, :self.app_n_comp], coordinate_plane, align_corners=True).view(3 * self.app_n_comp, -1)
-        line_feats = F.grid_sample(self.line_coef[:, :self.app_n_comp], coordinate_line, align_corners=True).view(3 * self.app_n_comp, -1)
-
-        app_features = self.basis_mat((plane_feats * line_feats).T)
-
-        return sigma_feature, app_features
-
-    def compute_densityfeature(self, xyz_sampled):
-        coordinate_plane = torch.stack((xyz_sampled[..., self.matMode[0]], xyz_sampled[..., self.matMode[1]], xyz_sampled[..., self.matMode[2]])).detach().view(3, -1, 1, 2)
-        coordinate_line = torch.stack((xyz_sampled[..., self.vecMode[0]], xyz_sampled[..., self.vecMode[1]], xyz_sampled[..., self.vecMode[2]]))
-        coordinate_line = torch.stack((torch.zeros_like(coordinate_line), coordinate_line), dim=-1).detach().view(3, -1, 1, 2)
-
-        plane_feats = F.grid_sample(self.plane_coef[:, -self.density_n_comp:], coordinate_plane, align_corners=True).view(
-                                        -1, *xyz_sampled.shape[:1])
-        line_feats = F.grid_sample(self.line_coef[:, -self.density_n_comp:], coordinate_line, align_corners=True).view(
-                                        -1, *xyz_sampled.shape[:1])
-
-        sigma_feature = torch.sum(plane_feats * line_feats, dim=0)
-
-        return sigma_feature
-
-    def compute_appfeature(self, xyz_sampled):
-        coordinate_plane = torch.stack((xyz_sampled[..., self.matMode[0]], xyz_sampled[..., self.matMode[1]], xyz_sampled[..., self.matMode[2]])).detach().view(3, -1, 1, 2)
-        coordinate_line = torch.stack((xyz_sampled[..., self.vecMode[0]], xyz_sampled[..., self.vecMode[1]], xyz_sampled[..., self.vecMode[2]]))
-        coordinate_line = torch.stack((torch.zeros_like(coordinate_line), coordinate_line), dim=-1).detach().view(3, -1, 1, 2)
-
-        plane_feats = F.grid_sample(self.plane_coef[:, :self.app_n_comp], coordinate_plane, align_corners=True).view(3 * self.app_n_comp, -1)
-        line_feats = F.grid_sample(self.line_coef[:, :self.app_n_comp], coordinate_line, align_corners=True).view(3 * self.app_n_comp, -1)
-
-        app_features = self.basis_mat((plane_feats * line_feats).T)
-
-        return app_features
-
-    def vectorDiffs(self, vector_comps):
-        total = 0
-
-        for idx in range(len(vector_comps)):
-            # print(self.line_coef.shape, vector_comps[idx].shape)
-            n_comp, n_size = vector_comps[idx].shape[:-1]
-
-            dotp = torch.matmul(vector_comps[idx].view(n_comp,n_size), vector_comps[idx].view(n_comp,n_size).transpose(-1,-2))
-            # print(vector_comps[idx].shape, vector_comps[idx].view(n_comp,n_size).transpose(-1,-2).shape, dotp.shape)
-            non_diagonal = dotp.view(-1)[1:].view(n_comp-1, n_comp+1)[...,:-1]
-            # print(vector_comps[idx].shape, vector_comps[idx].view(n_comp,n_size).transpose(-1,-2).shape, dotp.shape,non_diagonal.shape)
-            total = total + torch.mean(torch.abs(non_diagonal))
-        return total
-
-    def vector_comp_diffs(self):
-        return self.vectorDiffs(self.line_coef[:,-self.density_n_comp:]) + self.vectorDiffs(self.line_coef[:,:self.app_n_comp])
-
-    @torch.no_grad()
-    def up_sampling_VM(self, plane_coef, line_coef, res_target):
-        for i in range(len(self.vecMode)):
-            vec_id = self.vecMode[i]
-            mat_id_0, mat_id_1 = self.matMode[i]
-
-            plane_coef[i] = torch.nn.Parameter(
-                F.interpolate(plane_coef[i].data, size=(res_target[mat_id_1], res_target[mat_id_0]), mode='bilinear',
-                              align_corners=True))
-            line_coef[i] = torch.nn.Parameter(
-                F.interpolate(line_coef[i].data, size=(res_target[vec_id], 1), mode='bilinear', align_corners=True))
-
-        return plane_coef, line_coef
-
-    @torch.no_grad()
-    def upsample_volume_grid(self, res_target):
-        scale = res_target[0]/self.line_coef.shape[2] #assuming xyz have the same scale
-        plane_coef = F.interpolate(self.plane_coef.detach().data, scale_factor=scale, mode='bilinear',align_corners=True)
-        line_coef = F.interpolate(self.line_coef.detach().data, size=(res_target[0],1), mode='bilinear',align_corners=True)
-        self.plane_coef, self.line_coef = torch.nn.Parameter(plane_coef), torch.nn.Parameter(line_coef)
-        self.compute_stepSize(res_target)
-        print(f'upsamping to {res_target}')
-
-
 class TensorVMSplit(TensorBase):
-    def __init__(self, aabb, gridSize, device, **kargs):
+    def __init__(self, aabb, gridSize, device,
+                 use_mask=False, use_dwt=False, dwt_level=2, **kargs):
         super(TensorVMSplit, self).__init__(aabb, gridSize, device, **kargs)
 
+        self.use_mask = use_mask
+        self.use_dwt = use_dwt
+        self.dwt_level = dwt_level
+
+        if use_mask:
+            self.init_mask()
 
     def init_svd_volume(self, res, device):
         self.density_plane, self.density_line = self.init_one_svd(
             self.density_n_comp, self.gridSize, 0.1, device)
         self.app_plane, self.app_line = self.init_one_svd(
             self.app_n_comp, self.gridSize, 0.1, device)
-        self.basis_mat = torch.nn.Linear(
+        self.basis_mat = nn.Linear(
             sum(self.app_n_comp), self.app_dim, bias=False).to(device)
 
+    @torch.no_grad()
+    def init_mask(self):
+        self.density_plane_mask = nn.ParameterList(
+            [nn.Parameter(torch.zeros_like(self.density_plane[i]))
+             for i in range(3)])
+        self.density_line_mask = nn.ParameterList(
+            [nn.Parameter(torch.zeros_like(self.density_line[i]))
+             for i in range(3)])
+        self.app_plane_mask = nn.ParameterList(
+            [nn.Parameter(torch.zeros_like(self.app_plane[i]))
+             for i in range(3)])
+        self.app_line_mask = nn.ParameterList(
+            [nn.Parameter(torch.zeros_like(self.app_line[i]))
+             for i in range(3)])
+
     def init_one_svd(self, n_component, gridSize, scale, device):
         plane_coef, line_coef = [], []
 
         for i in range(len(self.vecMode)):
             vec_id = self.vecMode[i]
             mat_id_0, mat_id_1 = self.matMode[i]
-            plane_coef.append(torch.nn.Parameter(
+            plane_coef.append(nn.Parameter(
                 scale * torch.randn((1, n_component[i], gridSize[mat_id_1],
                                      gridSize[mat_id_0]))))
-            line_coef.append(torch.nn.Parameter(
+            line_coef.append(nn.Parameter(
                 scale * torch.randn((1, n_component[i], gridSize[vec_id],
                                      1))))
 
-        return (torch.nn.ParameterList(plane_coef).to(device),
-                torch.nn.ParameterList(line_coef).to(device))
+        return (nn.ParameterList(plane_coef).to(device),
+                nn.ParameterList(line_coef).to(device))
+
+    def get_optparam_groups(self, lr0=0.02, lr1=0.001):
+        grad_vars = [{'params': self.density_line, 'lr': lr0},
+                     {'params': self.density_plane, 'lr': lr0},
+                     {'params': self.app_line, 'lr': lr0},
+                     {'params': self.app_plane, 'lr': lr0},
+                     {'params': self.basis_mat.parameters(), 'lr':lr1}]
+
+        if isinstance(self.renderModule, nn.Module):
+            grad_vars += [{'params':self.renderModule.parameters(), 'lr':lr1}]
+
+        if self.use_mask:
+            grad_vars += [{'params': self.density_plane_mask, 'lr': lr0},
+                          {'params': self.density_line_mask, 'lr': lr0},
+                          {'params': self.app_plane_mask, 'lr': lr0},
+                          {'params': self.app_line_mask, 'lr': lr0}]
 
-    def get_optparam_groups(self, lr_init_spatialxyz=0.02,
-                            lr_init_network=0.001):
-        grad_vars = [{'params': self.density_line, 'lr': lr_init_spatialxyz},
-                     {'params': self.density_plane, 'lr': lr_init_spatialxyz},
-                     {'params': self.app_line, 'lr': lr_init_spatialxyz},
-                     {'params': self.app_plane, 'lr': lr_init_spatialxyz},
-                     {'params': self.basis_mat.parameters(),
-                      'lr':lr_init_network}]
-        if isinstance(self.renderModule, torch.nn.Module):
-            grad_vars += [{'params':self.renderModule.parameters(),
-                           'lr':lr_init_network}]
         return grad_vars
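
min_max_quantize above is a straight-through estimator: the forward pass sees values rounded to one of 2**bits levels between the tensor's min and max, while `(rounded - inputs).detach() + inputs` makes the backward pass behave as if no rounding had happened. A small worked check, assuming min_max_quantize as defined in this file is in scope:

import torch

x = torch.linspace(0, 1, 8, requires_grad=True)
q = min_max_quantize(x, 2)   # 2 bits -> 4 levels: 0, 1/3, 2/3, 1

print(q.detach())            # values snapped to multiples of 1/3
q.sum().backward()
print(x.grad)                # all ones: rounding is invisible to the optimizer
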
@@ -170,11 +98,20 @@ def compute_densityfeature(self, points):
         sigma_feature = torch.zeros((points.shape[0],), device=points.device)
 
         for idx in range(len(self.density_plane)):
-            # plane = self.density_plane[idx]
-            # line = self.density_line[idx]
             plane = min_max_quantize(self.density_plane[idx], self.grid_bit)
             line = min_max_quantize(self.density_line[idx], self.grid_bit)
 
+            if self.use_mask:
+                mask = torch.sigmoid(self.density_plane_mask[idx])
+                plane = (plane * (mask >= 0.5) - plane * mask).detach() \
+                      + plane * mask
+                mask = torch.sigmoid(self.density_line_mask[idx])
+                line = (line * (mask >= 0.5) - line * mask).detach() \
+                     + line * mask
+
+            if self.use_dwt:
+                plane = inverse(plane, self.dwt_level)
+
             plane_coef_point = F.grid_sample(
                 plane, coordinate_plane[[idx]],
                 align_corners=True).view(-1, *points.shape[:1])
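
The masking lines reuse the same straight-through trick as the quantizer: the rendered value is gated by the hard mask `(mask >= 0.5)`, but the gradient flows through the soft sigmoid mask, so the mask logits stay trainable and masked-out entries still receive an attenuated gradient. A minimal sketch of the identity on a toy tensor (names illustrative):

import torch

plane = torch.randn(4, requires_grad=True)
logits = torch.tensor([-2.0, -0.1, 0.1, 2.0], requires_grad=True)

mask = torch.sigmoid(logits)
out = (plane * (mask >= 0.5) - plane * mask).detach() + plane * mask

print(out.detach())   # forward: plane * (logits >= 0), a hard 0/1 gate
out.sum().backward()
print(plane.grad)     # backward: sigmoid(logits), the soft mask
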
@@ -197,11 +134,12 @@ def compute_appfeature(self, points):
         plane_coef_point, line_coef_point = [], []
 
         for idx in range(len(self.app_plane)):
-            # plane = self.app_plane[idx]
-            # line = self.app_line[idx]
             plane = min_max_quantize(self.app_plane[idx], self.grid_bit)
             line = min_max_quantize(self.app_line[idx], self.grid_bit)
 
+            if self.use_dwt:
+                plane = inverse(plane, self.dwt_level)
+
             plane_coef_point.append(F.grid_sample(
                 plane, coordinate_plane[[idx]],
                 align_corners=True).view(-1, *points.shape[:1]))
@@ -221,6 +159,13 @@ def upsample_volume_grid(self, res_target):
         self.density_plane, self.density_line = self.up_sampling_VM(
             self.density_plane, self.density_line, res_target)
 
+        if self.use_mask:
+            self.app_plane_mask, self.app_line_mask = self.up_sampling_VM(
+                self.app_plane_mask, self.app_line_mask, res_target)
+            self.density_plane_mask, self.density_line_mask = \
+                self.up_sampling_VM(self.density_plane_mask,
+                                    self.density_line_mask, res_target)
+
         self.update_stepSize(res_target)
         print(f'upsamping to {res_target}')
 
@@ -229,11 +174,17 @@ def up_sampling_VM(self, plane_coef, line_coef, res_target):
         for i in range(len(self.vecMode)):
             vec_id = self.vecMode[i]
             mat_id_0, mat_id_1 = self.matMode[i]
-            plane_coef[i] = torch.nn.Parameter(
+
+            if self.use_dwt:
+                plane_coef[i].set_(inverse(plane_coef[i], self.dwt_level))
+
+            plane_coef[i] = nn.Parameter(
                 F.interpolate(plane_coef[i].data, size=(res_target[mat_id_1], res_target[mat_id_0]), mode='bilinear',
                               align_corners=True))
-            line_coef[i] = torch.nn.Parameter(
+            if self.use_dwt:
+                plane_coef[i].set_(forward(plane_coef[i], self.dwt_level))
+
+            line_coef[i] = nn.Parameter(
                 F.interpolate(line_coef[i].data, size=(res_target[vec_id], 1),
                               mode='bilinear', align_corners=True))
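
Since the planes live in the wavelet domain when `use_dwt` is set, `up_sampling_VM` has to resize in signal space: `inverse` to pixels, bilinear interpolation, then `forward` back to coefficients (the `set_` calls above). The same round-trip in isolation, assuming the helpers from models/dwt.py are importable and the target size remains a multiple of 2**level:

import torch
import torch.nn.functional as F
from models.dwt import forward, inverse

level = 2
coeffs = torch.randn(1, 16, 64, 64)   # a plane held as DWT coefficients

pixels = inverse(coeffs, level)        # decode to the spatial grid
pixels = F.interpolate(pixels, size=(128, 128),
                       mode='bilinear', align_corners=True)
coeffs = forward(pixels, level)        # re-encode at the new resolution
print(coeffs.shape)                    # torch.Size([1, 16, 128, 128])
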
@@ -242,40 +193,59 @@
     @torch.no_grad()
     def shrink(self, new_aabb):
         print("====> shrinking ...")
-        xyz_min, xyz_max = new_aabb
-        t_l = (xyz_min - self.aabb[0]) / self.units
-        t_l = torch.round(t_l).long()
-
-        b_r = (xyz_max - self.aabb[0]) / self.units
-        b_r = torch.round(b_r).long() + 1
-        b_r = torch.stack([b_r, self.gridSize]).amin(0)
+        unit = 16  # unit for DWT
 
         for i in range(len(self.vecMode)):
+            # Lines
             mode0 = self.vecMode[i]
-            self.density_line[i] = torch.nn.Parameter(
-                self.density_line[i].data[...,t_l[mode0]:b_r[mode0],:])
-            self.app_line[i] = torch.nn.Parameter(
-                self.app_line[i].data[...,t_l[mode0]:b_r[mode0],:])
+            steps = (new_aabb[1][mode0]-new_aabb[0][mode0]) / self.units[mode0]
+            steps = int(steps / unit) * unit
 
-            mode0, mode1 = self.matMode[i]
-            self.density_plane[i] = torch.nn.Parameter(
-                self.density_plane[i].data[...,t_l[mode1]:b_r[mode1],
-                                           t_l[mode0]:b_r[mode0]])
-            self.app_plane[i] = torch.nn.Parameter(
-                self.app_plane[i].data[...,t_l[mode1]:b_r[mode1],
-                                       t_l[mode0]:b_r[mode0]])
+            grid = torch.linspace(new_aabb[0][mode0], new_aabb[1][mode0],
+                                  steps).to(self.density_line[i].device)
+            grid = F.pad(grid.reshape(1, -1, 1, 1), (0, 1))
 
-        if not torch.all(self.alphaMask.gridSize == self.gridSize):
-            t_l_r, b_r_r = t_l / (self.gridSize-1), (b_r-1) / (self.gridSize-1)
-            correct_aabb = torch.zeros_like(new_aabb)
-            correct_aabb[0] = (1-t_l_r)*self.aabb[0] + t_l_r*self.aabb[1]
-            correct_aabb[1] = (1-b_r_r)*self.aabb[0] + b_r_r*self.aabb[1]
-            print("aabb", new_aabb, "\ncorrect aabb", correct_aabb)
-            new_aabb = correct_aabb
+            self.density_line[i] = nn.Parameter(
+                F.grid_sample(self.density_line[i], grid, align_corners=True))
+            self.app_line[i] = nn.Parameter(
+                F.grid_sample(self.app_line[i], grid, align_corners=True))
+
+            # Planes
+            mode0, mode1 = self.matMode[i]
+            if self.use_dwt:
+                self.density_plane[i].set_(inverse(self.density_plane[i],
+                                                   self.dwt_level))
+                self.app_plane[i].set_(inverse(self.app_plane[i],
+                                               self.dwt_level))
+
+            steps = (new_aabb[1][mode0]-new_aabb[0][mode0]) / self.units[mode0]
+            steps = int(steps / unit) * unit
+            grid0 = torch.linspace(new_aabb[0][mode0], new_aabb[1][mode0],
+                                   steps).to(self.density_line[i].device)
+
+            steps = (new_aabb[1][mode1]-new_aabb[0][mode1]) / self.units[mode1]
+            steps = int(steps / unit) * unit
+            grid1 = torch.linspace(new_aabb[0][mode1], new_aabb[1][mode1],
+                                   steps).to(self.density_line[i].device)
+            grid = torch.stack(torch.meshgrid(grid0, grid1), -1).unsqueeze(0)
+
+            self.density_plane[i] = nn.Parameter(
+                F.grid_sample(self.density_plane[i], grid, align_corners=True))
+            self.app_plane[i] = nn.Parameter(
+                F.grid_sample(self.app_plane[i], grid, align_corners=True))
+
+            if self.use_dwt:
+                self.density_plane[i].set_(forward(self.density_plane[i],
+                                                   self.dwt_level))
+                self.app_plane[i].set_(forward(self.app_plane[i],
+                                               self.dwt_level))
 
-        newSize = b_r - t_l
         self.aabb = new_aabb
-        self.update_stepSize((newSize[0], newSize[1], newSize[2]))
+        self.update_stepSize(
+            tuple(reversed([p.shape[-2] for p in self.density_line])))
+
+        if self.use_mask:
+            self.init_mask()
 
     def vectorDiffs(self, vector_comps):
         breakpoint()
@@ -322,7 +292,7 @@ def __init__(self, aabb, gridSize, device, **kargs):
 
     def init_svd_volume(self, res, device):
         self.density_line = self.init_one_svd(self.density_n_comp[0], self.gridSize, 0.2, device)
         self.app_line = self.init_one_svd(self.app_n_comp[0], self.gridSize, 0.2, device)
-        self.basis_mat = torch.nn.Linear(self.app_n_comp[0], self.app_dim, bias=False).to(device)
+        self.basis_mat = nn.Linear(self.app_n_comp[0], self.app_dim, bias=False).to(device)
 
@@ -330,15 +300,15 @@ def init_one_svd(self, n_component, gridSize, scale, device):
         for i in range(len(self.vecMode)):
             vec_id = self.vecMode[i]
             line_coef.append(
-                torch.nn.Parameter(scale * torch.randn((1, n_component, gridSize[vec_id], 1))))
-        return torch.nn.ParameterList(line_coef).to(device)
+                nn.Parameter(scale * torch.randn((1, n_component, gridSize[vec_id], 1))))
+        return nn.ParameterList(line_coef).to(device)
 
 
     def get_optparam_groups(self, lr_init_spatialxyz = 0.02, lr_init_network = 0.001):
         grad_vars = [{'params': self.density_line, 'lr': lr_init_spatialxyz},
                      {'params': self.app_line, 'lr': lr_init_spatialxyz},
                      {'params': self.basis_mat.parameters(), 'lr':lr_init_network}]
-        if isinstance(self.renderModule, torch.nn.Module):
+        if isinstance(self.renderModule, nn.Module):
             grad_vars += [{'params':self.renderModule.parameters(), 'lr':lr_init_network}]
         return grad_vars
 
@@ -381,9 +351,9 @@ def up_sampling_Vector(self, density_line_coef, app_line_coef, res_target):
 
         for i in range(len(self.vecMode)):
             vec_id = self.vecMode[i]
-            density_line_coef[i] = torch.nn.Parameter(
+            density_line_coef[i] = nn.Parameter(
                 F.interpolate(density_line_coef[i].data, size=(res_target[vec_id], 1), mode='bilinear', align_corners=True))
-            app_line_coef[i] = torch.nn.Parameter(
+            app_line_coef[i] = nn.Parameter(
                 F.interpolate(app_line_coef[i].data, size=(res_target[vec_id], 1), mode='bilinear', align_corners=True))
 
         return density_line_coef, app_line_coef
@@ -404,13 +374,12 @@ def shrink(self, new_aabb):
         t_l, b_r = torch.round(torch.round(t_l)).long(), torch.round(b_r).long() + 1
         b_r = torch.stack([b_r, self.gridSize]).amin(0)
 
-
         for i in range(len(self.vecMode)):
             mode0 = self.vecMode[i]
-            self.density_line[i] = torch.nn.Parameter(
+            self.density_line[i] = nn.Parameter(
                 self.density_line[i].data[...,t_l[mode0]:b_r[mode0],:]
             )
-            self.app_line[i] = torch.nn.Parameter(
+            self.app_line[i] = nn.Parameter(
                 self.app_line[i].data[...,t_l[mode0]:b_r[mode0],:]
             )
diff --git a/TensoRF/models/tensorBase.py b/TensoRF/models/tensorBase.py
index 487d35c..58301ba 100644
--- a/TensoRF/models/tensorBase.py
+++ b/TensoRF/models/tensorBase.py
@@ -36,10 +36,10 @@ def SHRender(xyz_sampled, viewdirs, features):
 
 
 def RGBRender(xyz_sampled, viewdirs, features):
-
     rgb = features
     return rgb
 
+
 class AlphaGridMask(torch.nn.Module):
     def __init__(self, device, aabb, alpha_volume):
         super(AlphaGridMask, self).__init__()
@@ -66,17 +66,24 @@ def __init__(self,inChanel, viewpe=6, feape=6, featureC=128):
         super(MLPRender_Fea, self).__init__()
 
         self.in_mlpC = 2*viewpe*3 + 2*feape*inChanel + 3 + inChanel
+        # self.pospe = 4
+        # self.in_mlpC = 3 * ((1 + 2*viewpe) + (1+ 2*self.pospe)) \
+        #              + inChanel * (1 + 2*feape)
         self.viewpe = viewpe
         self.feape = feape
 
         layer1 = torch.nn.Linear(self.in_mlpC, featureC)
         layer2 = torch.nn.Linear(featureC, featureC)
-        layer3 = torch.nn.Linear(featureC,3)
+        layer3 = torch.nn.Linear(featureC, 3)
 
-        self.mlp = torch.nn.Sequential(layer1, torch.nn.ReLU(inplace=True), layer2, torch.nn.ReLU(inplace=True), layer3)
+        self.mlp = torch.nn.Sequential(layer1, torch.nn.ReLU(inplace=True),
+                                       layer2, torch.nn.ReLU(inplace=True),
+                                       layer3)
         torch.nn.init.constant_(self.mlp[-1].bias, 0)
 
     def forward(self, pts, viewdirs, features):
-        indata = [features, viewdirs]
+        indata = [features, viewdirs]  # , pts]
+        # if self.pospe > 0:
+        #     indata += [positional_encoding(pts, self.pospe)]
         if self.feape > 0:
             indata += [positional_encoding(features, self.feape)]
         if self.viewpe > 0:
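
For reference, `in_mlpC` is just the concatenated input width of MLPRender_Fea: the raw feature and view direction plus 2*pe sin/cos channels for each encoded input. With the chair config above (`view_pe = 2`, `fea_pe = 2`) and the default 27-channel appearance feature (an assumption; `data_dim_color` is commented out in the config), the arithmetic works out as:

# in_mlpC = 2*viewpe*3 + 2*feape*inChanel + 3 + inChanel
viewpe, feape, inChanel = 2, 2, 27
print(2*viewpe*3 + 2*feape*inChanel + 3 + inChanel)  # 12 + 108 + 3 + 27 = 150
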
@@ -87,11 +94,12 @@ def forward(self, pts, viewdirs, features):
         return rgb
 
+
 class MLPRender_PE(torch.nn.Module):
     def __init__(self,inChanel, viewpe=6, pospe=6, featureC=128):
         super(MLPRender_PE, self).__init__()
 
-        self.in_mlpC = (3+2*viewpe*3)+ (3+2*pospe*3) + inChanel #
+        # self.in_mlpC = (3+2*viewpe*3)+ (3+2*pospe*3) + inChanel #
         self.viewpe = viewpe
         self.pospe = pospe
         layer1 = torch.nn.Linear(self.in_mlpC, featureC)
@@ -113,6 +121,7 @@ def forward(self, pts, viewdirs, features):
         return rgb
 
+
 class MLPRender(torch.nn.Module):
     def __init__(self,inChanel, viewpe=6, featureC=128):
         super(MLPRender, self).__init__()
@@ -134,17 +143,18 @@ def forward(self, pts, viewdirs, features):
         mlp_in = torch.cat(indata, dim=-1)
         rgb = self.mlp(mlp_in)
         rgb = torch.sigmoid(rgb)
-
         return rgb
 
-
 class TensorBase(torch.nn.Module):
-    def __init__(self, aabb, gridSize, device, density_n_comp = 8, appearance_n_comp = 24, app_dim = 27,
-                    shadingMode = 'MLP_PE', alphaMask = None, near_far=[2.0,6.0],
-                    density_shift = -10, alphaMask_thres=0.001, distance_scale=25, rayMarch_weight_thres=0.0001,
-                    pos_pe = 6, view_pe = 6, fea_pe = 6, featureC=128, step_ratio=2.0,
-                    fea2denseAct = 'softplus', grid_bit=32):
+    def __init__(self, aabb, gridSize, device, density_n_comp=8,
+                 appearance_n_comp=24, app_dim=27,
+                 shadingMode='MLP_PE',
+                 alphaMask=None, near_far=[2.0, 6.0], density_shift=-10,
+                 alphaMask_thres=0.001, distance_scale=25,
+                 rayMarch_weight_thres=0.0001, pos_pe=6, view_pe=6, fea_pe=6,
+                 featureC=128, step_ratio=2.0, fea2denseAct='softplus',
+                 grid_bit=32):
         super(TensorBase, self).__init__()
 
         self.density_n_comp = density_n_comp
@@ -163,7 +173,6 @@ def __init__(self, aabb, gridSize, device, density_n_comp=8,
         self.near_far = near_far
         self.step_ratio = step_ratio
 
-
         self.update_stepSize(gridSize)
 
         self.matMode = [[0,1], [0,2], [1,2]]
@@ -204,7 +213,7 @@ def update_stepSize(self, gridSize):
         self.units=self.aabbSize / (self.gridSize-1)
         self.stepSize=torch.mean(self.units)*self.step_ratio
         self.aabbDiag = torch.sqrt(torch.sum(torch.square(self.aabbSize)))
-        self.nSamples=int((self.aabbDiag / self.stepSize).item()) + 1
+        self.nSamples = int((self.aabbDiag / self.stepSize).item()) + 1
         print("sampling step size: ", self.stepSize)
         print("sampling number: ", self.nSamples)
@@ -269,7 +278,6 @@ def load(self, ckpt):
             self.alphaMask = AlphaGridMask(self.device, ckpt['alphaMask.aabb'].to(self.device), alpha_volume.float().to(self.device))
         self.load_state_dict(ckpt['state_dict'])
 
-
     def sample_ray_ndc(self, rays_o, rays_d, is_train=True, N_samples=-1):
         N_samples = N_samples if N_samples > 0 else self.nSamples
         near, far = self.near_far
@@ -308,7 +316,6 @@ def sample_ray(self, rays_o, rays_d, is_train=True, N_samples=-1):
 
         return rays_pts, interpx, ~mask_outbbox
 
-
     def shrink(self, new_aabb, voxel_size):
         pass
@@ -468,7 +475,6 @@ def forward(self, rays_chunk, white_bg=True, is_train=False, ndc_ray=False, N_sa
         if white_bg or (is_train and torch.rand((1,))<0.5):
             rgb_map = rgb_map + (1. - acc_map[..., None])
 
-
         rgb_map = rgb_map.clamp(0,1)
 
         with torch.no_grad():
diff --git a/TensoRF/opt.py b/TensoRF/opt.py
index 1988add..7552ffd 100644
--- a/TensoRF/opt.py
+++ b/TensoRF/opt.py
@@ -1,5 +1,6 @@
 import configargparse
 
+
 def config_parser(cmd=None):
     parser = configargparse.ArgumentParser()
     parser.add_argument('--config', is_config_file=True,
@@ -29,7 +30,6 @@ def config_parser(cmd=None):
     parser.add_argument('--dataset_name', type=str, default='blender',
                         choices=['blender', 'llff', 'nsvf', 'dtu','tankstemple', 'own_data'])
 
-
     # training options
     # learning rate
     parser.add_argument("--lr_init", type=float, default=0.02,
@@ -44,6 +44,7 @@ def config_parser(cmd=None):
                         help='reset lr to inital after upsampling')
 
     # loss
+    parser.add_argument("--weight_decay", type=float, default=0.0)
     parser.add_argument("--L1_weight_inital", type=float, default=0.0,
                         help='loss weight')
     parser.add_argument("--L1_weight_rest", type=float, default=0,
@@ -70,7 +71,12 @@ def config_parser(cmd=None):
     parser.add_argument("--density_shift", type=float, default=-10,
                         help='shift density in softplus; making density = 0 when feature == 0')
 
+    # My Options
     parser.add_argument("--grid_bit", type=int, default=32)
+    parser.add_argument("--use_mask", action='store_true')
+    parser.add_argument("--mask_weight", type=float, default=0)
+    parser.add_argument("--use_dwt", action='store_true')
+    parser.add_argument("--dwt_level", type=int, default=2)
 
     # network decoder
     parser.add_argument("--shadingMode", type=str, default="MLP_PE",
@@ -83,8 +89,6 @@ def config_parser(cmd=None):
                         help='number of pe for features')
     parser.add_argument("--featureC", type=int, default=128,
                         help='hidden feature channel in MLP')
-
-
     parser.add_argument("--ckpt", type=str, default=None,
                         help='specific weights npy file to reload for coarse network')
@@ -106,13 +110,10 @@ def config_parser(cmd=None):
                         help='sample point each ray, pass 1e6 if automatic adjust')
     parser.add_argument('--step_ratio',type=float,default=0.5)
 
-
    ## blender flags
    parser.add_argument("--white_bkgd", action='store_true',
                        help='set to render synthetic data on a white bkgd (always use for dvoxels)')
 
-
-
    parser.add_argument('--N_voxel_init', type=int, default=100**3)
@@ -134,3 +135,4 @@ def config_parser(cmd=None):
         return parser.parse_args(cmd)
     else:
         return parser.parse_args()
+
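
The new switches compose with the existing configargparse setup, so a compressed run can be scripted without editing the .txt config. A hedged sketch (paths and values are illustrative, not a tested command):

from opt import config_parser

# config_parser accepts an explicit argv list via its `cmd` parameter:
args = config_parser([
    '--config', 'configs/chair.txt',
    '--grid_bit', '8',         # quantize grids to 8 bits
    '--use_mask',              # learn per-entry binary masks
    '--mask_weight', '1e-10',  # sparsity pressure on the mask logits
    '--use_dwt',               # store planes as wavelet coefficients
    '--dwt_level', '2',
])
print(args.grid_bit, args.use_mask, args.use_dwt, args.dwt_level)
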
@@ -1,9 +1,13 @@
-import torch,os,imageio,sys
+import imageio
+import os
+import sys
+import torch
 from tqdm.auto import tqdm
+
 from dataLoader.ray_utils import get_rays
-from models.tensoRF import TensorVM, TensorCP, raw2alpha, TensorVMSplit, AlphaGridMask
-from utils import *
 from dataLoader.ray_utils import ndc_rays_blender
+from models.tensoRF import TensorCP, raw2alpha, TensorVMSplit, AlphaGridMask
+from utils import *
 
 
 def OctreeRender_trilinear_fast(rays, tensorf, chunk=4096, N_samples=-1, ndc_ray=False, white_bg=True, is_train=False, device='cuda'):
diff --git a/TensoRF/train.py b/TensoRF/train.py
index 37f06bc..ec5cf8b 100644
--- a/TensoRF/train.py
+++ b/TensoRF/train.py
@@ -1,19 +1,14 @@
-
+import datetime
 import os
+import random
+import sys
+from torch.utils.tensorboard import SummaryWriter
 from tqdm.auto import tqdm
-from opt import config_parser
-
-
-import json, random
+from dataLoader import dataset_dict
+from opt import config_parser
 from renderer import *
 from utils import *
-from torch.utils.tensorboard import SummaryWriter
-import datetime
-
-from dataLoader import dataset_dict
-import sys
-
 
 device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
@@ -49,7 +44,6 @@ def nextids(self):
 
 @torch.no_grad()
 def export_mesh(args):
-
     ckpt = torch.load(args.ckpt, map_location=device)
     kwargs = ckpt['kwargs']
     kwargs.update({'device': device})
@@ -57,14 +51,16 @@ def export_mesh(args):
     tensorf.load(ckpt)
 
     alpha,_ = tensorf.getDenseAlpha()
-    convert_sdf_samples_to_ply(alpha.cpu(), f'{args.ckpt[:-3]}.ply',bbox=tensorf.aabb.cpu(), level=0.005)
+    convert_sdf_samples_to_ply(alpha.cpu(), f'{args.ckpt[:-3]}.ply',
+                               bbox=tensorf.aabb.cpu(), level=0.005)
 
 
 @torch.no_grad()
 def render_test(args):
     # init dataset
     dataset = dataset_dict[args.dataset_name]
-    test_dataset = dataset(args.datadir, split='test', downsample=args.downsample_train, is_stack=True)
+    test_dataset = dataset(args.datadir, split='test',
+                           downsample=args.downsample_train, is_stack=True)
     white_bg = test_dataset.white_bg
     ndc_ray = args.ndc_ray
@@ -97,8 +93,8 @@ def render_test(args):
         evaluation_path(test_dataset,tensorf, c2ws, renderer, f'{logfolder}/{args.expname}/imgs_path_all/',
                         N_vis=-1, N_samples=-1, white_bg = white_bg, ndc_ray=ndc_ray,device=device)
 
-def reconstruction(args):
 
+def reconstruction(args):
     # init dataset
     dataset = dataset_dict[args.dataset_name]
     train_dataset = dataset(args.datadir, split='train', downsample=args.downsample_train, is_stack=False)
@@ -112,14 +108,12 @@ def reconstruction(args):
     update_AlphaMask_list = args.update_AlphaMask_list
     n_lamb_sigma = args.n_lamb_sigma
     n_lamb_sh = args.n_lamb_sh
-
     if args.add_timestamp:
         logfolder = f'{args.basedir}/{args.expname}{datetime.datetime.now().strftime("-%Y%m%d-%H%M%S")}'
     else:
         logfolder = f'{args.basedir}/{args.expname}'
-
 
     # init log file
     os.makedirs(logfolder, exist_ok=True)
     os.makedirs(f'{logfolder}/imgs_vis', exist_ok=True)
@@ -129,6 +123,7 @@ def reconstruction(args):
 
     # init parameters
     aabb = train_dataset.scene_bbox.to(device)
+
     reso_cur = N_to_reso(args.N_voxel_init, aabb)
     nSamples = min(args.nSamples, cal_n_samples(reso_cur,args.step_ratio))
@@ -139,13 +134,22 @@ def reconstruction(args):
         tensorf = eval(args.model_name)(**kwargs)
         tensorf.load(ckpt)
     else:
-        tensorf = eval(args.model_name)(aabb, reso_cur, device,
-                    density_n_comp=n_lamb_sigma, appearance_n_comp=n_lamb_sh, app_dim=args.data_dim_color, near_far=near_far,
-                    shadingMode=args.shadingMode, alphaMask_thres=args.alpha_mask_thre, density_shift=args.density_shift, distance_scale=args.distance_scale,
-                    pos_pe=args.pos_pe, view_pe=args.view_pe, fea_pe=args.fea_pe, featureC=args.featureC, step_ratio=args.step_ratio, fea2denseAct=args.fea2denseAct, grid_bit=args.grid_bit)
-
-        print(tensorf)
-        print(sum([p.numel() for p in tensorf.parameters()]) * 16 / 8_388_608)
+        tensorf = eval(args.model_name)(
+            aabb, reso_cur, device,
+            density_n_comp=n_lamb_sigma, appearance_n_comp=n_lamb_sh,
+            app_dim=args.data_dim_color, near_far=near_far,
+            shadingMode=args.shadingMode, alphaMask_thres=args.alpha_mask_thre,
+            density_shift=args.density_shift,
+            distance_scale=args.distance_scale,
+            pos_pe=args.pos_pe, view_pe=args.view_pe, fea_pe=args.fea_pe,
+            featureC=args.featureC, step_ratio=args.step_ratio,
+            fea2denseAct=args.fea2denseAct,
+            grid_bit=args.grid_bit,
+            use_mask=args.use_mask,
+            use_dwt=args.use_dwt, dwt_level=args.dwt_level)
+
+    # print(tensorf)
+    print(f'{sum([p.numel() for p in tensorf.parameters()])*32/8_388_608}MB')
 
     grad_vars = tensorf.get_optparam_groups(args.lr_init, args.lr_basis)
     if args.lr_decay_iters > 0:
@@ -155,14 +159,12 @@ def reconstruction(args):
         lr_factor = args.lr_decay_target_ratio**(1/args.n_iters)
         print("lr decay", args.lr_decay_target_ratio, args.lr_decay_iters)
 
-
-    optimizer = torch.optim.Adam(grad_vars, betas=(0.9,0.99))
-
+    optimizer = torch.optim.Adam(grad_vars, betas=(0.9,0.99),
+                                 weight_decay=args.weight_decay)
 
     #linear in logrithmic space
     N_voxel_list = (torch.round(torch.exp(torch.linspace(np.log(args.N_voxel_init), np.log(args.N_voxel_final), len(upsamp_list)+1))).long()).tolist()[1:]
 
-
     torch.cuda.empty_cache()
     PSNRs,PSNRs_test = [],[0]
@@ -215,6 +217,17 @@ def reconstruction(args):
             total_loss = total_loss + loss_tv
             summary_writer.add_scalar('train/reg_tv_app', loss_tv.detach().item(), global_step=iteration)
 
+        if args.use_mask and args.mask_weight > 0:
+            mask_loss = sum([p.sum()
+                             for p in tensorf.density_plane_mask.parameters()])\
+                + sum([p.sum()
+                       for p in tensorf.density_line_mask.parameters()])\
+                + sum([p.sum()
+                       for p in tensorf.app_plane_mask.parameters()])\
+                + sum([p.sum()
+                       for p in tensorf.app_line_mask.parameters()])
+            total_loss = total_loss + args.mask_weight * mask_loss
+
         optimizer.zero_grad()
         total_loss.backward()
         optimizer.step()
@@ -276,32 +289,60 @@ def reconstruction(args):
                 lr_scale = 1 #0.1 ** (iteration / args.n_iters)
             else:
                 lr_scale = args.lr_decay_target_ratio ** (iteration / args.n_iters)
-            grad_vars = tensorf.get_optparam_groups(args.lr_init*lr_scale, args.lr_basis*lr_scale)
-            optimizer = torch.optim.Adam(grad_vars, betas=(0.9, 0.99))
+            grad_vars = tensorf.get_optparam_groups(args.lr_init*lr_scale,
+                                                    args.lr_basis*lr_scale)
+            optimizer = torch.optim.Adam(grad_vars, betas=(0.9, 0.99),
+                                         weight_decay=args.weight_decay)
+
+        if iteration + 1 in [500, 1000, 2500, 5000, 10000, 20000]:
+            print()
 
     tensorf.save(f'{logfolder}/{args.expname}.th')
 
     grid, non_grid = tensorf_param_count(tensorf)
+    if args.use_mask:
+        grid = grid / 2  # don't count masks
     grid_bytes = grid * args.grid_bit / 8
     non_grid_bytes = non_grid * 4
     print(f'total: {(grid_bytes + non_grid_bytes)/1_048_576:.3f}MB '
           f'(G ({args.grid_bit}bit): {grid_bytes/1_048_576:.3f}MB) '
          f'(N: {non_grid_bytes/1_048_576:3f}MB)')
 
+    if args.use_mask:
+        flat_mask = torch.cat([torch.cat([p[0].flatten(), p[1].flatten(),
+                                          p[2].flatten()])
+                               for p in [tensorf.density_plane_mask,
+                                         tensorf.density_line_mask,
+                                         tensorf.app_plane_mask,
+                                         tensorf.app_line_mask]])
+        ratio = (flat_mask >= 0).float().mean()
+        print(f'non-masked ratio: {ratio:.4f}')
+        grid_bytes = grid_bytes * ratio
+        print(f'masked_total: {(grid_bytes + non_grid_bytes)/1_048_576:.3f}MB '
+              f'(G ({args.grid_bit}bit): {grid_bytes/1_048_576:.3f}MB) '
+              f'(N: {non_grid_bytes/1_048_576:.3f}MB)')
+
     if args.render_train:
         os.makedirs(f'{logfolder}/imgs_train_all', exist_ok=True)
-        train_dataset = dataset(args.datadir, split='train', downsample=args.downsample_train, is_stack=True)
-        PSNRs_test = evaluation(train_dataset,tensorf, args, renderer, f'{logfolder}/imgs_train_all/',
-                                N_vis=-1, N_samples=-1, white_bg = white_bg, ndc_ray=ndc_ray,device=device)
-        print(f'======> {args.expname} test all psnr: {np.mean(PSNRs_test)} <========================')
+        train_dataset = dataset(args.datadir, split='train',
+                                downsample=args.downsample_train, is_stack=True)
+        PSNRs_test = evaluation(train_dataset,tensorf, args, renderer,
+                                f'{logfolder}/imgs_train_all/',
+                                N_vis=-1, N_samples=-1, white_bg=white_bg,
+                                ndc_ray=ndc_ray, device=device)
+        print(f'======> {args.expname} test all psnr: {np.mean(PSNRs_test)} '
+              f'<========================')
 
     if args.render_test:
         os.makedirs(f'{logfolder}/imgs_test_all', exist_ok=True)
-        PSNRs_test = evaluation(test_dataset, tensorf, args, renderer, f'{logfolder}/imgs_test_all/',
-                                N_vis=-1, N_samples=-1, white_bg = white_bg, ndc_ray=ndc_ray,device=device)
-        summary_writer.add_scalar('test/psnr_all', np.mean(PSNRs_test), global_step=iteration)
-
-        print(f'======> {args.expname} test all psnr: {np.mean(PSNRs_test)} <========================')
+        PSNRs_test = evaluation(test_dataset, tensorf, args, renderer,
+                                f'{logfolder}/imgs_test_all/',
+                                N_vis=-1, N_samples=-1, white_bg=white_bg,
+                                ndc_ray=ndc_ray, device=device)
+        summary_writer.add_scalar('test/psnr_all', np.mean(PSNRs_test),
+                                  global_step=iteration)
+        print(f'======> {args.expname} test all psnr: {np.mean(PSNRs_test)} '
+              f'<========================')
 
     if args.render_path:
         c2ws = test_dataset.render_path
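
The size report above is plain arithmetic: grid parameters are charged `grid_bit` bits each (halved first under `use_mask`, since the mask logits mirror the grid tensors but are not part of the shipped model), non-grid parameters cost 4 bytes each, and the mask scales the grid term by the fraction of entries kept. A worked example with illustrative counts, not numbers from a real run:

grid, non_grid = 50_000_000, 500_000
grid_bit = 8

grid = grid / 2                       # use_mask: exclude the mask logits
grid_bytes = grid * grid_bit / 8      # 25e6 params at 8 bits = 25e6 bytes
non_grid_bytes = non_grid * 4         # fp32 MLP / basis weights
print(f'{(grid_bytes + non_grid_bytes)/1_048_576:.3f}MB')  # 25.749MB

ratio = 0.3                           # mask keeps 30% of the grid entries
grid_bytes = grid_bytes * ratio
print(f'{(grid_bytes + non_grid_bytes)/1_048_576:.3f}MB')  # 9.060MB
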
diff --git a/TensoRF/utils.py b/TensoRF/utils.py
index 3c29586..7be5345 100644
--- a/TensoRF/utils.py
+++ b/TensoRF/utils.py
@@ -50,18 +50,18 @@ def visualize_depth(depth, minmax=None, cmap=cv2.COLORMAP_JET):
     x_ = T.ToTensor()(x_)  # (3, H, W)
     return x_, [mi,ma]
 
-def N_to_reso(n_voxels, bbox):
+
+def N_to_reso(n_voxels, bbox, unit=16):
     xyz_min, xyz_max = bbox
     dim = len(xyz_min)
     voxel_size = ((xyz_max - xyz_min).prod() / n_voxels).pow(1 / dim)
-    return ((xyz_max - xyz_min) / voxel_size).long().tolist()
+    return (((xyz_max - xyz_min) / voxel_size / unit).long() * unit).tolist()
+
 
 def cal_n_samples(reso, step_ratio=0.5):
     return int(np.linalg.norm(reso)/step_ratio)
 
-
-
 __LPIPS__ = {}
 def init_lpips(net_name, device):
     assert net_name in ['alex', 'vgg']
@@ -69,6 +69,7 @@ def init_lpips(net_name, device):
     print(f'init_lpips: lpips_{net_name}')
     return lpips.LPIPS(net=net_name, version='0.1').eval().to(device)
 
+
 def rgb_lpips(np_gt, np_im, net_name, device):
     if net_name not in __LPIPS__:
         __LPIPS__[net_name] = init_lpips(net_name, device)
@@ -219,3 +220,4 @@ def convert_sdf_samples_to_ply(
     ply_data = plyfile.PlyData([el_verts, el_faces])
     print("saving mesh to %s" % (ply_filename_out))
     ply_data.write(ply_filename_out)
+
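
N_to_reso now floors each axis resolution to a multiple of `unit=16`, which keeps every plane divisible by 2**4 and therefore safe for up to four DWT levels (and consistent with the `unit = 16` resampling in TensorVMSplit.shrink). A quick worked example, assuming a 3-unit-wide cube and the default 100**3 initial voxel budget:

import torch
from utils import N_to_reso  # the patched version above

bbox = torch.tensor([[-1.5, -1.5, -1.5], [1.5, 1.5, 1.5]])
# ~100 voxels per axis, floored to the nearest multiple of 16:
print(N_to_reso(100**3, bbox))  # [96, 96, 96]; 96 % 2**4 == 0
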