Skip to content

Commit

Permalink
Merge pull request #2 from threestudio-project/saurus/effsd
Browse files Browse the repository at this point in the history
Fix format
  • Loading branch information
jadevaibhav authored Oct 2, 2024
2 parents 5f4f664 + 083e397 commit bc8af2a
Show file tree
Hide file tree
Showing 4 changed files with 67 additions and 56 deletions.
40 changes: 26 additions & 14 deletions threestudio/data/uncond_eff.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@

import threestudio
from threestudio import register
from threestudio.data.uncond import RandomCameraDataset
from threestudio.utils.base import Updateable
from threestudio.utils.config import parse_structured
from threestudio.utils.misc import get_device
Expand All @@ -20,10 +21,10 @@
get_projection_matrix,
get_ray_directions,
get_rays,
mask_ray_directions
mask_ray_directions,
)
from threestudio.utils.typing import *
from threestudio.data.uncond import RandomCameraDataset


@dataclass
class EffRandomCameraDataModuleConfig:
Expand Down Expand Up @@ -73,17 +74,27 @@ def __init__(self, cfg: Any) -> None:
[self.cfg.width] if isinstance(self.cfg.width, int) else self.cfg.width
)
self.sample_heights: List[int] = (
[self.cfg.sample_height] if isinstance(self.cfg.sample_height, int) else self.cfg.sample_height
[self.cfg.sample_height]
if isinstance(self.cfg.sample_height, int)
else self.cfg.sample_height
)
self.sample_widths: List[int] = (
[self.cfg.sample_width] if isinstance(self.cfg.sample_width, int) else self.cfg.sample_width
[self.cfg.sample_width]
if isinstance(self.cfg.sample_width, int)
else self.cfg.sample_width
)
self.batch_sizes: List[int] = (
[self.cfg.batch_size]
if isinstance(self.cfg.batch_size, int)
else self.cfg.batch_size
)
assert len(self.heights) == len(self.widths) == len(self.batch_sizes) == len(self.sample_heights) == len(self.sample_widths)
assert (
len(self.heights)
== len(self.widths)
== len(self.batch_sizes)
== len(self.sample_heights)
== len(self.sample_widths)
)
self.resolution_milestones: List[int]
if (
len(self.heights) == 1
Expand All @@ -107,16 +118,18 @@ def __init__(self, cfg: Any) -> None:
]

self.efficiency_masks = [
(mask_ray_directions(H,W,s_H,s_W)) for (H,W,s_H,s_W)
in zip( self.heights, self.widths,
self.sample_heights, self.sample_widths)]
(mask_ray_directions(H, W, s_H, s_W))
for (H, W, s_H, s_W) in zip(
self.heights, self.widths, self.sample_heights, self.sample_widths
)
]
self.directions_unit_focals = [
(
self.directions_unit_focals[i].view(-1,3)[self.efficiency_masks[i]]
).view(self.sample_heights[i],self.sample_widths[i],3)
(self.directions_unit_focals[i].view(-1, 3)[self.efficiency_masks[i]]).view(
self.sample_heights[i], self.sample_widths[i], 3
)
for i in range(len(self.heights))
]

self.height: int = self.heights[0]
self.width: int = self.widths[0]
self.sample_height: int = self.sample_heights[0]
Expand Down Expand Up @@ -360,7 +373,7 @@ def collate(self, batch) -> Dict[str, Any]:
return {
"rays_o": rays_o,
"rays_d": rays_d,
"efficiency_mask":self.efficiency_mask,
"efficiency_mask": self.efficiency_mask,
"mvp_mtx": mvp_mtx,
"camera_positions": camera_positions,
"c2w": c2w,
Expand All @@ -377,7 +390,6 @@ def collate(self, batch) -> Dict[str, Any]:
}



@register("eff-random-camera-datamodule")
class EffRandomCameraDataModule(pl.LightningDataModule):
cfg: EffRandomCameraDataModuleConfig
Expand Down
2 changes: 0 additions & 2 deletions threestudio/systems/dreamfusion.py
Original file line number Diff line number Diff line change
Expand Up @@ -160,5 +160,3 @@ def on_test_epoch_end(self):
name="test",
step=self.true_global_step,
)


48 changes: 27 additions & 21 deletions threestudio/systems/eff_dreamfusion.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
from .dreamfusion import *


@threestudio.register("efficient-dreamfusion-system")
class EffDreamFusion(DreamFusion):
@dataclass
Expand All @@ -12,45 +13,50 @@ def configure(self):
# create geometry, material, background, renderer
super().configure()

def unmask(self,ind,subsampled_tensor,H,W):
def unmask(self, ind, subsampled_tensor, H, W):
"""
ind: s_H,s_W (flat pixel indices of the sampled rays; ind[0,0] encodes the sampling offset)
subsampled_tensor: B,C,s_H,s_W
"""

# Create a grid of coordinates for the original image size
offset = [ind[0,0]%H,ind[0,0]//H]
offset = [ind[0, 0] % H, ind[0, 0] // H]
indices_all = torch.meshgrid(
torch.arange(W, dtype=torch.float32,device=self.device) ,
torch.arange(H, dtype=torch.float32,device=self.device) ,
indexing="xy"
)
torch.arange(W, dtype=torch.float32, device=self.device),
torch.arange(H, dtype=torch.float32, device=self.device),
indexing="xy",
)

grid = torch.stack(
[(indices_all[0] - offset[0])*4/(3*W),
(indices_all[1] - offset[1])*4/(H*3)],
dim=-1)
grid = grid*2 - 1
[
(indices_all[0] - offset[0]) * 4 / (3 * W),
(indices_all[1] - offset[1]) * 4 / (H * 3),
],
dim=-1,
)
grid = grid * 2 - 1
grid = grid.repeat(subsampled_tensor.shape[0], 1, 1, 1)
# Use grid_sample to upsample the subsampled tensor (B,C,H,W)
upsampled_tensor = torch.nn.functional.grid_sample(subsampled_tensor, grid, mode='bilinear', align_corners=True)
upsampled_tensor = torch.nn.functional.grid_sample(
subsampled_tensor, grid, mode="bilinear", align_corners=True
)

