From 2f5bb64427abdebdbb4f3d2c0bc16b0e6482e255 Mon Sep 17 00:00:00 2001 From: Boyko Vodenicharski Date: Sat, 31 Aug 2024 19:19:41 +0100 Subject: [PATCH] Implement standard scaling * Also add it to the experiment script. * [WIP] Still some outstanding issues with the architecture itself behaving properly, so not train-ready --- experiments/reproduction.py | 15 ++++++++------- src/gwnet/datasets/metrla.py | 2 +- src/gwnet/train/gwnet.py | 12 +++++++----- src/gwnet/utils/__init__.py | 4 ++-- src/gwnet/utils/utils.py | 27 +++++++++++++++++++++++++-- 5 files changed, 43 insertions(+), 17 deletions(-) diff --git a/experiments/reproduction.py b/experiments/reproduction.py index c390ca5..dcb26ad 100644 --- a/experiments/reproduction.py +++ b/experiments/reproduction.py @@ -5,29 +5,30 @@ from gwnet.datasets import METRLA from gwnet.model import GraphWavenet from gwnet.train import GWnetForecasting - -# from gwnet.utils import StandardScaler +from gwnet.utils import TrafficStandardScaler def train(): args = {"lr": 0.001, "weight_decay": 0.0001} + device = "gpu" + num_workers = 0 # NOTE Set to 0 for single thread debugging! model = GraphWavenet(adaptive_embedding_dim=64, n_nodes=207, k_diffusion_hops=1) plmodule = GWnetForecasting(args, model, missing_value=0.0) tb_logger = pl_loggers.TensorBoardLogger(save_dir="logs/") trainer = pl.Trainer( - accelerator="cpu", - max_steps=100, + accelerator=device, + max_steps=10000, gradient_clip_val=5.0, # TODO There was something about this in the code. logger=tb_logger, ) dataset = METRLA("./") - # import pdb; pdb.set_trace() - # scaler = StandardScaler(dataset.z_norm_mean, dataset.z_norm_std) + scaler = TrafficStandardScaler.from_dataset(dataset, n_samples=30000) + plmodule.scaler = scaler ## TODO Parametrise. - loader = DataLoader(dataset, batch_size=1, num_workers=0) + loader = DataLoader(dataset, batch_size=32, num_workers=num_workers) trainer.fit(model=plmodule, train_dataloaders=loader) diff --git a/src/gwnet/datasets/metrla.py b/src/gwnet/datasets/metrla.py index 6a3e546..0917021 100644 --- a/src/gwnet/datasets/metrla.py +++ b/src/gwnet/datasets/metrla.py @@ -76,7 +76,7 @@ def _get_targets_and_features( num_timesteps_in: int, num_timesteps_out: int, interpolate: bool = False, - normalize: bool = True, + normalize: bool = False, ) -> tuple[np.ndarray, np.ndarray]: r""" Build the input and output features. diff --git a/src/gwnet/train/gwnet.py b/src/gwnet/train/gwnet.py index 8a1ba86..c4d4637 100644 --- a/src/gwnet/train/gwnet.py +++ b/src/gwnet/train/gwnet.py @@ -1,11 +1,10 @@ -from collections.abc import Callable from typing import Any import lightning.pytorch as pl import torch from torch_geometric.data import Data -from ..utils import create_mask +from ..utils import TrafficStandardScaler, create_mask class GWnetForecasting(pl.LightningModule): @@ -14,7 +13,7 @@ def __init__( args: dict[str, Any], model: torch.nn.Module, missing_value: float = 0.0, - scaler: Callable[[torch.Tensor], torch.Tensor] | None = None, + scaler: TrafficStandardScaler | None = None, ) -> None: r""" Trains Graph wavenet for the traffic forecasting task. @@ -65,12 +64,15 @@ def masked_mae_loss( return torch.sum(loss[mask]) / num_terms def training_step(self, input_batch: Data, batch_idx: int) -> torch.Tensor: # noqa: ARG002 + if self.scaler is not None: + # NOTE Normalise only the traffic feature, hardcoded as 0. + input_batch.x[:, 0] = self.scaler.transform(input_batch.x[:, 0]) + targets = input_batch.y out = self.model(input_batch) if self.scaler is not None: - raise NotImplementedError() - # out = self.scaler.inverse_transform(out) + out = self.scaler.inverse_transform(out) loss = self.masked_mae_loss(out, targets) self.log("train_loss", loss) diff --git a/src/gwnet/utils/__init__.py b/src/gwnet/utils/__init__.py index bda856e..d644cce 100644 --- a/src/gwnet/utils/__init__.py +++ b/src/gwnet/utils/__init__.py @@ -1,3 +1,3 @@ -from .utils import StandardScaler, create_mask +from .utils import TrafficStandardScaler, create_mask -__all__ = ["create_mask", "StandardScaler"] +__all__ = ["create_mask", "TrafficStandardScaler"] diff --git a/src/gwnet/utils/utils.py b/src/gwnet/utils/utils.py index cbd51fe..77127c8 100644 --- a/src/gwnet/utils/utils.py +++ b/src/gwnet/utils/utils.py @@ -1,4 +1,10 @@ +from __future__ import annotations + +from random import randint + import torch +from torch_geometric.data import Dataset +from tqdm import tqdm def create_mask( @@ -13,11 +19,28 @@ def create_mask( ) -class StandardScaler: - def __init__(self, mu: torch.Tensor, std: torch.Tensor) -> None: +class TrafficStandardScaler: + def __init__(self, mu: float, std: float) -> None: self.mu = mu self.std = std + @classmethod + def from_dataset( + cls, dataset: Dataset, n_samples: int = 100 + ) -> TrafficStandardScaler: + traffic_vals: torch.Tensor | list[torch.Tensor] = [] # Holds tensors for stack. + for _ in tqdm(range(n_samples), desc="Initialising scaler statistics..."): + randidx = randint(0, len(dataset) - 1) + # NOTE Here 0th feature is hardcoded as the traffic, unravel + # into a sequence of values to be computed over. + traffic_vals.append(dataset[randidx].x[:, 0, :].ravel()) + + traffic_vals = torch.stack(traffic_vals) + mu = traffic_vals.mean() + std = traffic_vals.std() + + return cls(mu, std) + def transform(self, x: torch.Tensor) -> torch.Tensor: return (x - self.mu) / self.std