diff --git a/TensoRF/models/tensoRF.py b/TensoRF/models/tensoRF.py index c7f4869..d03f598 100644 --- a/TensoRF/models/tensoRF.py +++ b/TensoRF/models/tensoRF.py @@ -323,6 +323,237 @@ def TV_loss_app(self, reg): return total +class TriPlane(TensorBase): + def __init__(self, aabb, gridSize, device, + use_mask=False, use_dwt=False, dwt_level=2, + trans_func='bior4.4', **kargs): + super(TriPlane, self).__init__(aabb, gridSize, device, **kargs) + self.use_mask = use_mask + self.use_dwt = use_dwt + self.dwt_level = dwt_level + self.trans_func = trans_func + + if use_mask: + self.init_mask() + + def get_kwargs(self): + return { + 'aabb': self.aabb, + 'gridSize':self.gridSize.tolist(), + 'density_n_comp': self.density_n_comp, + 'appearance_n_comp': self.app_n_comp, + 'app_dim': self.app_dim, + + 'density_shift': self.density_shift, + 'alphaMask_thres': self.alphaMask_thres, + 'distance_scale': self.distance_scale, + 'rayMarch_weight_thres': self.rayMarch_weight_thres, + 'fea2denseAct': self.fea2denseAct, + + 'near_far': self.near_far, + 'step_ratio': self.step_ratio, + + 'shadingMode': self.shadingMode, + 'pos_pe': self.pos_pe, + 'view_pe': self.view_pe, + 'fea_pe': self.fea_pe, + 'featureC': self.featureC, + + 'grid_bit': self.grid_bit, + 'use_mask': self.use_mask, + 'use_dwt': self.use_dwt, + 'dwt_level': self.dwt_level, + 'trans_func': self.trans_func, + } + + def init_svd_volume(self, res, device): + self.density_plane= self.init_one_svd( + self.density_n_comp, self.gridSize, 0.1, device) + self.app_plane = self.init_one_svd( + self.app_n_comp, self.gridSize, 0.1, device) + self.basis_mat = nn.Linear( + sum(self.app_n_comp), self.app_dim, bias=False).to(device) + + @torch.no_grad() + def init_mask(self): + self.density_plane_mask = nn.ParameterList( + [nn.Parameter(torch.ones_like(self.density_plane[i])) + for i in range(3)]) + self.app_plane_mask = nn.ParameterList( + [nn.Parameter(torch.ones_like(self.app_plane[i])) + for i in range(3)]) + + def init_one_svd(self, n_component, gridSize, scale, device): + plane_coef = [] + for i in range(len(self.matMode)): + mat_id_0, mat_id_1 = self.matMode[i] + plane_coef.append(nn.Parameter( + scale * torch.randn((1, n_component[i], gridSize[mat_id_1], + gridSize[mat_id_0])))) + + return nn.ParameterList(plane_coef).to(device) + + def get_optparam_groups(self, lr0=0.02, lr1=0.001): + grad_vars = [{'params': self.density_plane, 'lr': lr0}, + {'params': self.app_plane, 'lr': lr0}, + {'params': self.basis_mat.parameters(), 'lr':lr1}] + + if isinstance(self.renderModule, nn.Module): + grad_vars += [{'params':self.renderModule.parameters(), 'lr':lr1}] + + if self.use_mask: + grad_vars += [{'params': self.density_plane_mask, 'lr': lr0}, + {'params': self.app_plane_mask, 'lr': lr0}] + + return grad_vars + + def compute_densityfeature(self, points): + # plane + line basis + # [3, B, 1, 2] + coordinate_plane = points[..., self.matMode].transpose(0, -2) \ + .view(3, -1, 1, 2) + + sigma_feature = torch.zeros((points.shape[0],), device=points.device) + + for idx in range(len(self.density_plane)): + plane = min_max_quantize(self.density_plane[idx], self.grid_bit) + + if self.use_mask: + mask = torch.sigmoid(self.density_plane_mask[idx]) + plane = (plane * (mask >= 0.5) - plane * mask).detach() \ + + plane * mask + + if self.use_dwt: + plane = inverse(plane, self.dwt_level, self.trans_func) + + plane_coef_point = F.grid_sample( + plane, coordinate_plane[[idx]], + align_corners=True).view(-1, *points.shape[:1]) + + sigma_feature += torch.sum(plane_coef_point, dim=0) + + return sigma_feature + + def compute_appfeature(self, points): + # plane + line basis + # [3, B, 1, 2] + coordinate_plane = points[..., self.matMode].transpose(0, -2) \ + .view(3, -1, 1, 2) + + plane_coef_point = [] + for idx in range(len(self.app_plane)): + plane = min_max_quantize(self.app_plane[idx], self.grid_bit) + + if self.use_mask: + mask = torch.sigmoid(self.app_plane_mask[idx]) + plane = (plane * (mask >= 0.5) - plane * mask).detach() \ + + plane * mask + + if self.use_dwt: + plane = inverse(plane, self.dwt_level, self.trans_func) + + plane_coef_point.append(F.grid_sample( + plane, coordinate_plane[[idx]], + align_corners=True).view(-1, *points.shape[:1])) + + plane_coef_point = torch.cat(plane_coef_point) + + return self.basis_mat(plane_coef_point.T) + + @torch.no_grad() + def upsample_volume_grid(self, res_target): + self.app_plane = self.up_sampling_VM(self.app_plane, res_target) + self.density_plane = self.up_sampling_VM(self.density_plane, res_target) + + if self.use_mask: + self.app_plane_mask = self.up_sampling_VM(self.app_plane_mask, res_target) + self.density_plane_mask = self.up_sampling_VM(self.density_plane_mask, res_target) + + self.update_stepSize(res_target) + print(f'upsamping to {res_target}') + + @torch.no_grad() + def up_sampling_VM(self, plane_coef, res_target): + for i in range(len(self.matMode)): + mat_id_0, mat_id_1 = self.matMode[i] + + if self.use_dwt: + plane_coef[i].set_(inverse(plane_coef[i], self.dwt_level, self.trans_func)) + + plane_coef[i] = nn.Parameter( + F.interpolate(plane_coef[i].data, + size=(res_target[mat_id_1], res_target[mat_id_0]), + mode='bilinear', align_corners=True)) + + if self.use_dwt: + plane_coef[i].set_(forward(plane_coef[i], self.dwt_level, self.trans_func)) + + return plane_coef + + @torch.no_grad() + def shrink(self, new_aabb): + print("====> shrinking ...") + unit = 16 # unit for DWT + + for i in range(len(self.matMode)): + # Planes + mode0, mode1 = self.matMode[i] + if self.use_dwt: + self.density_plane[i].set_(inverse(self.density_plane[i], + self.dwt_level, self.trans_func)) + self.app_plane[i].set_(inverse(self.app_plane[i], + self.dwt_level, self.trans_func)) + + steps = (new_aabb[1][mode0]-new_aabb[0][mode0]) / self.units[mode0] + steps = int(steps / unit) * unit + grid0 = torch.linspace(new_aabb[0][mode0], new_aabb[1][mode0], + steps).to(self.density_plane[i].device) + + steps = (new_aabb[1][mode1]-new_aabb[0][mode1]) / self.units[mode1] + steps = int(steps / unit) * unit + grid1 = torch.linspace(new_aabb[0][mode1], new_aabb[1][mode1], + steps).to(self.density_plane[i].device) + grid = torch.stack(torch.meshgrid(grid0, grid1), -1).unsqueeze(0) + + self.density_plane[i] = nn.Parameter( + F.grid_sample(self.density_plane[i], grid, align_corners=True)) + self.app_plane[i] = nn.Parameter( + F.grid_sample(self.app_plane[i], grid, align_corners=True)) + + if self.use_dwt: + self.density_plane[i].set_(forward(self.density_plane[i], + self.dwt_level, self.trans_func)) + self.app_plane[i].set_(forward(self.app_plane[i], + self.dwt_level, self.trans_func)) + + self.aabb = new_aabb + + Y, X = self.density_plane[0].shape[-2:] + Z = self.density_plane[1].shape[-2] + self.update_stepSize((X,Y,Z)) + + if self.use_mask: + self.init_mask() + + def density_L1(self): + total = 0 + for idx in range(len(self.density_plane)): + total = total + torch.mean(torch.abs(self.density_plane[idx])) + return total + + def TV_loss_density(self, reg): + total = 0 + for idx in range(len(self.density_plane)): + total = total + reg(self.density_plane[idx]) * 1e-2 + return total + + def TV_loss_app(self, reg): + total = 0 + for idx in range(len(self.app_plane)): + total = total + reg(self.app_plane[idx]) * 1e-2 + return total + + class TensorCP(TensorBase): def __init__(self, aabb, gridSize, device, **kargs): super(TensorCP, self).__init__(aabb, gridSize, device, **kargs) diff --git a/TensoRF/opt.py b/TensoRF/opt.py index 5535b2a..562128f 100644 --- a/TensoRF/opt.py +++ b/TensoRF/opt.py @@ -21,7 +21,7 @@ def config_parser(cmd=None): parser.add_argument('--downsample_test', type=float, default=1.0) parser.add_argument('--model_name', type=str, default='TensorVMSplit', - choices=['TensorVMSplit', 'TensorCP']) + choices=['TensorVMSplit', 'TensorCP', 'TriPlane']) # loader options parser.add_argument("--batch_size", type=int, default=4096) diff --git a/TensoRF/renderer.py b/TensoRF/renderer.py index e28940c..1274391 100644 --- a/TensoRF/renderer.py +++ b/TensoRF/renderer.py @@ -6,7 +6,7 @@ from dataLoader.ray_utils import get_rays from dataLoader.ray_utils import ndc_rays_blender -from models.tensoRF import TensorCP, raw2alpha, TensorVMSplit, AlphaGridMask +from models.tensoRF import TensorCP, raw2alpha, TensorVMSplit, AlphaGridMask, TriPlane from utils import * diff --git a/TensoRF/train.py b/TensoRF/train.py index 4f7e09c..3bec285 100644 --- a/TensoRF/train.py +++ b/TensoRF/train.py @@ -75,6 +75,11 @@ def render_test(args): tensorf = eval(args.model_name)(**kwargs) tensorf.load(ckpt) + _, _, Z, Y, X = tensorf.alphaMask.alpha_volume.shape + tensorf.alphaMask = None + tensorf.alpha_offset = 0 + tensorf.updateAlphaMask((X,Y,Z)) + logfolder = os.path.dirname(args.ckpt) if args.render_train: os.makedirs(f'{logfolder}/imgs_train_all', exist_ok=True) @@ -85,8 +90,9 @@ def render_test(args): if args.render_test: os.makedirs(f'{logfolder}/{args.expname}/imgs_test_all', exist_ok=True) - evaluation(test_dataset,tensorf, args, renderer, f'{logfolder}/{args.expname}/imgs_test_all/', + PSNRs_test = evaluation(test_dataset,tensorf, args, renderer, f'{logfolder}/{args.expname}/imgs_test_all/', N_vis=-1, N_samples=-1, white_bg = white_bg, ndc_ray=ndc_ray,device=device) + print(f'======> {args.expname} train all psnr: {np.mean(PSNRs_test)} <========================') if args.render_path: c2ws = test_dataset.render_path @@ -222,10 +228,11 @@ def reconstruction(args): mask_loss = sum([p.sum() for p in tensorf.density_plane_mask.parameters()])\ + sum([p.sum() + for p in tensorf.app_plane_mask.parameters()]) + if hasattr(tensorf, "density_line_mask"): + mask_loss += sum([p.sum() for p in tensorf.density_line_mask.parameters()])\ - + sum([p.sum() - for p in tensorf.app_plane_mask.parameters()])\ - + sum([p.sum() + + sum([p.sum() for p in tensorf.app_line_mask.parameters()]) total_loss = total_loss + args.mask_weight * mask_loss @@ -306,16 +313,18 @@ def reconstruction(args): with torch.no_grad(): for i in range(3): tensorf.density_plane[i].set_(min_max_quantize(tensorf.density_plane[i], args.grid_bit) * (tensorf.density_plane_mask[i] >= 0)) - tensorf.density_line[i].set_(min_max_quantize(tensorf.density_line[i], args.grid_bit) * (tensorf.density_line_mask[i] >= 0)) tensorf.app_plane[i].set_(min_max_quantize(tensorf.app_plane[i], args.grid_bit) * (tensorf.app_plane_mask[i] >= 0)) - tensorf.app_line[i].set_(min_max_quantize(tensorf.app_line[i], args.grid_bit) * (tensorf.app_line_mask[i] >= 0)) + if hasattr(tensorf, "density_line_mask"): + tensorf.density_line[i].set_(min_max_quantize(tensorf.density_line[i], args.grid_bit) * (tensorf.density_line_mask[i] >= 0)) + tensorf.app_line[i].set_(min_max_quantize(tensorf.app_line[i], args.grid_bit) * (tensorf.app_line_mask[i] >= 0)) tensorf.use_mask = False del tensorf.density_plane_mask - del tensorf.density_line_mask del tensorf.app_plane_mask - del tensorf.app_line_mask + if hasattr(tensorf, "density_line_mask"): + del tensorf.density_line_mask + del tensorf.app_line_mask grid, non_grid = tensorf_param_count(tensorf) grid_bytes = grid * args.grid_bit / 8 @@ -325,13 +334,21 @@ def reconstruction(args): f'(N: {non_grid_bytes/1_048_576:3f}MB)') if args.use_mask: - flat_mask = torch.cat([torch.cat([min_max_quantize(p[0].flatten(), args.grid_bit), - min_max_quantize(p[1].flatten(), args.grid_bit), - min_max_quantize(p[2].flatten(), args.grid_bit)]) - for p in [tensorf.density_plane, - tensorf.density_line, - tensorf.app_plane, - tensorf.app_line]]) + if hasattr(tensorf, "density_line"): + flat_mask = torch.cat([torch.cat([min_max_quantize(p[0].flatten(), args.grid_bit), + min_max_quantize(p[1].flatten(), args.grid_bit), + min_max_quantize(p[2].flatten(), args.grid_bit)]) + for p in [tensorf.density_plane, + tensorf.density_line, + tensorf.app_plane, + tensorf.app_line]]) + else: + flat_mask = torch.cat([torch.cat([min_max_quantize(p[0].flatten(), args.grid_bit), + min_max_quantize(p[1].flatten(), args.grid_bit), + min_max_quantize(p[2].flatten(), args.grid_bit)]) + for p in [tensorf.density_plane, + tensorf.app_plane]]) + ratio = (flat_mask != 0).float().mean() print(f'non-masked ratio: {ratio:.4f}') grid_bytes = grid_bytes * ratio @@ -339,13 +356,13 @@ def reconstruction(args): f'(G ({args.grid_bit}bit): {grid_bytes/1_048_576:.3f}MB) ' f'(N: {non_grid_bytes/1_048_576:3f}MB)') + tensorf.save(f'{logfolder}/{args.expname}.th') # Alpha mask reconstruction _, _, Z, Y, X = tensorf.alphaMask.alpha_volume.shape tensorf.alphaMask = None tensorf.alpha_offset = 0 tensorf.updateAlphaMask((X,Y,Z)) - tensorf.save(f'{logfolder}/{args.expname}.th') if args.render_train: os.makedirs(f'{logfolder}/imgs_train_all', exist_ok=True) @@ -395,4 +412,4 @@ def reconstruction(args): render_test(args) else: reconstruction(args) - + \ No newline at end of file