Skip to content

Commit

Permalink
remove get_module from models.modules.py
Browse files Browse the repository at this point in the history
  • Loading branch information
daniel03c1 committed Jul 20, 2022
1 parent bf2571c commit 475975e
Showing 1 changed file with 0 additions and 22 deletions.
22 changes: 0 additions & 22 deletions models/modules.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,27 +33,6 @@ def forward(self, inputs):


""" MODULES """
def get_module(shadingMode, in_dim, pos_pe, view_pe, fea_pe, hidden_dim):
if shadingMode == 'MLP_PE':
return MLP(in_dim, include_pos=True, include_view=True,
pos_n_freq=pos_pe, view_n_freq=view_pe,
hidden_dim=hidden_dim)
elif shadingMode == 'MLP_Fea':
return MLP(in_dim, include_view=True, feat_n_freq=fea_pe,
view_n_freq=view_pe, hidden_dim=hidden_dim)
elif shadingMode == 'MLP':
return MLP(in_dim, include_view=True, view_n_freq=view_pe,
hidden_dim=hidden_dim)
elif shadingMode == 'SH':
return SHRender
elif shadingMode == 'RGB':
assert in_dim == 3
return RGBRender
else:
raise ValueError(f"Unrecognized shading module: {shadingMode}")


# modules
class MLP(nn.Module):
def __init__(self, feat_dim, out_dim=3,
include_feat=True, include_pos=False, include_view=False,
Expand Down Expand Up @@ -142,7 +121,6 @@ def forward(self, inputs):
return torch.cat(outs, -1)


# utils
def positional_encoding(positions, freqs):
freq_bands = (torch.pi*2**torch.arange(freqs).float()).to(positions.device)
pts = (positions[..., None] * freq_bands).reshape(*positions.shape[:-1], -1)
Expand Down

0 comments on commit 475975e

Please sign in to comment.