diff --git a/TensoRF/models/tensoRF.py b/TensoRF/models/tensoRF.py index f4ec39d..72e0c18 100644 --- a/TensoRF/models/tensoRF.py +++ b/TensoRF/models/tensoRF.py @@ -10,12 +10,10 @@ def min_max_quantize(inputs, bits): if bits == 32: return inputs - # rounding - min_value = torch.amin(inputs) - max_value = torch.amax(inputs) - scale = (max_value - min_value).clamp(min=1e-8) / (2 ** bits - 1) - - rounded = torch.round((inputs - min_value) / scale) * scale + min_value + scale = torch.amax(torch.abs(inputs)).clamp(min=1e-6) + n = float(2**(bits-1) - 1) + out = torch.floor(torch.abs(inputs / scale) * n) / n * scale + rounded = out * torch.sign(inputs) return (rounded - inputs).detach() + inputs @@ -42,16 +40,16 @@ def init_svd_volume(self, res, device): @torch.no_grad() def init_mask(self): self.density_plane_mask = nn.ParameterList( - [nn.Parameter(torch.zeros_like(self.density_plane[i])) + [nn.Parameter(torch.ones_like(self.density_plane[i])) for i in range(3)]) self.density_line_mask = nn.ParameterList( - [nn.Parameter(torch.zeros_like(self.density_line[i])) + [nn.Parameter(torch.ones_like(self.density_line[i])) for i in range(3)]) self.app_plane_mask = nn.ParameterList( - [nn.Parameter(torch.zeros_like(self.app_plane[i])) + [nn.Parameter(torch.ones_like(self.app_plane[i])) for i in range(3)]) self.app_line_mask = nn.ParameterList( - [nn.Parameter(torch.zeros_like(self.app_line[i])) + [nn.Parameter(torch.ones_like(self.app_line[i])) for i in range(3)]) def init_one_svd(self, n_component, gridSize, scale, device): @@ -185,17 +183,18 @@ def up_sampling_VM(self, plane_coef, line_coef, res_target): if self.use_dwt: plane_coef[i].set_(inverse(plane_coef[i], self.dwt_level)) + plane_coef[i] = nn.Parameter( F.interpolate(plane_coef[i].data, size=(res_target[mat_id_1], res_target[mat_id_0]), mode='bilinear', align_corners=True)) - if self.use_dwt: - plane_coef[i].set_(forward(plane_coef[i], self.dwt_level)) - line_coef[i] = nn.Parameter( F.interpolate(line_coef[i].data, size=(res_target[vec_id], 1), mode='bilinear', align_corners=True)) + if self.use_dwt: + plane_coef[i].set_(forward(plane_coef[i], self.dwt_level)) + return plane_coef, line_coef @torch.no_grad() @@ -296,13 +295,11 @@ class TensorCP(TensorBase): def __init__(self, aabb, gridSize, device, **kargs): super(TensorCP, self).__init__(aabb, gridSize, device, **kargs) - def init_svd_volume(self, res, device): self.density_line = self.init_one_svd(self.density_n_comp[0], self.gridSize, 0.2, device) self.app_line = self.init_one_svd(self.app_n_comp[0], self.gridSize, 0.2, device) self.basis_mat = nn.Linear(self.app_n_comp[0], self.app_dim, bias=False).to(device) - def init_one_svd(self, n_component, gridSize, scale, device): line_coef = [] for i in range(len(self.vecMode)): @@ -311,7 +308,6 @@ def init_one_svd(self, n_component, gridSize, scale, device): nn.Parameter(scale * torch.randn((1, n_component, gridSize[vec_id], 1)))) return nn.ParameterList(line_coef).to(device) - def get_optparam_groups(self, lr_init_spatialxyz = 0.02, lr_init_network = 0.001): grad_vars = [{'params': self.density_line, 'lr': lr_init_spatialxyz}, {'params': self.app_line, 'lr': lr_init_spatialxyz}, @@ -321,42 +317,35 @@ def get_optparam_groups(self, lr_init_spatialxyz = 0.02, lr_init_network = 0.001 return grad_vars def compute_densityfeature(self, xyz_sampled): - coordinate_line = torch.stack((xyz_sampled[..., self.vecMode[0]], xyz_sampled[..., self.vecMode[1]], xyz_sampled[..., self.vecMode[2]])) coordinate_line = torch.stack((torch.zeros_like(coordinate_line), coordinate_line), dim=-1).detach().view(3, -1, 1, 2) + shape = xyz_sampled.shape[:1] + line_coef_point = 1 + for i in range(3): + line_coef_point = line_coef_point * F.grid_sample( + min_max_quantize(self.density_line[i], self.grid_bit), + coordinate_line[[i]], align_corners=True).view(-1, *shape) - line_coef_point = F.grid_sample(self.density_line[0], coordinate_line[[0]], - align_corners=True).view(-1, *xyz_sampled.shape[:1]) - line_coef_point = line_coef_point * F.grid_sample(self.density_line[1], coordinate_line[[1]], - align_corners=True).view(-1, *xyz_sampled.shape[:1]) - line_coef_point = line_coef_point * F.grid_sample(self.density_line[2], coordinate_line[[2]], - align_corners=True).view(-1, *xyz_sampled.shape[:1]) sigma_feature = torch.sum(line_coef_point, dim=0) - - return sigma_feature - - def compute_appfeature(self, xyz_sampled): + def compute_appfeature(self, xyz_sampled): coordinate_line = torch.stack( (xyz_sampled[..., self.vecMode[0]], xyz_sampled[..., self.vecMode[1]], xyz_sampled[..., self.vecMode[2]])) coordinate_line = torch.stack((torch.zeros_like(coordinate_line), coordinate_line), dim=-1).detach().view(3, -1, 1, 2) - - line_coef_point = F.grid_sample(self.app_line[0], coordinate_line[[0]], - align_corners=True).view(-1, *xyz_sampled.shape[:1]) - line_coef_point = line_coef_point * F.grid_sample(self.app_line[1], coordinate_line[[1]], - align_corners=True).view(-1, *xyz_sampled.shape[:1]) - line_coef_point = line_coef_point * F.grid_sample(self.app_line[2], coordinate_line[[2]], - align_corners=True).view(-1, *xyz_sampled.shape[:1]) + shape = xyz_sampled.shape[:1] + line_coef_point = 1 + for i in range(3): + line_coef_point = line_coef_point * F.grid_sample( + min_max_quantize(self.app_line[i], self.grid_bit), + coordinate_line[[i]], align_corners=True).view(-1, *shape) return self.basis_mat(line_coef_point.T) - @torch.no_grad() def up_sampling_Vector(self, density_line_coef, app_line_coef, res_target): - for i in range(len(self.vecMode)): vec_id = self.vecMode[i] density_line_coef[i] = nn.Parameter( @@ -420,3 +409,4 @@ def TV_loss_app(self, reg): for idx in range(len(self.app_line)): total = total + reg(self.app_line[idx]) * 1e-3 return total +