diff --git a/models/__init__.py b/models/__init__.py index 8cbe384..8271e1f 100644 --- a/models/__init__.py +++ b/models/__init__.py @@ -104,6 +104,7 @@ def render(self, rays): rays: [..., 6] shaped tensor (3 for origin, 3 for direction) """ rays_o, rays_d = rays[..., :3], rays[..., 3:] # origins, viewdirs + z_vals = torch.linspace(self.near, self.far, self.n_samples_per_ray).unsqueeze(0).to(rays_o) @@ -142,7 +143,7 @@ def render(self, rays): self.main_net(pts[valid_rays])) * self.density_scale # alpha & weights - alpha = 1. - torch.exp(-sigma * dists) + alpha = 1. - torch.exp(-sigma * dists * self.density_scale) weights = alpha * torch.cumprod(F.pad(1 - alpha + 1e-10, [1, 0], value=1), -1)[..., :-1] # exclusive cumprod diff --git a/models/grid_based.py b/models/grid_based.py index 3900ae3..ff787ce 100644 --- a/models/grid_based.py +++ b/models/grid_based.py @@ -168,7 +168,10 @@ def __init__(self, resolution: int, n_chan: int, out_dim=1): 1e-1 * torch.randn(3, n_chan, 1, resolution), requires_grad=True) - self.basis_mat = nn.Linear(n_chan*3, out_dim, bias=False) + if out_dim > 1: + self.basis_mat = nn.Linear(n_chan*3, out_dim, bias=False) + else: + self.basis_mat = None def forward(self, coords, *args, **kwargs): # [B, 3] to [1, B, 1, 3] @@ -188,14 +191,17 @@ def forward(self, coords, *args, **kwargs): feats1 = F.grid_sample(grid, F.pad(torch.cat([coords[..., (0,)], coords[..., (1,)], coords[..., (2,)]], 0), - (1, 0)), + (0, 1)), mode='bilinear', padding_mode='zeros', align_corners=True) feats1 = feats1.squeeze(-1).permute(2, 1, 0) # [B, C, 3] feats0 = (feats0 * feats1).flatten(1, -1) # [B, C*3] del feats1 - return self.basis_mat(feats0).squeeze(-1) + + if self.basis_mat is not None: + return self.basis_mat(feats0).squeeze(-1) + return feats0.sum(dim=-1) def compute_tv(self): return F.mse_loss(self.planes[..., 1:, :], self.planes[..., :-1, :]) \