diff --git a/models/grid_based.py b/models/grid_based.py index 50275c1..35f7369 100644 --- a/models/grid_based.py +++ b/models/grid_based.py @@ -8,7 +8,7 @@ class FreqGrid(nn.Module): def __init__(self, resolution: int, n_chan: int, n_freq=None, - freq_resolution=None, bitwidth=3, grid_num=1): + freq_resolution=None, bitwidth=3, grid_num=1, channel_wise=True): # assume 3 axes have the same resolution super().__init__() self.resolution = resolution @@ -31,8 +31,13 @@ def __init__(self, resolution: int, n_chan: int, n_freq=None, # Assume that each channel has its own codebook self.bitwidth = bitwidth - self.codebook = nn.Parameter(torch.normal(0, 0.1, size=(3, 2**bitwidth, self.n_chan * self.n_freq)), requires_grad=True) - self.indices = nn.Parameter(torch.zeros(3, resolution * resolution, 2**bitwidth), requires_grad=True) # Tensor C in Variable Bitrate + self.channel_wise = channel_wise + if channel_wise: + self.codebook = nn.Parameter(torch.normal(0, 0.1, size=(3, 2**bitwidth, self.n_chan, self.n_freq)), requires_grad=True) + self.indices = nn.Parameter(torch.zeros(3, self.n_chan, resolution * resolution, 2**bitwidth), requires_grad=True) # Tensor C in Variable Bitrate + else : + self.codebook = nn.Parameter(torch.normal(0, 0.1, size=(3, 2**bitwidth, self.n_chan * self.n_freq)), requires_grad=True) + self.indices = nn.Parameter(torch.zeros(3, resolution * resolution, 2**bitwidth), requires_grad=True) # Tensor C in Variable Bitrate def forward(self, coords): @@ -77,8 +82,16 @@ def get_freqs(self): * np.log2(self.freq_resolution)) def get_grid(self): - softened_mat = torch.softmax(self.indices, dim=2) - grid = softened_mat @ self.codebook + if self.channel_wise: + softened_mat = torch.softmax(self.indices, dim=3) + grid_chan = [] + for i in range(self.n_chan): + grid_chan.append(softened_mat[:,i,...] @ self.codebook[...,i,:]) + grid = torch.stack(grid_chan, dim=1) + + else: + softened_mat = torch.softmax(self.indices, dim=2) + grid = softened_mat @ self.codebook return grid.view(3, -1, self.resolution, self.resolution)