Skip to content

Commit

Permalink
add channel-wise codebook
Browse files Browse the repository at this point in the history
  • Loading branch information
benhenryL committed Jul 23, 2022
1 parent 92b7e0c commit a12e783
Showing 1 changed file with 18 additions and 5 deletions.
23 changes: 18 additions & 5 deletions models/grid_based.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@

class FreqGrid(nn.Module):
def __init__(self, resolution: int, n_chan: int, n_freq=None,
freq_resolution=None, bitwidth=3, grid_num=1):
freq_resolution=None, bitwidth=3, grid_num=1, channel_wise=True):
# assume 3 axes have the same resolution
super().__init__()
self.resolution = resolution
Expand All @@ -31,8 +31,13 @@ def __init__(self, resolution: int, n_chan: int, n_freq=None,
# Assume that each channel has its own codebook

self.bitwidth = bitwidth
self.codebook = nn.Parameter(torch.normal(0, 0.1, size=(3, 2**bitwidth, self.n_chan * self.n_freq)), requires_grad=True)
self.indices = nn.Parameter(torch.zeros(3, resolution * resolution, 2**bitwidth), requires_grad=True) # Tensor C in Variable Bitrate
self.channel_wise = channel_wise
if channel_wise:
self.codebook = nn.Parameter(torch.normal(0, 0.1, size=(3, 2**bitwidth, self.n_chan, self.n_freq)), requires_grad=True)
self.indices = nn.Parameter(torch.zeros(3, self.n_chan, resolution * resolution, 2**bitwidth), requires_grad=True) # Tensor C in Variable Bitrate
else :
self.codebook = nn.Parameter(torch.normal(0, 0.1, size=(3, 2**bitwidth, self.n_chan * self.n_freq)), requires_grad=True)
self.indices = nn.Parameter(torch.zeros(3, resolution * resolution, 2**bitwidth), requires_grad=True) # Tensor C in Variable Bitrate


def forward(self, coords):
Expand Down Expand Up @@ -77,8 +82,16 @@ def get_freqs(self):
* np.log2(self.freq_resolution))

def get_grid(self):
softened_mat = torch.softmax(self.indices, dim=2)
grid = softened_mat @ self.codebook
if self.channel_wise:
softened_mat = torch.softmax(self.indices, dim=3)
grid_chan = []
for i in range(self.n_chan):
grid_chan.append(softened_mat[:,i,...] @ self.codebook[...,i,:])
grid = torch.stack(grid_chan, dim=1)

else:
softened_mat = torch.softmax(self.indices, dim=2)
grid = softened_mat @ self.codebook
return grid.view(3, -1, self.resolution, self.resolution)


Expand Down

0 comments on commit a12e783

Please sign in to comment.