diff --git a/experiments/reproduction.py b/experiments/reproduction.py new file mode 100644 index 0000000..97f69a8 --- /dev/null +++ b/experiments/reproduction.py @@ -0,0 +1,29 @@ +import lightning.pytorch as pl +from torch_geometric.loader.dataloader import DataLoader + +from gwnet.datasets import METRLA +from gwnet.model import GraphWavenet +from gwnet.train import GWnetForecasting + + +def train(): + args = {"lr": 0.01, "weight_decay": 0.0001} + + model = GraphWavenet(adaptive_embedding_dim=64, n_nodes=207, k_diffusion_hops=1) + plmodule = GWnetForecasting(args, model, missing_value=0.0) + + trainer = pl.Trainer( + accelerator="mps", + max_steps=100, + gradient_clip_val=5.0, # TODO There was something about this in the code. + ) + + dataset = METRLA("./") + ## TODO Parametrise. + loader = DataLoader(dataset, batch_size=1, num_workers=0) + + trainer.fit(model=plmodule, train_dataloaders=loader) + + +if __name__ == "__main__": + train() diff --git a/pyproject.toml b/pyproject.toml index 93eff92..ee9eeea 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -44,7 +44,8 @@ dependencies = [ "torch_geometric", "numpy", "pandas", - "scipy" + "scipy", + "lightning" ] [project.optional-dependencies] diff --git a/quarto/notes.qmd b/quarto/notes.qmd index d0d017f..8a104b5 100644 --- a/quarto/notes.qmd +++ b/quarto/notes.qmd @@ -152,7 +152,7 @@ There are two specific components to this architecture, denoted as TCN and GCN i ![Diagram of a series of causal convolutions. In this architecture, the sequence of dilated convolutions are applied in subsequent layers.](./images/tcn_diagram.png) -TCN is motivated in the Wavenet [paper](https://arxiv.org/abs/1609.03499) as a way to aggregate information from an exponentially increasing receptive field. However, in graph wavenet it seems to have been applied with dilations of 1,2,1,2,..., so only aggregating information from a small number of timesteps back. +TCN is motivated in the Wavenet [paper](https://arxiv.org/abs/1609.03499) as a way to aggregate information from an exponentially increasing receptive field. However, in graph wavenet it seems to have been applied with dilations of 1,2,1,2,..., so only aggregating information from a small number of timesteps back. I am not sure why the authors have chosen not to retreive information from a larger temporal context, as is done in the Wavenet paper. The equation for the 'Gated TCN' component is the following: @@ -162,9 +162,9 @@ The equation for the 'Gated TCN' component is the following: Here, $g(\cdot)$ and $\sigma(\cdot)$ are the hyperbolic tan and sigmoid activations, $*$ denotes a convolution (the causal convolution), and $\odot$ is the elementwise product. The input $\mathbf{X}$ is of shape $\mathbb{R}^{N \times D \times S}$ for $N$ nodes, feature dimensionality $D$, and $S$ timesteps back. For example, our METR-LA datapoints will have shape $\mathbb{R}^{207 \times 2 \times 12}$. Each node at a given timestep has two features - the traffic speed, and a number with periodicity of 288 steps (24 hours), which provides context for the time of day. -The GCN module aggregates information between nodes along the traffic network. -The paper introduces three components to the GCN module. -Two of these are part of the so-called 'diffusion convolution' [@li2018diffusionconvolutionalrecurrentneural], which fundamentally tells us which nodes to include, and with what weight, when computing the node state update. It is computed by taking powers of the adjacency matrix. +The GCN module aggregates information between nodes along the traffic network. +The paper introduces three components to the GCN module. +Two of these are part of the so-called 'diffusion convolution' [@li2018diffusionconvolutionalrecurrentneural], which fundamentally tells us which nodes to include, and with what weight, when computing the node state update. It is computed by taking powers of the adjacency matrix. This will include edges from a central node to nodes that are 'reachable' after a certain number of hops, reminiscent of diffusion along the edges. This is done once along the direction of the edges, and once against the direction of the edges (by using $\mathbf{A}^{\top}$ instead of $\mathbf{A}$), and these are called the 'forward' and 'backward' diffusion. @@ -175,7 +175,7 @@ Each node of the graph is assigned an embedding $\mathbf{E} \in \mathbb{R}^{c}$ \tilde{\mathbf{A}}_{adp} = \text{SoftMax}(\text{ReLU}(\mathbf{E}_1 \mathbf{E}_2^{\top})) \end{equation} -Where $\mathbf{E}_1 \mathbf{E}_2^{\top}$ is an inner product. The adjacency matrix is normalised due to the softmax. This adaptive adjacency matrix is supposed to infer the 'transition matrix of a hidden diffusion process'. +Where $\mathbf{E}_1 \mathbf{E}_2^{\top}$ is an inner product. The adjacency matrix is normalised due to the softmax. This adaptive adjacency matrix is supposed to infer the 'transition matrix of a hidden diffusion process'. It makes sense, I imagine, that you can learn the graph wiring which yields the best performance for your task. When you use this matrix in addition to the original adjacency, then it might be learning a diff between the current graph, and possibly more optimal version? It would be interesting to see whether the adjacency matrix can learn to somewhat recover the original road network, if used on its own. @@ -187,6 +187,199 @@ The full equation with diffusion in both directions and the adaptive adjacency i \end{equation} Here $\mathbf{W}_{1,2,3}$ are parameters, $\mathbf{P}$ are the forward and backward diffusion transition matrices, $\tilde{\mathbf{A}}_{apt}$ is the adaptive adjacency matrix, and $\mathbf{X}$ is the input. The adjacency matrices are all raised to the power $k$. +The forward and backward diffusion matrices are computed by normalising the adjacency matrix and interpreting each unit weight as carrying a unit of probability of transition. + +\begin{equation} + \mathbf{P}_f = \frac{\mathbf{A}}{\sum_j\mathbf{A}_{ij}} +\end{equation} + +\begin{equation} + \mathbf{P}_b = \frac{\mathbf{A^{\top}}}{\sum_j\mathbf{A^{\top}}_{ij}} +\end{equation} + +The original graph wavenet implementation works on the entire road network without subsampling. Maybe we can implement subsampling, but for now let's follow how they've done it originally. +In principle, GWnet has a limited receptive field, so anything more than that amounts to batching (right?). +Allowing spatial subsampling should be nice for when training on bigger graphs. + +![Let's refactor some of the gems in the original implementation. This is difficult to read, but also exceptions should be used for exceptional circumstances, and not for controlling execution in high performance code.](./images/tryexcept.png) + +The TCN layer can be implemented using a Conv1d or Conv2d on the 2-by-12 'image'. +Latter is how they've done it in the original implementation. +It is a bit difficult to read though. +They say that the causality can be achieved by padding the input with eros on one side. +Note how they mention the increasing dilation factors for capturing exponential receptive field, but they then explicitly don't do that and never motivate their choice. + +A design choice they've made is to set the padding, stride, and dilation parameters such that the input sequence gets gradually reduced to 1. +I suppose I'll just do the same. +If you set the Conv1D kernel size to 2, dilation to 1, and add no padding, the output sequence will be shorter by 1. +If you do the same, but dilation 2, the output will be shorter by 2. +They keep all kernel sizes 2, and have a sequence of dilations $\left[1,2,1,2,1,2,1,2\right]$ for their 8 layers. +This means that for an input sequence of 13, the final layer will be dealing with an input sequence of 1. +Hence, our TCN-a and TCN-b will be made of ```nn.Conv1d(32, 32, ks=2, dilation=d)``` where the dilation will alternate between 1 and 2. + +For the GCN layer we can precompute the powers of the diffusion matrix and reuse, since the adjacency is the same (the entire road network). +I can implement this as a caching mechanism, the same way the PyG [GCN](https://pytorch-geometric.readthedocs.io/en/latest/generated/torch_geometric.nn.conv.GCNConv.html) layer does. +There it caches the adjacency normalisation, but we can use a similar mechanism to store the powers of the adjacency. +These adjacencies won't change unless the graph structure changes, which we do not expect. + +But before I embark on implementing a new layer, could we possibly use the [GDN](https://pytorch-geometric.readthedocs.io/en/latest/generated/torch_geometric.transforms.GDC.html) PyG implementation of "Diffusion Improves Graph Learning" [@gasteiger2022diffusionimprovesgraphlearning]? +It seems to do a bunch of other things like sparsification as built-in, so maybe not the right choice. +How about the [MixHop](https://pytorch-geometric.readthedocs.io/en/latest/generated/torch_geometric.nn.conv.MixHopConv.html) module, introduced in [@abuelhaija2019mixhophigherordergraphconvolutional]? This one simply does the following operation: + +\begin{equation} + \mathbf{Z} = \Vert_{p \in P} \left( \hat{\mathbf{D}}^{-\frac{1}{2}}\hat{\mathbf{A}}\hat{\mathbf{D}}^{-\frac{1}{2}} \right)^{p}\mathbf{X}\mathbf{\Theta} +\end{equation} + +Where $\Vert$ is concatenation. We could reshape the concatenated vector and sum-reduce, yielding: + +\begin{equation} + \mathbf{Z} = \sum_{p \in P} \left( \hat{\mathbf{D}}^{-\frac{1}{2}}\hat{\mathbf{A}}\hat{\mathbf{D}}^{-\frac{1}{2}} \right)^{p}\mathbf{X}\mathbf{\Theta} +\end{equation} + +But unfortunately we can't seem to change the normalisation, it is built-in symmetric. +Note how close this is to the diffusion convolution: + +\begin{equation} + \mathbf{Z} = \sum_{k=0}^{K}\mathbf{P}^k \mathbf{X} \mathbf{W}_1 +\end{equation} + +If we were to set $P = (0,1,2)$, and replace $\left( \hat{\mathbf{D}}^{-\frac{1}{2}}\hat{\mathbf{A}}\hat{\mathbf{D}}^{-\frac{1}{2}} \right)^{p}$ with $\left(\mathbf{A}\mathbf{D}^{-1}\right)^p$, then we should have exactly the diffusion convolution as defined in the paper. +Maybe implementing the exact Diffusion module can be future work, let's just use MixHop. +The normalisation won't be right, but the support for the convolution should be the same. + +![Inspiring productivity: Author unknown](./images/mediocrates.png) + +- [ ] Implement Diffusion conv with the correct normalisation. + +So to summarise, use ```torch_geometric.nn.conv.MixHopConv``` as originally implemented. +It transforms $(*, C_{in}) \longrightarrow (*, |P|C_{out})$ due to concatenation. +Follow up the layer with reshape $(*, |P|, C_{out})$ then sum along $|P|$ to get back $(*, C_{out})$. + +:::{.callout-note} +In presentation, would be worth to explain PyG batching. +::: + +Annoyingly, I'm trying to use MixHop, but it doesn't know what to do with nodes that have more than one feature dimension. +Ok, so looks like we'll be implementing Diffusion from scratch. +Might as well go through how to implement a PyG module.. Mediocrates has lost on this occasion. + +MixHop is a mess when dealing with time series in the nodes. So build a custom diffusion using matrix multiplication... + +- [ ] Implement two versions of Diffusion. One in the graph and one in the node view. + +Ok, but on the other hand, building Diffusion in the graph view, you run into an issue if you also use PyG Data types. +This is because the way PyG does batching is to create block-diagonal adjacency instead of creating a batch dimension. +The PyG way is good for when the batched graphs have different sizes, but in my case I'd need to either modify the adjacency for the correct batch size, or stop using PyG Data. +I don't like these two solutions, because they are overly specific to this problem, and the code will start to look a lot like the original GWnet repo. +On the other hand, I could implement a Diffusion PyG module that works with the (C,T) feature tensors, which MixHop doesn't seem to like. +This approach seems to be the best in terms of design, because my modification is limited to subclassing MessagePassing for my particular problem. +It is also educational, because I get to introduce how to build a MessagePassing GNN layer. +However, unless I implement caching of the adjacency powers, every layer would have to compute three steps of message passing. +Luckily, I think it would be straightforward to cache, since you always pass the edge index to a graph conv layer, hence you can pass a different adjacency than the input's. + +Let's break down the full equation: + +\begin{equation} + \mathbf{Z} = \sum_{k=0}^{K}\mathbf{P}^k_f \mathbf{X} \mathbf{W}_{k}^{(f)} + \mathbf{P}^k_b \mathbf{X} \mathbf{W}_{k}^{(b)} + \tilde{\mathbf{A}}^k_{apt} \mathbf{X} \mathbf{W}_{k}^{(apt)} +\end{equation} + +The new state of a node $\mathbf{Z}$ is the sum of $K+1$ terms of the form $\mathbf{A}_{k}\mathbf{X}\mathbf{W}_{k}$, for some adjacency $\mathbf{A}_k$ and dense update weights $\mathbf{W}_k$. +For $k=0$ the graph convolution becomes simply a Linear update to each node, since $\mathbf{A}^0 = \mathbf{I}$, hence the three terms become: + +\begin{equation} + \mathbf{X}\mathbf{W}^{(f)}_0 + \mathbf{X}\mathbf{W}^{(b)}_0 + \mathbf{X}\mathbf{W}^{(adp)}_0 = \mathbf{X}\left( \mathbf{W}^{(f)}_0 + \mathbf{W}^{(b)}_0 + \mathbf{W}^{(adp)}_0\right) = \mathbf{X}\mathbf{W} +\end{equation} + +I'm not sure why the zeroth term is added to the sum, but it looks like we can instead start the sum from $k=1$ and add one Linear layer that updates each node without the graph signal. Hence, the update equation becomes: + +\begin{equation} + \mathbf{Z} = \sum_{k=1}^{K}\left[\mathbf{P}^k_f \mathbf{X} \mathbf{W}_{k}^{(f)} + \mathbf{P}^k_b \mathbf{X} \mathbf{W}_{k}^{(b)} + \tilde{\mathbf{A}}^k_{apt} \mathbf{X} \mathbf{W}_{k}^{(apt)}\right] + \mathbf{X}\mathbf{W}_0 + \mathbf{b} +\end{equation} + +Where $\mathbf{X}\mathbf{W}_0$ is the update from $k=0$ as discussed above, and $\mathbf{b}$ is the layer's bias. +The term $\mathbf{X}\mathbf{W}_0 + \mathbf{b}$ together can be built as a ```nn.Linear``` with ```bias=True```. + +Luckily, if we look at how the PyG [GCNConv](https://pytorch-geometric.readthedocs.io/en/latest/generated/torch_geometric.nn.conv.GCNConv.html) layer works, we notice that you can set ```normalize=False```, which configures the layer to only do: + +```python +def forward(self, x: Tensor, edge_index: Adj, + edge_weight: OptTensor = None) -> Tensor: + + x = self.lin(x) + + # propagate_type: (x: Tensor, edge_weight: OptTensor) + out = self.propagate(edge_index, x=x, edge_weight=edge_weight) + + if self.bias is not None: + out = out + self.bias + + return out +``` + +So further setting ```bias=False``` means the layer simply does $\mathbf{A}\mathbf{X}\mathbf{W}$, where the weights associated with $\mathbf{A}$ are passed as ```edge_weight```. +This means that our graph convolution module can be built up from $KM$ ```GCNConv``` Modules, each using a separate adjacency matrix. +The integer $M \in \left[1, 3\right]$ is how many of the adjacencies: forward, backward diffusion and adaptive the model is using. +These adjacency matrices could be stored in the parent Module, and our custom graph conv Module could be instantiated with a reference to its parent. +Note also that on every forward pass we would have to compute $\mathbf{A}_{apt}$, because backpropagation will have changed the node embeddings. +And, of course, we would also have to register the node embedding matrix as a ```Parameter``` which requires grad. +I wonder how this would change whether the model is inductive or transductive? +You definitely won't be able to use a new node's embeddings. + +:::{.callout-note} +Ok, so all the GCNs in our Graph Wavenet take the input and perform the propagation operation. +These operations can be done in parallel, similarly to multihead attention. +So, is there a way to do this easily in parallel, or does this start to enter the realm of optimisation..? +::: + +Passing the adjacency into the GWnet constructor seemed as the way to precompute, but I'd have to pass the adjacency of a BatchedData, otherwise the batching is not taken care of. +It might be easier to precompute the diffusion adjacencies on each forward pass, so that you directly get access to the batched adjacency index and weights. +It also means that this should trivially work if subgraphing, because the precomputation is done once the subgraph is sampled. +But in either case, there is no getting away from the fact that the adaptive adjacency has to be computed on each forward, that it is not batched, and that I will likely have to make it fit the batch manually. +Luckily, there is ```torch.block_diag```, which might be able to do it easily, but it does not support sparse inputs [yet](https://github.com/pytorch/pytorch/issues/31942) and might be inefficient to go to dense for the entire batch, then to sparse again. + +- [ ] Implement adjacency precomputation per forward pass +- [ ] Implement adaptive adjacency computation per forward pass using ```torch.block_diag```. +- [ ] Look whether I can contribute to [https://github.com/pytorch/pytorch/issues/31942](https://github.com/pytorch/pytorch/issues/31942) + +Ok so today I built the adjacency precomputation, but I'm not sure whether the powers are correct, because no new edges are added to the index. +Looks like taking powers of sparse matrices might be a little more involved than simply using Python syntax. + +- [ ] Write GH Pytorch issue requesting powers for sparse matrix? +- [ ] GH Pytorch sparse division by vector? + +Taking the power of the dense batch adjacency and then saving it once? +But it also looks like we can take powers of the COO matrix, so maybe it will work straightaway if I pass the real COO matrix. +The adaptive bit I'd have to compute densely and take powers, but then I can turn into a batch block diagonal using utility functions. +So maybe the way to go is dense powers of forward and backward once at beginning of training, then per forward pass update the adaptive adjacency. + +I've precomputed the adjacencies, but still there is the problem of GCN not knowing how to handle (N,C,L) tensors. +I'll likely have to implement a GCN that does that correctly. +By precomputing, I think we can run the sum concurrently, but I don't know how to do that. +Instead, I could run the newly implemented GCN as a loop over the k-hops. + +- [ ] Build a custom GCN layer that can handle the time series node type. +- [ ] Open PyG issue for (N, L, C) features breaking GCNConv. + +Ok so my layer seems to work if using ```torch.nn.Linear``` with transposition, but ```torch_geometric.nn.dense.Linear``` with transposition doesn't work? +What's going on? +Anyway, it looks like it's doing the correct thing after essentially replacing the PyG Linear with the torch Linear in my custom layer. +Also fixed bug when computing the transition matrix, so now it makes sense when I check the output. + +At this point, the residual module looks correct and it produces outputs of a sensible shape. +Ok so what are the skip convolutions, and how do I select the correct vector to add to the skip output? +I've chosen to add to the skipped output the value of only the last element of the temporal sequence. +This is in part because that's how I interpret the official implementation. +In this way, all previous timesteps act as context for the final one, a bit like how ViT has a special classification token that holds the output, and uses the other tokens as context. + +#### Training script + +I'll use torch lightning. + +- [ ] Introduce torch lightning. + +There are a couple of points to note about the training procedure. +First, loss is not calculated on the missing values, which in METR-LA are denoted as 0.0. +Second, the MAE in literature is calculated in the unnormalised range, so keep that in mind when reporting it. ### References diff --git a/src/gwnet/datasets/metrla.py b/src/gwnet/datasets/metrla.py index 1eee929..6a3e546 100644 --- a/src/gwnet/datasets/metrla.py +++ b/src/gwnet/datasets/metrla.py @@ -15,13 +15,25 @@ def __init__( transform: Callable[[Data], Data] | None = None, pre_transform: Callable[[Data], Data] | None = None, pre_filter: Callable[[Data], Data] | None = None, + in_timesteps: int = 13, + out_timesteps: int = 12, ): + r""" + Args: + ... + + in_timesteps (int): The number of timesteps contained in each node. + + out_timesteps (int): The number of timesteps ahead in the target. + """ self.node_filename = "METR-LA.csv" self.url_node = "https://zenodo.org/record/5724362/files/METR-LA.csv?download=1" self.adj_filename = "adj_mx_METR-LA.pkl" self.url_adj = ( "https://zenodo.org/record/5724362/files/adj_mx_METR-LA.pkl?download=1" ) + self.in_timesteps = in_timesteps + self.out_timesteps = out_timesteps super().__init__(root, transform, pre_transform, pre_filter) self.load(self.processed_paths[0]) @@ -100,8 +112,9 @@ def _get_targets_and_features( stacked_target = np.concatenate( [stacked_target[:, None, :], saw_tooth[:, None, :]], 1 ) + # TODO Parameterise the precision. features = [ - torch.from_numpy(stacked_target[i : i + num_timesteps_in, :, :].T) + torch.from_numpy(stacked_target[i : i + num_timesteps_in, :, :].T).float() for i in range( stacked_target.shape[0] - num_timesteps_in - num_timesteps_out ) @@ -113,7 +126,7 @@ def _get_targets_and_features( 0, :, ].T - ) + ).float() for i in range( stacked_target.shape[0] - num_timesteps_in - num_timesteps_out ) @@ -129,8 +142,8 @@ def process(self) -> None: node_series = pd.read_csv(self.raw_paths[0], index_col=0) features, targets = self._get_targets_and_features( - node_series, 12, 12 - ) # TODO Parametrise in/out n steps. + node_series, self.in_timesteps, self.out_timesteps + ) data_list = [ Data(x=features[i], edge_index=edges, edge_attr=edge_weights, y=targets[i]) for i in range(len(features)) diff --git a/src/gwnet/model/__init__.py b/src/gwnet/model/__init__.py new file mode 100644 index 0000000..581dd65 --- /dev/null +++ b/src/gwnet/model/__init__.py @@ -0,0 +1,3 @@ +from .gwnet import GraphWavenet + +__all__ = ["GraphWavenet"] diff --git a/src/gwnet/model/gwnet.py b/src/gwnet/model/gwnet.py new file mode 100644 index 0000000..65c222c --- /dev/null +++ b/src/gwnet/model/gwnet.py @@ -0,0 +1,429 @@ +from collections import OrderedDict +from typing import Any + +import torch +import torch.nn.functional as F +from torch_geometric.data import Data +from torch_geometric.utils import to_dense_adj + +# from torch_geometric.nn.conv import GCNConv +from torch_geometric.utils.sparse import dense_to_sparse + +from .layer import TempGCN + + +class GatedTCN(torch.nn.Module): + def __init__( + self, + in_channels: int, + out_channels: int, + kernel_size: int = 2, + dilation: int = 1, + ) -> None: + r""" + The GatedTCN module from Graph Wavenet. + """ + super().__init__() + self.filter_conv = torch.nn.Conv1d( + in_channels, out_channels, kernel_size, dilation=dilation + ) + self.gate_conv = torch.nn.Conv1d( + in_channels, out_channels, kernel_size, dilation=dilation + ) + + self.sigmoid = torch.nn.Sigmoid() + self.tanh = torch.nn.Tanh() + + def forward(self, batch: Data) -> Data: + r""" + This module uses Conv1d, so the sequence dimension is expected (*, C, L), + for channels C, sequence length L. + """ + assert len(batch.x.shape) == 3, f"Expected (N,C,L), got {batch.shape}" + # C == in_channels + N, C, L = batch.x.shape + + # The Conv can be done concurrently across each node with dimensionality C, + # series len L. + out = self.tanh(self.filter_conv(batch.x)) * self.sigmoid( + self.gate_conv(batch.x) + ) + # Last dimension is inferred, it will be the sequence length reduced. + batch.x = out + return batch + + +class DiffusionConv(torch.nn.Module): + def __init__( + self, + args: dict[str, Any], + in_channels: int, + out_channels: int, + # k_hops: int = 3, + bias: bool = True, + ) -> None: + r""" + Args: + args (dict): Parameters from the parent modules. + + in_channels (int): Input dimensionality + + out_channels (int): Output dimensionality + + k_hops (int): The highest power of the transition matrix. This is + equivalent to the number of hops away signal is propagated from. + """ + super().__init__() + + self.id_linear = torch.nn.Linear(in_channels, out_channels, bias=bias) + + # Create one GCN per hop and diffusion direction. + gcn_dict = {} + if args["fwd"]: + for k in range(1, args["k_hops"] + 1): + gcn_dict[f"fwd_{k}"] = TempGCN( + in_channels, out_channels, bias=False, node_dim=0 + ) + + if args["bck"]: + for k in range(1, args["k_hops"] + 1): + gcn_dict[f"bck_{k}"] = TempGCN( + in_channels, out_channels, bias=False, node_dim=0 + ) + + if args["adp"]: + for k in range(1, args["k_hops"] + 1): + gcn_dict[f"adp_{k}"] = TempGCN( + in_channels, out_channels, bias=False, node_dim=0 + ) + + self.gcns = torch.nn.ModuleDict(gcn_dict) + + def forward(self, x: Data, cached_params: dict[str, torch.Tensor]) -> Data: + r""" + Args: + x (Data): Input graph. + """ + out_sum = 0 + for adj_name, adj_index in cached_params["adj_indices"].items(): + gcn_ = self.gcns[adj_name] + # NOTE Some trickery here. In TempGCN the dense layer works on the + # last dimension. The layer is also configured to aggregate along + # the 0th dim rather than the default -2. x.x is (N, C, L), so we + # transpose C_in, L then transpose back to C_out, L. + out_sum += gcn_( + x.x.transpose(1, 2), adj_index, cached_params["adj_weights"][adj_name] + ).transpose(1, 2) + + return x + + +class STResidualModule(torch.nn.Module): + def __init__( + self, + args: dict[str, Any], + in_channels: int, + interm_channels: int, + out_channels: int, + dilation: int = 1, + kernel_size: int = 2, + ): + r""" + Wraps the TCN and GCN modules. + + Args: + args (dict): Contains parameters passed from the parent module, namely + flags showing which adjacency matrices to expect and initialise + parameters for. + """ + super().__init__() + + self.tcn = GatedTCN( + in_channels, interm_channels, kernel_size=kernel_size, dilation=dilation + ) + self.gcn = DiffusionConv( + args, interm_channels, out_channels + ) # TODO # interm -> out channels, diffusion_hops + + def forward(self, x: Data, cached_adj: dict[str, torch.Tensor]) -> Data: + r""" + Apply the gated TCN followed by GCN. + + Note that TCN will act on each node across the batches independently, + so we can reshape to merge the nodes into the batch dimension. When + applying GCN, we will need to keep the batch and node dimension + separate, because the adjacency matrix will define which nodes talk + to each other, and the node index matters. + + Args: + x: The batched Data object. + """ + # TCN works on the features alone, and handles the (N,C,L) shape + # internally. + tcn_out = self.tcn(x) + + return self.gcn(tcn_out, cached_adj) + + +class GraphWavenet(torch.nn.Module): + def __init__( + self, + in_channels: int = 2, + out_channels: int = 12, + dilation_channels: int = 32, + residual_channels: int = 32, + skip_channels: int = 256, + end_channels: int = 512, + n_layers: int = 8, + dilations_override: None | list[int] = None, + k_diffusion_hops: int = 3, + adaptive_embedding_dim: int | None = None, + n_nodes: int | None = None, + forward_diffusion: bool = True, + backward_diffusion: bool = True, + ): + r""" + Initialise the GWnet model. + + It requires the road network adjacency matrix, because it runs + on the entire road network for every prediction. It always + predicts for all nodes at once. The powers of the adjacency + used in the diffusion layer can be pre-computed, hence the adj + is required at instantiation. + """ + super().__init__() + + if adaptive_embedding_dim is not None: + if n_nodes is None: + adp_err_msg = "If adaptive_embedding_dim is passed, \ + n_nodes should also be passed." + raise Exception(adp_err_msg) + + self.node_embeddings = torch.nn.Parameter( + torch.rand(n_nodes, adaptive_embedding_dim) + ) + adp = True + + else: + self.register_parameter("node_embeddings", None) + adp = False + + # This model accepts the entire road network, hence can cache + # the diffusion adjacency matrices, doesn't have to take their + # powers on every forward pass. + + # Some of these parameters + self.global_elements: dict[str, Any] = { + "fwd": forward_diffusion, + "bck": backward_diffusion, + "adp": adp, + "adj_indices": None, + "adj_weights": None, + "k_hops": k_diffusion_hops, + } + + # Set the dilations to the default in the paper, unless different + # ones are given. + dilations = [1, 2, 1, 2, 1, 2, 1, 2] + if dilations_override: + assert len(dilations_override) == n_layers + dilations = dilations_override + + # Initialize the linear modules. + # NOTE For Linear, the channels should come last. + self.init_linear = torch.nn.Linear(in_channels, residual_channels) + + # Initialize the residual layers. + res_layers = [] + skip_convs = [] + for idx_layer in range(n_layers): + dilation = dilations[idx_layer] + res_layers.append( + ( + f"residual_layer_{idx_layer}", + STResidualModule( + self.global_elements, + residual_channels, + dilation_channels, + residual_channels, + dilation=dilation, + ), + ) + ) + + skip_convs.append( + ( + f"skip_conv_{idx_layer}", + torch.nn.Conv1d(residual_channels, skip_channels, kernel_size=1), + ) + ) + + self.residual_net = torch.nn.Sequential(OrderedDict(res_layers)) + self.skip_convs = torch.nn.Sequential(OrderedDict(skip_convs)) + + self.relu = torch.nn.ReLU() + self.out_linear_0 = torch.nn.Linear(skip_channels, end_channels) + self.out_linear_1 = torch.nn.Linear(end_channels, out_channels) + + def _update_adp_adj(self, batch_size: int, k_hops: int) -> None: + r""" + Update the adaptive adjacency matrix. + + Uses the node embeddings to compute an adjacency matrix using inner + products. + + Args: + batch_size (int): The input batch size. This is needed to + expand the adjacency as a block diagonal matrix across + the batch. + """ + + if ( + self.global_elements["adj_indices"] is None + or self.global_elements["adj_weights"] is None + ): + self.global_elements["adj_indices"] = {} + self.global_elements["adj_weights"] = {} + + # (N, C) @ (C, N) -> (N, N) + adp_adj = F.softmax( + F.relu(self.node_embeddings @ self.node_embeddings.T), dim=1 + ) + adp_adj_dense_batch = torch.block_diag(*[adp_adj] * batch_size) + + for k in range(1, k_hops + 1): + adp_dense_power = adp_adj_dense_batch**k + adp_indices, adp_weights = dense_to_sparse(adp_dense_power) + self.global_elements["adj_indices"][f"adp_{k}"] = adp_indices + self.global_elements["adj_weights"][f"adp_{k}"] = adp_weights + + def _precompute_gcn_adjs( + self, + adj: torch.Tensor, + weights: torch.Tensor, + k_hops: int, + fwd: bool = True, + bck: bool = True, + ) -> tuple[dict[str, torch.Tensor], dict[str, torch.Tensor]]: + r""" + Return the adj matrix along with powers of it. + + Args: + adj (torch.Tensor): The edge index of the dataset + + weights (torch.Tensor): The edge attrs (weights) of the graph. + + k_hops (int): If k_hops is 1, then the one-hop neighbourhood is included, + along with the identity. If k_hops is 0, then the adjacencies are an + empty list. + + fwd (bool): Whether to compute the powers of the forward transition matrix. + + bck (bool): Whether to compute the powers of the backward transition matrix. + + Returns: + dict, dict + """ + adj_dense = to_dense_adj(adj, edge_attr=weights)[0] + + # Compute the transition matrices from the adjacency. + transition_dense = adj_dense / adj_dense.sum(1).repeat((len(adj_dense), 1)).T + + index_dict = {} + weight_dict = {} + for k in range(1, k_hops + 1): + if fwd: + adj_dense_power = transition_dense**k + adj_sparse_indices, adj_sparse_weights = dense_to_sparse( + adj_dense_power + ) + index_dict[f"fwd_{k}"] = adj_sparse_indices + weight_dict[f"fwd_{k}"] = adj_sparse_weights + + if bck: + adj_T_dense_power = transition_dense.T**k + adj_T_sparse_indices, adj_T_sparse_weights = dense_to_sparse( + adj_T_dense_power + ) + index_dict[f"bck_{k}"] = adj_T_sparse_indices + weight_dict[f"bck_{k}"] = adj_T_sparse_weights + + return index_dict, weight_dict + + def forward(self, batch: Data) -> torch.Tensor: + r""" + Run the forward pass of the GraphWavenet model. + + The input data comes as a batch of Data objects, but throughout the + execution of the model, the features might be reshaped and operated + on by torch Modules. The GCN module requires the adjacency matrix + and works on the Data batch itself. + + Args: + batch: Input batch + """ + batched_index, batched_weights = batch.edge_index, batch.edge_attr + + # The powers of the graph transition matrix are computed once and + # stored globally. + if ( + self.global_elements["adj_indices"] is None + and self.global_elements["adj_weights"] is None + ): + self.global_elements["adj_indices"], self.global_elements["adj_weights"] = ( + self._precompute_gcn_adjs( + batched_index, + batched_weights, + self.global_elements["k_hops"], + fwd=self.global_elements["fwd"], + bck=self.global_elements["bck"], + ) + ) + + self._update_adp_adj(batch.batch.max() + 1, self.global_elements["k_hops"]) + + # x_dict = batch.x_dict + # edge_index_dict = batch.edge_index_dict + # edge_attr_dict = ( + # batch.edge_attr_dict + # ) # TODO Should I check if this is present first? + + # Here transpose (N, C, L) -> (N, L, C) then back. + # Linear expects channels at the end. + batch.x = self.init_linear(batch.x.transpose(1, 2)).transpose(1, 2) + + x_skipped: None | torch.Tensor = None + x_intermetiate: None | torch.Tensor = None + for layer_idx in range(len(self.residual_net)): + residual_layer = self.residual_net[layer_idx] + skip_conv = self.skip_convs[layer_idx] + + if x_intermetiate is None: + # keep interlayer result. + x_intermediate = residual_layer(batch, self.global_elements) + else: + # Update the dict to record the latest layer's output. + x_intermediate = residual_layer(x_intermediate, self.global_elements) + + if x_skipped is None: + # Create the output tensor that will store the sum of the skip + # connections. + # TODO Make sure I am using Conv1d correctly, and that L has to be + # the final dimension. + # import pdb; pdb.set_trace() + # TODO Off the top of my head, you take the final sequence element + # and add to the skip vector. + # TODO Should check this. + # NOTE Selecting final element, then expanding the axis back out. + # TODO Can probably replace this Conv1D with Linear, should be more + # more readable. + x_skipped = skip_conv(x_intermediate.x[..., -1][..., None])[..., 0] + else: + x_skipped = ( + x_skipped + skip_conv(x_intermediate.x[..., -1][..., None])[..., 0] + ) + + x_out = self.relu(x_skipped) + x_out = self.out_linear_0(x_out) + x_out = self.relu(x_out) + return self.out_linear_1(x_out) diff --git a/src/gwnet/model/layer/__init__.py b/src/gwnet/model/layer/__init__.py new file mode 100644 index 0000000..6dc94ff --- /dev/null +++ b/src/gwnet/model/layer/__init__.py @@ -0,0 +1,3 @@ +from .tempo_gcn import TempGCN + +__all__ = ["TempGCN"] diff --git a/src/gwnet/model/layer/tempo_gcn.py b/src/gwnet/model/layer/tempo_gcn.py new file mode 100644 index 0000000..ceb69fe --- /dev/null +++ b/src/gwnet/model/layer/tempo_gcn.py @@ -0,0 +1,63 @@ +import torch +from torch_geometric.nn import MessagePassing + + +class TempGCN(MessagePassing): + def __init__( # type: ignore[no-untyped-def] + self, + in_channels: int, + out_channels: int, + bias: bool = False, + **kwargs, + ) -> None: + r""" + Very cut-down version of GCNConv that works on (N, C, L) inputs. + + This class is built, because the provided modules from PyG don't + seem to be happy working on the feature tensors associated with + the nodes of GraphWavenet. They seem to instead have been built + for (N, C) feature tensors, where N is the number of nodes, and + C is the dimensionality of the feature vectors. In our case, we + work with (N, C, L), where L is the sequence length of the agglo- + merated time series features. + + Args: + in_channels (int): The input channels, C_in in (N, C_in, L) + + out_channels (int): The output channels, C_out in (N, C_out, L) + + bias (bool): Whether to add a bias term to the linear trans- + form. Default to False, because in my use case I handle + the bias in a separate layer. + + **kwargs: Options for MessagePassing. + """ + kwargs.setdefault("aggr", "add") + + super().__init__(**kwargs) + + self.in_channels = in_channels + self.out_channels = out_channels + + self.linear = torch.nn.Linear(in_channels, out_channels, bias=bias) + + def forward( + self, + x: torch.Tensor, + edge_index: torch.Tensor, + edge_weight: torch.Tensor | None = None, + ) -> torch.Tensor: + r""" + Run the TempGCN layer. + + Note here that x has a nonstandard dimensionality. + + Args: + x (torch.Tensor): (N, C, L) node features. + + edge_index (torch.Tensor): (2, N) Edge index. + + edge_weight (torch.Tensor): (N, 1) + """ + z = self.propagate(edge_index, x=x, edge_weight=edge_weight) + return self.linear(z) diff --git a/src/gwnet/train/__init__.py b/src/gwnet/train/__init__.py new file mode 100644 index 0000000..d6341f4 --- /dev/null +++ b/src/gwnet/train/__init__.py @@ -0,0 +1,3 @@ +from .gwnet import GWnetForecasting + +__all__ = ["GWnetForecasting"] diff --git a/src/gwnet/train/gwnet.py b/src/gwnet/train/gwnet.py new file mode 100644 index 0000000..8a1ba86 --- /dev/null +++ b/src/gwnet/train/gwnet.py @@ -0,0 +1,78 @@ +from collections.abc import Callable +from typing import Any + +import lightning.pytorch as pl +import torch +from torch_geometric.data import Data + +from ..utils import create_mask + + +class GWnetForecasting(pl.LightningModule): + def __init__( + self, + args: dict[str, Any], + model: torch.nn.Module, + missing_value: float = 0.0, + scaler: Callable[[torch.Tensor], torch.Tensor] | None = None, + ) -> None: + r""" + Trains Graph wavenet for the traffic forecasting task. + + Args: + args (dict): training args. + + model (Module): The Graph Wavenet model. + + missing_value (float): Defaults to 0.0, this is the value which + denotes a missing value in the dataset (relevant to METR-LA), + and no gradient is calculated to predict these values. + + scaler: A function that inverts the normalisation of the data, + used to produce MAE losses comparable to literature. + """ + super().__init__() + + self.args = args + self.model = model + self.scaler = scaler + self.missing_value = missing_value + + def configure_optimizers(self) -> torch.optim.Optimizer: + return torch.optim.Adam( + self.model.parameters(), + lr=self.args["lr"], + weight_decay=self.args["weight_decay"], + ) + + def masked_mae_loss( + self, preds: torch.Tensor, targets: torch.Tensor + ) -> torch.Tensor: + r""" + Calculates the masked MAE loss. + + The values which match the class missing_value do not take part in + the calculation, and thus do not propagate learning signal. This is + done because METR-LA denotes missing values using 0.0, which is + part of the data domain. + """ + # create_mask handles missing_value set to None. + mask = create_mask(targets, self.missing_value) + + num_terms = torch.sum(mask) + + loss = torch.abs(preds - targets) + return torch.sum(loss[mask]) / num_terms + + def training_step(self, input_batch: Data, batch_idx: int) -> torch.Tensor: # noqa: ARG002 + targets = input_batch.y + out = self.model(input_batch) + + if self.scaler is not None: + raise NotImplementedError() + # out = self.scaler.inverse_transform(out) + + loss = self.masked_mae_loss(out, targets) + self.log("train_loss", loss) + + return loss diff --git a/src/gwnet/utils/__init__.py b/src/gwnet/utils/__init__.py new file mode 100644 index 0000000..0601def --- /dev/null +++ b/src/gwnet/utils/__init__.py @@ -0,0 +1,3 @@ +from .utils import create_mask + +__all__ = ["create_mask"] diff --git a/src/gwnet/utils/utils.py b/src/gwnet/utils/utils.py new file mode 100644 index 0000000..cf1923f --- /dev/null +++ b/src/gwnet/utils/utils.py @@ -0,0 +1,13 @@ +import torch + + +def create_mask( + matrix: torch.TensorType, null_value: None | float = 0.0 +) -> torch.Tensor: + if null_value is None: + return torch.ones_like(matrix).bool() + + # NOTE Casting to float in order to handle integer values. + return ~torch.isclose( + matrix.float() - torch.tensor(null_value).float(), torch.tensor(0.0) + ) diff --git a/tests/gwnet/test_model.py b/tests/gwnet/test_model.py new file mode 100644 index 0000000..c2edf3a --- /dev/null +++ b/tests/gwnet/test_model.py @@ -0,0 +1,26 @@ +import torch +from torch_geometric.data import Batch, Data +from torch_geometric.utils import dense_to_sparse + +from gwnet.model import GraphWavenet + + +class TestModel: + def make_datapoint(self): + dpoint = Data() + dpoint.x = torch.randn(207, 2, 12) + index, weights = dense_to_sparse( + torch.nn.functional.relu(torch.randn(207, 207) - 0.3) + ) + dpoint.edge_index = index + dpoint.edge_attr = weights + return dpoint + + def test_creation(self): + GraphWavenet() + + def test_forward(self): + batch = Batch.from_data_list([self.make_datapoint(), self.make_datapoint()]) + + model = GraphWavenet(adaptive_embedding_dim=64, n_nodes=207) + model(batch) diff --git a/tests/training/test_model.py b/tests/training/test_model.py new file mode 100644 index 0000000..882b5d0 --- /dev/null +++ b/tests/training/test_model.py @@ -0,0 +1,49 @@ +import lightning.pytorch as pl +import torch +from torch_geometric.data import Batch, Data +from torch_geometric.loader.dataloader import DataLoader +from torch_geometric.utils import dense_to_sparse + +from gwnet.datasets import METRLA +from gwnet.model import GraphWavenet +from gwnet.train import GWnetForecasting + + +class TestTrainingStep: + def make_datapoint(self): + dpoint = Data() + dpoint.x = torch.randn(207, 2, 12) + dpoint.y = torch.randn(207, 2, 12) + index, weights = dense_to_sparse( + torch.nn.functional.relu(torch.randn(207, 207) - 0.3) + ) + dpoint.edge_index = index + dpoint.edge_attr = weights + return dpoint + + def test_training_step(self): + args = {"lr": 0.01, "weight_decay": 0.0001} + + batch = Batch.from_data_list([self.make_datapoint(), self.make_datapoint()]) + model = GraphWavenet(adaptive_embedding_dim=64, n_nodes=207) + + lightning_module = GWnetForecasting(args, model, missing_value=0.0) + lightning_module.training_step(batch, 0) + + def test_trainer(self): + args = {"lr": 0.01, "weight_decay": 0.0001} + + model = GraphWavenet(adaptive_embedding_dim=64, n_nodes=207) + plmodule = GWnetForecasting(args, model, missing_value=0.0) + + trainer = pl.Trainer( + # accelerator='cpu', + max_steps=100, + gradient_clip_val=5.0, # TODO There was something about this in the code. + logger=False, + ) + + dataset = METRLA("./") + loader = DataLoader(dataset, batch_size=16, num_workers=0) + + trainer.fit(model=plmodule, train_dataloaders=loader)