diff --git a/.gitignore b/.gitignore index b774bf79..e3a1470a 100644 --- a/.gitignore +++ b/.gitignore @@ -55,6 +55,8 @@ coverage.xml .pytest_cache/ cover/ +# Slurm logs +slurm* # Translations *.mo *.pot diff --git a/configs/dreamfusion-sd-eff.yaml b/configs/dreamfusion-sd-eff.yaml new file mode 100644 index 00000000..88a23aa6 --- /dev/null +++ b/configs/dreamfusion-sd-eff.yaml @@ -0,0 +1,115 @@ +name: "dreamfusion-sd" +tag: "${rmspace:${system.prompt_processor.prompt},_}" +exp_root_dir: "outputs" +seed: 0 + +data_type: "eff-random-camera-datamodule" +data: + batch_size: 1 + width: 128 + height: 128 + sample_width: 64 + sample_height: 64 + camera_distance_range: [1.5, 2.0] + fovy_range: [40, 70] + elevation_range: [-10, 45] + light_sample_strategy: "dreamfusion" + eval_camera_distance: 2.0 + eval_fovy_deg: 70. + +system_type: "efficient-dreamfusion-system" +system: + geometry_type: "implicit-volume" + geometry: + radius: 2.0 + normal_type: "analytic" + + # the density initialization proposed in the DreamFusion paper + # does not work very well + # density_bias: "blob_dreamfusion" + # density_activation: exp + # density_blob_scale: 5. + # density_blob_std: 0.2 + + # use Magic3D density initialization instead + density_bias: "blob_magic3d" + density_activation: softplus + density_blob_scale: 10. + density_blob_std: 0.5 + + # coarse to fine hash grid encoding + # to ensure smooth analytic normals + pos_encoding_config: + otype: ProgressiveBandHashGrid + n_levels: 16 + n_features_per_level: 2 + log2_hashmap_size: 19 + base_resolution: 16 + per_level_scale: 1.447269237440378 # max resolution 4096 + start_level: 8 # resolution ~200 + start_step: 2000 + update_steps: 500 + + material_type: "diffuse-with-point-light-material" + material: + ambient_only_steps: 2001 + albedo_activation: sigmoid + + background_type: "neural-environment-map-background" + background: + color_activation: sigmoid + + renderer_type: "nerf-volume-renderer" + renderer: + radius: ${system.geometry.radius} + num_samples_per_ray: 512 + + prompt_processor_type: "stable-diffusion-prompt-processor" + prompt_processor: + pretrained_model_name_or_path: "stabilityai/stable-diffusion-2-1-base" + prompt: ??? + + guidance_type: "stable-diffusion-guidance" + guidance: + pretrained_model_name_or_path: "stabilityai/stable-diffusion-2-1-base" + guidance_scale: 100. + weighting_strategy: sds + min_step_percent: 0.02 + max_step_percent: 0.98 + + loggers: + wandb: + enable: false + project: "threestudio" + name: None + + loss: + lambda_sds: 1. + lambda_orient: [0, 10., 1000., 5000] + lambda_sparsity: 1. + lambda_opaque: 0. + + optimizer: + name: Adam + args: + lr: 0.01 + betas: [0.9, 0.99] + eps: 1.e-15 + params: + geometry: + lr: 0.01 + background: + lr: 0.001 + +trainer: + max_steps: 10000 + log_every_n_steps: 1 + num_sanity_val_steps: 0 + val_check_interval: 200 + enable_progress_bar: true + precision: 16-mixed + +checkpoint: + save_last: true # save at each validation time + save_top_k: -1 + every_n_train_steps: ${trainer.max_steps} diff --git a/threestudio/data/__init__.py b/threestudio/data/__init__.py index ce2e5cc7..70aaeeb1 100644 --- a/threestudio/data/__init__.py +++ b/threestudio/data/__init__.py @@ -1 +1 @@ -from . import co3d, image, multiview, uncond +from . 
import co3d, image, multiview, uncond, uncond_eff diff --git a/threestudio/data/uncond_eff.py b/threestudio/data/uncond_eff.py new file mode 100644 index 00000000..a1ac04f7 --- /dev/null +++ b/threestudio/data/uncond_eff.py @@ -0,0 +1,441 @@ +import bisect +import math +import random +from dataclasses import dataclass, field + +import numpy as np +import pytorch_lightning as pl +import torch +import torch.nn.functional as F +from torch.utils.data import DataLoader, Dataset, IterableDataset + +import threestudio +from threestudio import register +from threestudio.data.uncond import RandomCameraDataset +from threestudio.utils.base import Updateable +from threestudio.utils.config import parse_structured +from threestudio.utils.misc import get_device +from threestudio.utils.ops import ( + get_full_projection_matrix, + get_mvp_matrix, + get_projection_matrix, + get_ray_directions, + get_rays, + mask_ray_directions, +) +from threestudio.utils.typing import * + + +@dataclass +class EffRandomCameraDataModuleConfig: + # height, width, and batch_size should be Union[int, List[int]] + # but OmegaConf does not support Union of containers + height: Any = 128 + width: Any = 128 + sample_height: Any = 64 + sample_width: Any = 64 + batch_size: Any = 1 + resolution_milestones: List[int] = field(default_factory=lambda: []) + eval_height: int = 512 + eval_width: int = 512 + eval_batch_size: int = 1 + n_val_views: int = 1 + n_test_views: int = 120 + elevation_range: Tuple[float, float] = (-10, 90) + azimuth_range: Tuple[float, float] = (-180, 180) + camera_distance_range: Tuple[float, float] = (1, 1.5) + fovy_range: Tuple[float, float] = ( + 40, + 70, + ) # in degrees, in vertical direction (along height) + camera_perturb: float = 0.1 + center_perturb: float = 0.2 + up_perturb: float = 0.02 + light_position_perturb: float = 1.0 + light_distance_range: Tuple[float, float] = (0.8, 1.5) + eval_elevation_deg: float = 15.0 + eval_camera_distance: float = 1.5 + eval_fovy_deg: float = 70.0 + light_sample_strategy: str = "dreamfusion" + batch_uniform_azimuth: bool = True + progressive_until: int = 0 # progressive ranges for elevation, azimuth, r, fovy + + rays_d_normalize: bool = True + + +class EffRandomCameraIterableDataset(IterableDataset, Updateable): + def __init__(self, cfg: Any) -> None: + super().__init__() + self.cfg: EffRandomCameraDataModuleConfig = cfg + self.heights: List[int] = ( + [self.cfg.height] if isinstance(self.cfg.height, int) else self.cfg.height + ) + self.widths: List[int] = ( + [self.cfg.width] if isinstance(self.cfg.width, int) else self.cfg.width + ) + self.sample_heights: List[int] = ( + [self.cfg.sample_height] + if isinstance(self.cfg.sample_height, int) + else self.cfg.sample_height + ) + self.sample_widths: List[int] = ( + [self.cfg.sample_width] + if isinstance(self.cfg.sample_width, int) + else self.cfg.sample_width + ) + self.batch_sizes: List[int] = ( + [self.cfg.batch_size] + if isinstance(self.cfg.batch_size, int) + else self.cfg.batch_size + ) + assert ( + len(self.heights) + == len(self.widths) + == len(self.batch_sizes) + == len(self.sample_heights) + == len(self.sample_widths) + ) + self.resolution_milestones: List[int] + if ( + len(self.heights) == 1 + and len(self.widths) == 1 + and len(self.batch_sizes) == 1 + and len(self.sample_heights) == 1 + and len(self.sample_widths) == 1 + ): + if len(self.cfg.resolution_milestones) > 0: + threestudio.warn( + "Ignoring resolution_milestones since height and width are not changing" + ) + self.resolution_milestones = [-1] + else: + 
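+            # when resolutions change over training, N resolution_milestones require N + 1 entries for height/width/sample sizes/batch_size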
assert len(self.heights) == len(self.cfg.resolution_milestones) + 1 + self.resolution_milestones = [-1] + self.cfg.resolution_milestones + + self.directions_unit_focals = [ + get_ray_directions(H=height, W=width, focal=1.0) + for (height, width) in zip(self.heights, self.widths) + ] + + self.efficiency_masks = [ + (mask_ray_directions(H, W, s_H, s_W)) + for (H, W, s_H, s_W) in zip( + self.heights, self.widths, self.sample_heights, self.sample_widths + ) + ] + self.directions_unit_focals = [ + (self.directions_unit_focals[i].view(-1, 3)[self.efficiency_masks[i]]).view( + self.sample_heights[i], self.sample_widths[i], 3 + ) + for i in range(len(self.heights)) + ] + + self.height: int = self.heights[0] + self.width: int = self.widths[0] + self.sample_height: int = self.sample_heights[0] + self.sample_width: int = self.sample_widths[0] + self.batch_size: int = self.batch_sizes[0] + self.directions_unit_focal = self.directions_unit_focals[0] + self.efficiency_mask = self.efficiency_masks[0] + self.elevation_range = self.cfg.elevation_range + self.azimuth_range = self.cfg.azimuth_range + self.camera_distance_range = self.cfg.camera_distance_range + self.fovy_range = self.cfg.fovy_range + + def update_step(self, epoch: int, global_step: int, on_load_weights: bool = False): + size_ind = bisect.bisect_right(self.resolution_milestones, global_step) - 1 + self.height = self.heights[size_ind] + self.width = self.widths[size_ind] + self.sample_height = self.sample_heights[size_ind] + self.sample_width = self.sample_widths[size_ind] + self.batch_size = self.batch_sizes[size_ind] + self.directions_unit_focal = self.directions_unit_focals[size_ind] + self.efficiency_mask = self.efficiency_masks[size_ind] + threestudio.debug( + f"Training height: {self.height}, width: {self.width}, batch_size: {self.batch_size}" + ) + # progressive view + self.progressive_view(global_step) + + def __iter__(self): + while True: + yield {} + + def progressive_view(self, global_step): + r = min(1.0, global_step / (self.cfg.progressive_until + 1)) + self.elevation_range = [ + (1 - r) * self.cfg.eval_elevation_deg + r * self.cfg.elevation_range[0], + (1 - r) * self.cfg.eval_elevation_deg + r * self.cfg.elevation_range[1], + ] + self.azimuth_range = [ + (1 - r) * 0.0 + r * self.cfg.azimuth_range[0], + (1 - r) * 0.0 + r * self.cfg.azimuth_range[1], + ] + # self.camera_distance_range = [ + # (1 - r) * self.cfg.eval_camera_distance + # + r * self.cfg.camera_distance_range[0], + # (1 - r) * self.cfg.eval_camera_distance + # + r * self.cfg.camera_distance_range[1], + # ] + # self.fovy_range = [ + # (1 - r) * self.cfg.eval_fovy_deg + r * self.cfg.fovy_range[0], + # (1 - r) * self.cfg.eval_fovy_deg + r * self.cfg.fovy_range[1], + # ] + + def collate(self, batch) -> Dict[str, Any]: + # sample elevation angles + elevation_deg: Float[Tensor, "B"] + elevation: Float[Tensor, "B"] + if random.random() < 0.5: + # sample elevation angles uniformly with a probability 0.5 (biased towards poles) + elevation_deg = ( + torch.rand(self.batch_size) + * (self.elevation_range[1] - self.elevation_range[0]) + + self.elevation_range[0] + ) + elevation = elevation_deg * math.pi / 180 + else: + # otherwise sample uniformly on sphere + elevation_range_percent = [ + self.elevation_range[0] / 180.0 * math.pi, + self.elevation_range[1] / 180.0 * math.pi, + ] + # inverse transform sampling + elevation = torch.asin( + ( + torch.rand(self.batch_size) + * ( + math.sin(elevation_range_percent[1]) + - math.sin(elevation_range_percent[0]) + ) + + 
math.sin(elevation_range_percent[0]) + ) + ) + elevation_deg = elevation / math.pi * 180.0 + + # sample azimuth angles from a uniform distribution bounded by azimuth_range + azimuth_deg: Float[Tensor, "B"] + if self.cfg.batch_uniform_azimuth: + # ensures sampled azimuth angles in a batch cover the whole range + azimuth_deg = ( + torch.rand(self.batch_size) + torch.arange(self.batch_size) + ) / self.batch_size * ( + self.azimuth_range[1] - self.azimuth_range[0] + ) + self.azimuth_range[ + 0 + ] + else: + # simple random sampling + azimuth_deg = ( + torch.rand(self.batch_size) + * (self.azimuth_range[1] - self.azimuth_range[0]) + + self.azimuth_range[0] + ) + azimuth = azimuth_deg * math.pi / 180 + + # sample distances from a uniform distribution bounded by distance_range + camera_distances: Float[Tensor, "B"] = ( + torch.rand(self.batch_size) + * (self.camera_distance_range[1] - self.camera_distance_range[0]) + + self.camera_distance_range[0] + ) + + # convert spherical coordinates to cartesian coordinates + # right hand coordinate system, x back, y right, z up + # elevation in (-90, 90), azimuth from +x to +y in (-180, 180) + camera_positions: Float[Tensor, "B 3"] = torch.stack( + [ + camera_distances * torch.cos(elevation) * torch.cos(azimuth), + camera_distances * torch.cos(elevation) * torch.sin(azimuth), + camera_distances * torch.sin(elevation), + ], + dim=-1, + ) + + # default scene center at origin + center: Float[Tensor, "B 3"] = torch.zeros_like(camera_positions) + # default camera up direction as +z + up: Float[Tensor, "B 3"] = torch.as_tensor([0, 0, 1], dtype=torch.float32)[ + None, : + ].repeat(self.batch_size, 1) + + # sample camera perturbations from a uniform distribution [-camera_perturb, camera_perturb] + camera_perturb: Float[Tensor, "B 3"] = ( + torch.rand(self.batch_size, 3) * 2 * self.cfg.camera_perturb + - self.cfg.camera_perturb + ) + camera_positions = camera_positions + camera_perturb + # sample center perturbations from a normal distribution with mean 0 and std center_perturb + center_perturb: Float[Tensor, "B 3"] = ( + torch.randn(self.batch_size, 3) * self.cfg.center_perturb + ) + center = center + center_perturb + # sample up perturbations from a normal distribution with mean 0 and std up_perturb + up_perturb: Float[Tensor, "B 3"] = ( + torch.randn(self.batch_size, 3) * self.cfg.up_perturb + ) + up = up + up_perturb + + # sample fovs from a uniform distribution bounded by fov_range + fovy_deg: Float[Tensor, "B"] = ( + torch.rand(self.batch_size) * (self.fovy_range[1] - self.fovy_range[0]) + + self.fovy_range[0] + ) + fovy = fovy_deg * math.pi / 180 + + # sample light distance from a uniform distribution bounded by light_distance_range + light_distances: Float[Tensor, "B"] = ( + torch.rand(self.batch_size) + * (self.cfg.light_distance_range[1] - self.cfg.light_distance_range[0]) + + self.cfg.light_distance_range[0] + ) + + if self.cfg.light_sample_strategy == "dreamfusion": + # sample light direction from a normal distribution with mean camera_position and std light_position_perturb + light_direction: Float[Tensor, "B 3"] = F.normalize( + camera_positions + + torch.randn(self.batch_size, 3) * self.cfg.light_position_perturb, + dim=-1, + ) + # get light position by scaling light direction by light distance + light_positions: Float[Tensor, "B 3"] = ( + light_direction * light_distances[:, None] + ) + elif self.cfg.light_sample_strategy == "magic3d": + # sample light direction within restricted angle range (pi/3) + local_z = F.normalize(camera_positions, dim=-1) + 
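+            # magic3d strategy: build an orthonormal frame (local_x, local_y, local_z) around the view direction and sample the light inside it, with elevation restricted to [pi/6, pi/2]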
local_x = F.normalize( + torch.stack( + [local_z[:, 1], -local_z[:, 0], torch.zeros_like(local_z[:, 0])], + dim=-1, + ), + dim=-1, + ) + local_y = F.normalize(torch.cross(local_z, local_x, dim=-1), dim=-1) + rot = torch.stack([local_x, local_y, local_z], dim=-1) + light_azimuth = ( + torch.rand(self.batch_size) * math.pi * 2 - math.pi + ) # [-pi, pi] + light_elevation = ( + torch.rand(self.batch_size) * math.pi / 3 + math.pi / 6 + ) # [pi/6, pi/2] + light_positions_local = torch.stack( + [ + light_distances + * torch.cos(light_elevation) + * torch.cos(light_azimuth), + light_distances + * torch.cos(light_elevation) + * torch.sin(light_azimuth), + light_distances * torch.sin(light_elevation), + ], + dim=-1, + ) + light_positions = (rot @ light_positions_local[:, :, None])[:, :, 0] + else: + raise ValueError( + f"Unknown light sample strategy: {self.cfg.light_sample_strategy}" + ) + + lookat: Float[Tensor, "B 3"] = F.normalize(center - camera_positions, dim=-1) + right: Float[Tensor, "B 3"] = F.normalize(torch.cross(lookat, up), dim=-1) + up = F.normalize(torch.cross(right, lookat), dim=-1) + c2w3x4: Float[Tensor, "B 3 4"] = torch.cat( + [torch.stack([right, up, -lookat], dim=-1), camera_positions[:, :, None]], + dim=-1, + ) + c2w: Float[Tensor, "B 4 4"] = torch.cat( + [c2w3x4, torch.zeros_like(c2w3x4[:, :1])], dim=1 + ) + c2w[:, 3, 3] = 1.0 + + # get directions by dividing directions_unit_focal by focal length + focal_length: Float[Tensor, "B"] = 0.5 * self.height / torch.tan(0.5 * fovy) + directions: Float[Tensor, "B H W 3"] = self.directions_unit_focal[ + None, :, :, : + ].repeat(self.batch_size, 1, 1, 1) + directions[:, :, :, :2] = ( + directions[:, :, :, :2] / focal_length[:, None, None, None] + ) + + # Importance note: the returned rays_d MUST be normalized! + ### Efficiency masking added here + rays_o, rays_d = get_rays( + directions, c2w, keepdim=True, normalize=self.cfg.rays_d_normalize + ) + + self.proj_mtx: Float[Tensor, "B 4 4"] = get_projection_matrix( + fovy, self.width / self.height, 0.01, 100.0 + ) # FIXME: hard-coded near and far + mvp_mtx: Float[Tensor, "B 4 4"] = get_mvp_matrix(c2w, self.proj_mtx) + self.fovy = fovy + + return { + "rays_o": rays_o, + "rays_d": rays_d, + "efficiency_mask": self.efficiency_mask, + "mvp_mtx": mvp_mtx, + "camera_positions": camera_positions, + "c2w": c2w, + "light_positions": light_positions, + "elevation": elevation_deg, + "azimuth": azimuth_deg, + "camera_distances": camera_distances, + "height": self.height, + "width": self.width, + "sample_height": self.sample_height, + "sample_width": self.sample_width, + "fovy": self.fovy, + "proj_mtx": self.proj_mtx, + } + + +@register("eff-random-camera-datamodule") +class EffRandomCameraDataModule(pl.LightningDataModule): + cfg: EffRandomCameraDataModuleConfig + + def __init__(self, cfg: Optional[Union[dict, DictConfig]] = None) -> None: + super().__init__() + self.cfg = parse_structured(EffRandomCameraDataModuleConfig, cfg) + + def setup(self, stage=None) -> None: + if stage in [None, "fit"]: + self.train_dataset = EffRandomCameraIterableDataset(self.cfg) + if stage in [None, "fit", "validate"]: + self.val_dataset = RandomCameraDataset(self.cfg, "val") + if stage in [None, "test", "predict"]: + self.test_dataset = RandomCameraDataset(self.cfg, "test") + + def prepare_data(self): + pass + + def general_loader(self, dataset, batch_size, collate_fn=None) -> DataLoader: + return DataLoader( + dataset, + # very important to disable multi-processing if you want to change self attributes at runtime! 
+ # (for example setting self.width and self.height in update_step) + num_workers=5, # type: ignore + batch_size=batch_size, + collate_fn=collate_fn, + ) + + def train_dataloader(self) -> DataLoader: + return self.general_loader( + self.train_dataset, batch_size=None, collate_fn=self.train_dataset.collate + ) + + def val_dataloader(self) -> DataLoader: + return self.general_loader( + self.val_dataset, batch_size=1, collate_fn=self.val_dataset.collate + ) + # return self.general_loader(self.train_dataset, batch_size=None, collate_fn=self.train_dataset.collate) + + def test_dataloader(self) -> DataLoader: + return self.general_loader( + self.test_dataset, batch_size=1, collate_fn=self.test_dataset.collate + ) + + def predict_dataloader(self) -> DataLoader: + return self.general_loader( + self.test_dataset, batch_size=1, collate_fn=self.test_dataset.collate + ) diff --git a/threestudio/systems/__init__.py b/threestudio/systems/__init__.py index edbe7bf2..3da7dd67 100644 --- a/threestudio/systems/__init__.py +++ b/threestudio/systems/__init__.py @@ -1,6 +1,7 @@ from . import ( control4d_multiview, dreamfusion, + eff_dreamfusion, fantasia3d, imagedreamfusion, instructnerf2nerf, diff --git a/threestudio/systems/eff_dreamfusion.py b/threestudio/systems/eff_dreamfusion.py new file mode 100644 index 00000000..bb9db2c6 --- /dev/null +++ b/threestudio/systems/eff_dreamfusion.py @@ -0,0 +1,104 @@ +from .dreamfusion import * + + +@threestudio.register("efficient-dreamfusion-system") +class EffDreamFusion(DreamFusion): + @dataclass + class Config(DreamFusion.Config): + pass + + cfg: Config + + def configure(self): + # create geometry, material, background, renderer + super().configure() + + def unmask(self, ind, subsampled_tensor, H, W): + """ + ind: B,s_H,s_W + subsampled_tensor: B,C,s_H,s_W + """ + + # Create a grid of coordinates for the original image size + offset = [ind[0, 0] % H, ind[0, 0] // H] + indices_all = torch.meshgrid( + torch.arange(W, dtype=torch.float32, device=self.device), + torch.arange(H, dtype=torch.float32, device=self.device), + indexing="xy", + ) + + grid = torch.stack( + [ + (indices_all[0] - offset[0]) * 4 / (3 * W), + (indices_all[1] - offset[1]) * 4 / (H * 3), + ], + dim=-1, + ) + grid = grid * 2 - 1 + grid = grid.repeat(subsampled_tensor.shape[0], 1, 1, 1) + # Use grid_sample to upsample the subsampled tensor (B,C,H,W) + upsampled_tensor = torch.nn.functional.grid_sample( + subsampled_tensor, grid, mode="bilinear", align_corners=True + ) + + return upsampled_tensor.permute(0, 2, 3, 1) + + def training_step(self, batch, batch_idx): + out = self(batch) + ### using mask to create image at original resolution during training + (B, s_H, s_W, C) = out["comp_rgb"].shape + comp_rgb = out["comp_rgb"].permute(0, 3, 1, 2) + mask = batch["efficiency_mask"] + comp_rgb = self.unmask(mask, comp_rgb, batch["height"], batch["width"]) + # comp_rgb = torch.zeros(B,batch["height"],batch["width"],C,device=self.device).view(B,-1,C) + # comp_rgb[:,mask.view(-1)] = out["comp_rgb"].view(B,-1,C) + out.update( + { + "comp_rgb": comp_rgb, + } + ) + + prompt_utils = self.prompt_processor() + guidance_out = self.guidance( + out["comp_rgb"], prompt_utils, **batch, rgb_as_latents=False + ) + + loss = 0.0 + + for name, value in guidance_out.items(): + if not (type(value) is torch.Tensor and value.numel() > 1): + self.log(f"train/{name}", value) + if name.startswith("loss_"): + loss += value * self.C(self.cfg.loss[name.replace("loss_", "lambda_")]) + + if self.C(self.cfg.loss.lambda_orient) > 0: + if 
"normal" not in out: + raise ValueError( + "Normal is required for orientation loss, no normal is found in the output." + ) + loss_orient = ( + out["weights"].detach() + * dot(out["normal"], out["t_dirs"]).clamp_min(0.0) ** 2 + ).sum() / (out["opacity"] > 0).sum() + self.log("train/loss_orient", loss_orient) + loss += loss_orient * self.C(self.cfg.loss.lambda_orient) + + loss_sparsity = (out["opacity"] ** 2 + 0.01).sqrt().mean() + self.log("train/loss_sparsity", loss_sparsity) + loss += loss_sparsity * self.C(self.cfg.loss.lambda_sparsity) + + opacity_clamped = out["opacity"].clamp(1.0e-3, 1.0 - 1.0e-3) + loss_opaque = binary_cross_entropy(opacity_clamped, opacity_clamped) + self.log("train/loss_opaque", loss_opaque) + loss += loss_opaque * self.C(self.cfg.loss.lambda_opaque) + + # z-variance loss proposed in HiFA: https://hifa-team.github.io/HiFA-site/ + if "z_variance" in out and "lambda_z_variance" in self.cfg.loss: + loss_z_variance = out["z_variance"][out["opacity"] > 0.5].mean() + self.log("train/loss_z_variance", loss_z_variance) + loss += loss_z_variance * self.C(self.cfg.loss.lambda_z_variance) + + for name, value in self.cfg.loss.items(): + self.log(f"train_params/{name}", self.C(value)) + + return {"loss": loss} diff --git a/threestudio/utils/ops.py b/threestudio/utils/ops.py index 81d5b599..1e09e68e 100644 --- a/threestudio/utils/ops.py +++ b/threestudio/utils/ops.py @@ -217,6 +217,60 @@ def get_ray_directions( return directions +def mask_ray_directions(H: int, W: int, s_H: int, s_W: int) -> Float[Tensor, "s_H s_W"]: + """ + Masking the (H,W) image to (s_H,s_W), for efficient training at higher resolution image. + pixels from (s_H,s_W) are sampled more (1-aspect_ratio) than outside pixels(aspect_ratio). + the masking is deferred to before calling get_rays(). 
+ """ + # indices_all = torch.meshgrid( + # torch.arange(W, dtype=torch.float32) , + # torch.arange(H, dtype=torch.float32) , + # indexing="xy", + # ) + + indices_inner = torch.meshgrid( + torch.linspace(0, 0.75 * W, s_W, dtype=torch.int8), + torch.linspace(0, 0.75 * H, s_H, dtype=torch.int8), + indexing="xy", + ) + offset = [torch.randint(0, W // 8 + 1, (1,)), torch.randint(0, H // 8 + 1, (1,))] + + select_ind = indices_inner[0] + offset[0] + H * (indices_inner[1] + offset[1]) + + ### removing the random sampling approach, we sample in uniform grid + # mask = torch.zeros(H,W, dtype=torch.bool) + # mask[(H-s_H)//2 : H - math.ceil((H-s_H)/2),(W-s_W)//2 : W - math.ceil((W-s_W)/2)] = True + + # in_ind_1d = (indices_all[0]+H*indices_all[1])[mask] + # out_ind_1d = (indices_all[0]+H*indices_all[1])[torch.logical_not(mask)] + # ### tried using 0.5 p ratio of sampling inside vs outside, as smaller area already + # ### leads to more samples inside anyways + + # p = 0.5#(s_H*s_W)/(H*W) + # select_ind = in_ind_1d[ + # torch.multinomial( + # torch.ones_like(in_ind_1d)*(1-p),int((1-p)*(s_H*s_W)),replacement=False)] + # select_ind = torch.concatenate( + # [select_ind, out_ind_1d[torch.multinomial( + # torch.ones_like(out_ind_1d)*(p),int((p)*(s_H*s_W)),replacement=False)] + # ], + # dim=0).to(dtype=torch.int).view(s_H,s_W) + + ### first attempt at sampling, this produces variable number of rays, + ### so 4D tensor directions cant be sampled + # mask = torch.zeros(H,W, device= directions.device) + # p = (s_H*s_W)/(H*W) + # mask += p + # mask[(H-s_H)//2 : H - math.ceil((H-s_H)/2),(W-s_W)//2 : W - math.ceil((W-s_W)/2)] = 1 - p + # ### mask contains prob of individual pixel, drawing using Bernoulli dist + # mask = torch.bernoulli(mask).to(dtype=torch.bool) + ### postponing masking before get_rays is called + # directions = directions[mask] + + return select_ind + + def get_rays( directions: Float[Tensor, "... 3"], c2w: Float[Tensor, "... 4 4"],