diff --git a/TensoRF/models/dwt.py b/TensoRF/models/dwt.py index dbc73b6..09fdd2e 100644 --- a/TensoRF/models/dwt.py +++ b/TensoRF/models/dwt.py @@ -16,7 +16,7 @@ def split2d(inputs, level=4): inputs[..., res0//(2**(i+1)):res0//(2**i), :res1//(2**(i+1))], inputs[..., res0//(2**(i+1)):res0//(2**i), - res1//(2**(i+1)):res1//(2**i)]], 2) # /(level-i+1) + res1//(2**(i+1)):res1//(2**i)]], 2)/(level-i+1) for i in range(level) ] @@ -36,21 +36,26 @@ def split1d(inputs, level=4): # inverse and forward -def inverse(inputs, level=4): - return DWTInverse(wave='bior4.4', mode='periodization').to(inputs.device)\ +def inverse(inputs, level=4, trans_func='bior4.4'): + if trans_func == 'cosine': + return idctn(inputs, (-2, -1)) + return DWTInverse(wave=trans_func, mode='periodization').to(inputs.device)\ (split2d(inputs, level)) -def forward(inputs, level=4): +def forward(inputs, level=4, trans_func='bior4.4'): assert inputs.size(-1) % 2**level == 0 assert inputs.size(-2) % 2**level == 0 - yl, yh = DWTForward(wave='bior4.4', J=level, + if trans_func == 'cosine': + return dctn(inputs, (-2, -1)) + + yl, yh = DWTForward(wave=trans_func, J=level, mode='periodization').to(inputs.device)(inputs) outs = yl for i in range(level): - cf = yh[-i-1] # * (i+2) + cf = yh[-i-1] * (i+ 2) outs = torch.cat([torch.cat([outs, cf[..., 0, :, :]], -1), torch.cat([cf[..., 1, :, :], cf[..., 2, :, :]], -1)], -2) @@ -75,6 +80,55 @@ def forward1d(inputs, level=4): return outs +def dct(coefs, coords=None): + ''' + coefs: [..., C] # C: n_coefs + coords: [..., S] # S: n_samples + ''' + if coords is None: + coords = torch.ones_like(coefs) \ + * torch.arange(coefs.size(-1)).to(coefs.device) # \ + cos = torch.cos(torch.pi * (coords.unsqueeze(-1) + 0.5) / coefs.size(-1) + * (torch.arange(coefs.size(-1)).to(coefs.device) + 0.5)) + return torch.einsum('...C,...SC->...S', coefs*(2/coefs.size(-1))**0.5, cos) + + +def dctn(coefs, axes=None): + if axes is None: + axes = tuple(range(len(coefs.shape))) + out = coefs + for ax in axes: + out = out.transpose(-1, ax) + out = dct(out) + out = out.transpose(-1, ax) + return out + + +def idctn(coefs, axes=None, n_out=None, **kwargs): + if axes is None: + axes = tuple(range(len(coefs.shape))) + + if n_out is None or isinstance(n_out, int): + n_out = [n_out] * len(axes) + + out = coefs + for ax, n_o in zip(axes, n_out): + out = out.transpose(-1, ax) + out = idct(out, n_o, **kwargs) + out = out.transpose(-1, ax) + return out + + +def idct(coefs, n_out=None): + N = coefs.size(-1) + if n_out is None: + n_out = N + # TYPE IV + out = torch.cos(torch.pi * (torch.arange(N).to(coefs.device) + 0.5) / N + * (torch.linspace(0, N-1, n_out).unsqueeze(-1).to(coefs.device) + 0.5)) + return torch.einsum('...C,...SC->...S', coefs*(2/N)**0.5, out) + + if __name__ == '__main__': a = torch.randn(3, 5, 64, 64).cuda() * 10 print(a.shape, inverse(a).shape) diff --git a/TensoRF/models/tensoRF.py b/TensoRF/models/tensoRF.py index fef562c..c7f4869 100644 --- a/TensoRF/models/tensoRF.py +++ b/TensoRF/models/tensoRF.py @@ -20,11 +20,13 @@ def min_max_quantize(inputs, bits): class TensorVMSplit(TensorBase): def __init__(self, aabb, gridSize, device, - use_mask=False, use_dwt=False, dwt_level=2, **kargs): + use_mask=False, use_dwt=False, dwt_level=2, + trans_func='bior4.4', **kargs): super(TensorVMSplit, self).__init__(aabb, gridSize, device, **kargs) self.use_mask = use_mask self.use_dwt = use_dwt self.dwt_level = dwt_level + self.trans_func = trans_func if use_mask: self.init_mask() @@ -55,7 +57,8 @@ def get_kwargs(self): 'grid_bit': self.grid_bit, 'use_mask': self.use_mask, 'use_dwt': self.use_dwt, - 'dwt_level': self.dwt_level + 'dwt_level': self.dwt_level, + 'trans_func': self.trans_func, } def init_svd_volume(self, res, device): @@ -137,7 +140,7 @@ def compute_densityfeature(self, points): + line * mask if self.use_dwt: - plane = inverse(plane, self.dwt_level) + plane = inverse(plane, self.dwt_level, self.trans_func) plane_coef_point = F.grid_sample( plane, coordinate_plane[[idx]], @@ -173,7 +176,7 @@ def compute_appfeature(self, points): + line * mask if self.use_dwt: - plane = inverse(plane, self.dwt_level) + plane = inverse(plane, self.dwt_level, self.trans_func) plane_coef_point.append(F.grid_sample( plane, coordinate_plane[[idx]], @@ -211,7 +214,7 @@ def up_sampling_VM(self, plane_coef, line_coef, res_target): mat_id_0, mat_id_1 = self.matMode[i] if self.use_dwt: - plane_coef[i].set_(inverse(plane_coef[i], self.dwt_level)) + plane_coef[i].set_(inverse(plane_coef[i], self.dwt_level, self.trans_func)) plane_coef[i] = nn.Parameter( F.interpolate(plane_coef[i].data, @@ -222,7 +225,7 @@ def up_sampling_VM(self, plane_coef, line_coef, res_target): mode='bilinear', align_corners=True)) if self.use_dwt: - plane_coef[i].set_(forward(plane_coef[i], self.dwt_level)) + plane_coef[i].set_(forward(plane_coef[i], self.dwt_level, self.trans_func)) return plane_coef, line_coef @@ -250,9 +253,9 @@ def shrink(self, new_aabb): mode0, mode1 = self.matMode[i] if self.use_dwt: self.density_plane[i].set_(inverse(self.density_plane[i], - self.dwt_level)) + self.dwt_level, self.trans_func)) self.app_plane[i].set_(inverse(self.app_plane[i], - self.dwt_level)) + self.dwt_level, self.trans_func)) steps = (new_aabb[1][mode0]-new_aabb[0][mode0]) / self.units[mode0] steps = int(steps / unit) * unit @@ -272,9 +275,9 @@ def shrink(self, new_aabb): if self.use_dwt: self.density_plane[i].set_(forward(self.density_plane[i], - self.dwt_level)) + self.dwt_level, self.trans_func)) self.app_plane[i].set_(forward(self.app_plane[i], - self.dwt_level)) + self.dwt_level, self.trans_func)) self.aabb = new_aabb self.update_stepSize( diff --git a/TensoRF/opt.py b/TensoRF/opt.py index ebbd5ec..853ca61 100644 --- a/TensoRF/opt.py +++ b/TensoRF/opt.py @@ -77,9 +77,10 @@ def config_parser(cmd=None): parser.add_argument("--mask_weight", type=float, default=0) parser.add_argument("--use_dwt", action='store_true') parser.add_argument("--dwt_level", type=int, default=2) + parser.add_argument("--trans_func", type=str, default='bior4.4') # Alpha mask - parser.add_argument("--alpha_offset", type=float, default=0.0, + parser.add_argument("--alpha_offset", type=float, default=1e-4, help='add to alphamask threshold') # encoding option @@ -140,6 +141,7 @@ def config_parser(cmd=None): help='N images to vis') parser.add_argument("--vis_every", type=int, default=10000, help='frequency of visualize the image') + if cmd is not None: return parser.parse_args(cmd) else: