Skip to content

Commit

Permalink
accelerating train time
Browse files Browse the repository at this point in the history
  • Loading branch information
daniel03c1 committed Jul 28, 2022
1 parent b6a3d93 commit a45ff8a
Show file tree
Hide file tree
Showing 2 changed files with 19 additions and 12 deletions.
23 changes: 13 additions & 10 deletions PREF/models/phasoBase.py
Original file line number Diff line number Diff line change
Expand Up @@ -66,7 +66,8 @@ def __init__(self, aabb, gridSize, device,

self.near_far = near_far
self.step_ratio = step_ratio
self.axis = [torch.arange(d).to(self.device).to(torch.float) for d in self.den_num_comp]
self.axis = [torch.arange(d, dtype=torch.float32, device=self.device)
for d in self.den_num_comp]

self.update_stepSize(gridSize)
self.init_phasor_volume(gridSize[0], device)
Expand Down Expand Up @@ -173,9 +174,10 @@ def load(self, ckpt):
def sample_ray_ndc(self, rays_o, rays_d, is_train=True, N_samples=-1):
N_samples = N_samples if N_samples > 0 else self.nSamples
near, far = self.near_far
interpx = torch.linspace(near, far, N_samples).unsqueeze(0).to(rays_o)
interpx = torch.linspace(near, far, N_samples, device=rays_o.device) \
.unsqueeze(0)
if is_train:
interpx += torch.rand_like(interpx).to(rays_o) * ((far - near) / N_samples)
interpx += torch.rand_like(interpx) * ((far - near) / N_samples)

rays_pts = rays_o[..., None, :] + rays_d[..., None, :] * interpx[..., None]
mask_outbbox = ((self.aabb[0] > rays_pts) | (rays_pts > self.aabb[1])).any(dim=-1)
Expand All @@ -190,11 +192,12 @@ def sample_ray(self, rays_o, rays_d, is_train=True, N_samples=-1):
rate_b = (self.aabb[0] - rays_o) / vec
t_min = torch.minimum(rate_a, rate_b).amax(-1).clamp(min=near, max=far)

rng = torch.arange(N_samples)[None].float()
rng = torch.arange(N_samples, dtype=torch.float32,
device=rays_o.device)[None]
if is_train:
rng = rng.repeat(rays_d.shape[-2],1)
rng += torch.rand_like(rng[:,[0]])
step = stepsize * rng.to(rays_o.device)
step = stepsize * rng
interpx = (t_min[...,None] + step)

rays_pts = rays_o[...,None,:] + rays_d[...,None,:] * interpx[...,None]
Expand All @@ -211,10 +214,10 @@ def getDenseAlpha(self,gridSize=None):
gridSize = self.gridSize if gridSize is None else gridSize

samples = torch.stack(torch.meshgrid(
torch.linspace(0, 1, gridSize[0]),
torch.linspace(0, 1, gridSize[1]),
torch.linspace(0, 1, gridSize[2]),
), -1).to(self.device)
torch.linspace(0, 1, gridSize[0], device=self.device),
torch.linspace(0, 1, gridSize[1], device=self.device),
torch.linspace(0, 1, gridSize[2], device=self.device),
), -1)
dense_xyz = self.aabb[0] * (1-samples) + self.aabb[1] * samples

# dense_xyz = dense_xyz
Expand Down Expand Up @@ -376,4 +379,4 @@ def forward(self, rays_chunk, white_bg=True, is_train=False, ndc_ray=False, N_sa
# assert (normal_map.amin(1) >= 0).all()
return rgb_map, depth_map, normal_map

return rgb_map, depth_map
return rgb_map, depth_map
8 changes: 6 additions & 2 deletions PREF/train.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,9 +24,9 @@ def __init__(self, total, batch):
self.ids = None

def nextids(self):
self.curr+=self.batch
self.curr += self.batch
if self.curr + self.batch > self.total:
self.ids = torch.LongTensor(np.random.permutation(self.total))
self.ids = torch.randperm(self.total).cuda()
self.curr = 0
return self.ids[self.curr:self.curr+self.batch]

Expand Down Expand Up @@ -179,6 +179,8 @@ def reconstruction(args, return_bbox=False, return_memory=False, bbox_only=False
allrays, allrgbs = train_dataset.all_rays, train_dataset.all_rgbs
if not args.ndc_ray:
allrays, allrgbs = phasorf.filtering_rays(allrays, allrgbs, bbox_only=True)
allrays = allrays.cuda()
allrgbs = allrgbs.cuda()
trainingSampler = SimpleSampler(allrays.shape[0], args.batch_size)


Expand Down Expand Up @@ -266,6 +268,8 @@ def reconstruction(args, return_bbox=False, return_memory=False, bbox_only=False
# filter rays outside the bbox
allrays,allrgbs = phasorf.filtering_rays(allrays,allrgbs)
trainingSampler = SimpleSampler(allrgbs.shape[0], args.batch_size)
allrays = allrays.cuda()
allrgbs = allrgbs.cuda()

# TODO:
if upsamp_list is not None and iteration in upsamp_list:
Expand Down

0 comments on commit a45ff8a

Please sign in to comment.