Skip to content

Commit

Permalink
update trans_func
Browse files Browse the repository at this point in the history
  • Loading branch information
daniel03c1 committed Nov 11, 2022
1 parent 3934256 commit 238a55b
Show file tree
Hide file tree
Showing 3 changed files with 76 additions and 17 deletions.
66 changes: 60 additions & 6 deletions TensoRF/models/dwt.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@ def split2d(inputs, level=4):
inputs[..., res0//(2**(i+1)):res0//(2**i),
:res1//(2**(i+1))],
inputs[..., res0//(2**(i+1)):res0//(2**i),
res1//(2**(i+1)):res1//(2**i)]], 2) # /(level-i+1)
res1//(2**(i+1)):res1//(2**i)]], 2)/(level-i+1)
for i in range(level)
]

Expand All @@ -36,21 +36,26 @@ def split1d(inputs, level=4):


# inverse and forward
def inverse(inputs, level=4):
return DWTInverse(wave='bior4.4', mode='periodization').to(inputs.device)\
def inverse(inputs, level=4, trans_func='bior4.4'):
if trans_func == 'cosine':
return idctn(inputs, (-2, -1))
return DWTInverse(wave=trans_func, mode='periodization').to(inputs.device)\
(split2d(inputs, level))


def forward(inputs, level=4):
def forward(inputs, level=4, trans_func='bior4.4'):
assert inputs.size(-1) % 2**level == 0
assert inputs.size(-2) % 2**level == 0

yl, yh = DWTForward(wave='bior4.4', J=level,
if trans_func == 'cosine':
return dctn(inputs, (-2, -1))

yl, yh = DWTForward(wave=trans_func, J=level,
mode='periodization').to(inputs.device)(inputs)
outs = yl

for i in range(level):
cf = yh[-i-1] # * (i+2)
cf = yh[-i-1] * (i+ 2)
outs = torch.cat([torch.cat([outs, cf[..., 0, :, :]], -1),
torch.cat([cf[..., 1, :, :], cf[..., 2, :, :]], -1)],
-2)
Expand All @@ -75,6 +80,55 @@ def forward1d(inputs, level=4):
return outs


def dct(coefs, coords=None):
    '''
    Evaluate a type-IV cosine series at the given sample positions.

    coefs: [..., C] # C: n_coefs
    coords: [..., S] # S: n_samples, positions on the 0..C-1 index grid.
            When omitted, the C integer positions 0, 1, ..., C-1 are used.

    returns: [..., S] with
        out[..., s] = sqrt(2 / C) * sum_c coefs[..., c]
                      * cos(pi * (coords[..., s] + 0.5) * (c + 0.5) / C)
    '''
    n_coefs = coefs.size(-1)
    scaled = coefs * (2 / n_coefs) ** 0.5

    # Build the cosine basis in at least float32 and directly on the target
    # device, then match the coefficient dtype so einsum never mixes dtypes.
    if scaled.is_floating_point():
        basis_dtype = torch.promote_types(scaled.dtype, torch.float32)
    else:
        basis_dtype = torch.get_default_dtype()
    freqs = torch.arange(n_coefs, device=coefs.device,
                         dtype=basis_dtype) + 0.5

    if coords is None:
        # The default sample grid is the same for every batch element, so a
        # single [S, C] basis is shared instead of materialising [..., S, C].
        basis = torch.cos(torch.pi * freqs.unsqueeze(-1) / n_coefs * freqs)
        if scaled.is_floating_point():
            basis = basis.to(scaled.dtype)
        return torch.einsum('...C,SC->...S', scaled, basis)

    # Caller-supplied positions may differ per batch element: [..., S, C].
    cos = torch.cos(torch.pi * (coords.unsqueeze(-1) + 0.5) / n_coefs * freqs)
    if scaled.is_floating_point():
        cos = cos.to(scaled.dtype)
    return torch.einsum('...C,...SC->...S', scaled, cos)


def dctn(coefs, axes=None):
    '''
    Apply the 1-D `dct` along several axes, one axis after another.

    coefs: tensor to transform
    axes: iterable of axis indices; every axis of `coefs` when omitted

    returns: the result of chaining `dct` over the requested axes
    '''
    target_axes = tuple(range(coefs.ndim)) if axes is None else axes

    result = coefs
    for axis in target_axes:
        # `dct` works on the last dimension: swap the axis in, transform,
        # then swap it back to its original position.
        swapped = result.transpose(-1, axis)
        result = dct(swapped).transpose(-1, axis)
    return result


def idctn(coefs, axes=None, n_out=None, **kwargs):
    '''
    Apply the 1-D `idct` along several axes, one axis after another.

    coefs: tensor of coefficients
    axes: sequence of axis indices; every axis of `coefs` when omitted
    n_out: output length per axis. None or a single int is reused for all
           axes; otherwise it is paired with `axes` element by element.
    kwargs: forwarded unchanged to `idct`

    returns: the result of chaining `idct` over the requested axes
    '''
    if axes is None:
        axes = tuple(range(coefs.ndim))

    sizes = n_out
    if sizes is None or isinstance(sizes, int):
        # A scalar (or missing) size applies to every transformed axis.
        sizes = [sizes] * len(axes)

    result = coefs
    for axis, size in zip(axes, sizes):
        # `idct` works on the last dimension: swap the axis in, transform,
        # then swap it back to its original position.
        swapped = result.transpose(-1, axis)
        result = idct(swapped, size, **kwargs).transpose(-1, axis)
    return result


def idct(coefs, n_out=None):
    '''
    Inverse type-IV cosine transform along the last dimension, optionally
    resampled to a different number of output points.

    coefs: [..., C] # C: n_coefs
    n_out: number of output samples S, spread evenly over the positions
           0 .. C-1 (both ends included); defaults to C

    returns: [..., S] with
        out[..., s] = sqrt(2 / C) * sum_c coefs[..., c]
                      * cos(pi * (c + 0.5) * (pos[s] + 0.5) / C)
    '''
    n_coefs = coefs.size(-1)
    if n_out is None:
        n_out = n_coefs
    scaled = coefs * (2 / n_coefs) ** 0.5

    # Build the cosine basis in at least float32 and directly on the target
    # device, then match the coefficient dtype so einsum never mixes dtypes
    # (e.g. float64 coefficients against a float32 basis).
    if scaled.is_floating_point():
        basis_dtype = torch.promote_types(scaled.dtype, torch.float32)
    else:
        basis_dtype = torch.get_default_dtype()
    freqs = torch.arange(n_coefs, device=coefs.device,
                         dtype=basis_dtype) + 0.5
    positions = torch.linspace(0, n_coefs - 1, n_out, device=coefs.device,
                               dtype=basis_dtype).unsqueeze(-1) + 0.5

    # TYPE IV basis, shape [S, C]; shared by every batch element.
    basis = torch.cos(torch.pi * freqs / n_coefs * positions)
    if scaled.is_floating_point():
        basis = basis.to(scaled.dtype)
    return torch.einsum('...C,SC->...S', scaled, basis)


if __name__ == '__main__':
a = torch.randn(3, 5, 64, 64).cuda() * 10
print(a.shape, inverse(a).shape)
Expand Down
23 changes: 13 additions & 10 deletions TensoRF/models/tensoRF.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,11 +20,13 @@ def min_max_quantize(inputs, bits):

class TensorVMSplit(TensorBase):
def __init__(self, aabb, gridSize, device,
use_mask=False, use_dwt=False, dwt_level=2, **kargs):
use_mask=False, use_dwt=False, dwt_level=2,
trans_func='bior4.4', **kargs):
super(TensorVMSplit, self).__init__(aabb, gridSize, device, **kargs)
self.use_mask = use_mask
self.use_dwt = use_dwt
self.dwt_level = dwt_level
self.trans_func = trans_func

if use_mask:
self.init_mask()
Expand Down Expand Up @@ -55,7 +57,8 @@ def get_kwargs(self):
'grid_bit': self.grid_bit,
'use_mask': self.use_mask,
'use_dwt': self.use_dwt,
'dwt_level': self.dwt_level
'dwt_level': self.dwt_level,
'trans_func': self.trans_func,
}

