update TensoRF (mask + dwt + quantization)
daniel03c1 committed Nov 4, 2022
1 parent 9b57f5b commit 8111899
Showing 8 changed files with 320 additions and 246 deletions.
28 changes: 15 additions & 13 deletions TensoRF/configs/chair.txt
@@ -1,6 +1,5 @@
-
 dataset_name = blender
-datadir = ../nerf_synthetic/chair
+datadir = ../../nerf_synthetic/chair
 expname = tensorf_lego_VM
 basedir = ./log
 
@@ -15,30 +14,33 @@ update_AlphaMask_list = [2000, 4000]
 N_vis = 5
 vis_every = 10000
 
-# lr_init = 0.005 # 0.001 # 0.5 # 0.02 # test
+# lr_init = 0.01 # 0.001 # 0.5 # 0.02 # test
 # lr_basis = 0.005 # 0.001 # 0.02 # 0.001 # test
 
 render_test = 1
 
-n_lamb_sigma = [16, 16, 16]
-n_lamb_sh = [48, 48, 48]
+n_lamb_sigma = [16, 16, 16] # 3, 3, 3] # 16, 16, 16]
+n_lamb_sh = [48, 48, 48] # 6, 6, 6] # 48, 48, 48]
 model_name = TensorVMSplit
 
 shadingMode = MLP_Fea
 fea2denseAct = softplus
 
-view_pe = 2
-fea_pe = 2
+pos_pe = 0 # 6 # None
+view_pe = 2 # 3 # 2
+fea_pe = 2 # 7 # 3 # 2
+featureC = 128 # 116 # 128
+# data_dim_color = 64 # 22 # 8 # 2
 
 L1_weight_inital = 0 # 8e-5
 L1_weight_rest = 0 # 4e-5
 rm_weight_mask_thre = 1e-4
 
 ## please uncomment following configuration if hope to training on cp model
-#model_name = TensorCP
-#n_lamb_sigma = [96]
-#n_lamb_sh = [288]
-#N_voxel_final = 125000000 # 500**3
-#L1_weight_inital = 1e-5
-#L1_weight_rest = 1e-5
+# model_name = TensorCP
+# n_lamb_sigma = [96]
+# n_lamb_sh = [288]
+# N_voxel_final = 125000000 # 500**3
+# L1_weight_inital = 1e-5
+# L1_weight_rest = 1e-5
 
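The `pos_pe`, `view_pe`, and `fea_pe` values above set how many frequency bands the NeRF-style positional encoding uses for positions, viewing directions, and appearance features (setting a value to 0 effectively leaves that input unencoded). A minimal sketch of what that number controls, assuming the standard sin/cos formulation rather than this repo's exact implementation:

import torch

def positional_encoding(x, n_freqs):
    # Generic NeRF-style encoding: for each of n_freqs octaves, append
    # sin and cos of the input scaled by 2**k; n_freqs = 0 adds nothing.
    freqs = 2.0 ** torch.arange(n_freqs, device=x.device)   # (F,)
    scaled = (x[..., None] * freqs).flatten(-2)             # (..., D*F)
    return torch.cat([torch.sin(scaled), torch.cos(scaled)], -1)

# view_pe = 2 maps a 3-D view direction to 3 * 2 * 2 = 12 encoded features
dirs = torch.rand(1024, 3)
print(positional_encoding(dirs, 2).shape)  # torch.Size([1024, 12])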

48 changes: 48 additions & 0 deletions TensoRF/models/dwt.py
@@ -0,0 +1,48 @@
import torch
from pytorch_wavelets import DWTInverse, DWTForward


def inverse(inputs, level=4):
    # `inputs` packs a level-deep 2D wavelet pyramid into a single plane;
    # both spatial dims must therefore be divisible by 2**level.
    assert inputs.size(-1) % 2**level == 0
    assert inputs.size(-2) % 2**level == 0

    res0, res1 = inputs.shape[-2:]

    # Low-pass (approximation) coefficients live in the top-left corner.
    yl = inputs[..., :res0//(2**level), :res1//(2**level)]

    # For each level, gather the three detail blocks (top-right, bottom-left,
    # bottom-right) and undo the per-level scaling applied in forward();
    # i = 0 yields the finest scale, which DWTInverse expects first.
    yh = [
        torch.stack([inputs[..., :res0//(2**(i+1)),
                            res1//(2**(i+1)):res1//(2**i)],
                     inputs[..., res0//(2**(i+1)):res0//(2**i),
                            :res1//(2**(i+1))],
                     inputs[..., res0//(2**(i+1)):res0//(2**i),
                            res1//(2**(i+1)):res1//(2**i)]], 2)/(level-i+1)
        for i in range(level)
    ]

    return DWTInverse(wave='bior4.4',
                      mode='periodization').to(inputs.device)((yl, yh))


def forward(inputs, level=4):
    assert inputs.size(-1) % 2**level == 0
    assert inputs.size(-2) % 2**level == 0

    # Multi-level 2D DWT; yh[0] holds the finest-scale detail coefficients.
    yl, yh = DWTForward(wave='bior4.4', J=level,
                        mode='periodization').to(inputs.device)(inputs)
    outs = yl

    # Pack the pyramid into one plane, coarsest level first, scaling finer
    # detail bands by larger factors (undone in inverse()).
    for i in range(level):
        cf = yh[-i-1] * (i+2)
        outs = torch.cat([torch.cat([outs, cf[..., 0, :, :]], -1),
                          torch.cat([cf[..., 1, :, :], cf[..., 2, :, :]], -1)],
                         -2)
    return outs


if __name__ == '__main__':
    # Round-trip sanity check: forward and inverse should be near-inverses.
    a = torch.randn(3, 5, 64, 80).cuda() * 10
    print(a.shape, inverse(a).shape)
    print((a - forward(inverse(a))).abs().max())
    print((a - inverse(forward(a))).abs().max())
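Since the commit couples this DWT packing with masking and quantization, a natural use is to sparsify a feature plane in the wavelet domain before reconstructing it. A hypothetical round-trip sketch (the plane shape, import path, and 10% keep ratio are illustrative, not taken from this commit):

import torch
from models.dwt import forward, inverse  # module added in this commit

plane = torch.randn(1, 16, 128, 128)  # stand-in for a TensoRF feature plane
coeffs = forward(plane, level=4)      # same shape, but wavelet-packed

# keep only the top 10% of coefficients by magnitude (illustrative mask)
k = int(0.1 * coeffs.numel())
thresh = coeffs.abs().flatten().kthvalue(coeffs.numel() - k).values
recon = inverse(coeffs * (coeffs.abs() >= thresh), level=4)
print((plane - recon).abs().mean())   # error introduced by the masking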

(Diffs for the remaining 6 changed files did not load.)
