Commit

[WIP] Fix node_embeddings grad
* Previously, no gradient propagated to node_embeddings.
* The model still doesn't seem to train properly; the loss is too noisy.
boykovdn committed Aug 31, 2024
1 parent 01dbcf3 commit 3bfddca
Showing 3 changed files with 20 additions and 8 deletions.
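
The registration issue the commit message refers to usually comes down to how node_embeddings is stored on the module. Below is a minimal sketch, not taken from this repository, of the difference between a plain tensor attribute and a registered Parameter: only the latter shows up in .parameters() and gets updated by the optimizer.

```python
import torch


class Plain(torch.nn.Module):
    def __init__(self) -> None:
        super().__init__()
        # Plain tensor attribute: not registered, absent from .parameters(),
        # and requires_grad=False by default, so no gradient reaches it.
        self.node_embeddings = torch.randn(207, 64)


class Registered(torch.nn.Module):
    def __init__(self) -> None:
        super().__init__()
        # A Parameter is registered on the module and trained by the optimizer.
        self.node_embeddings = torch.nn.Parameter(torch.randn(207, 64))


print(len(list(Plain().parameters())))       # 0
print(len(list(Registered().parameters())))  # 1
```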
4 changes: 2 additions & 2 deletions experiments/reproduction.py

```diff
@@ -13,7 +13,7 @@ def train():
     device = "gpu"
     num_workers = 0  # NOTE Set to 0 for single thread debugging!

-    model = GraphWavenet(adaptive_embedding_dim=64, n_nodes=207, k_diffusion_hops=1)
+    model = GraphWavenet(adaptive_embedding_dim=None, n_nodes=207, k_diffusion_hops=3)
     plmodule = GWnetForecasting(args, model, missing_value=0.0)

     tb_logger = pl_loggers.TensorBoardLogger(save_dir="logs/")
@@ -28,7 +28,7 @@ def train():
     scaler = TrafficStandardScaler.from_dataset(dataset, n_samples=30000)
     plmodule.scaler = scaler
     ## TODO Parametrise.
-    loader = DataLoader(dataset, batch_size=32, num_workers=num_workers)
+    loader = DataLoader(dataset, batch_size=8, num_workers=num_workers)

     trainer.fit(model=plmodule, train_dataloaders=loader)
```
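
Taken together with the model changes below, passing adaptive_embedding_dim=None appears to switch the adaptive adjacency off entirely: the constructor registers node_embeddings as None and the forward pass now checks for that. A hypothetical usage sketch of the two configurations (the import path is assumed; the keyword arguments match the call in reproduction.py):

```python
from gwnet.model import GraphWavenet  # import path assumed

# Adaptive adjacency disabled: node_embeddings is registered as None.
model_fixed = GraphWavenet(adaptive_embedding_dim=None, n_nodes=207, k_diffusion_hops=3)

# Adaptive adjacency enabled: node_embeddings should now receive gradients.
model_adp = GraphWavenet(adaptive_embedding_dim=64, n_nodes=207, k_diffusion_hops=1)
```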
20 changes: 14 additions & 6 deletions src/gwnet/model/gwnet.py

```diff
@@ -3,6 +3,7 @@

 import torch
 import torch.nn.functional as F
+from torch.nn.parameter import Parameter
 from torch_geometric.data import Data
 from torch_geometric.utils import to_dense_adj

@@ -75,7 +76,7 @@ def __init__(
         """
         super().__init__()

-        self.id_linear = torch.nn.Linear(in_channels, out_channels, bias=bias)
+        # self.id_linear = torch.nn.Linear(in_channels, out_channels, bias=bias)

         # Create one GCN per hop and diffusion direction.
         gcn_dict = {}
@@ -97,6 +98,11 @@ def __init__(
                 in_channels, out_channels, bias=False, node_dim=0
             )

+        if bias:
+            self.bias = Parameter(torch.randn(out_channels))
+        else:
+            self.register_parameter("bias", None)
+
         self.gcns = torch.nn.ModuleDict(gcn_dict)

     def forward(self, x: Data, cached_params: dict[str, torch.Tensor]) -> Data:
```
```diff
@@ -115,6 +121,11 @@ def forward(self, x: Data, cached_params: dict[str, torch.Tensor]) -> Data:
                 x.x.transpose(1, 2), adj_index, cached_params["adj_weights"][adj_name]
             ).transpose(1, 2)

+        x.x = out_sum
+
+        if self.bias is not None:
+            x.x += self.bias.view(1, -1, 1)
+
         return x
```
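
The bias added in the forward pass above is per-channel. A small sketch (shapes assumed for illustration, not taken from the repo) of how view(1, -1, 1) lines the bias up with the (N, C, L) layout mentioned in the TCN comment below:

```python
import torch

# Assumed shapes: 207 nodes, 32 channels, 12 time steps, i.e. (N, C, L).
out_sum = torch.randn(207, 32, 12)
bias = torch.randn(32)

# view(1, -1, 1) aligns the per-channel bias with the channel axis, so
# broadcasting adds the same bias to every node and every time step.
out = out_sum + bias.view(1, -1, 1)
print(out.shape)  # torch.Size([207, 32, 12])
```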
```diff
@@ -161,7 +172,6 @@ def forward(self, x: Data, cached_adj: dict[str, torch.Tensor]) -> Data:
         # TCN works on the features alone, and handles the (N,C,L) shape
         # internally.
         tcn_out = self.tcn(x)
-
         return self.gcn(tcn_out, cached_adj)

@@ -208,9 +218,6 @@ def __init__(
             self.register_parameter("node_embeddings", None)
             adp = False

-        # TODO Why is node_embeddings not registering?
-        # import pdb; pdb.set_trace()
-
         # This model accepts the entire road network, hence can cache
         # the diffusion adjacency matrices, doesn't have to take their
         # powers on every forward pass.
```
```diff
@@ -383,7 +390,8 @@ def forward(self, batch: Data) -> torch.Tensor:
             )
         )

-        self._update_adp_adj(batch.batch.max() + 1, self.global_elements["k_hops"])
+        if self.node_embeddings is not None:
+            self._update_adp_adj(batch.batch.max() + 1, self.global_elements["k_hops"])

         # x_dict = batch.x_dict
         # edge_index_dict = batch.edge_index_dict
```
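
The new guard pairs with the register_parameter("node_embeddings", None) call earlier in __init__: when the adaptive embedding is disabled, the attribute exists but is None, and the adaptive adjacency update is skipped. A hypothetical sketch of that pattern:

```python
import torch


class AdaptiveAdjSketch(torch.nn.Module):
    """Hypothetical sketch of the None-guard pattern used in the diff."""

    def __init__(self, n_nodes: int, adaptive_embedding_dim: int | None = None) -> None:
        super().__init__()
        if adaptive_embedding_dim is not None:
            # Registered Parameter: trained, and enables the adaptive path.
            self.node_embeddings = torch.nn.Parameter(
                torch.randn(n_nodes, adaptive_embedding_dim)
            )
        else:
            # Keeps the attribute defined so `is not None` checks are safe.
            self.register_parameter("node_embeddings", None)


print(AdaptiveAdjSketch(207, 64).node_embeddings is None)    # False
print(AdaptiveAdjSketch(207, None).node_embeddings is None)  # True
```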
4 changes: 4 additions & 0 deletions src/gwnet/model/layer/tempo_gcn.py

```diff
@@ -41,6 +41,10 @@ def __init__(  # type: ignore[no-untyped-def]

         self.linear = torch.nn.Linear(in_channels, out_channels, bias=bias)

+    def message(self, x_j: torch.Tensor, edge_weight: torch.Tensor) -> torch.Tensor:
+        # Message is the original input times the edge weight.
+        return x_j * edge_weight.view(-1, 1, 1)
+
     def forward(
         self,
         x: torch.Tensor,
```

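For context on the new message() override above: in PyTorch Geometric, keyword arguments passed to propagate() are forwarded to message(), and the _j suffix selects source-node features per edge. A self-contained sketch of the same weighting scheme, with tensor shapes assumed for illustration:

```python
import torch
from torch_geometric.nn import MessagePassing


class WeightedConv(MessagePassing):
    def __init__(self) -> None:
        # node_dim=0 tells PyG the node axis of x is the first dimension,
        # matching the (N, C, L) layout used elsewhere in the diff.
        super().__init__(aggr="add", node_dim=0)

    def forward(
        self, x: torch.Tensor, edge_index: torch.Tensor, edge_weight: torch.Tensor
    ) -> torch.Tensor:
        # Extra kwargs to propagate() are routed into message().
        return self.propagate(edge_index, x=x, edge_weight=edge_weight)

    def message(self, x_j: torch.Tensor, edge_weight: torch.Tensor) -> torch.Tensor:
        # Scale each source-node message by its edge weight; the view
        # broadcasts over the trailing channel and time dimensions.
        return x_j * edge_weight.view(-1, 1, 1)


x = torch.randn(4, 3, 2)  # (nodes, channels, time steps)
edge_index = torch.tensor([[0, 1, 2], [1, 2, 3]])
edge_weight = torch.tensor([0.5, 1.0, 2.0])
print(WeightedConv()(x, edge_index, edge_weight).shape)  # torch.Size([4, 3, 2])
```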