Skip to content

Commit

Permalink
Lint and mypy
Browse files Browse the repository at this point in the history
  • Loading branch information
boykovdn committed Aug 15, 2024
1 parent a49078e commit d84b3a3
Show file tree
Hide file tree
Showing 14 changed files with 916 additions and 10 deletions.
29 changes: 29 additions & 0 deletions experiments/reproduction.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,29 @@
import lightning.pytorch as pl
from torch_geometric.loader.dataloader import DataLoader

from gwnet.datasets import METRLA
from gwnet.model import GraphWavenet
from gwnet.train import GWnetForecasting


def train():
args = {"lr": 0.01, "weight_decay": 0.0001}

model = GraphWavenet(adaptive_embedding_dim=64, n_nodes=207, k_diffusion_hops=1)
plmodule = GWnetForecasting(args, model, missing_value=0.0)

trainer = pl.Trainer(
accelerator="mps",
max_steps=100,
gradient_clip_val=5.0, # TODO There was something about this in the code.
)

dataset = METRLA("./")
## TODO Parametrise.
loader = DataLoader(dataset, batch_size=1, num_workers=0)

trainer.fit(model=plmodule, train_dataloaders=loader)


if __name__ == "__main__":
train()
3 changes: 2 additions & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -44,7 +44,8 @@ dependencies = [
"torch_geometric",
"numpy",
"pandas",
"scipy"
"scipy",
"lightning"
]

[project.optional-dependencies]
Expand Down
203 changes: 198 additions & 5 deletions quarto/notes.qmd

Large diffs are not rendered by default.

21 changes: 17 additions & 4 deletions src/gwnet/datasets/metrla.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,13 +15,25 @@ def __init__(
transform: Callable[[Data], Data] | None = None,
pre_transform: Callable[[Data], Data] | None = None,
pre_filter: Callable[[Data], Data] | None = None,
in_timesteps: int = 13,
out_timesteps: int = 12,
):
r"""
Args:
...
in_timesteps (int): The number of timesteps contained in each node.
out_timesteps (int): The number of timesteps ahead in the target.
"""
self.node_filename = "METR-LA.csv"
self.url_node = "https://zenodo.org/record/5724362/files/METR-LA.csv?download=1"
self.adj_filename = "adj_mx_METR-LA.pkl"
self.url_adj = (
"https://zenodo.org/record/5724362/files/adj_mx_METR-LA.pkl?download=1"
)
self.in_timesteps = in_timesteps
self.out_timesteps = out_timesteps
super().__init__(root, transform, pre_transform, pre_filter)
self.load(self.processed_paths[0])

Expand Down Expand Up @@ -100,8 +112,9 @@ def _get_targets_and_features(
stacked_target = np.concatenate(
[stacked_target[:, None, :], saw_tooth[:, None, :]], 1
)
# TODO Parameterise the precision.
features = [
torch.from_numpy(stacked_target[i : i + num_timesteps_in, :, :].T)
torch.from_numpy(stacked_target[i : i + num_timesteps_in, :, :].T).float()
for i in range(
stacked_target.shape[0] - num_timesteps_in - num_timesteps_out
)
Expand All @@ -113,7 +126,7 @@ def _get_targets_and_features(
0,
:,
].T
)
).float()
for i in range(
stacked_target.shape[0] - num_timesteps_in - num_timesteps_out
)
Expand All @@ -129,8 +142,8 @@ def process(self) -> None:

node_series = pd.read_csv(self.raw_paths[0], index_col=0)
features, targets = self._get_targets_and_features(
node_series, 12, 12
) # TODO Parametrise in/out n steps.
node_series, self.in_timesteps, self.out_timesteps
)
data_list = [
Data(x=features[i], edge_index=edges, edge_attr=edge_weights, y=targets[i])
for i in range(len(features))
Expand Down
3 changes: 3 additions & 0 deletions src/gwnet/model/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
from .gwnet import GraphWavenet

__all__ = ["GraphWavenet"]
Loading

0 comments on commit d84b3a3

Please sign in to comment.