return upsampled_tensor.permute(0,2,3,1)
return upsampled_tensor.permute(0, 2, 3, 1)

def training_step(self, batch, batch_idx):
out = self(batch)
### using mask to create image at original resolution during training
(B,s_H,s_W,C) = out["comp_rgb"].shape
comp_rgb = out["comp_rgb"].permute(0,3,1,2)
(B, s_H, s_W, C) = out["comp_rgb"].shape
comp_rgb = out["comp_rgb"].permute(0, 3, 1, 2)
mask = batch["efficiency_mask"]
comp_rgb = self.unmask(mask,comp_rgb,batch["height"],batch["width"])
comp_rgb = self.unmask(mask, comp_rgb, batch["height"], batch["width"])
# comp_rgb = torch.zeros(B,batch["height"],batch["width"],C,device=self.device).view(B,-1,C)
# comp_rgb[:,mask.view(-1)] = out["comp_rgb"].view(B,-1,C)
out.update(
{
"comp_rgb": comp_rgb,
}
)
{
"comp_rgb": comp_rgb,
}
)

prompt_utils = self.prompt_processor()
guidance_out = self.guidance(
Expand Down Expand Up @@ -95,4 +101,4 @@ def training_step(self, batch, batch_idx):
for name, value in self.cfg.loss.items():
self.log(f"train_params/{name}", self.C(value))

return {"loss": loss}
return {"loss": loss}
33 changes: 14 additions & 19 deletions threestudio/utils/ops.py
Original file line number Diff line number Diff line change
Expand Up @@ -216,12 +216,8 @@ def get_ray_directions(

return directions

def mask_ray_directions(
H: int,
W:int,
s_H:int,
s_W:int
) -> Float[Tensor, "s_H s_W"]:

def mask_ray_directions(H: int, W: int, s_H: int, s_W: int) -> Float[Tensor, "s_H s_W"]:
"""
Mask the (H,W) image down to (s_H,s_W) for efficient training on higher-resolution images.
Pixels inside the (s_H,s_W) region are sampled more often (1-aspect_ratio) than pixels outside it (aspect_ratio).
Expand All @@ -233,26 +229,24 @@ def mask_ray_directions(
# indexing="xy",
# )

indices_inner = torch.meshgrid(
torch.linspace(0,0.75*W,s_W, dtype=torch.int8) ,
torch.linspace(0,0.75*H,s_H, dtype=torch.int8) ,
indices_inner = torch.meshgrid(
torch.linspace(0, 0.75 * W, s_W, dtype=torch.int8),
torch.linspace(0, 0.75 * H, s_H, dtype=torch.int8),
indexing="xy",
)
offset = [torch.randint(0,W//8 +1,(1,)),
torch.randint(0,H//8 +1,(1,))]

select_ind = indices_inner[0]+offset[0] + H*(indices_inner[1] + offset[1])

offset = [torch.randint(0, W // 8 + 1, (1,)), torch.randint(0, H // 8 + 1, (1,))]

select_ind = indices_inner[0] + offset[0] + H * (indices_inner[1] + offset[1])

### removing the random sampling approach, we sample in uniform grid
# mask = torch.zeros(H,W, dtype=torch.bool)
# mask[(H-s_H)//2 : H - math.ceil((H-s_H)/2),(W-s_W)//2 : W - math.ceil((W-s_W)/2)] = True

# in_ind_1d = (indices_all[0]+H*indices_all[1])[mask]
# out_ind_1d = (indices_all[0]+H*indices_all[1])[torch.logical_not(mask)]
# ### tried using 0.5 p ratio of sampling inside vs outside, as smaller area already
# ### tried using 0.5 p ratio of sampling inside vs outside, as smaller area already
# ### leads to more samples inside anyways

# p = 0.5#(s_H*s_W)/(H*W)
# select_ind = in_ind_1d[
# torch.multinomial(
Expand All @@ -263,19 +257,20 @@ def mask_ray_directions(
# ],
# dim=0).to(dtype=torch.int).view(s_H,s_W)

### first attempt at sampling, this produces variable number of rays,
### first attempt at sampling, this produces variable number of rays,
### so 4D tensor directions can't be sampled
# mask = torch.zeros(H,W, device= directions.device)
# p = (s_H*s_W)/(H*W)
# mask += p
# mask += p
# mask[(H-s_H)//2 : H - math.ceil((H-s_H)/2),(W-s_W)//2 : W - math.ceil((W-s_W)/2)] = 1 - p
# ### mask contains prob of individual pixel, drawing using Bernoulli dist
# mask = torch.bernoulli(mask).to(dtype=torch.bool)
### postponing masking before get_rays is called
#directions = directions[mask]
# directions = directions[mask]

return select_ind


def get_rays(
directions: Float[Tensor, "... 3"],
c2w: Float[Tensor, "... 4 4"],
Expand Down

0 comments on commit bc8af2a

Please sign in to comment.