class VQ(nn.Module):
    """Tri-plane frequency grid whose entries are soft mixtures of codebook rows.

    Three 2D planes of shape ``[n_chan * n_freq, resolution, resolution]`` are
    not stored directly.  Instead every pixel owns ``2**bitwidth`` logits
    (``indices``); their softmax mixes the rows of a learnable ``codebook``
    to produce that pixel's coefficients.

    Args:
        resolution: side length of each plane (all axes share it).
        n_chan: number of feature channels.
        n_freq: number of frequencies per channel; defaults to
            ``ceil(log2(freq_resolution))``.
        freq_resolution: controls the largest frequency returned by
            ``get_freqs``; defaults to ``resolution``.
        bitwidth: the codebook holds ``2**bitwidth`` entries.
        grid_num: accepted for interface compatibility; not used by this class.
        channel_wise: if True every channel has its own codebook of
            ``n_freq``-dimensional entries and its own logits; otherwise one
            codebook of ``n_chan * n_freq``-dimensional entries is shared.
    """

    def __init__(self, resolution: int, n_chan: int, n_freq=None,
                 freq_resolution=None, bitwidth=4, grid_num=1,
                 channel_wise=True):
        # assume 3 axes have the same resolution
        super().__init__()
        self.resolution = resolution
        self.n_chan = n_chan
        if freq_resolution is None:
            freq_resolution = resolution
        self.freq_resolution = freq_resolution

        if n_freq is None:
            n_freq = int(np.ceil(np.log2(freq_resolution)))
        self.n_freq = n_freq

        # Fixed (non-trainable) positions in [0, 1]; get_freqs() maps them
        # to the actual frequencies.
        self.freqs = nn.Parameter(torch.linspace(0., 1, self.n_freq),
                                  requires_grad=False)

        self.bitwidth = bitwidth
        self.channel_wise = channel_wise
        n_codes = 2 ** bitwidth
        n_pixels = resolution * resolution
        if channel_wise:
            # codebook: [plane, code, channel, freq]
            # indices:  [plane, channel, pixel, code]
            codebook_shape = (3, n_codes, self.n_chan, self.n_freq)
            indices_shape = (3, self.n_chan, n_pixels, n_codes)
        else:
            # codebook: [plane, code, channel * freq]
            # indices:  [plane, pixel, code]
            codebook_shape = (3, n_codes, self.n_chan * self.n_freq)
            indices_shape = (3, n_pixels, n_codes)
        self.codebook = nn.Parameter(
            torch.normal(0, 0.1, size=codebook_shape), requires_grad=True)
        self.indices = nn.Parameter(
            torch.zeros(*indices_shape), requires_grad=True)

    def forward(self, coords):
        """Evaluate the features at ``coords``.

        Args:
            coords: tensor whose last dimension holds 3 coordinates; they are
                used as normalized ``grid_sample`` coordinates, i.e. [-1, 1]
                spans each plane.  Leading dimensions are flattened to ``B``.

        Returns:
            Tensor of shape ``[B, n_chan * n_freq * 3]``.
        """
        # [B, 3] to [1, B, 1, 3]
        coords = coords.reshape(1, -1, 1, coords.shape[-1])

        # Each plane is sampled with the two remaining axes: [3, B, 1, 2]
        plane_coords = torch.cat([coords[..., [1, 2]],
                                  coords[..., [0, 2]],
                                  coords[..., [0, 1]]], 0)
        # coefs: [3, C*F, B, 1]
        coefs = F.grid_sample(self.get_grid(), plane_coords,
                              mode='bilinear',
                              padding_mode='zeros', align_corners=True)
        coefs = coefs.squeeze(-1).permute(2, 1, 0)  # [B, C*F, 3]
        coefs = coefs.reshape(coefs.size(0), self.n_chan, -1, 3)  # [B, C, F, 3]

        coords = coords.squeeze(0)  # [B, 1, 3]
        # Map [-1, 1] to pixel units [0, resolution - 1] before evaluating
        # the cosine basis.
        coords = (coords + 1) / 2 * (self.resolution - 1)
        basis = torch.cos(torch.pi / self.resolution * coords
                          * self.get_freqs().unsqueeze(-1))  # [B, F, 3]
        outputs = 2 * (coefs * basis.unsqueeze(-3))  # [B, C, F, 3]
        return outputs.reshape(outputs.shape[0], -1)

    def compute_tv(self):
        """Return the mean of the squared, frequency-weighted grid values.

        Each grid channel ``c * n_freq + f`` is scaled by frequency ``f``
        before squaring, so higher frequencies are penalized more.
        """
        weight = self.get_freqs().repeat(self.n_chan).reshape(-1, 1, 1)
        return (self.get_grid() * weight).square().mean()

    def get_freqs(self):
        """Return the ``n_freq`` frequencies, from 0 up to ``freq_resolution - 1``."""
        return -1 + 2**(self.freqs.clamp(min=0, max=1)
                        * np.log2(self.freq_resolution))

    def get_grid(self):
        """Decode the codebook into planes of shape ``[3, C*F, R, R]``.

        Channel ``c * n_freq + f`` at pixel ``(h, w)`` is the softmax-weighted
        sum over codes of the logits stored at flat pixel ``h * R + w``.
        """
        weights = torch.softmax(self.indices, dim=-1)
        if self.channel_wise:
            # [3, C, P, K] x [3, K, C, F] -> [3, C, F, P]
            grid = torch.einsum('acpk,akcf->acfp', weights, self.codebook)
        else:
            # [3, P, K] @ [3, K, C*F] -> [3, P, C*F] -> [3, C*F, P]
            grid = (weights @ self.codebook).transpose(1, 2)
        # The pixel axis must be last before it is unfolded into (R, R);
        # otherwise frequencies and pixels would be interleaved.
        return grid.reshape(3, -1, self.resolution, self.resolution)