Skip to content

Commit

Permalink
Add modules of 4d-fy for 4D generation(#353)
Browse files Browse the repository at this point in the history
  • Loading branch information
DSaurus authored Dec 6, 2023
1 parent 2c20227 commit 3fe3153
Show file tree
Hide file tree
Showing 2 changed files with 80 additions and 0 deletions.
16 changes: 16 additions & 0 deletions threestudio/models/geometry/implicit_volume.py
Original file line number Diff line number Diff line change
Expand Up @@ -53,6 +53,9 @@ class Config(BaseImplicitGeometry.Config):
# automatically determine the threshold
isosurface_threshold: Union[float, str] = 25.0

# 4D Gaussian Annealing
anneal_density_blob_std_config: Optional[dict] = None

cfg: Config

def configure(self) -> None:
Expand Down Expand Up @@ -267,3 +270,16 @@ def create_from(
raise TypeError(
f"Cannot create {ImplicitVolume.__name__} from {other.__class__.__name__}"
)

def update_step(
    self, epoch: int, global_step: int, on_load_weights: bool = False
) -> None:
    """Linearly anneal ``self.density_blob_std`` inside a configured step window.

    When ``self.cfg.anneal_density_blob_std_config`` is set, it is read for
    ``min_anneal_step``, ``max_anneal_step``, ``start_val`` and ``end_val``.
    For ``min_anneal_step <= global_step <= max_anneal_step`` the attribute
    ``self.density_blob_std`` is set to the linear interpolation between
    ``start_val`` and ``end_val``. Outside that window, or when no annealing
    config is given, nothing is modified.

    Args:
        epoch: Current epoch (not used by this method).
        global_step: Current global training step; drives the interpolation.
        on_load_weights: Not used by this method.
    """
    anneal_cfg = self.cfg.anneal_density_blob_std_config
    if anneal_cfg is None:
        return

    min_step = anneal_cfg.min_anneal_step
    max_step = anneal_cfg.max_anneal_step
    if not (min_step <= global_step <= max_step):
        # Outside the annealing window the current value is left untouched.
        return

    start_val = anneal_cfg.start_val
    end_val = anneal_cfg.end_val
    if max_step == min_step:
        # Zero-length window: the interpolation below would divide by zero,
        # so jump straight to the final value.
        self.density_blob_std = end_val
        return

    self.density_blob_std = start_val + (global_step - min_step) * (
        end_val - start_val
    ) / (max_step - min_step)
64 changes: 64 additions & 0 deletions threestudio/models/networks.py
Original file line number Diff line number Diff line change
Expand Up @@ -64,6 +64,68 @@ def forward(self, x):
return self.encoding(x)


# 4D implicit decomposition of space and time (4D-fy)
class TCNNEncodingSpatialTime(nn.Module):
    """Sum of a spatial hash-grid encoding and a space+time hash-grid encoding.

    Two ``tcnn.Encoding`` modules are built from the same configuration: one
    over ``in_channels`` inputs and one over ``in_channels + 1`` inputs, where
    the extra input column carries the frame time. ``forward`` returns the
    element-wise sum of the two encodings.
    """

    def __init__(
        self, in_channels, config, dtype=torch.float32, init_time_zero=False
    ) -> None:
        """Build both encodings.

        Args:
            in_channels: Number of input dimensions of the spatial encoding.
            config: Mapping-like encoding config; must contain ``"static"``
                and may contain ``"n_key_frames"`` and ``"use_key_frame"``.
            dtype: Forwarded to both ``tcnn.Encoding`` constructors.
            init_time_zero: Accepted but not read anywhere in this class.
        """
        super().__init__()
        self.n_input_dims = in_channels

        # NOTE(review): this writes into the caller's config object so that the
        # primitive config handed to tcnn has the "HashGrid" otype.
        config["otype"] = "HashGrid"
        self.num_frames = 1  # config["num_frames"]
        self.static = config["static"]

        primitive_cfg = config_to_primitive(config)
        self.cfg = primitive_cfg
        # The time encoding shares the very same config object.
        self.cfg_time = primitive_cfg
        self.n_key_frames = config.get("n_key_frames", 1)

        with torch.cuda.device(get_rank()):
            self.encoding = tcnn.Encoding(self.n_input_dims, self.cfg, dtype=dtype)
            self.encoding_time = tcnn.Encoding(
                self.n_input_dims + 1, self.cfg_time, dtype=dtype
            )

        self.n_output_dims = self.encoding.n_output_dims
        # Stays None until assigned from outside this class; see forward().
        self.frame_time = None
        if self.static:
            # Static mode: freeze the parameters of the time-dependent encoding.
            self.set_temp_param_grad(requires_grad=False)
        self.use_key_frame = config.get("use_key_frame", False)
        self.is_video = True
        self.update_occ_grid = False

    def set_temp_param_grad(self, requires_grad=False):
        """Set ``requires_grad`` on every parameter of the time encoding."""
        self.set_param_grad(self.encoding_time, requires_grad=requires_grad)

    def set_param_grad(self, param_list, requires_grad=False):
        """Set ``requires_grad`` on a single parameter or on a module's parameters.

        Args:
            param_list: Either an ``nn.Parameter`` or an object exposing
                ``.parameters()``.
            requires_grad: Value assigned to each parameter's ``requires_grad``.
        """
        if isinstance(param_list, nn.Parameter):
            targets = (param_list,)
        else:
            targets = param_list.parameters()
        for parameter in targets:
            parameter.requires_grad = requires_grad

    def forward(self, x):
        """Encode ``x`` spatially and spatio-temporally and return the sum.

        A time column is concatenated to ``x`` along dim 1 before it is fed to
        the time encoding. The column comes from ``self.frame_time``; when that
        is ``None`` a zero time is used.
        """
        # TODO frame_time only supports batch_size == 1 cases
        if self.update_occ_grid and not isinstance(self.frame_time, float):
            # Occupancy-grid update with a non-float frame_time: use it as is.
            time_column = self.frame_time
        else:
            n_rows = x.shape[0]
            if self.frame_time is None and (self.static or not self.training):
                time_column = torch.zeros(
                    (self.num_frames, 1), device=x.device, dtype=x.dtype
                ).expand(n_rows, 1)
            else:
                time_value = 0.0 if self.frame_time is None else self.frame_time
                unit_column = torch.ones(
                    (self.num_frames, 1), device=x.device, dtype=x.dtype
                )
                time_column = (unit_column * time_value).expand(n_rows, 1)
            time_column = time_column.view(-1, 1)

        spatial_features = self.encoding(x)
        spatiotemporal_features = self.encoding_time(torch.cat((x, time_column), 1))
        return spatial_features + spatiotemporal_features


class ProgressiveBandHashGrid(nn.Module, Updateable):
def __init__(self, in_channels, config, dtype=torch.float32):
super().__init__()
Expand Down Expand Up @@ -136,6 +198,8 @@ def get_encoding(n_input_dims: int, config) -> nn.Module:
encoding = ProgressiveBandFrequency(n_input_dims, config_to_primitive(config))
elif config.otype == "ProgressiveBandHashGrid":
encoding = ProgressiveBandHashGrid(n_input_dims, config_to_primitive(config))
elif config.otype == "HashGridSpatialTime":
encoding = TCNNEncodingSpatialTime(n_input_dims, config) # 4D-fy encoding
else:
encoding = TCNNEncoding(n_input_dims, config_to_primitive(config))
encoding = CompositeEncoding(
Expand Down

0 comments on commit 3fe3153

Please sign in to comment.