From 49f0acb45b8f19a9ffaa96a3931ec31c51d32ffe Mon Sep 17 00:00:00 2001 From: daniel03c1 Date: Sat, 30 Jul 2022 14:19:29 +0000 Subject: [PATCH] change sum to cat (phaseMLP.CPhaseMLP) --- PREF/models/phasoMLP.py | 124 +++++++++++++--------------------------- 1 file changed, 41 insertions(+), 83 deletions(-) diff --git a/PREF/models/phasoMLP.py b/PREF/models/phasoMLP.py index a6feebc..124b827 100644 --- a/PREF/models/phasoMLP.py +++ b/PREF/models/phasoMLP.py @@ -5,7 +5,8 @@ from .phasoBase import * from .utils import positional_encoding from .utils_fft import ( - getMask_fft, getMask, grid_sample, grid_sample_cmplx, irfft, rfft, batch_irfft + getMask_fft, getMask, grid_sample, grid_sample_cmplx, irfft, rfft, + batch_irfft ) from .cosine_transform import idctn @@ -30,21 +31,14 @@ def init_phasor_volume(self, res, device): self.init_(self.den_num_comp, (self.gridSize * self.den_scale).long(), ksize=self.den_ksize)) + breakpoint() self.app = torch.nn.ParameterList( self.init_(self.app_num_comp, (self.gridSize * self.app_scale).long(), ksize=self.app_ksize)) - den_outdim = self.den_ksize - app_outdim = self.app_ksize - - # duplicate real and image parts - ''' - if self.den_num_comp == [1, 1, 1]: - den_outdim = den_outdim * 2 - if self.app_num_comp == [1, 1, 1]: - app_outdim = app_outdim * 2 - ''' + den_outdim = self.den_ksize * 3 + app_outdim = self.app_ksize * 3 if self.app_aug == 'flip': app_outdim = app_outdim * 2 @@ -68,19 +62,17 @@ def init_phasor_volume(self, res, device): @torch.no_grad() def compute_ktraj(self, axis, res): - ktraj2d = [torch.arange(i, dtype=torch.float32, device=self.device) for i in res] + ktraj2d = [torch.arange(i, dtype=torch.float32, device=self.device) + for i in res] ktraj1d = [torch.arange(ax, dtype=torch.float32, device=self.device) if type(ax) == int else ax for ax in axis] ktrajx = torch.stack( - torch.meshgrid([ktraj1d[0], ktraj2d[1], ktraj2d[2]]), - dim=-1) + torch.meshgrid([ktraj1d[0], ktraj2d[1], ktraj2d[2]]), dim=-1) ktrajy = torch.stack( - torch.meshgrid([ktraj2d[0], ktraj1d[1], ktraj2d[2]]), - dim=-1) + torch.meshgrid([ktraj2d[0], ktraj1d[1], ktraj2d[2]]), dim=-1) ktrajz = torch.stack( - torch.meshgrid([ktraj2d[0], ktraj2d[1], ktraj1d[2]]), - dim=-1) + torch.meshgrid([ktraj2d[0], ktraj2d[1], ktraj1d[2]]), dim=-1) return [ktrajx, ktrajy, ktrajz] @@ -133,7 +125,6 @@ def compute_appfeature(self, xyz_sampled): def compute_fft(self, features, xyz_sampled, interp=True): if interp: # Nx: num of samples - # using interpolation to compute fft = (N*N) log (N) d + (N*N*d*d) + Nsamples breakpoint() # why interp? fx, fy, fz = self.compute_spatial_volume(features) volume = fx+fy+fz @@ -171,23 +162,10 @@ def compute_fft(self, features, xyz_sampled, interp=True): fyy = batch_irfft(fy, ys, ky, Ny) fzz = batch_irfft(fz, zs, kz, Nz) - return fxx+fyy+fzz + return torch.cat([fxx, fyy, fzz], 0) # fxx+fyy+fzz return points - def compute_spatial_volume(self, features): - breakpoint() - Fx, Fy, Fz = features - Nx, Ny, Nz = Fy.shape[2], Fz.shape[3], Fx.shape[4] - xx, yy, zz = [torch.linspace(0, 1, N).to(self.device) for N in [Nx, Ny, Nz]] - d1, d2, d3 = Fx.shape[2], Fy.shape[3], Fz.shape[4] - kx, ky, kz = self.axis - kx, ky, kz = kx[:d1], ky[:d2], kz[:d3] - fx = irfft(torch.fft.ifftn(Fx, dim=(3,4), norm='forward'), xx, ff=kx, T=Nx, dim=2) - fy = irfft(torch.fft.ifftn(Fy, dim=(2,4), norm='forward'), yy, ff=ky, T=Ny, dim=3) - fz = irfft(torch.fft.ifftn(Fz, dim=(2,3), norm='forward'), zz, ff=kz, T=Nz, dim=4) - return (fx, fy, fz) - def Parseval_Loss(self): # Parseval Loss i.e., suppressing higher frequencies # avoid higher freqeuncies explaining everything @@ -201,13 +179,11 @@ def compute_normal(self, xyz_sampled): with torch.enable_grad(): xyz_sampled.requires_grad = True outs = self.compute_densityfeature(xyz_sampled) - d_points = torch.ones_like(outs, requires_grad=False, device=self.device) - normal = grad( - outputs=outs, - inputs=xyz_sampled, - grad_outputs=d_points, - retain_graph=False, - only_inputs=True)[0] + d_points = torch.ones_like(outs, requires_grad=False, + device=self.device) + normal = grad(outputs=outs, inputs=xyz_sampled, + grad_outputs=d_points, retain_graph=False, + only_inputs=True)[0] normal = normal.T normal = normal / torch.linalg.norm(normal, dim=0, keepdims=True) return normal.detach() @@ -218,60 +194,41 @@ def upsample_volume_grid(self, res_target): res_den = [math.ceil(n * self.den_scale) for n in res_target] res_app = [math.ceil(n * self.app_scale) for n in res_target] - new_den = self.upsample_fft(self.den, res_den) - self.den = torch.nn.ParameterList([torch.nn.Parameter(den) for den in new_den]) - new_app = self.upsample_fft(self.app, res_app) - self.app = torch.nn.ParameterList([torch.nn.Parameter(app) for app in new_app]) + new_den = self.upsample_feats(self.den, res_den) + self.den = torch.nn.ParameterList([torch.nn.Parameter(den) + for den in new_den]) + + new_app = self.upsample_feats(self.app, res_app) + self.app = torch.nn.ParameterList([torch.nn.Parameter(app) + for app in new_app]) self.print_size() self.update_stepSize(res_target) print(f'upsamping to {res_target}') - def upsample_fft(self, features, res_target, update_dd=False): + def upsample_feats(self, features, res_target, update_dd=False): Tx, Ty, Tz = res_target Fkx, Fky, Fkz = features d1, d2, d3 = Fkx.shape[2], Fky.shape[3], Fkz.shape[4] Nx, Ny, Nz = Fky.shape[2], Fkz.shape[3], Fkx.shape[4] - if update_dd: - t1, t2, t3 = d1, d2, d3 - d1, d2, d3 = [int(np.log2(d))+1+2 for d in res_target] - self.den_num_comp = [d1, d2, d3] - self.axis = [torch.tensor([0.]+[2**i for i in torch.arange(d-1)]).to(self.device) for d in self.den_num_comp] - - maskx = getMask([Ny, Nz], [Ty, Tz]).unsqueeze(0).repeat(d1,1,1) - masky = getMask([Nx, Nz], [Tx, Tz]).unsqueeze(1).repeat(1,d2,1) - maskz = getMask([Nx, Ny], [Tx, Ty]).unsqueeze(2).repeat(1,1,d3) - - if update_dd: - maskx[t1:, :, :] = False - masky[:, t2:, :] = False - maskz[:, :, t3:] = False - - new_Fkx = torch.zeros(*Fkx.shape[:2], d1, Ty, Tz).to(Fkx) - new_Fky = torch.zeros(*Fky.shape[:2], Tx, d2, Tz).to(Fky) - new_Fkz = torch.zeros(*Fkz.shape[:2], Tx, Ty, d3).to(Fkz) - - try: - new_Fkx[..., maskx] = Fkx[:, :, :d1, :, :].flatten(2) - new_Fky[..., masky] = Fky[:, :, :, :d2, :].flatten(2) - new_Fkz[..., maskz] = Fkz[:, :, :, :, :d3].flatten(2) - except: - raise ValueError("Error") - - return new_Fkx, new_Fky, new_Fkz + return F.pad(Fkx, (0, Tz-Nz, 0, Ty-Ny, 0, 0)), \ + F.pad(Fky, (0, Tz-Nz, 0, 0, 0, Tx-Nx)), \ + F.pad(Fkz, (0, 0, 0, Ty-Ny, 0, Tx-Nx)) def update_stepSize(self, gridSize): self.ktraj = self.compute_ktraj(self.axis, gridSize) - self.ktraj_den = self.compute_ktraj(self.axis, [math.ceil(n * self.den_scale) for n in gridSize]) - print("dimensions largest ", [torch.max(ax).item() for ax in self.axis]) + self.ktraj_den = self.compute_ktraj( + self.axis, [math.ceil(n * self.den_scale) for n in gridSize]) + print("dimensions largest ", [torch.max(ax).item() for ax in self.axis]) return super(CPhasoMLP, self).update_stepSize(gridSize) def print_size(self): print(self) print(f' ==> Actual Model Size {np.sum([v.numel() * v.element_size() for k, v in self.named_parameters()])/2**20} MB') for k,v in self.named_parameters(): - print(f'Model Size ({k}) : {v.numel() * v.element_size()/2**20:.4f} MB') + print(f'Model Size ({k}) : ' + f'{v.numel() * v.element_size()/2**20:.4f} MB') def get_optparam_groups(self, lr_init_spatialxyz=0.02, lr_init_network=0.001): @@ -288,7 +245,8 @@ def get_optparam_groups(self, lr_init_spatialxyz=0.02, @property def alpha(self): - return F.softplus(self.alpha_params, beta=10, threshold=1e-4) # avoid negative value + # avoid negative value + return F.softplus(self.alpha_params, beta=10, threshold=1e-4) @property def density(self): @@ -299,6 +257,7 @@ def appearance(self): return [app * self.beta for app in self.app] def compute_gaussian(self, variance, mode, ktraj=None): + breakpoint() if mode== 'none': return [1., 1., 1.] if mode == 'prod': @@ -446,6 +405,7 @@ def compute_fft(self, features, xyz_sampled, interp=True): kx, ky, kz = self.axis kx, ky, kz = kx[:d1], ky[:d2], kz[:d3] xs, ys, zs = xyz_sampled.chunk(3, dim=-1) + Fx = torch.fft.ifftn(Fx, dim=(3,4), norm='forward') Fy = torch.fft.ifftn(Fy, dim=(2,4), norm='forward') Fz = torch.fft.ifftn(Fz, dim=(2,3), norm='forward') @@ -492,13 +452,11 @@ def compute_normal(self, xyz_sampled): with torch.enable_grad(): xyz_sampled.requires_grad = True outs = self.compute_densityfeature(xyz_sampled) - d_points = torch.ones_like(outs, requires_grad=False, device=self.device) - normal = grad( - outputs=outs, - inputs=xyz_sampled, - grad_outputs=d_points, - retain_graph=False, - only_inputs=True)[0] + d_points = torch.ones_like(outs, requires_grad=False, + device=self.device) + normal = grad(outputs=outs, inputs=xyz_sampled, + grad_outputs=d_points, retain_graph=False, + only_inputs=True)[0] normal = normal.T normal = normal / torch.linalg.norm(normal, dim=0, keepdims=True) return normal.detach()