def init_svd_volume(self, res, device):
Expand Down Expand Up @@ -137,7 +140,7 @@ def compute_densityfeature(self, points):
+ line * mask

if self.use_dwt:
plane = inverse(plane, self.dwt_level)
plane = inverse(plane, self.dwt_level, self.trans_func)

plane_coef_point = F.grid_sample(
plane, coordinate_plane[[idx]],
Expand Down Expand Up @@ -173,7 +176,7 @@ def compute_appfeature(self, points):
+ line * mask

if self.use_dwt:
plane = inverse(plane, self.dwt_level)
plane = inverse(plane, self.dwt_level, self.trans_func)

plane_coef_point.append(F.grid_sample(
plane, coordinate_plane[[idx]],
Expand Down Expand Up @@ -211,7 +214,7 @@ def up_sampling_VM(self, plane_coef, line_coef, res_target):
mat_id_0, mat_id_1 = self.matMode[i]

if self.use_dwt:
plane_coef[i].set_(inverse(plane_coef[i], self.dwt_level))
plane_coef[i].set_(inverse(plane_coef[i], self.dwt_level, self.trans_func))

plane_coef[i] = nn.Parameter(
F.interpolate(plane_coef[i].data,
Expand All @@ -222,7 +225,7 @@ def up_sampling_VM(self, plane_coef, line_coef, res_target):
mode='bilinear', align_corners=True))

if self.use_dwt:
plane_coef[i].set_(forward(plane_coef[i], self.dwt_level))
plane_coef[i].set_(forward(plane_coef[i], self.dwt_level, self.trans_func))

return plane_coef, line_coef

Expand Down Expand Up @@ -250,9 +253,9 @@ def shrink(self, new_aabb):
mode0, mode1 = self.matMode[i]
if self.use_dwt:
self.density_plane[i].set_(inverse(self.density_plane[i],
self.dwt_level))
self.dwt_level, self.trans_func))
self.app_plane[i].set_(inverse(self.app_plane[i],
self.dwt_level))
self.dwt_level, self.trans_func))

