added desc, 2d-3d support for attention and tests
timetoai committed Aug 31, 2023
1 parent 57034ec commit 5880c54
Showing 5 changed files with 131 additions and 17 deletions.
pyproject.toml (2 changes: 1 addition & 1 deletion)

@@ -8,7 +8,7 @@ version = "0.1"
 authors = [
     { name="Shishkov Vladislav", email="[email protected]" },
 ]
-description = ""
+description = "TimeDiffusion - unified framework for multiple time series tasks"
 readme = "README.md"
 requires-python = ">=3.7"
 classifiers = [
requirements.txt (3 changes: 3 additions & 0 deletions)

@@ -0,0 +1,3 @@
+numpy==1.23.5
+torch==2.0.1+cu117
+tqdm==4.65.0
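Note on the pinned `torch` build: `2.0.1+cu117` is a CUDA 11.7 wheel that is not published on plain PyPI, so installing it typically requires the PyTorch wheel index (for example `pip install -r requirements.txt --extra-index-url https://download.pytorch.org/whl/cu117`); on a CPU-only setup the plain `torch==2.0.1` from PyPI is the usual substitute.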
src/timediffusion/__init__.py (7 changes: 6 additions & 1 deletion)

@@ -1,6 +1,11 @@
-from .timediffusion import TD, TimeDiffusionProjector, TimeDiffusion
+from .timediffusion import TD, TimeDiffusionProjector, TimeDiffusion, count_params, DimUniversalStandardScaler
 
 __all__ = [
+    # useful functions
+    "count_params",
+    # data processing
+    "DimUniversalStandardScaler",
+    # models
     "TimeDiffusionProjector",
     "TimeDiffusion",
     "TD"
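With the additions above, `count_params` and `DimUniversalStandardScaler` become part of the package's public API alongside the models. A minimal orientation sketch (hypothetical usage, assuming the constructor and `fit` arguments exercised by the tests below):

```python
import numpy as np
from timediffusion import TD, count_params

model = TD(input_dims=[1, 35], verbose=True)  # verbose also prints the parameter count
print(count_params(model))                    # total number of model parameters
model.fit(np.ones([1, 35]), epochs=1, batch_size=1, steps_per_epoch=2)
```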
src/timediffusion/timediffusion.py (47 changes: 35 additions & 12 deletions)

@@ -69,11 +69,11 @@ def __init__(self, chomp_size: int, dims: int=1):

     def forward(self, x: torch.Tensor) -> torch.Tensor:
         if self.dims == 1:
-            return x[:, :, : - self.chomp_size].contiguous()
+            return x[..., : - self.chomp_size].contiguous()
         if self.dims == 2:
-            return x[:, :, : - self.chomp_size, : - self.chomp_size].contiguous()
+            return x[..., : - self.chomp_size, : - self.chomp_size].contiguous()
         if self.dims == 3:
-            return x[:, :, : - self.chomp_size, : - self.chomp_size, : - self.chomp_size].contiguous()
+            return x[..., : - self.chomp_size, : - self.chomp_size, : - self.chomp_size].contiguous()
 
 class TemporalBlock(nn.Module):
     """
@@ -183,31 +183,54 @@ class TimeDiffusion(nn.Module):
"""
main model, uses projectors to create (q, k, v) for vanilla attention layer
"""
def __init__(self, **params):
def __init__(self, *args, **params):
"""
`params` - parameters for projectors
"""
super().__init__()
self.key_proj = TimeDiffusionProjector(**params)
self.val_proj = TimeDiffusionProjector(**params)
self.query_proj = TimeDiffusionProjector(**params)
self.key_proj = TimeDiffusionProjector(*args, **params)
self.val_proj = TimeDiffusionProjector(*args, **params)
self.query_proj = TimeDiffusionProjector(*args, **params)

self.input_dims = self.key_proj.input_dims
self.dims = self.key_proj.dims

def forward(self, x: torch.Tensor) -> torch.Tensor:
# projections
key = self.key_proj(x)
val = self.val_proj(x)
query = self.query_proj(x)

scores = torch.bmm(query, key.transpose(1, 2))
is_batched = self.dims + 2 == len(key.size())
mat_mul = torch.bmm if is_batched else torch.matmul

# flattening last dimensionalities in case of 2D and 3D input
# TODO: think of better solution
if self.dims > 1:
orig_shape = key.shape
new_shape = list(key.size()[: - self.dims]) + [np.prod(key.size()[ - self.dims:])]

key = key.view(new_shape)
val = val.view(new_shape)
query = query.view(new_shape)

# vanilla attenion
scores = mat_mul(query, key.transpose(- 2, - 1))
weights = torch.nn.functional.softmax(scores, dim=1)
attention = torch.bmm(weights, val)
attention = mat_mul(weights, val)

# back to original shape in case of 2D and 3D input
if self.dims > 1:
attention = attention.view(orig_shape)

return attention


 class TD(nn.Module):
     """
     Class provides a convenient framework for effectively working with TimeDiffusion, encompassing all essential functions.
     """
-    def __init__(self, verbose: bool = False, seed=42, **params):
+    def __init__(self, verbose: bool = False, seed=42, *args, **params):
         """
         args (mostly same as TimeDiffusionProjector):

@@ -228,7 +251,7 @@ def __init__(self, verbose: bool = False, seed=42, **params):
"""
super().__init__()
torch.random.manual_seed(seed)
self.model = TimeDiffusion(**params)
self.model = TimeDiffusion(*args, **params)
self.is_fitted = False
if verbose:
print(f"Created model with {count_params(self):.1e} parameters")
@@ -292,7 +315,7 @@ def _kl_div(x, y, eps=1e-3):
         if isinstance(distance_loss, str):
             if distance_loss not in ("MAE", "MSE"):
                 raise NotImplementedError(f"Distance loss {distance_loss} doesn't exist")
-            distance_loss = {"MAE": _mae, "MSE": _mse}
+            distance_loss = {"MAE": _mae, "MSE": _mse}[distance_loss]
         elif not isinstance(distance_loss, nn.Module):
             raise NotImplementedError(f"Distance loss should be 'MAE', 'MSE' or nn.Module, got {type(distance_loss)}")

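The core of the new 2D/3D support is the flatten-attend-restore pattern in `TimeDiffusion.forward`: the trailing spatial dimensions are collapsed into one axis so that vanilla attention applies, then the result is reshaped back. A self-contained sketch of the idea with hypothetical shapes (not the library's exact code; softmax is taken over the last axis here):

```python
import numpy as np
import torch

def flattened_attention(query: torch.Tensor, key: torch.Tensor,
                        val: torch.Tensor, dims: int) -> torch.Tensor:
    # dims = number of trailing "signal" dims (1 for sequences, 2 for images, ...)
    is_batched = query.dim() == dims + 2
    mat_mul = torch.bmm if is_batched else torch.matmul

    if dims > 1:
        orig_shape = key.shape
        # collapse the trailing dims into a single axis
        new_shape = list(key.shape[:-dims]) + [int(np.prod(key.shape[-dims:]))]
        query, key, val = (t.reshape(new_shape) for t in (query, key, val))

    scores = mat_mul(query, key.transpose(-2, -1))
    weights = torch.softmax(scores, dim=-1)
    attention = mat_mul(weights, val)

    if dims > 1:
        attention = attention.reshape(orig_shape)  # restore the 2D/3D layout
    return attention

# batched 2D input: (batch=1, channels=2, 7x7) -> output keeps the same shape
q = k = v = torch.ones(1, 2, 7, 7)
assert flattened_attention(q, k, v, dims=2).shape == (1, 2, 7, 7)
```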
tests/test_timediffusion.py (89 changes: 86 additions & 3 deletions)

@@ -1,8 +1,91 @@
 import pytest
 
+import numpy as np
+
 import torch
 from torch import nn
 
-def test_count_params():
-    linear = nn.Linear(100, 100)
-    pass
+from timediffusion import count_params, TimeDiffusionProjector, TimeDiffusion, TD
+
+
+@pytest.mark.parametrize(
+    "in_features,out_features",
+    [(100, 100), (200, 200)],
+)
+def test_count_params_linear(in_features, out_features):
+    linear = nn.Linear(in_features, out_features)
+    assert count_params(linear) == in_features * out_features + out_features
+
+@pytest.mark.parametrize(
+    "in_channels,out_channels,kernel_size,groups",
+    [
+        (3, 24, 3, 1),
+        (1, 10, 2, 1),
+        (3, 12, 4, 3)
+    ],
+)
+def test_count_params_conv(in_channels, out_channels, kernel_size, groups):
+    conv = nn.Conv1d(in_channels=in_channels, out_channels=out_channels,
+                     kernel_size=kernel_size, groups=groups)
+    assert count_params(conv) == out_channels * (in_channels // groups) * kernel_size + out_channels

+@pytest.mark.parametrize("model_init", [TimeDiffusionProjector, TimeDiffusion])
+@pytest.mark.parametrize("dims", [[1, 35], [1, 7, 7], [1, 5, 5, 5], [2, 35], [2, 7, 7], [2, 5, 5, 5]])
+class TestTimeDiffusion:
+    def test_forward_pass(self, model_init, dims):
+        model = model_init(input_dims=dims)
+
+        # unbatched forward pass
+        data = torch.ones(*dims)
+        try:
+            res = model(data)
+        except Exception as e:
+            pytest.fail(f"Unbatched forward pass of {type(model).__name__} with {dims = } failed with exception: {e}")
+        assert data.shape == res.shape
+
+        # batched forward pass
+        data = torch.ones(1, *dims)
+        try:
+            res = model(data)
+        except Exception as e:
+            pytest.fail(f"Batched forward pass of {type(model).__name__} with {dims = } failed with exception: {e}")
+        assert data.shape == res.shape
+
+    def test_backward_pass(self, model_init, dims):
+        model = model_init(input_dims=dims)
+
+        # unbatched backward pass
+        data = torch.ones(*dims)
+        try:
+            res = model(data)
+            (res - 1).mean().backward()
+        except Exception as e:
+            pytest.fail(f"Unbatched backward pass of {type(model).__name__} with {dims = } failed with exception: {e}")
+
+        # batched backward pass
+        data = torch.ones(1, *dims)
+        try:
+            res = model(data)
+            (res - 1).mean().backward()
+        except Exception as e:
+            pytest.fail(f"Batched backward pass of {type(model).__name__} with {dims = } failed with exception: {e}")


@pytest.mark.parametrize("dims", [[1, 35], [1, 7, 7], [1, 5, 5, 5], [2, 35], [2, 7, 7], [2, 5, 5, 5]])
@pytest.mark.parametrize("mask_dropout", [None, 0.2])
class TestTD:
def test_fit(self, dims, mask_dropout):
model = TD(input_dims=dims)

data = np.ones(dims)
if mask_dropout is None:
mask = None
else:
np.random.seed(42)
mask = np.random.uniform(low=0., high=1.0, size=data.shape) < mask_dropout

try:
model.fit(data, mask=mask, epochs=1, batch_size=1, steps_per_epoch=2)
except Exception as e:
pytest.fail(f"TD fit with {dims = } failed with exception: {e}")

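With the package and the pinned requirements installed, the new suite runs as standard pytest, e.g. `pytest tests/test_timediffusion.py`; each test above is parametrized over all supported 1D, 2D, and 3D input layouts.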