Skip to content

Commit

Permalink
Implement 3DGS-MCMC in gsplat. (nerfstudio-project#238)
Browse files Browse the repository at this point in the history
* Integrate 3dgs-mcmc into gsplat. (#1)

* Initialize from random point cloud (nerfstudio-project#2)

* Clean up simple trainer changes. (nerfstudio-project#3)

* formatting

* Create dedicated mcmc train script. (nerfstudio-project#4)

* minor

* scene scale bug

* scheduler

* PR comments wip (nerfstudio-project#5)

* Some cuda changes (nerfstudio-project#6)

* data to data_ptr

* use scale_rot_to_cov3d from cuda_legacy

* quat_scale_to_covar_preci

* removed _sample_new_gaussians

* remove old prefix

* zeros_like

* empty_like

* cuda stream

* min opacity
  • Loading branch information
jefequien authored Jun 28, 2024
1 parent 0ac847d commit 4b32664
Show file tree
Hide file tree
Showing 7 changed files with 990 additions and 13 deletions.
2 changes: 1 addition & 1 deletion examples/benchmark.sh
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@ do
--result_dir $RESULT_DIR/$SCENE/

# run eval and render
for CKPT in results/benchmark/$SCENE/ckpts/*;
for CKPT in $RESULT_DIR/$SCENE/ckpts/*;
do
python simple_trainer.py --disable_viewer --data_factor $DATA_FACTOR \
--data_dir data/360_v2/$SCENE/ \
Expand Down
44 changes: 33 additions & 11 deletions examples/simple_trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -66,12 +66,20 @@ class Config:
# Steps to save the model
save_steps: List[int] = field(default_factory=lambda: [7_000, 30_000])

# Initialization strategy
init_type: str = "sfm"
# Initial number of GSs. Ignored if using sfm
init_num_pts: int = 100_000
# Initial extent of GSs as a multiple of the camera extent. Ignored if using sfm
init_extent: float = 3.0
# Degree of spherical harmonics
sh_degree: int = 3
# Turn on another SH degree every this steps
sh_degree_interval: int = 1000
# Initial opacity of GS
init_opa: float = 0.1
# Initial scale of GS
init_scale: float = 1.0
# Weight for SSIM loss
ssim_lambda: float = 0.2

Expand Down Expand Up @@ -150,22 +158,33 @@ def adjust_steps(self, factor: float):


def create_splats_with_optimizers(
points: Tensor, # [N, 3]
rgbs: Tensor, # [N, 3]
parser: Parser,
init_type: str = "sfm",
init_num_pts: int = 100_000,
init_extent: float = 3.0,
init_opacity: float = 0.1,
init_scale: float = 1.0,
scene_scale: float = 1.0,
sh_degree: int = 3,
init_opacity: float = 0.1,
sparse_grad: bool = False,
batch_size: int = 1,
feature_dim: Optional[int] = None,
device: str = "cuda",
) -> Tuple[torch.nn.ParameterDict, torch.optim.Optimizer]:
N = points.shape[0]
if init_type == "sfm":
points = torch.from_numpy(parser.points).float()
rgbs = torch.from_numpy(parser.points_rgb / 255.0).float()
elif init_type == "random":
points = init_extent * scene_scale * (torch.rand((init_num_pts, 3)) * 2 - 1)
rgbs = torch.rand((init_num_pts, 3))
else:
raise ValueError("Please specify a correct init_type: sfm or random")

N = points.shape[0]
# Initialize the GS size to be the average dist of the 3 nearest neighbors
dist2_avg = (knn(points, 4)[:, 1:] ** 2).mean(dim=-1) # [N,]
dist_avg = torch.sqrt(dist2_avg)
scales = torch.log(dist_avg).unsqueeze(-1).repeat(1, 3) # [N, 3]
scales = torch.log(dist_avg * init_scale).unsqueeze(-1).repeat(1, 3) # [N, 3]
quats = torch.rand((N, 4)) # [N, 4]
opacities = torch.logit(torch.full((N,), init_opacity)) # [N,]

Expand Down Expand Up @@ -249,11 +268,14 @@ def __init__(self, cfg: Config) -> None:
# Model
feature_dim = 32 if cfg.app_opt else None
self.splats, self.optimizers = create_splats_with_optimizers(
torch.from_numpy(self.parser.points).float(),
torch.from_numpy(self.parser.points_rgb / 255.0).float(),
self.parser,
init_type=cfg.init_type,
init_num_pts=cfg.init_num_pts,
init_extent=cfg.init_extent,
init_opacity=cfg.init_opa,
init_scale=cfg.init_scale,
scene_scale=self.scene_scale,
sh_degree=cfg.sh_degree,
init_opacity=cfg.init_opa,
sparse_grad=cfg.sparse_grad,
batch_size=cfg.batch_size,
feature_dim=feature_dim,
Expand Down Expand Up @@ -378,15 +400,15 @@ def train(self):
max_steps = cfg.max_steps
init_step = 0

scheulers = [
schedulers = [
# means3d has a learning rate schedule, that end at 0.01 of the initial value
torch.optim.lr_scheduler.ExponentialLR(
self.optimizers[0], gamma=0.01 ** (1.0 / max_steps)
),
]
if cfg.pose_opt:
# pose optimization has a learning rate schedule
scheulers.append(
schedulers.append(
torch.optim.lr_scheduler.ExponentialLR(
self.pose_optimizers[0], gamma=0.01 ** (1.0 / max_steps)
)
Expand Down Expand Up @@ -601,7 +623,7 @@ def train(self):
for optimizer in self.app_optimizers:
optimizer.step()
optimizer.zero_grad(set_to_none=True)
for scheduler in scheulers:
for scheduler in schedulers:
scheduler.step()

# save checkpoint
Expand Down
Loading

0 comments on commit 4b32664

Please sign in to comment.