src structure rework
timetoai committed Aug 31, 2023
1 parent ef535e0 commit 8c0edf4
Showing 5 changed files with 251 additions and 225 deletions.
6 changes: 4 additions & 2 deletions src/timediffusion/__init__.py
@@ -1,4 +1,6 @@
from .timediffusion import TD, TimeDiffusionProjector, TimeDiffusion, count_params, DimUniversalStandardScaler
from .utils import count_params, DimUniversalStandardScaler
from .models import TimeDiffusionProjector, TimeDiffusion
from .frameworks import TD

__all__ = [
# useful functions
@@ -9,4 +11,4 @@
"TimeDiffusionProjector",
"TimeDiffusion",
"TD"
]
]
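
The public interface is preserved by the rework; every re-exported name still resolves from the package root, now routed through the new submodules. A quick import check (illustrative):

from timediffusion import TD, TimeDiffusion, TimeDiffusionProjector, count_params, DimUniversalStandardScaler
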
227 changes: 4 additions & 223 deletions src/timediffusion/timediffusion.py → src/timediffusion/frameworks.py
@@ -7,227 +7,8 @@
import torch
from torch import nn

def count_params(model: nn.Module) -> int:
"""
counts number of model parameters
"""
res = 0
for param in model.parameters():
res += np.prod(param.shape)
return res
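
A minimal sanity check for count_params (an illustrative sketch, assuming torch is installed):

layer = nn.Conv1d(3, 8, kernel_size=2)
print(count_params(layer))  # 8 * 3 * 2 weights + 8 biases = 56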

def get_appropriate_conv_layer(dims: int) -> nn.Module:
"""
returns the appropriate convolutional layer for a given number of dimensions
"""
if dims not in (1, 2, 3):
raise NotImplementedError("Convolutional layer for dimensionalty {dims} not implemented")
return {1: nn.Conv1d, 2: nn.Conv2d, 3: nn.Conv3d}[dims]


class DimUniversalStandardScaler:
"""
Universal class for standard scaling of data
"""
def __init__(self, eps=1e-9):
self.eps = eps

def fit(self, data):
self.mu = data.mean()
self.std = data.std()
if isinstance(data, torch.Tensor):
self.mu = self.mu.item()
self.std = self.std.item()

def transform(self, data):
return (data - self.mu) / (self.std + self.eps)

def fit_transform(self, data):
self.fit(data)
return self.transform(data)

def inverse_transform(self, data):
return data * self.std + self.mu
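
A short usage sketch for the scaler (illustrative):

scaler = DimUniversalStandardScaler()
data = torch.randn(2, 100) * 5 + 3           # arbitrary location and scale
scaled = scaler.fit_transform(data)          # ~zero mean, ~unit std
restored = scaler.inverse_transform(scaled)  # recovers data up to eps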


class Chomp(nn.Module):
"""
cuts padding part of sequence
inspired by https://github.com/locuslab/TCN
"""
def __init__(self, chomp_size: int, dims: int=1):
"""
args:
`chomp_size` - padding size to cut off
`dims` - number of trailing dimensions to chomp
"""
super().__init__()
self.chomp_size = chomp_size
if dims not in (1, 2, 3):
raise NotImplementedError(f"Chomp layer for {dims = } not implemented")
self.dims = dims

def forward(self, x: torch.Tensor) -> torch.Tensor:
if self.dims == 1:
return x[..., : - self.chomp_size].contiguous()
if self.dims == 2:
return x[..., : - self.chomp_size, : - self.chomp_size].contiguous()
if self.dims == 3:
return x[..., : - self.chomp_size, : - self.chomp_size, : - self.chomp_size].contiguous()

class TemporalBlock(nn.Module):
"""
combination of (convolutional layer, chomp, relu, dropout) repeated `layers` times
adds additional convolutional layer if needed to downsample number of channels
inspired by https://github.com/locuslab/TCN
"""
def __init__(self, n_inputs: int, n_outputs: int, kernel_size: Union[int, tuple[int]],
stride: Union[int, tuple[int]], dilation: Union[int, tuple[int]], padding: Union[int, tuple[int]],
dropout: float = 0.2, dims: int = 1, layers: int = 2):
super().__init__()

conv_layer = get_appropriate_conv_layer(dims)
self.padding = padding
self.dropout = dropout

net = []
for i in range(layers):
net.append(torch.nn.utils.weight_norm(conv_layer(
(n_inputs if i == 0 else n_outputs), n_outputs, kernel_size, stride=stride, padding=padding, dilation=dilation)))
if padding > 0:
net.append(Chomp(padding, dims))
net.append(nn.ReLU())
if dropout > 0:
net.append(nn.Dropout(dropout))
self.net = nn.ModuleList(net)

self.downsample = conv_layer(n_inputs, n_outputs, 1) if n_inputs != n_outputs else None
self.relu = nn.ReLU()
self.init_weights()

def init_weights(self):
"""
sets normal weight distribution for convolutional layers
"""
for i in range(0, len(self.net), 2 + (self.dropout > 0) + (self.padding > 0)):
self.net[i].weight.data.normal_(0, 0.5)

if self.downsample is not None:
self.downsample.weight.data.normal_(0, 0.5)

def forward(self, x: torch.Tensor) -> torch.Tensor:
"""
`x` in format [batch_size, channels, *other_dims]
"""
out = x
for i in range(len(self.net)):
out = self.net[i](out)

res = x if self.downsample is None else self.downsample(x)
return out, self.relu(out + res)


class TimeDiffusionProjector(nn.Module):
"""
convolutional network, used as projector in TD
consists of temporal blocks with exponentially increasing padding/dilation parameters
"""
def __init__(self, input_dims: Union[list[int], tuple[int]], max_deg_constraint: int = 13,
conv_filters: int = 128, base_dropout: float = 0.05):
"""
args:
`input_dims` - [channels, *dims]
needed for dynamical network building
best way to pass it as `x.shape` (without batches)
`max_deg_constraint` - upper bound on network depth to keep the model small; too small a value will degrade model quality
the network will contain at most (2 + max_deg_constraint) temporal blocks
`conv_filters` - number of convolutional filters for each layer
`base_dropout` - dropout for first temporal block
"""
super().__init__()

self.input_dims = input_dims
self.dims = len(input_dims) - 1
self.channels = input_dims[0]
self.max_seq = max(input_dims[1:])
self.max_deg = int(np.ceil(np.log2(self.max_seq)))
if max_deg_constraint < self.max_deg:
print(f"For better TimeDiffusion performance it's recommended to use max_deg_constraint ", end="")
print(f"with value{self.max_deg} for input with shape {input_dims}")
self.max_deg = max_deg_constraint
print(f"Setting current {self.max_deg = }")

self.tcn = nn.ModuleList(
[TemporalBlock(self.channels, conv_filters,
kernel_size=1, stride=1, dilation=1, padding=0, dropout=base_dropout, dims=self.dims),
*[TemporalBlock(conv_filters, conv_filters,
kernel_size=2, stride=1, dilation=i, padding=i, dropout=0.0, dims=self.dims)
for i in [2 ** i for i in range(self.max_deg + 1)]]
])

self.last = get_appropriate_conv_layer(self.dims)(conv_filters, self.channels, kernel_size=1, stride=1, dilation=1)

def forward(self, x: torch.Tensor) -> torch.Tensor:
skip_acc = None
for layer in self.tcn:
skip, x = layer(x)
if skip_acc is None:
skip_acc = skip
else:
skip_acc += skip
x = self.last(x + skip_acc)
return x
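
A usage sketch for the projector (illustrative; the shapes are assumptions for a univariate series of length 100):

proj = TimeDiffusionProjector(input_dims=(1, 100))  # [channels, seq_len], best taken from x.shape without batch
x = torch.randn(4, 1, 100)                          # [batch, channels, seq_len]
out = proj(x)                                       # keeps the input shape: [4, 1, 100]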


class TimeDiffusion(nn.Module):
"""
main model, uses projectors to create (q, k, v) for vanilla attention layer
"""
def __init__(self, *args, **params):
"""
`args`, `params` - parameters for projectors
"""
super().__init__()
self.key_proj = TimeDiffusionProjector(*args, **params)
self.val_proj = TimeDiffusionProjector(*args, **params)
self.query_proj = TimeDiffusionProjector(*args, **params)

self.input_dims = self.key_proj.input_dims
self.dims = self.key_proj.dims

def forward(self, x: torch.Tensor) -> torch.Tensor:
# projections
key = self.key_proj(x)
val = self.val_proj(x)
query = self.query_proj(x)

is_batched = self.dims + 2 == len(key.size())
mat_mul = torch.bmm if is_batched else torch.matmul

# flattening last dimensionalities in case of 2D and 3D input
# TODO: think of better solution
if self.dims > 1:
orig_shape = key.shape
new_shape = list(key.size()[: - self.dims]) + [np.prod(key.size()[ - self.dims:])]

key = key.view(new_shape)
val = val.view(new_shape)
query = query.view(new_shape)

# vanilla attention
scores = mat_mul(query, key.transpose(- 2, - 1))
weights = torch.nn.functional.softmax(scores, dim=1)
attention = mat_mul(weights, val)

# back to original shape in case of 2D and 3D input
if self.dims > 1:
attention = attention.view(orig_shape)

return attention
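
A matching sketch for the attention wrapper (illustrative, same assumed shapes):

model = TimeDiffusion(input_dims=(1, 100))
x = torch.randn(4, 1, 100)  # batched input, so torch.bmm is used internally
attn = model(x)             # [4, 1, 100]
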
from .utils import count_params, DimUniversalStandardScaler
from .models import TimeDiffusion


class TD(nn.Module):
@@ -412,10 +193,10 @@ def restore(self, example: Union[None, np.array, torch.Tensor] = None, shape: Un

if example is None:
if shape is None:
raise ValueEror("Either `example` or `shape` should be passed")
raise ValueError("Either `example` or `shape` should be passed")

torch.random.manual_seed(seed)
X = torch.rand(*dims).to(device=self.device(), dtype=self.dtype())
X = torch.rand(*shape).to(device=self.device(), dtype=self.dtype())
else:
if len(self.input_dims) != len(example.shape):
raise ValueError(f"Model fitted with {len(self.input_dims)} dims, but got {len(example.shape)}")
83 changes: 83 additions & 0 deletions src/timediffusion/layers.py
@@ -0,0 +1,83 @@
from typing import Union

import torch
from torch import nn

from .utils import get_appropriate_conv_layer


class Chomp(nn.Module):
"""
cuts padding part of sequence
inspired by https://github.com/locuslab/TCN
"""
def __init__(self, chomp_size: int, dims: int=1):
"""
args:
`chomp_size` - padding size to cut off
`dims` - number of trailing dimensions to chomp
"""
super().__init__()
self.chomp_size = chomp_size
if dims not in (1, 2, 3):
raise NotImplementedError(f"Chomp layer for {dims = } not implemented")
self.dims = dims

def forward(self, x: torch.Tensor) -> torch.Tensor:
if self.dims == 1:
return x[..., : - self.chomp_size].contiguous()
if self.dims == 2:
return x[..., : - self.chomp_size, : - self.chomp_size].contiguous()
if self.dims == 3:
return x[..., : - self.chomp_size, : - self.chomp_size, : - self.chomp_size].contiguous()
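
A small example of the chomp in action (illustrative):

chomp = Chomp(chomp_size=2, dims=1)
x = torch.randn(1, 8, 12)  # e.g. the output of a padded Conv1d
y = chomp(x)               # trailing padding removed: [1, 8, 10]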

class TemporalBlock(nn.Module):
"""
combination of (convolutional layer, chomp, relu, dropout) repeated `layers` times
adds additional convolutional layer if needed to downsample number of channels
inspired by https://github.com/locuslab/TCN
"""
def __init__(self, n_inputs: int, n_outputs: int, kernel_size: Union[int, tuple[int]],
stride: Union[int, tuple[int]], dilation: Union[int, tuple[int]], padding: Union[int, tuple[int]],
dropout: float = 0.2, dims: int = 1, layers: int = 2):
super().__init__()

conv_layer = get_appropriate_conv_layer(dims)
self.padding = padding
self.dropout = dropout

net = []
for i in range(layers):
net.append(torch.nn.utils.weight_norm(conv_layer(
(n_inputs if i == 0 else n_outputs), n_outputs, kernel_size, stride=stride, padding=padding, dilation=dilation)))
if padding > 0:
net.append(Chomp(padding, dims))
net.append(nn.ReLU())
if dropout > 0:
net.append(nn.Dropout(dropout))
self.net = nn.ModuleList(net)

self.downsample = conv_layer(n_inputs, n_outputs, 1) if n_inputs != n_outputs else None
self.relu = nn.ReLU()
self.init_weights()

def init_weights(self):
"""
sets normal weight distribution for convolutional layers
"""
for i in range(0, len(self.net), 2 + (self.dropout > 0) + (self.padding > 0)):
self.net[i].weight.data.normal_(0, 0.5)

if self.downsample is not None:
self.downsample.weight.data.normal_(0, 0.5)

def forward(self, x: torch.Tensor) -> torch.Tensor:
"""
`x` in format [batch_size, channels, *other_dims]
"""
out = x
for i in range(len(self.net)):
out = self.net[i](out)

res = x if self.downsample is None else self.downsample(x)
return out, self.relu(out + res)
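
A usage sketch for a single temporal block (illustrative; the parameter values are assumptions):

block = TemporalBlock(n_inputs=1, n_outputs=16, kernel_size=2,
                      stride=1, dilation=1, padding=1, dims=1)
x = torch.randn(2, 1, 50)
skip, out = block(x)  # skip feeds the skip-accumulation, out continues the stack; both [2, 16, 50]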
