diff --git a/PREF/models/phasoBase.py b/PREF/models/phasoBase.py index e5c95b5..9e7c4d6 100644 --- a/PREF/models/phasoBase.py +++ b/PREF/models/phasoBase.py @@ -66,7 +66,8 @@ def __init__(self, aabb, gridSize, device, self.near_far = near_far self.step_ratio = step_ratio - self.axis = [torch.arange(d).to(self.device).to(torch.float) for d in self.den_num_comp] + self.axis = [torch.arange(d, dtype=torch.float32, device=self.device) + for d in self.den_num_comp] self.update_stepSize(gridSize) self.init_phasor_volume(gridSize[0], device) @@ -173,9 +174,10 @@ def load(self, ckpt): def sample_ray_ndc(self, rays_o, rays_d, is_train=True, N_samples=-1): N_samples = N_samples if N_samples > 0 else self.nSamples near, far = self.near_far - interpx = torch.linspace(near, far, N_samples).unsqueeze(0).to(rays_o) + interpx = torch.linspace(near, far, N_samples, device=rays_o.device) \ + .unsqueeze(0) if is_train: - interpx += torch.rand_like(interpx).to(rays_o) * ((far - near) / N_samples) + interpx += torch.rand_like(interpx) * ((far - near) / N_samples) rays_pts = rays_o[..., None, :] + rays_d[..., None, :] * interpx[..., None] mask_outbbox = ((self.aabb[0] > rays_pts) | (rays_pts > self.aabb[1])).any(dim=-1) @@ -190,11 +192,12 @@ def sample_ray(self, rays_o, rays_d, is_train=True, N_samples=-1): rate_b = (self.aabb[0] - rays_o) / vec t_min = torch.minimum(rate_a, rate_b).amax(-1).clamp(min=near, max=far) - rng = torch.arange(N_samples)[None].float() + rng = torch.arange(N_samples, dtype=torch.float32, + device=rays_o.device)[None] if is_train: rng = rng.repeat(rays_d.shape[-2],1) rng += torch.rand_like(rng[:,[0]]) - step = stepsize * rng.to(rays_o.device) + step = stepsize * rng interpx = (t_min[...,None] + step) rays_pts = rays_o[...,None,:] + rays_d[...,None,:] * interpx[...,None] @@ -211,10 +214,10 @@ def getDenseAlpha(self,gridSize=None): gridSize = self.gridSize if gridSize is None else gridSize samples = torch.stack(torch.meshgrid( - torch.linspace(0, 1, gridSize[0]), - torch.linspace(0, 1, gridSize[1]), - torch.linspace(0, 1, gridSize[2]), - ), -1).to(self.device) + torch.linspace(0, 1, gridSize[0], device=self.device), + torch.linspace(0, 1, gridSize[1], device=self.device), + torch.linspace(0, 1, gridSize[2], device=self.device), + ), -1) dense_xyz = self.aabb[0] * (1-samples) + self.aabb[1] * samples # dense_xyz = dense_xyz @@ -376,4 +379,4 @@ def forward(self, rays_chunk, white_bg=True, is_train=False, ndc_ray=False, N_sa # assert (normal_map.amin(1) >= 0).all() return rgb_map, depth_map, normal_map - return rgb_map, depth_map \ No newline at end of file + return rgb_map, depth_map diff --git a/PREF/train.py b/PREF/train.py index 09c3cc7..d8a9482 100644 --- a/PREF/train.py +++ b/PREF/train.py @@ -24,9 +24,9 @@ def __init__(self, total, batch): self.ids = None def nextids(self): - self.curr+=self.batch + self.curr += self.batch if self.curr + self.batch > self.total: - self.ids = torch.LongTensor(np.random.permutation(self.total)) + self.ids = torch.randperm(self.total).cuda() self.curr = 0 return self.ids[self.curr:self.curr+self.batch] @@ -179,6 +179,8 @@ def reconstruction(args, return_bbox=False, return_memory=False, bbox_only=False allrays, allrgbs = train_dataset.all_rays, train_dataset.all_rgbs if not args.ndc_ray: allrays, allrgbs = phasorf.filtering_rays(allrays, allrgbs, bbox_only=True) + allrays = allrays.cuda() + allrgbs = allrgbs.cuda() trainingSampler = SimpleSampler(allrays.shape[0], args.batch_size) @@ -266,6 +268,8 @@ def reconstruction(args, return_bbox=False, return_memory=False, bbox_only=False # filter rays outside the bbox allrays,allrgbs = phasorf.filtering_rays(allrays,allrgbs) trainingSampler = SimpleSampler(allrgbs.shape[0], args.batch_size) + allrays = allrays.cuda() + allrgbs = allrgbs.cuda() # TODO: if upsamp_list is not None and iteration in upsamp_list: