diff --git a/experiments/reproduction.py b/experiments/reproduction.py
index 7055adb..a805e29 100644
--- a/experiments/reproduction.py
+++ b/experiments/reproduction.py
@@ -12,10 +12,12 @@ def train():
     args = {"lr": 0.001, "weight_decay": 0.0001}

-    device = "mps"
+    device = "cuda"
     num_workers = 0  # NOTE Set to 0 for single thread debugging!

-    model = GraphWavenet(adaptive_embedding_dim=None, n_nodes=207, k_diffusion_hops=3)
+    model = GraphWavenet(
+        adaptive_embedding_dim=None, n_nodes=207, k_diffusion_hops=3, disable_gcn=True
+    )
     plmodule = GWnetForecasting(args, model, missing_value=0.0)

     tb_logger = pl_loggers.TensorBoardLogger(save_dir="logs/")

diff --git a/src/gwnet/callbacks/visualise_sequence_pred.py b/src/gwnet/callbacks/visualise_sequence_pred.py
index 36fb2ed..ec9ca71 100644
--- a/src/gwnet/callbacks/visualise_sequence_pred.py
+++ b/src/gwnet/callbacks/visualise_sequence_pred.py
@@ -47,7 +47,7 @@ def visualise_sequence(self, trainer: Trainer, pl_module: LightningModule) -> No
             data.x[:, 0] = pl_module.scaler.inverse_transform(data.x[:, 0])

             trainer.logger.experiment.add_scalar(
-                "Sequence true", data.x[node_idx, 0, steps_ahead], t
+                "Sequence true", data.y[node_idx, steps_ahead], t
             )

             trainer.logger.experiment.add_scalar(
diff --git a/src/gwnet/model/gwnet.py b/src/gwnet/model/gwnet.py
index a67e2bd..794f073 100644
--- a/src/gwnet/model/gwnet.py
+++ b/src/gwnet/model/gwnet.py
@@ -138,6 +138,7 @@ def __init__(
         out_channels: int,
         dilation: int = 1,
         kernel_size: int = 2,
+        disable_gcn: bool = False,
     ):
         r"""
         Wraps the TCN and GCN modules.
@@ -146,9 +147,14 @@
             args (dict): Contains parameters passed from the parent module,
                 namely flags showing which adjacency matrices to expect and
                 initialise parameters for.
+
+            disable_gcn (bool): If True, the GCN aggregation will not be computed,
+                so effectively the model will have no graph component.
         """
         super().__init__()

+        self._disable_gcn = disable_gcn
+
         self.tcn = GatedTCN(
             in_channels, interm_channels, kernel_size=kernel_size, dilation=dilation
         )
@@ -167,11 +173,17 @@ def forward(self, x: Data, cached_adj: dict[str, torch.Tensor]) -> Data:
         to each other, and the node index matters.

         Args:
-            x: The batched Data object.
+            x (Data): The batched Data object.
+
+            cached_adj (dict): Adjacency matrices used in the GCN calculation.
         """
         # TCN works on the features alone, and handles the (N,C,L) shape
         # internally.
         tcn_out = self.tcn(x)
+
+        if self._disable_gcn:
+            return tcn_out
+
         return self.gcn(tcn_out, cached_adj)


@@ -191,6 +203,7 @@ def __init__(
         n_nodes: int | None = None,
         forward_diffusion: bool = True,
         backward_diffusion: bool = True,
+        disable_gcn: bool = False,
     ):
         r"""
         Initialise the GWnet model.
@@ -257,6 +270,7 @@ def __init__(
                         dilation_channels,
                         residual_channels,
                         dilation=dilation,
+                        disable_gcn=disable_gcn,
                     ),
                 )
            )
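
Note on the new flag (reviewer sketch, not code from this repository): `disable_gcn` makes the wrapped layer return the gated-TCN output directly instead of passing it through the graph convolution, so setting it in `experiments/reproduction.py` trains a purely temporal ablation of GraphWavenet. The Python sketch below illustrates that short-circuit pattern under stated assumptions; `DummyTCN`, `DummyGCN` and `STLayer` are hypothetical stand-ins for the repository's `GatedTCN`, GCN and layer classes, included only to show how the flag is stored in `__init__` and checked in `forward`.

import torch
import torch.nn as nn


class DummyTCN(nn.Module):
    # Stand-in for the gated temporal convolution; works on (N, C, L) tensors.
    def __init__(self, channels: int):
        super().__init__()
        self.conv = nn.Conv1d(channels, channels, kernel_size=2, padding=1)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.conv(x)


class DummyGCN(nn.Module):
    # Stand-in graph aggregation: mixes node features with an adjacency matrix.
    def forward(self, x: torch.Tensor, adj: torch.Tensor) -> torch.Tensor:
        return torch.einsum("nm,mcl->ncl", adj, x)


class STLayer(nn.Module):
    # Hypothetical spatio-temporal layer mirroring the pattern in the diff:
    # the flag is stored at construction time and checked in forward().
    def __init__(self, channels: int, disable_gcn: bool = False):
        super().__init__()
        self._disable_gcn = disable_gcn
        self.tcn = DummyTCN(channels)
        self.gcn = DummyGCN()

    def forward(self, x: torch.Tensor, adj: torch.Tensor) -> torch.Tensor:
        tcn_out = self.tcn(x)
        if self._disable_gcn:
            # Ablation: skip the graph aggregation entirely.
            return tcn_out
        return self.gcn(tcn_out, adj)


if __name__ == "__main__":
    n_nodes, channels, seq_len = 207, 8, 12
    x = torch.randn(n_nodes, channels, seq_len)
    adj = torch.eye(n_nodes)
    print(STLayer(channels, disable_gcn=True)(x, adj).shape)

Constructing the layer with `disable_gcn=True` mirrors the ablation configured in `experiments/reproduction.py`, where `GraphWavenet(adaptive_embedding_dim=None, n_nodes=207, k_diffusion_hops=3, disable_gcn=True)` is trained without any graph component.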