Bugfixes and model updates
* Add skip connection to residual block
* Fix bug with NaNs in loss
* Add a 0th-order linear layer to the diffusion GCN that also carries the bias
boykovdn committed Sep 6, 2024
1 parent 2f724bb commit e1963ca
Showing 3 changed files with 27 additions and 17 deletions.
4 changes: 2 additions & 2 deletions experiments/reproduction.py
@@ -16,7 +16,7 @@ def train():
    num_workers = 0  # NOTE Set to 0 for single thread debugging!

    model = GraphWavenet(
-        adaptive_embedding_dim=None, n_nodes=207, k_diffusion_hops=3, disable_gcn=True
+        adaptive_embedding_dim=None, n_nodes=207, k_diffusion_hops=3, disable_gcn=False
    )
    plmodule = GWnetForecasting(args, model, missing_value=0.0)

@@ -35,7 +35,7 @@ def train():
    scaler = TrafficStandardScaler.from_dataset(dataset, n_samples=30000)
    plmodule.scaler = scaler
    ## TODO Parametrise.
-    loader = DataLoader(dataset, batch_size=16, num_workers=num_workers)
+    loader = DataLoader(dataset, batch_size=64, num_workers=num_workers)

    ## TODO Change the validation loader to NOT the training loader!
    # This is for debugging the visualisation atm.
28 changes: 14 additions & 14 deletions src/gwnet/model/gwnet.py
@@ -3,7 +3,6 @@

import torch
import torch.nn.functional as F
-from torch.nn.parameter import Parameter
from torch_geometric.data import Data
from torch_geometric.utils import to_dense_adj

@@ -76,7 +75,7 @@ def __init__(
        """
        super().__init__()

-        # self.id_linear = torch.nn.Linear(in_channels, out_channels, bias=bias)
+        self.id_linear = torch.nn.Linear(in_channels, out_channels, bias=bias)

        # Create one GCN per hop and diffusion direction.
        gcn_dict = {}
@@ -98,19 +97,14 @@ def __init__(
                    in_channels, out_channels, bias=False, node_dim=0
                )

-        if bias:
-            self.bias = Parameter(torch.randn(out_channels))
-        else:
-            self.register_parameter("bias", None)
-
        self.gcns = torch.nn.ModuleDict(gcn_dict)

    def forward(self, x: Data, cached_params: dict[str, torch.Tensor]) -> Data:
        r"""
        Args:
            x (Data): Input graph.
        """
-        out_sum = 0
+        out_sum = self.id_linear(x.x.transpose(1, 2))
        for adj_name, adj_index in cached_params["adj_indices"].items():
            gcn_ = self.gcns[adj_name]
            # NOTE Some trickery here. In TempGCN the dense layer works on the
@@ -119,12 +113,9 @@ def forward(self, x: Data, cached_params: dict[str, torch.Tensor]) -> Data:
            # transpose C_in, L then transpose back to C_out, L.
            out_sum += gcn_(
                x.x.transpose(1, 2), adj_index, cached_params["adj_weights"][adj_name]
-            ).transpose(1, 2)
-
-        x.x = out_sum
+            )

-        if self.bias is not None:
-            x.x += self.bias.view(1, -1, 1)
+        x.x = out_sum.transpose(1, 2)

        return x
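Context for the 0th-order change: because id_linear is constructed with bias=bias, the identity (0th-hop) term now carries the bias that the removed standalone Parameter used to add at the end, so the two formulations are algebraically equivalent. A minimal standalone check, not part of the commit; the tensor sizes below are assumed, not taken from the repository:

import torch

# Equivalence check: a Linear with bias on the 0th-order term equals a
# bias-free Linear plus a separately added per-channel bias (old formulation).
N, C_in, C_out, L = 207, 32, 64, 12
x = torch.randn(N, C_in, L)

id_linear = torch.nn.Linear(C_in, C_out, bias=True)
no_bias = torch.nn.Linear(C_in, C_out, bias=False)
with torch.no_grad():
    no_bias.weight.copy_(id_linear.weight)

new = id_linear(x.transpose(1, 2)).transpose(1, 2)                               # (N, C_out, L)
old = no_bias(x.transpose(1, 2)).transpose(1, 2) + id_linear.bias.view(1, -1, 1)

print(torch.allclose(new, old, atol=1e-6))  # True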

@@ -162,6 +153,8 @@ def __init__(
            args, interm_channels, out_channels
        )  # TODO # interm -> out channels, diffusion_hops

+        self.skip_linear = torch.nn.Linear(in_channels, out_channels)
+
    def forward(self, x: Data, cached_adj: dict[str, torch.Tensor]) -> Data:
        r"""
        Apply the gated TCN followed by GCN.
@@ -182,9 +175,16 @@ def forward(self, x: Data, cached_adj: dict[str, torch.Tensor]) -> Data:
        tcn_out = self.tcn(x)

        if self._disable_gcn:
+            tcn_out.x = tcn_out.x + self.skip_linear(x.x.transpose(1, 2)).transpose(
+                1, 2
+            )
            return tcn_out

-        return self.gcn(tcn_out, cached_adj)
+        residual = self.gcn(tcn_out, cached_adj)
+
+        residual.x = residual.x + self.skip_linear(x.x.transpose(1, 2)).transpose(1, 2)
+
+        return residual


class GraphWavenet(torch.nn.Module):
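The skip connection added in this block follows the usual projected-residual pattern: because the block changes the channel count, the input is passed through a channel-wise linear map before being added to the block output. A rough standalone sketch of that pattern, not taken from the repository; the channel sizes are assumed and a 1x1 convolution stands in for the gated TCN + GCN:

import torch

# Projected skip connection over (N, C, L) tensors.
N, C_in, C_out, L = 207, 32, 64, 12
x = torch.randn(N, C_in, L)

block = torch.nn.Conv1d(C_in, C_out, kernel_size=1)  # stand-in for gated TCN + GCN
skip_linear = torch.nn.Linear(C_in, C_out)

out = block(x)                                               # (N, C_out, L)
out = out + skip_linear(x.transpose(1, 2)).transpose(1, 2)   # project input, add back

print(out.shape)  # torch.Size([207, 64, 12])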
12 changes: 11 additions & 1 deletion src/gwnet/train/gwnet.py
@@ -61,6 +61,13 @@ def masked_mae_loss(
        num_terms = torch.sum(mask)

        loss = torch.abs(preds - targets)
+
+        if num_terms == 0:
+            # Occasionally, all values are missing or 0. In this case,
+            # return a loss of 0 and gradient function, which can be
+            # done by selecting no values (mask all False) and summing.
+            return torch.sum(loss[mask])
+
        return torch.sum(loss[mask]) / num_terms

    def validation_step(self, input_batch: Data, batch_idx: int) -> torch.Tensor:  # noqa: ARG002
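The early return is what removes the NaNs: when every target is masked out, num_terms is 0 and the old expression divided zero by zero. Summing over an empty selection instead returns exactly 0.0 while still carrying a grad_fn, so backpropagation becomes a no-op rather than propagating NaNs. A toy illustration, not from the repository; the masking rule here is simplified:

import torch

# All targets equal the missing value, so the mask selects nothing.
preds = torch.zeros(4, requires_grad=True)
targets = torch.zeros(4)
mask = targets != 0.0                      # all False

loss_terms = torch.abs(preds - targets)
num_terms = torch.sum(mask)

old = torch.sum(loss_terms[mask]) / num_terms   # 0 / 0 -> nan
new = torch.sum(loss_terms[mask])               # 0.0, still differentiable

print(old.item(), new.item())  # nan 0.0
new.backward()                 # backward works; gradients are simply zero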
@@ -78,6 +85,9 @@ def training_step(self, input_batch: Data, batch_idx: int) -> torch.Tensor:  # n
        out = self.scaler.inverse_transform(out)

        loss = self.masked_mae_loss(out, targets)
-        self.log("train_loss", loss)
+
+        if loss != 0.0:
+            # A loss of 0.0 means all values are missing or 0. This pollutes the log.
+            self.log("train_loss", loss)

        return loss
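The logging guard is a readability choice rather than a correctness fix: fully masked batches now return exactly 0.0, and logging those zeros would pollute the train_loss curve and drag down any averaged view of it. A small illustration with made-up batch losses, not from the repository:

import torch

losses = [torch.tensor(3.2), torch.tensor(2.8), torch.tensor(0.0)]  # last batch fully masked

mean_with_zeros = torch.stack(losses).mean()   # 2.0 -- skewed by the empty batch
mean_without = torch.stack(losses[:2]).mean()  # 3.0 -- roughly what is seen when zeros are skipped

print(mean_with_zeros.item(), mean_without.item())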
