Skip to content

Commit

Permalink
fix wrong implementation (TensoRF)
Browse files Browse the repository at this point in the history
  • Loading branch information
daniel03c1 committed Oct 21, 2022
1 parent 082959a commit aec802b
Show file tree
Hide file tree
Showing 2 changed files with 11 additions and 4 deletions.
3 changes: 2 additions & 1 deletion models/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -104,6 +104,7 @@ def render(self, rays):
rays: [..., 6] shaped tensor (3 for origin, 3 for direction)
"""
rays_o, rays_d = rays[..., :3], rays[..., 3:] # origins, viewdirs

z_vals = torch.linspace(self.near, self.far,
self.n_samples_per_ray).unsqueeze(0).to(rays_o)

Expand Down Expand Up @@ -142,7 +143,7 @@ def render(self, rays):
self.main_net(pts[valid_rays])) * self.density_scale

# alpha & weights
alpha = 1. - torch.exp(-sigma * dists)
alpha = 1. - torch.exp(-sigma * dists * self.density_scale)
weights = alpha * torch.cumprod(F.pad(1 - alpha + 1e-10, [1, 0],
value=1),
-1)[..., :-1] # exclusive cumprod
Expand Down
12 changes: 9 additions & 3 deletions models/grid_based.py
Original file line number Diff line number Diff line change
Expand Up @@ -168,7 +168,10 @@ def __init__(self, resolution: int, n_chan: int, out_dim=1):
1e-1 * torch.randn(3, n_chan, 1, resolution),
requires_grad=True)

self.basis_mat = nn.Linear(n_chan*3, out_dim, bias=False)
if out_dim > 1:
self.basis_mat = nn.Linear(n_chan*3, out_dim, bias=False)
else:
self.basis_mat = None

def forward(self, coords, *args, **kwargs):
# [B, 3] to [1, B, 1, 3]
Expand All @@ -188,14 +191,17 @@ def forward(self, coords, *args, **kwargs):
feats1 = F.grid_sample(grid, F.pad(torch.cat([coords[..., (0,)],
coords[..., (1,)],
coords[..., (2,)]], 0),
(1, 0)),
(0, 1)),
mode='bilinear',
padding_mode='zeros', align_corners=True)
feats1 = feats1.squeeze(-1).permute(2, 1, 0) # [B, C, 3]

feats0 = (feats0 * feats1).flatten(1, -1) # [B, C*3]
del feats1
return self.basis_mat(feats0).squeeze(-1)

if self.basis_mat is not None:
return self.basis_mat(feats0).squeeze(-1)
return feats0.sum(dim=-1)

def compute_tv(self):
return F.mse_loss(self.planes[..., 1:, :], self.planes[..., :-1, :]) \
Expand Down

0 comments on commit aec802b

Please sign in to comment.