diff --git a/experiments/reproduction.py b/experiments/reproduction.py index dcb26ad..278591e 100644 --- a/experiments/reproduction.py +++ b/experiments/reproduction.py @@ -13,7 +13,7 @@ def train(): device = "gpu" num_workers = 0 # NOTE Set to 0 for single thread debugging! - model = GraphWavenet(adaptive_embedding_dim=64, n_nodes=207, k_diffusion_hops=1) + model = GraphWavenet(adaptive_embedding_dim=None, n_nodes=207, k_diffusion_hops=3) plmodule = GWnetForecasting(args, model, missing_value=0.0) tb_logger = pl_loggers.TensorBoardLogger(save_dir="logs/") @@ -28,7 +28,7 @@ def train(): scaler = TrafficStandardScaler.from_dataset(dataset, n_samples=30000) plmodule.scaler = scaler ## TODO Parametrise. - loader = DataLoader(dataset, batch_size=32, num_workers=num_workers) + loader = DataLoader(dataset, batch_size=8, num_workers=num_workers) trainer.fit(model=plmodule, train_dataloaders=loader) diff --git a/src/gwnet/model/gwnet.py b/src/gwnet/model/gwnet.py index f08fb45..a67e2bd 100644 --- a/src/gwnet/model/gwnet.py +++ b/src/gwnet/model/gwnet.py @@ -3,6 +3,7 @@ import torch import torch.nn.functional as F +from torch.nn.parameter import Parameter from torch_geometric.data import Data from torch_geometric.utils import to_dense_adj @@ -75,7 +76,7 @@ def __init__( """ super().__init__() - self.id_linear = torch.nn.Linear(in_channels, out_channels, bias=bias) + # self.id_linear = torch.nn.Linear(in_channels, out_channels, bias=bias) # Create one GCN per hop and diffusion direction. gcn_dict = {} @@ -97,6 +98,11 @@ def __init__( in_channels, out_channels, bias=False, node_dim=0 ) + if bias: + self.bias = Parameter(torch.randn(out_channels)) + else: + self.register_parameter("bias", None) + self.gcns = torch.nn.ModuleDict(gcn_dict) def forward(self, x: Data, cached_params: dict[str, torch.Tensor]) -> Data: @@ -115,6 +121,11 @@ def forward(self, x: Data, cached_params: dict[str, torch.Tensor]) -> Data: x.x.transpose(1, 2), adj_index, cached_params["adj_weights"][adj_name] ).transpose(1, 2) + x.x = out_sum + + if self.bias is not None: + x.x += self.bias.view(1, -1, 1) + return x @@ -161,7 +172,6 @@ def forward(self, x: Data, cached_adj: dict[str, torch.Tensor]) -> Data: # TCN works on the features alone, and handles the (N,C,L) shape # internally. tcn_out = self.tcn(x) - return self.gcn(tcn_out, cached_adj) @@ -208,9 +218,6 @@ def __init__( self.register_parameter("node_embeddings", None) adp = False - # TODO Why is node_embeddings not registering? - # import pdb; pdb.set_trace() - # This model accepts the entire road network, hence can cache # the diffusion adjacency matrices, doesn't have to take their # powers on every forward pass. @@ -383,7 +390,8 @@ def forward(self, batch: Data) -> torch.Tensor: ) ) - self._update_adp_adj(batch.batch.max() + 1, self.global_elements["k_hops"]) + if self.node_embeddings is not None: + self._update_adp_adj(batch.batch.max() + 1, self.global_elements["k_hops"]) # x_dict = batch.x_dict # edge_index_dict = batch.edge_index_dict diff --git a/src/gwnet/model/layer/tempo_gcn.py b/src/gwnet/model/layer/tempo_gcn.py index ceb69fe..e2019dd 100644 --- a/src/gwnet/model/layer/tempo_gcn.py +++ b/src/gwnet/model/layer/tempo_gcn.py @@ -41,6 +41,10 @@ def __init__( # type: ignore[no-untyped-def] self.linear = torch.nn.Linear(in_channels, out_channels, bias=bias) + def message(self, x_j: torch.Tensor, edge_weight: torch.Tensor) -> torch.Tensor: + # Message is the original input times the edge weight. + return x_j * edge_weight.view(-1, 1, 1) + def forward( self, x: torch.Tensor,