Implement standard scaling
* Also add it to the experiment script.
* [WIP] The architecture itself still has outstanding issues, so this is not yet train-ready.
boykovdn committed Aug 31, 2024
1 parent 0c5fb36 commit 2f5bb64
Showing 5 changed files with 43 additions and 17 deletions.
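For context, the standard scaling implemented here is plain z-normalisation: the traffic feature is shifted by its mean and divided by its standard deviation before entering the model, and predictions are mapped back to the original scale before the loss. A minimal illustrative sketch (not part of the diff):

    import torch

    x = torch.tensor([50.0, 60.0, 70.0])  # e.g. traffic speeds
    mu, std = x.mean(), x.std()

    z = (x - mu) / std       # transform: roughly zero mean, unit variance
    x_back = z * std + mu    # inverse transform: recovers the original scale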
15 changes: 8 additions & 7 deletions experiments/reproduction.py
@@ -5,29 +5,30 @@
 from gwnet.datasets import METRLA
 from gwnet.model import GraphWavenet
 from gwnet.train import GWnetForecasting

-# from gwnet.utils import StandardScaler
+from gwnet.utils import TrafficStandardScaler


 def train():
     args = {"lr": 0.001, "weight_decay": 0.0001}
+    device = "gpu"
+    num_workers = 0  # NOTE Set to 0 for single thread debugging!

     model = GraphWavenet(adaptive_embedding_dim=64, n_nodes=207, k_diffusion_hops=1)
     plmodule = GWnetForecasting(args, model, missing_value=0.0)

     tb_logger = pl_loggers.TensorBoardLogger(save_dir="logs/")
     trainer = pl.Trainer(
-        accelerator="cpu",
-        max_steps=100,
+        accelerator=device,
+        max_steps=10000,
         gradient_clip_val=5.0,  # TODO There was something about this in the code.
         logger=tb_logger,
     )

     dataset = METRLA("./")
-    # import pdb; pdb.set_trace()
-    # scaler = StandardScaler(dataset.z_norm_mean, dataset.z_norm_std)
-    ## TODO Parametrise.
-    loader = DataLoader(dataset, batch_size=1, num_workers=0)
+    scaler = TrafficStandardScaler.from_dataset(dataset, n_samples=30000)
+    plmodule.scaler = scaler
+    loader = DataLoader(dataset, batch_size=32, num_workers=num_workers)

     trainer.fit(model=plmodule, train_dataloaders=loader)

2 changes: 1 addition & 1 deletion src/gwnet/datasets/metrla.py
@@ -76,7 +76,7 @@ def _get_targets_and_features(
         num_timesteps_in: int,
         num_timesteps_out: int,
         interpolate: bool = False,
-        normalize: bool = True,
+        normalize: bool = False,
     ) -> tuple[np.ndarray, np.ndarray]:
         r"""
         Build the input and output features.
12 changes: 7 additions & 5 deletions src/gwnet/train/gwnet.py
@@ -1,11 +1,10 @@
-from collections.abc import Callable
 from typing import Any

 import lightning.pytorch as pl
 import torch
 from torch_geometric.data import Data

-from ..utils import create_mask
+from ..utils import TrafficStandardScaler, create_mask


 class GWnetForecasting(pl.LightningModule):
@@ -14,7 +13,7 @@ def __init__(
         args: dict[str, Any],
         model: torch.nn.Module,
         missing_value: float = 0.0,
-        scaler: Callable[[torch.Tensor], torch.Tensor] | None = None,
+        scaler: TrafficStandardScaler | None = None,
     ) -> None:
         r"""
         Trains Graph wavenet for the traffic forecasting task.
@@ -65,12 +64,15 @@ def masked_mae_loss(
         return torch.sum(loss[mask]) / num_terms

     def training_step(self, input_batch: Data, batch_idx: int) -> torch.Tensor:  # noqa: ARG002
+        if self.scaler is not None:
+            # NOTE Normalise only the traffic feature, hardcoded as 0.
+            input_batch.x[:, 0] = self.scaler.transform(input_batch.x[:, 0])
+
         targets = input_batch.y
         out = self.model(input_batch)

         if self.scaler is not None:
-            raise NotImplementedError()
-            # out = self.scaler.inverse_transform(out)
+            out = self.scaler.inverse_transform(out)

         loss = self.masked_mae_loss(out, targets)
         self.log("train_loss", loss)
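The body of masked_mae_loss is collapsed above; only its return statement is visible. A minimal sketch of the masked-MAE pattern that return line implies (an assumption for illustration, not the committed code; the real method uses the imported create_mask helper):

    import torch

    def masked_mae_sketch(
        out: torch.Tensor, targets: torch.Tensor, missing_value: float = 0.0
    ) -> torch.Tensor:
        # Exclude entries whose target equals the missing-value sentinel,
        # so missing readings do not contribute to the loss.
        mask = targets != missing_value
        num_terms = mask.sum().clamp(min=1)  # Guard against an empty mask.
        loss = (out - targets).abs()
        return torch.sum(loss[mask]) / num_terms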
4 changes: 2 additions & 2 deletions src/gwnet/utils/__init__.py
@@ -1,3 +1,3 @@
-from .utils import StandardScaler, create_mask
+from .utils import TrafficStandardScaler, create_mask

-__all__ = ["create_mask", "StandardScaler"]
+__all__ = ["create_mask", "TrafficStandardScaler"]
27 changes: 25 additions & 2 deletions src/gwnet/utils/utils.py
@@ -1,4 +1,10 @@
+from __future__ import annotations
+
+from random import randint
+
 import torch
+from torch_geometric.data import Dataset
+from tqdm import tqdm


 def create_mask(
@@ -13,11 +19,28 @@ def create_mask(
     )


-class StandardScaler:
-    def __init__(self, mu: torch.Tensor, std: torch.Tensor) -> None:
+class TrafficStandardScaler:
+    def __init__(self, mu: float, std: float) -> None:
         self.mu = mu
         self.std = std

+    @classmethod
+    def from_dataset(
+        cls, dataset: Dataset, n_samples: int = 100
+    ) -> TrafficStandardScaler:
+        traffic_vals: torch.Tensor | list[torch.Tensor] = []  # Holds tensors for stack.
+        for _ in tqdm(range(n_samples), desc="Initialising scaler statistics..."):
+            randidx = randint(0, len(dataset) - 1)
+            # NOTE Here 0th feature is hardcoded as the traffic, unravel
+            # into a sequence of values to be computed over.
+            traffic_vals.append(dataset[randidx].x[:, 0, :].ravel())
+
+        traffic_vals = torch.stack(traffic_vals)
+        mu = traffic_vals.mean()
+        std = traffic_vals.std()
+
+        return cls(mu, std)
+
     def transform(self, x: torch.Tensor) -> torch.Tensor:
         return (x - self.mu) / self.std

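The hunk is truncated after transform, so inverse_transform does not appear even though training_step now calls it. A sketch consistent with transform above (assumed, since the committed body is collapsed):

    def inverse_transform(self, x: torch.Tensor) -> torch.Tensor:
        # Undo z-normalisation; composing transform then inverse_transform
        # is the identity up to floating-point error.
        return x * self.std + self.mu

With this, scaler.inverse_transform(scaler.transform(x)) returns x to within numerical precision.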