steps = (new_aabb[1][mode0]-new_aabb[0][mode0]) / self.units[mode0]
steps = int(steps / unit) * unit
Expand All @@ -272,9 +275,9 @@ def shrink(self, new_aabb):

if self.use_dwt:
self.density_plane[i].set_(forward(self.density_plane[i],
self.dwt_level))
self.dwt_level, self.trans_func))
self.app_plane[i].set_(forward(self.app_plane[i],
self.dwt_level))
self.dwt_level, self.trans_func))

self.aabb = new_aabb
self.update_stepSize(
Expand Down
4 changes: 3 additions & 1 deletion TensoRF/opt.py
Original file line number Diff line number Diff line change
Expand Up @@ -77,9 +77,10 @@ def config_parser(cmd=None):
parser.add_argument("--mask_weight", type=float, default=0)
parser.add_argument("--use_dwt", action='store_true')
parser.add_argument("--dwt_level", type=int, default=2)
parser.add_argument("--trans_func", type=str, default='bior4.4')

# Alpha mask
parser.add_argument("--alpha_offset", type=float, default=0.0,
parser.add_argument("--alpha_offset", type=float, default=1e-4,
help='add to alphamask threshold')

# encoding option
Expand Down Expand Up @@ -140,6 +141,7 @@ def config_parser(cmd=None):
help='N images to vis')
parser.add_argument("--vis_every", type=int, default=10000,
help='frequency of visualize the image')

if cmd is not None:
return parser.parse_args(cmd)
else:
Expand Down

0 comments on commit 238a55b

Please sign in to comment.