Skip to content

Commit

Permalink
end of the day (7.31)
Browse files Browse the repository at this point in the history
  • Loading branch information
daniel03c1 committed Jul 31, 2022
1 parent 4765688 commit 1787b9a
Show file tree
Hide file tree
Showing 2 changed files with 7 additions and 112 deletions.
111 changes: 5 additions & 106 deletions PREF/models/phasoMLP.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,10 +37,8 @@ def init_phasor_volume(self, res, device):

# mask
self.den_mask = nn.ParameterList(
# [nn.Parameter(torch.zeros_like(d[0, 0])) for d in self.den])
[nn.Parameter(torch.zeros_like(d)) for d in self.den])
self.app_mask = nn.ParameterList(
# [nn.Parameter(torch.zeros_like(d[0, 0])) for d in self.app])
[nn.Parameter(torch.zeros_like(d)) for d in self.app])

den_outdim = self.den_ksize * 3
Expand Down Expand Up @@ -93,7 +91,6 @@ def init_(self, axis, res, ksize=1, init_scale=1):
return [nn.Parameter(fx), nn.Parameter(fy), nn.Parameter(fz)]

def compute_densityfeature(self, xyz_sampled):
# sigma_feature = self.compute_fft(self.density, xyz_sampled)
sigma_feature = self.compute_fft(self.density, xyz_sampled,
self.den_mask)
return self.mlp(sigma_feature.T).T
Expand All @@ -106,11 +103,11 @@ def feature2density(self, density_features):
return F.relu(density_features)

def compute_appfeature(self, xyz_sampled):
# app_points = self.compute_fft(self.appearance, xyz_sampled)
app_points = self.compute_fft(self.appearance, xyz_sampled,
self.app_mask)
if self.app_aug == 'flip':
aug = self.compute_fft(self.appearance, xyz_sampled.flip(-1))
aug = self.compute_fft(self.appearance, xyz_sampled.flip(-1),
self.app_mask)
app_points = torch.cat([app_points, aug], dim=0)
elif self.app_aug == 'normal':
aug = self.compute_normal(xyz_sampled)
Expand All @@ -135,33 +132,6 @@ def compute_fft(self, features, xyz_sampled, mask=None):
kx, ky, kz = kx[:d1], ky[:d2], kz[:d3]
xs, ys, zs = xyz_sampled.chunk(3, dim=-1)

'''
# [1, F, N, N, N]
pad_size = [(8 - x%8) % 8 for x in [Nx, Ny, Nz]]
Fx = F.pad(Fx, (0, pad_size[-1], 0, pad_size[-2]))
Fx = Fx.reshape(*Fx.shape[:-2], Fx.shape[-2]//8, 8,
Fx.shape[-1]//8, 8)
Fx = idctn(Fx, axes=(-3, -1))
Fx = Fx.flatten(-2, -1).flatten(-3, -2)[..., :Ny, :Nz]
Fy = F.pad(Fy, (0, pad_size[-1], 0, 0, 0, pad_size[-3]))
Fy = Fy.reshape(*Fy.shape[:-3], Fy.shape[-3]//8, 8, Fy.shape[-2],
Fy.shape[-1]//8, 8)
Fy = idctn(Fy, axes=(-4, -1))
Fy = Fy.flatten(-2, -1).flatten(-4, -3)[..., :Nx, :, :Nz]
Fz = F.pad(Fz, (0, 0, 0, pad_size[-2], 0, pad_size[-3]))
Fz = Fz.reshape(*Fz.shape[:-3], Fz.shape[-3]//8, 8,
Fz.shape[-2]//8, 8, Fz.shape[-1])
Fz = idctn(Fz, axes=(-4, -2))
Fz = Fz.flatten(-3, -2).flatten(-4, -3)[..., :Nx, :Ny, :]
assert tuple(Fx.shape[-3:]) == (d1, Ny, Nz)
assert tuple(Fy.shape[-3:]) == (Nx, d2, Nz)
assert tuple(Fz.shape[-3:]) == (Nx, Ny, d3)
'''

if mask is not None:
mx, my, mz = mask
mx = torch.sigmoid(mx)
Expand Down Expand Up @@ -207,10 +177,6 @@ def Parseval_Loss(self):
for f, w in zip(self.density, self.ktraj_den):
feat = torch.pi * f[..., None] * w.reshape(1, 1, *f.shape[2:], -1)
loss = loss + feat.square().mean()

# for f, w in zip(self.appearance, self.ktraj):
# feat = torch.pi * f[..., None] * w.reshape(1, 1, *f.shape[2:], -1)
# loss += feat.square().mean()
return loss

