From e1963ca6d1efdc154dcedbdd3dffc8800e38eca7 Mon Sep 17 00:00:00 2001
From: Boyko Vodenicharski
Date: Fri, 6 Sep 2024 21:38:49 +0100
Subject: [PATCH] Bugfixes and model updates

* Add skip connection to residual block
* Fix bug with NaNs in loss
* Add linear layer to diffusion (0th order) that also contains the bias
---
 experiments/reproduction.py |  4 ++--
 src/gwnet/model/gwnet.py    | 28 ++++++++++++++--------------
 src/gwnet/train/gwnet.py    | 12 +++++++++++-
 3 files changed, 27 insertions(+), 17 deletions(-)

diff --git a/experiments/reproduction.py b/experiments/reproduction.py
index a805e29..9870e86 100644
--- a/experiments/reproduction.py
+++ b/experiments/reproduction.py
@@ -16,7 +16,7 @@ def train():
     num_workers = 0  # NOTE Set to 0 for single thread debugging!
 
     model = GraphWavenet(
-        adaptive_embedding_dim=None, n_nodes=207, k_diffusion_hops=3, disable_gcn=True
+        adaptive_embedding_dim=None, n_nodes=207, k_diffusion_hops=3, disable_gcn=False
     )
     plmodule = GWnetForecasting(args, model, missing_value=0.0)
 
@@ -35,7 +35,7 @@ def train():
     scaler = TrafficStandardScaler.from_dataset(dataset, n_samples=30000)
     plmodule.scaler = scaler
     ## TODO Parametrise.
-    loader = DataLoader(dataset, batch_size=16, num_workers=num_workers)
+    loader = DataLoader(dataset, batch_size=64, num_workers=num_workers)
 
     ## TODO Change the validation loader to NOT the training loader!
     # This is for debugging the visualisation atm.
diff --git a/src/gwnet/model/gwnet.py b/src/gwnet/model/gwnet.py
index 794f073..74b044f 100644
--- a/src/gwnet/model/gwnet.py
+++ b/src/gwnet/model/gwnet.py
@@ -3,7 +3,6 @@
 
 import torch
 import torch.nn.functional as F
-from torch.nn.parameter import Parameter
 from torch_geometric.data import Data
 from torch_geometric.utils import to_dense_adj
 
@@ -76,7 +75,7 @@ def __init__(
         """
         super().__init__()
 
-        # self.id_linear = torch.nn.Linear(in_channels, out_channels, bias=bias)
+        self.id_linear = torch.nn.Linear(in_channels, out_channels, bias=bias)
 
         # Create one GCN per hop and diffusion direction.
         gcn_dict = {}
@@ -98,11 +97,6 @@ def __init__(
                     in_channels, out_channels, bias=False, node_dim=0
                 )
 
-        if bias:
-            self.bias = Parameter(torch.randn(out_channels))
-        else:
-            self.register_parameter("bias", None)
-
         self.gcns = torch.nn.ModuleDict(gcn_dict)
 
     def forward(self, x: Data, cached_params: dict[str, torch.Tensor]) -> Data:
@@ -110,7 +104,7 @@ def forward(self, x: Data, cached_params: dict[str, torch.Tensor]) -> Data:
         Args:
             x (Data): Input graph.
         """
-        out_sum = 0
+        out_sum = self.id_linear(x.x.transpose(1, 2))
         for adj_name, adj_index in cached_params["adj_indices"].items():
             gcn_ = self.gcns[adj_name]
             # NOTE Some trickery here. In TempGCN the dense layer works on the
@@ -119,12 +113,9 @@ def forward(self, x: Data, cached_params: dict[str, torch.Tensor]) -> Data:
             # transpose C_in, L then transpose back to C_out, L.
             out_sum += gcn_(
                 x.x.transpose(1, 2), adj_index, cached_params["adj_weights"][adj_name]
-            ).transpose(1, 2)
-
-        x.x = out_sum
+            )
 
-        if self.bias is not None:
-            x.x += self.bias.view(1, -1, 1)
+        x.x = out_sum.transpose(1, 2)
 
         return x
 
@@ -162,6 +153,8 @@ def __init__(
             args, interm_channels, out_channels
         )  # TODO # interm -> out channels, diffusion_hops
 
+        self.skip_linear = torch.nn.Linear(in_channels, out_channels)
+
     def forward(self, x: Data, cached_adj: dict[str, torch.Tensor]) -> Data:
         r"""
         Apply the gated TCN followed by GCN.
@@ -182,9 +175,16 @@ def forward(self, x: Data, cached_adj: dict[str, torch.Tensor]) -> Data:
         tcn_out = self.tcn(x)
 
         if self._disable_gcn:
+            tcn_out.x = tcn_out.x + self.skip_linear(x.x.transpose(1, 2)).transpose(
+                1, 2
+            )
             return tcn_out
 
-        return self.gcn(tcn_out, cached_adj)
+        residual = self.gcn(tcn_out, cached_adj)
+
+        residual.x = residual.x + self.skip_linear(x.x.transpose(1, 2)).transpose(1, 2)
+
+        return residual
 
 
 class GraphWavenet(torch.nn.Module):
diff --git a/src/gwnet/train/gwnet.py b/src/gwnet/train/gwnet.py
index f6abc57..329adac 100644
--- a/src/gwnet/train/gwnet.py
+++ b/src/gwnet/train/gwnet.py
@@ -61,6 +61,13 @@ def masked_mae_loss(
         num_terms = torch.sum(mask)
 
         loss = torch.abs(preds - targets)
+
+        if num_terms == 0:
+            # Occasionally, all values are missing or 0. In this case,
+            # return a loss of 0 and gradient function, which can be
+            # done by selecting no values (mask all False) and summing.
+            return torch.sum(loss[mask])
+
         return torch.sum(loss[mask]) / num_terms
 
     def validation_step(self, input_batch: Data, batch_idx: int) -> torch.Tensor:  # noqa: ARG002
@@ -78,6 +85,9 @@ def training_step(self, input_batch: Data, batch_idx: int) -> torch.Tensor:  # n
         out = self.scaler.inverse_transform(out)
 
         loss = self.masked_mae_loss(out, targets)
-        self.log("train_loss", loss)
+
+        if loss != 0.0:
+            # A loss of 0.0 means all values are missing or 0. This pollutes the log.
+            self.log("train_loss", loss)
 
         return loss
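
Note on the NaN fix (not part of the patch): the NaN came from dividing by num_terms
when every target equals the missing value, so the mask selects nothing and the mean
becomes 0/0. Below is a minimal, self-contained sketch of the guard; the function name
masked_mae, the mask construction (targets != missing_value), and the tensor shapes are
assumptions for illustration, while the repository's actual method lives on the
Lightning module and also applies the scaler before the loss.

    import torch

    def masked_mae(
        preds: torch.Tensor, targets: torch.Tensor, missing_value: float = 0.0
    ) -> torch.Tensor:
        # Ignore entries whose target equals the sentinel "missing" value.
        mask = targets != missing_value
        num_terms = torch.sum(mask)
        loss = torch.abs(preds - targets)
        if num_terms == 0:
            # Empty selection: the sum is 0.0 but still carries a grad_fn,
            # so backward() works and no 0/0 NaN is produced.
            return torch.sum(loss[mask])
        return torch.sum(loss[mask]) / num_terms

    # Example: every target is "missing", so the loss is 0, not NaN.
    preds = torch.randn(4, 3, requires_grad=True)
    targets = torch.zeros(4, 3)
    print(masked_mae(preds, targets))  # tensor(0., grad_fn=<SumBackward0>)

The skip connection added to the residual block uses the same transpose pattern as the
new id_linear term: node features are stored as (nodes, channels, length), while
torch.nn.Linear mixes the last dimension, hence the transpose to (nodes, length,
channels) and back around the projection from in_channels to out_channels. A small
sketch with assumed shapes and a hypothetical class name:

    import torch

    class SkipProjection(torch.nn.Module):
        # Illustrative only: projects the block input to the branch width and adds it.
        def __init__(self, in_channels: int, out_channels: int) -> None:
            super().__init__()
            self.skip_linear = torch.nn.Linear(in_channels, out_channels)

        def forward(self, x: torch.Tensor, branch_out: torch.Tensor) -> torch.Tensor:
            # x: (N, C_in, L); branch_out: (N, C_out, L)
            return branch_out + self.skip_linear(x.transpose(1, 2)).transpose(1, 2)

    x = torch.randn(207, 32, 12)       # e.g. 207 sensors, 32 channels, 12 time steps
    branch = torch.randn(207, 64, 12)
    print(SkipProjection(32, 64)(x, branch).shape)  # torch.Size([207, 64, 12])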