diff --git a/src/timediffusion/__init__.py b/src/timediffusion/__init__.py index 2e7c128..aada5c7 100644 --- a/src/timediffusion/__init__.py +++ b/src/timediffusion/__init__.py @@ -1,4 +1,6 @@ -from .timediffusion import TD, TimeDiffusionProjector, TimeDiffusion, count_params, DimUniversalStandardScaler +from .utils import count_params, DimUniversalStandardScaler +from .models import TimeDiffusionProjector, TimeDiffusion +from .frameworks import TD __all__ = [ # useful functions @@ -9,4 +11,4 @@ "TimeDiffusionProjector", "TimeDiffusion", "TD" -] \ No newline at end of file +] diff --git a/src/timediffusion/timediffusion.py b/src/timediffusion/frameworks.py similarity index 56% rename from src/timediffusion/timediffusion.py rename to src/timediffusion/frameworks.py index 38f154a..42521a8 100644 --- a/src/timediffusion/timediffusion.py +++ b/src/timediffusion/frameworks.py @@ -7,227 +7,8 @@ import torch from torch import nn -def count_params(model: nn.Module) -> int: - """ - counts number of model parameters - """ - res = 0 - for param in model.parameters(): - res += np.prod(param.shape) - return res - -def get_appropriate_conv_layer(dims: int) -> nn.Module: - """ - returns appropriate convolutional layer for certain number of dimensionalities - """ - if dims not in (1, 2, 3): - raise NotImplementedError("Convolutional layer for dimensionalty {dims} not implemented") - return {1: nn.Conv1d, 2: nn.Conv2d, 3: nn.Conv3d}[dims] - - -class DimUniversalStandardScaler: - """ - Universal class for normal scaling data - """ - def __init__(self, eps=1e-9): - self.eps = eps - - def fit(self, data): - self.mu = data.mean() - self.std = data.std() - if isinstance(data, torch.Tensor): - self.mu = self.mu.item() - self.std = self.std.item() - - def transform(self, data): - return (data - self.mu) / (self.std + self.eps) - - def fit_transform(self, data): - self.fit(data) - return self.transform(data) - - def inverse_transform(self, data): - return data * self.std + self.mu - - -class Chomp(nn.Module): - """ - cuts padding part of sequence - inspired by https://github.com/locuslab/TCN - """ - def __init__(self, chomp_size: int, dims: int=1): - """ - args: - `chomp_size` - padding size to cut off - `dims` - number of working dimensionalities, which needed to be chomped - """ - super().__init__() - self.chomp_size = chomp_size - if dims not in (1, 2, 3): - raise NotImplementedError(f"Chomp layer for {dims = } not implemented") - self.dims = dims - - def forward(self, x: torch.Tensor) -> torch.Tensor: - if self.dims == 1: - return x[..., : - self.chomp_size].contiguous() - if self.dims == 2: - return x[..., : - self.chomp_size, : - self.chomp_size].contiguous() - if self.dims == 3: - return x[..., : - self.chomp_size, : - self.chomp_size, : - self.chomp_size].contiguous() - -class TemporalBlock(nn.Module): - """ - combination of (convolutional layer, chomp, relu, dropout) repeated `layers` times - adds additional convolutional layer if needed to downsample number of channels - inspired by https://github.com/locuslab/TCN - """ - def __init__(self, n_inputs: int, n_outputs: int, kernel_size: Union[int, tuple[int]], - stride: Union[int, tuple[int]], dilation: Union[int, tuple[int]], padding: Union[int, tuple[int]], - dropout: int = 0.2, dims: int = 1, layers: int = 2): - super().__init__() - - conv_layer = get_appropriate_conv_layer(dims) - self.padding = padding - self.dropout = dropout - - net = [] - for i in range(layers): - net.append(torch.nn.utils.weight_norm(conv_layer( - (n_inputs if i == 0 else 
n_outputs), n_outputs, kernel_size, stride=stride, padding=padding, dilation=dilation))) - if padding > 0: - net.append(Chomp(padding, dims)) - net.append(nn.ReLU()) - if dropout > 0: - net.append(nn.Dropout(dropout)) - self.net = nn.ModuleList(net) - - self.downsample = conv_layer(n_inputs, n_outputs, 1) if n_inputs != n_outputs else None - self.relu = nn.ReLU() - self.init_weights() - - def init_weights(self): - """ - sets normal weight distribution for convolutional layers - """ - for i in range(0, len(self.net), 2 + (self.dropout > 0) + (self.padding > 0)): - self.net[i].weight.data.normal_(0, 0.5) - - if self.downsample is not None: - self.downsample.weight.data.normal_(0, 0.5) - - def forward(self, x: torch.Tensor) -> torch.Tensor: - """ - `x` in format [batch_size, channels, *other_dims] - """ - out = x - for i in range(len(self.net)): - out = self.net[i](out) - - res = x if self.downsample is None else self.downsample(x) - return out, self.relu(out + res) - - -class TimeDiffusionProjector(nn.Module): - """ - convolutional network, used as projector in TD - consists of temporal blocks with exponentially increasing padding/dilation parameters - """ - def __init__(self, input_dims: Union[list[int], tuple[int]], max_deg_constraint: int = 13, - conv_filters: int = 128, base_dropout: float = 0.05): - """ - args: - - `input_dims` - [channels, *dims] - needed for dynamical network building - best way to pass it as `x.shape` (without batches) - - `max_deg_constraint` - constraint to lessen network size, if not big enough will worsen model quality - number of temporal blocks in network will be (1 + max_deg_constraint) maximum - - `conv_filters` - number of convolutional filters for each layer - - `base_dropout` - dropout for first temporal block - """ - super().__init__() - - self.input_dims = input_dims - self.dims = len(input_dims) - 1 - self.channels = input_dims[0] - self.max_seq = max(input_dims[1:]) - self.max_deg = int(np.ceil(np.log2(self.max_seq))) - if max_deg_constraint < self.max_deg: - print(f"For better TimeDiffusion performance it's recommended to use max_deg_constraint ", end="") - print(f"with value{self.max_deg} for input with shape {input_dims}") - self.max_deg = max_deg_constraint - print(f"Setting current {self.max_deg = }") - - self.tcn = nn.ModuleList( - [TemporalBlock(self.channels, conv_filters, - kernel_size=1, stride=1, dilation=1, padding=0, dropout=base_dropout, dims=self.dims), - *[TemporalBlock(conv_filters, conv_filters, - kernel_size=2, stride=1, dilation=i, padding=i, dropout=0.0, dims=self.dims) - for i in [2 ** i for i in range(self.max_deg + 1)]] - ]) - - self.last = get_appropriate_conv_layer(self.dims)(conv_filters, self.channels, kernel_size=1, stride=1, dilation=1) - - def forward(self, x: torch.Tensor) -> torch.Tensor: - skip_acc = None - for layer in self.tcn: - skip, x = layer(x) - if skip_acc is None: - skip_acc = skip - else: - skip_acc += skip - x = self.last(x + skip_acc) - return x - - -class TimeDiffusion(nn.Module): - """ - main model, uses projectors to create (q, k, v) for vanilla attention layer - """ - def __init__(self, *args, **params): - """ - `args`, `params` - parameters for projectors - """ - super().__init__() - self.key_proj = TimeDiffusionProjector(*args, **params) - self.val_proj = TimeDiffusionProjector(*args, **params) - self.query_proj = TimeDiffusionProjector(*args, **params) - - self.input_dims = self.key_proj.input_dims - self.dims = self.key_proj.dims - - def forward(self, x: torch.Tensor) -> torch.Tensor: - # 
projections - key = self.key_proj(x) - val = self.val_proj(x) - query = self.query_proj(x) - - is_batched = self.dims + 2 == len(key.size()) - mat_mul = torch.bmm if is_batched else torch.matmul - - # flattening last dimensionalities in case of 2D and 3D input - # TODO: think of better solution - if self.dims > 1: - orig_shape = key.shape - new_shape = list(key.size()[: - self.dims]) + [np.prod(key.size()[ - self.dims:])] - - key = key.view(new_shape) - val = val.view(new_shape) - query = query.view(new_shape) - - # vanilla attenion - scores = mat_mul(query, key.transpose(- 2, - 1)) - weights = torch.nn.functional.softmax(scores, dim=1) - attention = mat_mul(weights, val) - - # back to original shape in case of 2D and 3D input - if self.dims > 1: - attention = attention.view(orig_shape) - - return attention +from .utils import count_params, DimUniversalStandardScaler +from .models import TimeDiffusion class TD(nn.Module): @@ -412,10 +193,10 @@ def restore(self, example: Union[None, np.array, torch.Tensor] = None, shape: Un if example is None: if shape is None: - raise ValueEror("Either `example` or `shape` should be passed") + raise ValueError("Either `example` or `shape` should be passed") torch.random.manual_seed(seed) - X = torch.rand(*dims).to(device=self.device(), dtype=self.dtype()) + X = torch.rand(*shape).to(device=self.device(), dtype=self.dtype()) else: if len(self.input_dims) != len(example.shape): raise ValueError(f"Model fitted with {len(self.input_dims)} dims, but got {len(example.shape)}") diff --git a/src/timediffusion/layers.py b/src/timediffusion/layers.py new file mode 100644 index 0000000..43517ef --- /dev/null +++ b/src/timediffusion/layers.py @@ -0,0 +1,83 @@ +from typing import Union + +import torch +from torch import nn + +from .utils import get_appropriate_conv_layer + + +class Chomp(nn.Module): + """ + cuts padding part of sequence + inspired by https://github.com/locuslab/TCN + """ + def __init__(self, chomp_size: int, dims: int=1): + """ + args: + `chomp_size` - padding size to cut off + `dims` - number of working dimensionalities, which needed to be chomped + """ + super().__init__() + self.chomp_size = chomp_size + if dims not in (1, 2, 3): + raise NotImplementedError(f"Chomp layer for {dims = } not implemented") + self.dims = dims + + def forward(self, x: torch.Tensor) -> torch.Tensor: + if self.dims == 1: + return x[..., : - self.chomp_size].contiguous() + if self.dims == 2: + return x[..., : - self.chomp_size, : - self.chomp_size].contiguous() + if self.dims == 3: + return x[..., : - self.chomp_size, : - self.chomp_size, : - self.chomp_size].contiguous() + +class TemporalBlock(nn.Module): + """ + combination of (convolutional layer, chomp, relu, dropout) repeated `layers` times + adds additional convolutional layer if needed to downsample number of channels + inspired by https://github.com/locuslab/TCN + """ + def __init__(self, n_inputs: int, n_outputs: int, kernel_size: Union[int, tuple[int]], + stride: Union[int, tuple[int]], dilation: Union[int, tuple[int]], padding: Union[int, tuple[int]], + dropout: int = 0.2, dims: int = 1, layers: int = 2): + super().__init__() + + conv_layer = get_appropriate_conv_layer(dims) + self.padding = padding + self.dropout = dropout + + net = [] + for i in range(layers): + net.append(torch.nn.utils.weight_norm(conv_layer( + (n_inputs if i == 0 else n_outputs), n_outputs, kernel_size, stride=stride, padding=padding, dilation=dilation))) + if padding > 0: + net.append(Chomp(padding, dims)) + net.append(nn.ReLU()) + if 
dropout > 0: + net.append(nn.Dropout(dropout)) + self.net = nn.ModuleList(net) + + self.downsample = conv_layer(n_inputs, n_outputs, 1) if n_inputs != n_outputs else None + self.relu = nn.ReLU() + self.init_weights() + + def init_weights(self): + """ + sets normal weight distribution for convolutional layers + """ + for i in range(0, len(self.net), 2 + (self.dropout > 0) + (self.padding > 0)): + self.net[i].weight.data.normal_(0, 0.5) + + if self.downsample is not None: + self.downsample.weight.data.normal_(0, 0.5) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + """ + `x` in format [batch_size, channels, *other_dims] + """ + out = x + for i in range(len(self.net)): + out = self.net[i](out) + + res = x if self.downsample is None else self.downsample(x) + return out, self.relu(out + res) diff --git a/src/timediffusion/models.py b/src/timediffusion/models.py new file mode 100644 index 0000000..0adac93 --- /dev/null +++ b/src/timediffusion/models.py @@ -0,0 +1,112 @@ +from typing import Union + +import numpy as np + +import torch +from torch import nn + +from .utils import get_appropriate_conv_layer +from .layers import TemporalBlock + + +class TimeDiffusionProjector(nn.Module): + """ + convolutional network, used as projector in TD + consists of temporal blocks with exponentially increasing padding/dilation parameters + """ + def __init__(self, input_dims: Union[list[int], tuple[int]], max_deg_constraint: int = 13, + conv_filters: int = 128, base_dropout: float = 0.05): + """ + args: + + `input_dims` - [channels, *dims] + needed for dynamical network building + best way to pass it as `x.shape` (without batches) + + `max_deg_constraint` - constraint to lessen network size, if not big enough will worsen model quality + number of temporal blocks in network will be (1 + max_deg_constraint) maximum + + `conv_filters` - number of convolutional filters for each layer + + `base_dropout` - dropout for first temporal block + """ + super().__init__() + + self.input_dims = input_dims + self.dims = len(input_dims) - 1 + self.channels = input_dims[0] + self.max_seq = max(input_dims[1:]) + self.max_deg = int(np.ceil(np.log2(self.max_seq))) + if max_deg_constraint < self.max_deg: + print(f"For better TimeDiffusion performance it's recommended to use max_deg_constraint ", end="") + print(f"with value{self.max_deg} for input with shape {input_dims}") + self.max_deg = max_deg_constraint + print(f"Setting current {self.max_deg = }") + + self.tcn = nn.ModuleList( + [TemporalBlock(self.channels, conv_filters, + kernel_size=1, stride=1, dilation=1, padding=0, dropout=base_dropout, dims=self.dims), + *[TemporalBlock(conv_filters, conv_filters, + kernel_size=2, stride=1, dilation=i, padding=i, dropout=0.0, dims=self.dims) + for i in [2 ** i for i in range(self.max_deg + 1)]] + ]) + + self.last = get_appropriate_conv_layer(self.dims)(conv_filters, self.channels, kernel_size=1, stride=1, dilation=1) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + skip_acc = None + for layer in self.tcn: + skip, x = layer(x) + if skip_acc is None: + skip_acc = skip + else: + skip_acc += skip + x = self.last(x + skip_acc) + return x + + +class TimeDiffusion(nn.Module): + """ + main model, uses projectors to create (q, k, v) for vanilla attention layer + """ + def __init__(self, *args, **params): + """ + `args`, `params` - parameters for projectors + """ + super().__init__() + self.key_proj = TimeDiffusionProjector(*args, **params) + self.val_proj = TimeDiffusionProjector(*args, **params) + self.query_proj = 
TimeDiffusionProjector(*args, **params) + + self.input_dims = self.key_proj.input_dims + self.dims = self.key_proj.dims + + def forward(self, x: torch.Tensor) -> torch.Tensor: + # projections + key = self.key_proj(x) + val = self.val_proj(x) + query = self.query_proj(x) + + is_batched = self.dims + 2 == len(key.size()) + mat_mul = torch.bmm if is_batched else torch.matmul + + # flattening last dimensionalities in case of 2D and 3D input + # TODO: think of better solution + if self.dims > 1: + orig_shape = key.shape + new_shape = list(key.size()[: - self.dims]) + [np.prod(key.size()[ - self.dims:])] + + key = key.view(new_shape) + val = val.view(new_shape) + query = query.view(new_shape) + + # vanilla attenion + scores = mat_mul(query, key.transpose(- 2, - 1)) + weights = torch.nn.functional.softmax(scores, dim=1) + attention = mat_mul(weights, val) + + # back to original shape in case of 2D and 3D input + if self.dims > 1: + attention = attention.view(orig_shape) + + return attention diff --git a/src/timediffusion/utils.py b/src/timediffusion/utils.py new file mode 100644 index 0000000..7d7b701 --- /dev/null +++ b/src/timediffusion/utils.py @@ -0,0 +1,48 @@ +import numpy as np + +import torch +from torch import nn + + +def count_params(model: nn.Module) -> int: + """ + counts number of model parameters + """ + res = 0 + for param in model.parameters(): + res += np.prod(param.shape) + return res + + +def get_appropriate_conv_layer(dims: int) -> nn.Module: + """ + returns appropriate convolutional layer for certain number of dimensionalities + """ + if dims not in (1, 2, 3): + raise NotImplementedError(f"Convolutional layer for dimensionalty {dims} not implemented") + return {1: nn.Conv1d, 2: nn.Conv2d, 3: nn.Conv3d}[dims] + + +class DimUniversalStandardScaler: + """ + Universal class for normal scaling data + """ + def __init__(self, eps=1e-9): + self.eps = eps + + def fit(self, data): + self.mu = data.mean() + self.std = data.std() + if isinstance(data, torch.Tensor): + self.mu = self.mu.item() + self.std = self.std.item() + + def transform(self, data): + return (data - self.mu) / (self.std + self.eps) + + def fit_transform(self, data): + self.fit(data) + return self.transform(data) + + def inverse_transform(self, data): + return data * self.std + self.mu
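Note (not part of the patch): the split keeps the public surface re-exported from __init__.py, so downstream code that imports from the package root is unaffected, while the new submodules can also be imported directly. A minimal sketch of both import paths, assuming the refactored package is installed as timediffusion:

    # hypothetical smoke check, not part of the diff above
    # old-style imports keep working via the re-exports in __init__.py
    from timediffusion import TD, TimeDiffusion, TimeDiffusionProjector, count_params, DimUniversalStandardScaler

    # new-style imports straight from the split submodules
    from timediffusion.frameworks import TD as TD_direct
    from timediffusion.models import TimeDiffusion as TimeDiffusion_direct
    from timediffusion.utils import count_params as count_params_direct

    # both paths resolve to the same objects, since __init__.py only re-exports them
    assert TD is TD_direct
    assert TimeDiffusion is TimeDiffusion_direct
    assert count_params is count_params_direct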
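The model classes move verbatim into models.py, so a quick forward pass through the relocated classes should confirm that shapes are preserved. This is a sketch only: the tensor sizes and the conv_filters value below are illustrative choices, not values taken from the patch.

    import torch
    from timediffusion.models import TimeDiffusionProjector, TimeDiffusion
    from timediffusion.utils import count_params

    x = torch.rand(1, 2, 32)    # [batch, channels, sequence] -- arbitrary example sizes
    input_dims = x.shape[1:]    # per the docstring: [channels, *dims], i.e. x.shape without the batch axis

    proj = TimeDiffusionProjector(input_dims=input_dims, conv_filters=8)
    model = TimeDiffusion(input_dims=input_dims, conv_filters=8)
    print(count_params(proj), count_params(model))   # TimeDiffusion holds three projectors internally

    assert proj(x).shape == x.shape    # temporal blocks chomp their padding, so the sequence length is unchanged
    assert model(x).shape == x.shape   # the attention output keeps the input shape as well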
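DimUniversalStandardScaler lands in utils.py unchanged; a small round-trip sketch of fit_transform / inverse_transform (arbitrary values, not from the patch):

    import torch
    from timediffusion.utils import DimUniversalStandardScaler

    scaler = DimUniversalStandardScaler()
    data = torch.tensor([1.0, 2.0, 3.0, 4.0])

    scaled = scaler.fit_transform(data)               # (data - mean) / (std + eps)
    restored = scaler.inverse_transform(scaled)       # scaled * std + mean

    print(round(scaled.mean().item(), 6))             # ~0.0
    assert torch.allclose(restored, data, atol=1e-5)  # round trip is exact up to the eps in the denominator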