def compute_normal(self, xyz_sampled):
Expand Down Expand Up @@ -259,59 +225,6 @@ def upsample_feats(self, features, res_target, update_dd=False):
return F.pad(Fx, (0, Tz-Nz, 0, Ty-Ny, 0, 0)), \
F.pad(Fy, (0, Tz-Nz, 0, 0, 0, Tx-Nx)), \
F.pad(Fz, (0, 0, 0, Ty-Ny, 0, Tx-Nx))
'''
pad_size = [(8 - x%8) % 8 for x in [Nx, Ny, Nz]]
Fx = F.pad(Fx, (0, pad_size[-1], 0, pad_size[-2]))
Fx = Fx.reshape(*Fx.shape[:-2], Fx.shape[-2]//8, 8,
Fx.shape[-1]//8, 8)
Fx = idctn(Fx, axes=(-3, -1))
Fx = Fx.flatten(-2, -1).flatten(-3, -2)[..., :Ny, :Nz]
Fy = F.pad(Fy, (0, pad_size[-1], 0, 0, 0, pad_size[-3]))
Fy = Fy.reshape(*Fy.shape[:-3], Fy.shape[-3]//8, 8, Fy.shape[-2],
Fy.shape[-1]//8, 8)
Fy = idctn(Fy, axes=(-4, -1))
Fy = Fy.flatten(-2, -1).flatten(-4, -3)[..., :Nx, :, :Nz]
Fz = F.pad(Fz, (0, 0, 0, pad_size[-2], 0, pad_size[-3]))
Fz = Fz.reshape(*Fz.shape[:-3], Fz.shape[-3]//8, 8,
Fz.shape[-2]//8, 8, Fz.shape[-1])
Fz = idctn(Fz, axes=(-4, -2))
Fz = Fz.flatten(-3, -2).flatten(-4, -3)[..., :Nx, :Ny, :]
# interpolate
Fx = F.interpolate(Fx, (d1, Ty, Tz), mode='trilinear',
align_corners=True)
Fy = F.interpolate(Fy, (Tx, d2, Tz), mode='trilinear',
align_corners=True)
Fz = F.interpolate(Fz, (Tx, Ty, d3), mode='trilinear',
align_corners=True)
# dctn
pad_size = [(8 - x%8) % 8 for x in [Tx, Ty, Tz]]
Fx = F.pad(Fx, (0, pad_size[-1], 0, pad_size[-2]))
Fx = Fx.reshape(*Fx.shape[:-2], Fx.shape[-2]//8, 8,
Fx.shape[-1]//8, 8)
Fx = dctn(Fx, axes=(-3, -1))
Fx = Fx.flatten(-2, -1).flatten(-3, -2)[..., :Ty, :Tz]
Fy = F.pad(Fy, (0, pad_size[-1], 0, 0, 0, pad_size[-3]))
Fy = Fy.reshape(*Fy.shape[:-3], Fy.shape[-3]//8, 8, Fy.shape[-2],
Fy.shape[-1]//8, 8)
Fy = dctn(Fy, axes=(-4, -1))
Fy = Fy.flatten(-2, -1).flatten(-4, -3)[..., :Tx, :, :Tz]
Fz = F.pad(Fz, (0, 0, 0, pad_size[-2], 0, pad_size[-3]))
Fz = Fz.reshape(*Fz.shape[:-3], Fz.shape[-3]//8, 8,
Fz.shape[-2]//8, 8, Fz.shape[-1])
Fz = dctn(Fz, axes=(-4, -2))
Fz = Fz.flatten(-3, -2).flatten(-4, -3)[..., :Tx, :Ty, :]
return Fx, Fy, Fz
'''

def update_stepSize(self, gridSize):
self.ktraj = self.compute_ktraj(self.axis, gridSize)
Expand Down Expand Up @@ -342,11 +255,6 @@ def get_optparam_groups(self, lr_init_spatialxyz=0.02,
'lr':lr_init_network}]
return grad_vars

@property
def alpha(self):
# avoid negative value
return F.softplus(self.alpha_params, beta=10, threshold=1e-4)

@property
def density(self):
return [self.alpha * den for den in self.den]
Expand All @@ -355,18 +263,9 @@ def density(self):
def appearance(self):
return [app * self.beta for app in self.app]

def compute_gaussian(self, variance, mode, ktraj=None):
breakpoint()
if mode== 'none':
return [1., 1., 1.]
if mode == 'prod':
if ktraj is None:
ktraj = self.ktraj
ktraj = [k / g for k,g in zip(ktraj, self.gridSize)]
gauss = [torch.exp((-2*(np.pi*kk)**2*variance[None]).sum(-1)).reshape(1,1,*kk.shape[:-1]) for kk in ktraj]
else:
raise ValueError(f'no mode named {mode}')
return gauss
@property
def alpha(self):
return F.softplus(self.alpha_params, beta=10, threshold=1e-4)


class PhasoMLP(PhasorBase):
Expand Down
8 changes: 2 additions & 6 deletions PREF/train.py
Original file line number Diff line number Diff line change
Expand Up @@ -128,7 +128,7 @@ def reconstruction(args, return_bbox=False, return_memory=False,
json.dump(args.__dict__, open(f'{logfolder}/config.json',mode='w'),indent=2)

# init parameters
if not bbox_only and args.dataset_name=='blender':
if not bbox_only and args.dataset_name == 'blender':
# use tight bbox pre-extracted and stored in misc.py,
# which takes 2k iters
data = args.datadir.split('/')[-1]
Expand Down Expand Up @@ -186,7 +186,7 @@ def reconstruction(args, return_bbox=False, return_memory=False,

optimizer = torch.optim.Adam(grad_vars, betas=(0.9,0.99))

#linear in logrithmic space
# linear in logrithmic space
if upsamp_list:
N_voxel_list = (torch.round(torch.exp(torch.linspace(np.log(args.N_voxel_init),
np.log(args.N_voxel_final), len(upsamp_list)+1))).long()).tolist()[1:]
Expand Down Expand Up @@ -231,10 +231,6 @@ def reconstruction(args, return_bbox=False, return_memory=False,
loss_tv.detach().item(),
global_step=iteration)

# mask
# total_loss += 1e-5 * sum([(m * (m>=0)).abs().mean() for m in phasorf.den_mask])
# total_loss += 1e-5 * sum([(m * (m>=0)).abs().mean() for m in phasorf.app_mask])

if TV_weight_app > 0:
TV_weight_app *= lr_factor
raise NotImplementedError('not implemented')
Expand Down

0 comments on commit 1787b9a

Please sign in to comment.