change sum to cat (phaseMLP.CPhaseMLP)
daniel03c1 committed Jul 30, 2022
1 parent a45ff8a commit 49f0acb
Showing 1 changed file with 41 additions and 83 deletions.
124 changes: 41 additions & 83 deletions PREF/models/phasoMLP.py
@@ -5,7 +5,8 @@
from .phasoBase import *
from .utils import positional_encoding
from .utils_fft import (
getMask_fft, getMask, grid_sample, grid_sample_cmplx, irfft, rfft, batch_irfft
getMask_fft, getMask, grid_sample, grid_sample_cmplx, irfft, rfft,
batch_irfft
)
from .cosine_transform import idctn

@@ -30,21 +31,14 @@ def init_phasor_volume(self, res, device):
self.init_(self.den_num_comp,
(self.gridSize * self.den_scale).long(),
ksize=self.den_ksize))
breakpoint()
self.app = torch.nn.ParameterList(
self.init_(self.app_num_comp,
(self.gridSize * self.app_scale).long(),
ksize=self.app_ksize))

den_outdim = self.den_ksize
app_outdim = self.app_ksize

# duplicate real and image parts
'''
if self.den_num_comp == [1, 1, 1]:
den_outdim = den_outdim * 2
if self.app_num_comp == [1, 1, 1]:
app_outdim = app_outdim * 2
'''
den_outdim = self.den_ksize * 3
app_outdim = self.app_ksize * 3

if self.app_aug == 'flip':
app_outdim = app_outdim * 2
@@ -68,19 +62,17 @@ def init_phasor_volume(self, res, device):

@torch.no_grad()
def compute_ktraj(self, axis, res):
ktraj2d = [torch.arange(i, dtype=torch.float32, device=self.device) for i in res]
ktraj2d = [torch.arange(i, dtype=torch.float32, device=self.device)
for i in res]
ktraj1d = [torch.arange(ax, dtype=torch.float32, device=self.device)
if type(ax) == int else ax for ax in axis]

ktrajx = torch.stack(
torch.meshgrid([ktraj1d[0], ktraj2d[1], ktraj2d[2]]),
dim=-1)
torch.meshgrid([ktraj1d[0], ktraj2d[1], ktraj2d[2]]), dim=-1)
ktrajy = torch.stack(
torch.meshgrid([ktraj2d[0], ktraj1d[1], ktraj2d[2]]),
dim=-1)
torch.meshgrid([ktraj2d[0], ktraj1d[1], ktraj2d[2]]), dim=-1)
ktrajz = torch.stack(
torch.meshgrid([ktraj2d[0], ktraj2d[1], ktraj1d[2]]),
dim=-1)
torch.meshgrid([ktraj2d[0], ktraj2d[1], ktraj1d[2]]), dim=-1)

return [ktrajx, ktrajy, ktrajz]
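
compute_ktraj builds one k-space coordinate grid per axis: the axis's own (possibly non-integer) frequencies from `axis` are combined with full integer ranges along the other two dimensions. A shape-level sketch under assumed sizes (names and sizes below are illustrative, not from the repository):

    import torch

    res = [4, 5, 6]                                   # assumed grid resolution
    ktraj2d = [torch.arange(n, dtype=torch.float32) for n in res]
    ktraj1d = torch.tensor([0., 1., 2.])              # e.g. three x-frequencies

    ktrajx = torch.stack(
        torch.meshgrid([ktraj1d, ktraj2d[1], ktraj2d[2]], indexing='ij'),
        dim=-1)
    print(ktrajx.shape)  # torch.Size([3, 5, 6, 3]) -> (d1, Ny, Nz, coord)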

@@ -133,7 +125,6 @@ def compute_appfeature(self, xyz_sampled):
def compute_fft(self, features, xyz_sampled, interp=True):
if interp:
# Nx: num of samples
# using interpolation to compute fft = (N*N) log (N) d + (N*N*d*d) + Nsamples
breakpoint() # why interp?
fx, fy, fz = self.compute_spatial_volume(features)
volume = fx+fy+fz
@@ -171,23 +162,10 @@ def compute_fft(self, features, xyz_sampled, interp=True):
fyy = batch_irfft(fy, ys, ky, Ny)
fzz = batch_irfft(fz, zs, kz, Nz)

return fxx+fyy+fzz
return torch.cat([fxx, fyy, fzz], 0) # fxx+fyy+fzz

return points
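
Concatenating the per-axis features instead of summing them keeps the x/y/z branches in separate channels, which is why den_outdim and app_outdim are tripled to ksize * 3 above. A minimal sketch of the difference, assuming (ksize, n_samples)-shaped branch outputs as the dim-0 cat suggests:

    import torch

    ksize, n = 8, 1024                       # illustrative sizes
    fxx, fyy, fzz = (torch.randn(ksize, n) for _ in range(3))

    old = fxx + fyy + fzz                    # (8, 1024): branches mixed
    new = torch.cat([fxx, fyy, fzz], 0)      # (24, 1024): branches kept apart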

def compute_spatial_volume(self, features):
breakpoint()
Fx, Fy, Fz = features
Nx, Ny, Nz = Fy.shape[2], Fz.shape[3], Fx.shape[4]
xx, yy, zz = [torch.linspace(0, 1, N).to(self.device) for N in [Nx, Ny, Nz]]
d1, d2, d3 = Fx.shape[2], Fy.shape[3], Fz.shape[4]
kx, ky, kz = self.axis
kx, ky, kz = kx[:d1], ky[:d2], kz[:d3]
fx = irfft(torch.fft.ifftn(Fx, dim=(3,4), norm='forward'), xx, ff=kx, T=Nx, dim=2)
fy = irfft(torch.fft.ifftn(Fy, dim=(2,4), norm='forward'), yy, ff=ky, T=Ny, dim=3)
fz = irfft(torch.fft.ifftn(Fz, dim=(2,3), norm='forward'), zz, ff=kz, T=Nz, dim=4)
return (fx, fy, fz)

def Parseval_Loss(self):
# Parseval Loss i.e., suppressing higher frequencies
# avoid higher frequencies explaining everything
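
The body of Parseval_Loss is collapsed in this view; the idea the comments describe can be sketched as a frequency-weighted L2 penalty, which by Parseval's theorem corresponds to penalizing the spatial gradient (illustrative only, assuming coefficients `Fk` on the same grid as `ktraj`):

    import torch

    def parseval_penalty(Fk, ktraj):
        # weight each Fourier coefficient by its frequency magnitude |k|,
        # so high-frequency energy costs more than low-frequency energy
        weight = torch.linalg.norm(ktraj, dim=-1)
        return (weight * Fk.abs()).square().mean()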
@@ -201,13 +179,11 @@ def compute_normal(self, xyz_sampled):
with torch.enable_grad():
xyz_sampled.requires_grad = True
outs = self.compute_densityfeature(xyz_sampled)
d_points = torch.ones_like(outs, requires_grad=False, device=self.device)
normal = grad(
outputs=outs,
inputs=xyz_sampled,
grad_outputs=d_points,
retain_graph=False,
only_inputs=True)[0]
d_points = torch.ones_like(outs, requires_grad=False,
device=self.device)
normal = grad(outputs=outs, inputs=xyz_sampled,
grad_outputs=d_points, retain_graph=False,
only_inputs=True)[0]
normal = normal.T
normal = normal / torch.linalg.norm(normal, dim=0, keepdims=True)
return normal.detach()
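
compute_normal differentiates the density feature with respect to the sample positions and normalizes the result, so each returned column is a unit vector along the local density gradient. A hypothetical usage sketch (`model` is an already-constructed CPhasoMLP instance):

    xyz = torch.rand(4096, 3, device=model.device)  # sampled positions
    n = model.compute_normal(xyz)                   # (3, 4096) after the .T
    lengths = torch.linalg.norm(n, dim=0)           # ~1.0 everywhere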
@@ -218,60 +194,41 @@ def upsample_volume_grid(self, res_target):
res_den = [math.ceil(n * self.den_scale) for n in res_target]
res_app = [math.ceil(n * self.app_scale) for n in res_target]

new_den = self.upsample_fft(self.den, res_den)
self.den = torch.nn.ParameterList([torch.nn.Parameter(den) for den in new_den])
new_app = self.upsample_fft(self.app, res_app)
self.app = torch.nn.ParameterList([torch.nn.Parameter(app) for app in new_app])
new_den = self.upsample_feats(self.den, res_den)
self.den = torch.nn.ParameterList([torch.nn.Parameter(den)
for den in new_den])

new_app = self.upsample_feats(self.app, res_app)
self.app = torch.nn.ParameterList([torch.nn.Parameter(app)
for app in new_app])

self.print_size()
self.update_stepSize(res_target)
print(f'upsampling to {res_target}')

def upsample_fft(self, features, res_target, update_dd=False):
def upsample_feats(self, features, res_target, update_dd=False):
Tx, Ty, Tz = res_target
Fkx, Fky, Fkz = features
d1, d2, d3 = Fkx.shape[2], Fky.shape[3], Fkz.shape[4]
Nx, Ny, Nz = Fky.shape[2], Fkz.shape[3], Fkx.shape[4]

if update_dd:
t1, t2, t3 = d1, d2, d3
d1, d2, d3 = [int(np.log2(d))+1+2 for d in res_target]
self.den_num_comp = [d1, d2, d3]
self.axis = [torch.tensor([0.]+[2**i for i in torch.arange(d-1)]).to(self.device) for d in self.den_num_comp]

maskx = getMask([Ny, Nz], [Ty, Tz]).unsqueeze(0).repeat(d1,1,1)
masky = getMask([Nx, Nz], [Tx, Tz]).unsqueeze(1).repeat(1,d2,1)
maskz = getMask([Nx, Ny], [Tx, Ty]).unsqueeze(2).repeat(1,1,d3)

if update_dd:
maskx[t1:, :, :] = False
masky[:, t2:, :] = False
maskz[:, :, t3:] = False

new_Fkx = torch.zeros(*Fkx.shape[:2], d1, Ty, Tz).to(Fkx)
new_Fky = torch.zeros(*Fky.shape[:2], Tx, d2, Tz).to(Fky)
new_Fkz = torch.zeros(*Fkz.shape[:2], Tx, Ty, d3).to(Fkz)

try:
new_Fkx[..., maskx] = Fkx[:, :, :d1, :, :].flatten(2)
new_Fky[..., masky] = Fky[:, :, :, :d2, :].flatten(2)
new_Fkz[..., maskz] = Fkz[:, :, :, :, :d3].flatten(2)
except:
raise ValueError("Error")

return new_Fkx, new_Fky, new_Fkz
return F.pad(Fkx, (0, Tz-Nz, 0, Ty-Ny, 0, 0)), \
F.pad(Fky, (0, Tz-Nz, 0, 0, 0, Tx-Nx)), \
F.pad(Fkz, (0, 0, 0, Ty-Ny, 0, Tx-Nx))
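
Since the stored coefficients are in the frequency domain with DC at index 0, growing the grid amounts to zero-padding the high-frequency end of each spatial dimension, so the old masked scatter-copy in upsample_fft collapses to three F.pad calls. Note that F.pad's pad tuple runs from the last dimension backwards. A shape check under an assumed layout:

    import torch
    import torch.nn.functional as F

    Fkx = torch.randn(1, 8, 3, 16, 16)   # assumed (batch, ch, d1, Ny, Nz)
    Ty, Tz = 32, 32
    up = F.pad(Fkx, (0, Tz - 16, 0, Ty - 16, 0, 0))  # pads Nz, then Ny; d1 kept
    print(up.shape)  # torch.Size([1, 8, 3, 32, 32])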

def update_stepSize(self, gridSize):
self.ktraj = self.compute_ktraj(self.axis, gridSize)
self.ktraj_den = self.compute_ktraj(self.axis, [math.ceil(n * self.den_scale) for n in gridSize])
print("dimensions largest ", [torch.max(ax).item() for ax in self.axis])
self.ktraj_den = self.compute_ktraj(
self.axis, [math.ceil(n * self.den_scale) for n in gridSize])
print("dimensions largest ", [torch.max(ax).item() for ax in self.axis])
return super(CPhasoMLP, self).update_stepSize(gridSize)

def print_size(self):
print(self)
print(f' ==> Actual Model Size {np.sum([v.numel() * v.element_size() for k, v in self.named_parameters()])/2**20} MB')
for k,v in self.named_parameters():
print(f'Model Size ({k}) : {v.numel() * v.element_size()/2**20:.4f} MB')
print(f'Model Size ({k}) : '
f'{v.numel() * v.element_size()/2**20:.4f} MB')

def get_optparam_groups(self, lr_init_spatialxyz=0.02,
lr_init_network=0.001):
@@ -288,7 +245,8 @@ def get_optparam_groups(self, lr_init_spatialxyz=0.02,

@property
def alpha(self):
return F.softplus(self.alpha_params, beta=10, threshold=1e-4) # avoid negative value
# avoid negative value
return F.softplus(self.alpha_params, beta=10, threshold=1e-4)
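
softplus is strictly positive, which is what the comment relies on: alpha stays valid no matter where alpha_params wanders during optimization. A quick check with the same arguments:

    import torch
    import torch.nn.functional as F

    x = torch.linspace(-1., 1., 5)
    print(F.softplus(x, beta=10, threshold=1e-4))  # every entry > 0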

@property
def density(self):
@@ -299,6 +257,7 @@ def appearance(self):
return [app * self.beta for app in self.app]

def compute_gaussian(self, variance, mode, ktraj=None):
breakpoint()
if mode == 'none':
return [1., 1., 1.]
if mode == 'prod':
@@ -446,6 +405,7 @@ def compute_fft(self, features, xyz_sampled, interp=True):
kx, ky, kz = self.axis
kx, ky, kz = kx[:d1], ky[:d2], kz[:d3]
xs, ys, zs = xyz_sampled.chunk(3, dim=-1)

Fx = torch.fft.ifftn(Fx, dim=(3,4), norm='forward')
Fy = torch.fft.ifftn(Fy, dim=(2,4), norm='forward')
Fz = torch.fft.ifftn(Fz, dim=(2,3), norm='forward')
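
Each coefficient volume is dense along two axes and keeps only a few frequencies along its own axis, so the inverse transform is split: a 2D ifftn over the dense axes here, then a per-sample inverse real FT (irfft/batch_irfft) along the remaining axis at arbitrary coordinates. A shape-level sketch under that assumption:

    import torch

    Fx = torch.randn(1, 8, 3, 16, 16, dtype=torch.complex64)  # (B, C, d1, Ny, Nz)
    fx = torch.fft.ifftn(Fx, dim=(3, 4), norm='forward')      # y, z -> spatial
    print(fx.shape)  # torch.Size([1, 8, 3, 16, 16]); dim 2 still frequency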
@@ -492,13 +452,11 @@ def compute_normal(self, xyz_sampled):
with torch.enable_grad():
xyz_sampled.requires_grad = True
outs = self.compute_densityfeature(xyz_sampled)
d_points = torch.ones_like(outs, requires_grad=False, device=self.device)
normal = grad(
outputs=outs,
inputs=xyz_sampled,
grad_outputs=d_points,
retain_graph=False,
only_inputs=True)[0]
d_points = torch.ones_like(outs, requires_grad=False,
device=self.device)
normal = grad(outputs=outs, inputs=xyz_sampled,
grad_outputs=d_points, retain_graph=False,
only_inputs=True)[0]
normal = normal.T
normal = normal / torch.linalg.norm(normal, dim=0, keepdims=True)
return normal.detach()
