Skip to content

Commit

Permalink
Merge pull request #1 from daniel03c1/dev/codebook
Browse files Browse the repository at this point in the history
Implement codebook(vector quantization)
  • Loading branch information
daniel03c1 authored Jul 26, 2022
2 parents 475975e + 6feb301 commit 304d6be
Showing 1 changed file with 84 additions and 1 deletion.
85 changes: 84 additions & 1 deletion models/grid_based.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,7 +50,6 @@ def forward(self, coords):
outputs = torch.stack(
[torch.cos(torch.pi * coords * self.get_freqs().unsqueeze(-1)),
torch.sin(torch.pi * coords * self.get_freqs().unsqueeze(-1))], 1)
outputs = 2 * (coefs * outputs.repeat(1, self.n_chan//2, 1, 1))
'''

Expand All @@ -69,6 +68,90 @@ def get_freqs(self):
* np.log2(self.freq_resolution))


class VQ(nn.Module):
def __init__(self, resolution: int, n_chan: int, n_freq=None,
freq_resolution=None, bitwidth=4, grid_num=1, channel_wise=True):
# assume 3 axes have the same resolution
super().__init__()
self.resolution = resolution
self.n_chan = n_chan
if freq_resolution is None:
freq_resolution = resolution
self.freq_resolution = freq_resolution

if n_freq is None:
n_freq = int(np.ceil(np.log2(freq_resolution)))
self.n_freq = n_freq

self.freqs = nn.Parameter(torch.linspace(0., 1, self.n_freq),
requires_grad=False)

# Assume that each channel has its own codebook

self.bitwidth = bitwidth
self.channel_wise = channel_wise
if channel_wise:
self.codebook = nn.Parameter(torch.normal(0, 0.1, size=(3, 2**bitwidth, self.n_chan, self.n_freq)), requires_grad=True)
self.indices = nn.Parameter(torch.zeros(3, self.n_chan, resolution * resolution, 2**bitwidth), requires_grad=True)
else :
self.codebook = nn.Parameter(torch.normal(0, 0.1, size=(3, 2**bitwidth, self.n_chan * self.n_freq)), requires_grad=True)
self.indices = nn.Parameter(torch.zeros(3, resolution * resolution, 2**bitwidth), requires_grad=True)


def forward(self, coords):
# [B, 3] to [1, B, 1, 3]
coords = coords.reshape(1, -1, 1, coords.shape[-1])

# coefs: [3, 1, C, B]
grid = self.get_grid()
coefs = F.grid_sample(grid, torch.cat([coords[..., (1, 2)],
coords[..., (0, 2)],
coords[..., (0, 1)]], 0),
mode='bilinear',
padding_mode='zeros', align_corners=True)
coefs = coefs.squeeze(-1).permute(2, 1, 0) # [B, C*F, 3]
coefs = coefs.reshape(coefs.size(0), self.n_chan, -1, 3) # [B, C, F, 3]

# numerical integration
coords = coords.squeeze(0) # [B, 1, 3]

'''
# POS ENCODING
outputs = torch.stack(
[torch.cos(torch.pi * coords * self.get_freqs().unsqueeze(-1)),
torch.sin(torch.pi * coords * self.get_freqs().unsqueeze(-1))], 1)
outputs = 2 * (coefs * outputs.repeat(1, self.n_chan//2, 1, 1))
'''

coords = (coords + 1) / 2 * (self.resolution - 1)
outputs = torch.cos(torch.pi / self.resolution * coords
* self.get_freqs().unsqueeze(-1))
outputs = 2 * (coefs * outputs.unsqueeze(-3)) # [B, C, F, 3]
return outputs.reshape(outputs.shape[0], -1)

def compute_tv(self):
weight = self.get_freqs().repeat(self.n_chan).reshape(-1, 1, 1)
return (self.get_grid() * weight).square().mean()

def get_freqs(self):
return -1 + 2**(self.freqs.clamp(min=0, max=1)
* np.log2(self.freq_resolution))

def get_grid(self):
if self.channel_wise:
softened_mat = torch.softmax(self.indices, dim=3)
grid_chan = []
for i in range(self.n_chan):
grid_chan.append(softened_mat[:,i,...] @ self.codebook[...,i,:])
grid = torch.stack(grid_chan, dim=1)

else:
softened_mat = torch.softmax(self.indices, dim=2)
grid = softened_mat @ self.codebook
return grid.view(3, -1, self.resolution, self.resolution)


class PREF(nn.Module):
def __init__(self, res, ch):
"""
Expand Down

0 comments on commit 304d6be

Please sign in to comment.