diff --git a/experiments/reproduction.py b/experiments/reproduction.py index 37c6764..414f36b 100644 --- a/experiments/reproduction.py +++ b/experiments/reproduction.py @@ -16,7 +16,7 @@ def train(): num_workers = 0 # NOTE Set to 0 for single thread debugging! model = GraphWavenet( - adaptive_embedding_dim=None, n_nodes=207, k_diffusion_hops=3, disable_gcn=False + adaptive_embedding_dim=64, n_nodes=207, k_diffusion_hops=3, disable_gcn=False ) plmodule = GWnetForecasting(args, model, missing_value=0.0) @@ -35,7 +35,7 @@ def train(): scaler = TrafficStandardScaler.from_dataset(dataset, n_samples=30000) plmodule.scaler = scaler ## TODO Parametrise. - loader = DataLoader(dataset, batch_size=32, num_workers=num_workers, shuffle=True) + loader = DataLoader(dataset, batch_size=8, num_workers=num_workers, shuffle=True) ## TODO Change the validation loader to NOT the training loader! # This is for debugging the visualisation atm. diff --git a/quarto/notes.qmd b/quarto/notes.qmd index c6a08f8..e374448 100644 --- a/quarto/notes.qmd +++ b/quarto/notes.qmd @@ -386,6 +386,21 @@ Yet to find out whether the training does anything useful. Next step might be to plot the adaptive adjacency as the training progresses. Also profile the training to see whether I get any bottlenecks as I used to before. +I fixed a NaN bug where having all 0 inputs breaks. +Forgot to shuffle the dataset, which completely throws it off. +And added the missing skip connections to the residual modules. +Now it looks like it does something useful. +However, adding the adaptive adjacency massively increases the GPU cost. +I will likely need to optimise it, I can only train up to 8 batch size on 12Gb of VRAM at the moment. +Not sure off the top of my head where the massive memory footprint comes from, maybe I didn't make the matrix sparse? + +- [ ] Write the validation and test steps. +- [ ] Parameterise the reproduction script. +- [ ] Debug/optimise the adaptive adjacency. +- [ ] Visualise the adaptive adjacency evolution using nx in tensorboard. +- [ ] Add PEMS-BAY dataset. + + ### References ::: {#refs} diff --git a/src/gwnet/callbacks/visualise_sequence_pred.py b/src/gwnet/callbacks/visualise_sequence_pred.py index ec9ca71..f0edd99 100644 --- a/src/gwnet/callbacks/visualise_sequence_pred.py +++ b/src/gwnet/callbacks/visualise_sequence_pred.py @@ -27,7 +27,12 @@ def visualise_sequence(self, trainer: Trainer, pl_module: LightningModule) -> No model = pl_module.model for t in range(offset, offset + seq_len): - data = dset[t].to(self.device) + data = dset[t] + # Add dummy batch array of all 0s. This corresponds to all nodes + # being part of the same batch. + data.batch = torch.zeros(data.x.shape[0]).int() + data.ptr = torch.zeros(1).int() + data = data.to(self.device) # Two ugly conditionals make sure that the model works on # scaled data, and the output is scaled back into mph.