From 92b7e0c54f6f09ead6bb09e33124809dacea3d6a Mon Sep 17 00:00:00 2001 From: blee Date: Fri, 22 Jul 2022 06:39:19 +0000 Subject: [PATCH 1/3] apply codebook ++ one codebook for each channel(Not shared within channels) --- models/grid_based.py | 27 ++++++++++++++++++++------- 1 file changed, 20 insertions(+), 7 deletions(-) diff --git a/models/grid_based.py b/models/grid_based.py index 1b7264d..50275c1 100644 --- a/models/grid_based.py +++ b/models/grid_based.py @@ -8,7 +8,7 @@ class FreqGrid(nn.Module): def __init__(self, resolution: int, n_chan: int, n_freq=None, - freq_resolution=None): + freq_resolution=None, bitwidth=3, grid_num=1): # assume 3 axes have the same resolution super().__init__() self.resolution = resolution @@ -24,16 +24,23 @@ def __init__(self, resolution: int, n_chan: int, n_freq=None, self.freqs = nn.Parameter(torch.linspace(0., 1, self.n_freq), requires_grad=False) - self.grid = nn.Parameter(nn.Parameter( - torch.zeros(3, n_chan*self.n_freq, resolution, resolution), - requires_grad=True)) + # self.grid = nn.Parameter(nn.Parameter( + # torch.zeros(3, n_chan*self.n_freq, resolution, resolution), + # requires_grad=True)) + + # Assume that each channel has its own codebook + + self.bitwidth = bitwidth + self.codebook = nn.Parameter(torch.normal(0, 0.1, size=(3, 2**bitwidth, self.n_chan * self.n_freq)), requires_grad=True) + self.indices = nn.Parameter(torch.zeros(3, resolution * resolution, 2**bitwidth), requires_grad=True) # Tensor C in Variable Bitrate + def forward(self, coords): # [B, 3] to [1, B, 1, 3] - coords = coords.reshape(1, -1, 1, coords.shape[-1]) + coords = coords.reshape(1, -1, 1, coords.shape[-1]) # coefs: [3, 1, C, B] - grid = self.grid + grid = self.get_grid() coefs = F.grid_sample(grid, torch.cat([coords[..., (1, 2)], coords[..., (0, 2)], coords[..., (0, 1)]], 0), @@ -62,12 +69,18 @@ def forward(self, coords): def compute_tv(self): weight = self.get_freqs().repeat(self.n_chan).reshape(-1, 1, 1) - return (self.grid * weight).square().mean() + # return (self.grid * weight).square().mean() + return (self.get_grid() * weight).square().mean() def get_freqs(self): return -1 + 2**(self.freqs.clamp(min=0, max=1) * np.log2(self.freq_resolution)) + def get_grid(self): + softened_mat = torch.softmax(self.indices, dim=2) + grid = softened_mat @ self.codebook + return grid.view(3, -1, self.resolution, self.resolution) + class PREF(nn.Module): def __init__(self, res, ch): From a12e783c7fbcf8a88c3e48c62cfd8b19be59b99d Mon Sep 17 00:00:00 2001 From: blee Date: Sat, 23 Jul 2022 02:53:35 +0000 Subject: [PATCH 2/3] add channel-wise codebook --- models/grid_based.py | 23 ++++++++++++++++++----- 1 file changed, 18 insertions(+), 5 deletions(-) diff --git a/models/grid_based.py b/models/grid_based.py index 50275c1..35f7369 100644 --- a/models/grid_based.py +++ b/models/grid_based.py @@ -8,7 +8,7 @@ class FreqGrid(nn.Module): def __init__(self, resolution: int, n_chan: int, n_freq=None, - freq_resolution=None, bitwidth=3, grid_num=1): + freq_resolution=None, bitwidth=3, grid_num=1, channel_wise=True): # assume 3 axes have the same resolution super().__init__() self.resolution = resolution @@ -31,8 +31,13 @@ def __init__(self, resolution: int, n_chan: int, n_freq=None, # Assume that each channel has its own codebook self.bitwidth = bitwidth - self.codebook = nn.Parameter(torch.normal(0, 0.1, size=(3, 2**bitwidth, self.n_chan * self.n_freq)), requires_grad=True) - self.indices = nn.Parameter(torch.zeros(3, resolution * resolution, 2**bitwidth), requires_grad=True) # Tensor C in Variable Bitrate + self.channel_wise = channel_wise + if channel_wise: + self.codebook = nn.Parameter(torch.normal(0, 0.1, size=(3, 2**bitwidth, self.n_chan, self.n_freq)), requires_grad=True) + self.indices = nn.Parameter(torch.zeros(3, self.n_chan, resolution * resolution, 2**bitwidth), requires_grad=True) # Tensor C in Variable Bitrate + else : + self.codebook = nn.Parameter(torch.normal(0, 0.1, size=(3, 2**bitwidth, self.n_chan * self.n_freq)), requires_grad=True) + self.indices = nn.Parameter(torch.zeros(3, resolution * resolution, 2**bitwidth), requires_grad=True) # Tensor C in Variable Bitrate def forward(self, coords): @@ -77,8 +82,16 @@ def get_freqs(self): * np.log2(self.freq_resolution)) def get_grid(self): - softened_mat = torch.softmax(self.indices, dim=2) - grid = softened_mat @ self.codebook + if self.channel_wise: + softened_mat = torch.softmax(self.indices, dim=3) + grid_chan = [] + for i in range(self.n_chan): + grid_chan.append(softened_mat[:,i,...] @ self.codebook[...,i,:]) + grid = torch.stack(grid_chan, dim=1) + + else: + softened_mat = torch.softmax(self.indices, dim=2) + grid = softened_mat @ self.codebook return grid.view(3, -1, self.resolution, self.resolution) From 6feb30117882009b0fb18b2442583cb30ca3337f Mon Sep 17 00:00:00 2001 From: blee Date: Mon, 25 Jul 2022 14:46:01 +0000 Subject: [PATCH 3/3] add VQ class --- models/grid_based.py | 71 +++++++++++++++++++++++++++++++++++++++----- 1 file changed, 64 insertions(+), 7 deletions(-) diff --git a/models/grid_based.py b/models/grid_based.py index 35f7369..54c7fe8 100644 --- a/models/grid_based.py +++ b/models/grid_based.py @@ -8,7 +8,7 @@ class FreqGrid(nn.Module): def __init__(self, resolution: int, n_chan: int, n_freq=None, - freq_resolution=None, bitwidth=3, grid_num=1, channel_wise=True): + freq_resolution=None): # assume 3 axes have the same resolution super().__init__() self.resolution = resolution @@ -24,9 +24,67 @@ def __init__(self, resolution: int, n_chan: int, n_freq=None, self.freqs = nn.Parameter(torch.linspace(0., 1, self.n_freq), requires_grad=False) - # self.grid = nn.Parameter(nn.Parameter( - # torch.zeros(3, n_chan*self.n_freq, resolution, resolution), - # requires_grad=True)) + self.grid = nn.Parameter(nn.Parameter( + torch.zeros(3, n_chan*self.n_freq, resolution, resolution), + requires_grad=True)) + + def forward(self, coords): + # [B, 3] to [1, B, 1, 3] + coords = coords.reshape(1, -1, 1, coords.shape[-1]) + + # coefs: [3, 1, C, B] + grid = self.grid + coefs = F.grid_sample(grid, torch.cat([coords[..., (1, 2)], + coords[..., (0, 2)], + coords[..., (0, 1)]], 0), + mode='bilinear', + padding_mode='zeros', align_corners=True) + coefs = coefs.squeeze(-1).permute(2, 1, 0) # [B, C*F, 3] + coefs = coefs.reshape(coefs.size(0), self.n_chan, -1, 3) # [B, C, F, 3] + + # numerical integration + coords = coords.squeeze(0) # [B, 1, 3] + + ''' + # POS ENCODING + outputs = torch.stack( + [torch.cos(torch.pi * coords * self.get_freqs().unsqueeze(-1)), + torch.sin(torch.pi * coords * self.get_freqs().unsqueeze(-1))], 1) + outputs = 2 * (coefs * outputs.repeat(1, self.n_chan//2, 1, 1)) + ''' + + coords = (coords + 1) / 2 * (self.resolution - 1) + outputs = torch.cos(torch.pi / self.resolution * coords + * self.get_freqs().unsqueeze(-1)) + outputs = 2 * (coefs * outputs.unsqueeze(-3)) # [B, C, F, 3] + return outputs.reshape(outputs.shape[0], -1) + + def compute_tv(self): + weight = self.get_freqs().repeat(self.n_chan).reshape(-1, 1, 1) + return (self.grid * weight).square().mean() + + def get_freqs(self): + return -1 + 2**(self.freqs.clamp(min=0, max=1) + * np.log2(self.freq_resolution)) + + +class VQ(nn.Module): + def __init__(self, resolution: int, n_chan: int, n_freq=None, + freq_resolution=None, bitwidth=4, grid_num=1, channel_wise=True): + # assume 3 axes have the same resolution + super().__init__() + self.resolution = resolution + self.n_chan = n_chan + if freq_resolution is None: + freq_resolution = resolution + self.freq_resolution = freq_resolution + + if n_freq is None: + n_freq = int(np.ceil(np.log2(freq_resolution))) + self.n_freq = n_freq + + self.freqs = nn.Parameter(torch.linspace(0., 1, self.n_freq), + requires_grad=False) # Assume that each channel has its own codebook @@ -34,10 +92,10 @@ def __init__(self, resolution: int, n_chan: int, n_freq=None, self.channel_wise = channel_wise if channel_wise: self.codebook = nn.Parameter(torch.normal(0, 0.1, size=(3, 2**bitwidth, self.n_chan, self.n_freq)), requires_grad=True) - self.indices = nn.Parameter(torch.zeros(3, self.n_chan, resolution * resolution, 2**bitwidth), requires_grad=True) # Tensor C in Variable Bitrate + self.indices = nn.Parameter(torch.zeros(3, self.n_chan, resolution * resolution, 2**bitwidth), requires_grad=True) else : self.codebook = nn.Parameter(torch.normal(0, 0.1, size=(3, 2**bitwidth, self.n_chan * self.n_freq)), requires_grad=True) - self.indices = nn.Parameter(torch.zeros(3, resolution * resolution, 2**bitwidth), requires_grad=True) # Tensor C in Variable Bitrate + self.indices = nn.Parameter(torch.zeros(3, resolution * resolution, 2**bitwidth), requires_grad=True) def forward(self, coords): @@ -74,7 +132,6 @@ def forward(self, coords): def compute_tv(self): weight = self.get_freqs().repeat(self.n_chan).reshape(-1, 1, 1) - # return (self.grid * weight).square().mean() return (self.get_grid() * weight).square().mean() def get_freqs(self):