Skip to content

Commit

Permalink
update mask init and QAT
Browse files Browse the repository at this point in the history
  • Loading branch information
daniel03c1 committed Nov 6, 2022
1 parent 034619c commit 64175fb
Showing 1 changed file with 26 additions and 36 deletions.
62 changes: 26 additions & 36 deletions TensoRF/models/tensoRF.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,12 +10,10 @@ def min_max_quantize(inputs, bits):
if bits == 32:
return inputs

# rounding
min_value = torch.amin(inputs)
max_value = torch.amax(inputs)
scale = (max_value - min_value).clamp(min=1e-8) / (2 ** bits - 1)

rounded = torch.round((inputs - min_value) / scale) * scale + min_value
scale = torch.amax(torch.abs(inputs)).clamp(min=1e-6)
n = float(2**(bits-1) - 1)
out = torch.floor(torch.abs(inputs / scale) * n) / n * scale
rounded = out * torch.sign(inputs)

return (rounded - inputs).detach() + inputs

Expand All @@ -42,16 +40,16 @@ def init_svd_volume(self, res, device):
@torch.no_grad()
def init_mask(self):
self.density_plane_mask = nn.ParameterList(
[nn.Parameter(torch.zeros_like(self.density_plane[i]))
[nn.Parameter(torch.ones_like(self.density_plane[i]))
for i in range(3)])
self.density_line_mask = nn.ParameterList(
[nn.Parameter(torch.zeros_like(self.density_line[i]))
[nn.Parameter(torch.ones_like(self.density_line[i]))
for i in range(3)])
self.app_plane_mask = nn.ParameterList(
[nn.Parameter(torch.zeros_like(self.app_plane[i]))
[nn.Parameter(torch.ones_like(self.app_plane[i]))
for i in range(3)])
self.app_line_mask = nn.ParameterList(
[nn.Parameter(torch.zeros_like(self.app_line[i]))
[nn.Parameter(torch.ones_like(self.app_line[i]))
for i in range(3)])

def init_one_svd(self, n_component, gridSize, scale, device):
Expand Down Expand Up @@ -185,17 +183,18 @@ def up_sampling_VM(self, plane_coef, line_coef, res_target):

if self.use_dwt:
plane_coef[i].set_(inverse(plane_coef[i], self.dwt_level))

plane_coef[i] = nn.Parameter(
F.interpolate(plane_coef[i].data,
size=(res_target[mat_id_1], res_target[mat_id_0]),
mode='bilinear', align_corners=True))
if self.use_dwt:
plane_coef[i].set_(forward(plane_coef[i], self.dwt_level))

line_coef[i] = nn.Parameter(
F.interpolate(line_coef[i].data, size=(res_target[vec_id], 1),
mode='bilinear', align_corners=True))

if self.use_dwt:
plane_coef[i].set_(forward(plane_coef[i], self.dwt_level))

return plane_coef, line_coef

@torch.no_grad()
Expand Down Expand Up @@ -296,13 +295,11 @@ class TensorCP(TensorBase):
def __init__(self, aabb, gridSize, device, **kargs):
super(TensorCP, self).__init__(aabb, gridSize, device, **kargs)


def init_svd_volume(self, res, device):
self.density_line = self.init_one_svd(self.density_n_comp[0], self.gridSize, 0.2, device)
self.app_line = self.init_one_svd(self.app_n_comp[0], self.gridSize, 0.2, device)
self.basis_mat = nn.Linear(self.app_n_comp[0], self.app_dim, bias=False).to(device)


def init_one_svd(self, n_component, gridSize, scale, device):
line_coef = []
for i in range(len(self.vecMode)):
Expand All @@ -311,7 +308,6 @@ def init_one_svd(self, n_component, gridSize, scale, device):
nn.Parameter(scale * torch.randn((1, n_component, gridSize[vec_id], 1))))
return nn.ParameterList(line_coef).to(device)


def get_optparam_groups(self, lr_init_spatialxyz = 0.02, lr_init_network = 0.001):
grad_vars = [{'params': self.density_line, 'lr': lr_init_spatialxyz},
{'params': self.app_line, 'lr': lr_init_spatialxyz},
Expand All @@ -321,42 +317,35 @@ def get_optparam_groups(self, lr_init_spatialxyz = 0.02, lr_init_network = 0.001
return grad_vars

def compute_densityfeature(self, xyz_sampled):

coordinate_line = torch.stack((xyz_sampled[..., self.vecMode[0]], xyz_sampled[..., self.vecMode[1]], xyz_sampled[..., self.vecMode[2]]))
coordinate_line = torch.stack((torch.zeros_like(coordinate_line), coordinate_line), dim=-1).detach().view(3, -1, 1, 2)

shape = xyz_sampled.shape[:1]
line_coef_point = 1
for i in range(3):
line_coef_point = line_coef_point * F.grid_sample(
min_max_quantize(self.density_line[i], self.grid_bit),
coordinate_line[[i]], align_corners=True).view(-1, *shape)

line_coef_point = F.grid_sample(self.density_line[0], coordinate_line[[0]],
align_corners=True).view(-1, *xyz_sampled.shape[:1])
line_coef_point = line_coef_point * F.grid_sample(self.density_line[1], coordinate_line[[1]],
align_corners=True).view(-1, *xyz_sampled.shape[:1])
line_coef_point = line_coef_point * F.grid_sample(self.density_line[2], coordinate_line[[2]],
align_corners=True).view(-1, *xyz_sampled.shape[:1])
sigma_feature = torch.sum(line_coef_point, dim=0)


return sigma_feature

def compute_appfeature(self, xyz_sampled):

def compute_appfeature(self, xyz_sampled):
coordinate_line = torch.stack(
(xyz_sampled[..., self.vecMode[0]], xyz_sampled[..., self.vecMode[1]], xyz_sampled[..., self.vecMode[2]]))
coordinate_line = torch.stack((torch.zeros_like(coordinate_line), coordinate_line), dim=-1).detach().view(3, -1, 1, 2)


line_coef_point = F.grid_sample(self.app_line[0], coordinate_line[[0]],
align_corners=True).view(-1, *xyz_sampled.shape[:1])
line_coef_point = line_coef_point * F.grid_sample(self.app_line[1], coordinate_line[[1]],
align_corners=True).view(-1, *xyz_sampled.shape[:1])
line_coef_point = line_coef_point * F.grid_sample(self.app_line[2], coordinate_line[[2]],
align_corners=True).view(-1, *xyz_sampled.shape[:1])
shape = xyz_sampled.shape[:1]
line_coef_point = 1
for i in range(3):
line_coef_point = line_coef_point * F.grid_sample(
min_max_quantize(self.app_line[i], self.grid_bit),
coordinate_line[[i]], align_corners=True).view(-1, *shape)

return self.basis_mat(line_coef_point.T)


@torch.no_grad()
def up_sampling_Vector(self, density_line_coef, app_line_coef, res_target):

for i in range(len(self.vecMode)):
vec_id = self.vecMode[i]
density_line_coef[i] = nn.Parameter(
Expand Down Expand Up @@ -420,3 +409,4 @@ def TV_loss_app(self, reg):
for idx in range(len(self.app_line)):
total = total + reg(self.app_line[idx]) * 1e-3
return total

0 comments on commit 64175fb

Please sign in to comment.