Skip to content

Commit

Permalink
Update and bugfix
Browse files Browse the repository at this point in the history
* Add disable_gcn option for ablation and debugging
* Bugfix logging visualisation: it previously plotted x as the true value, where it should have been y.
  • Loading branch information
boykovdn committed Sep 6, 2024
1 parent 445ca84 commit 2f724bb
Show file tree
Hide file tree
Showing 3 changed files with 20 additions and 4 deletions.
6 changes: 4 additions & 2 deletions experiments/reproduction.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,10 +12,12 @@

def train():
args = {"lr": 0.001, "weight_decay": 0.0001}
device = "mps"
device = "cuda"
num_workers = 0 # NOTE Set to 0 for single thread debugging!

model = GraphWavenet(adaptive_embedding_dim=None, n_nodes=207, k_diffusion_hops=3)
model = GraphWavenet(
adaptive_embedding_dim=None, n_nodes=207, k_diffusion_hops=3, disable_gcn=True
)
plmodule = GWnetForecasting(args, model, missing_value=0.0)

tb_logger = pl_loggers.TensorBoardLogger(save_dir="logs/")
Expand Down
2 changes: 1 addition & 1 deletion src/gwnet/callbacks/visualise_sequence_pred.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,7 +47,7 @@ def visualise_sequence(self, trainer: Trainer, pl_module: LightningModule) -> No
data.x[:, 0] = pl_module.scaler.inverse_transform(data.x[:, 0])

trainer.logger.experiment.add_scalar(
"Sequence true", data.x[node_idx, 0, steps_ahead], t
"Sequence true", data.y[node_idx, steps_ahead], t
)

trainer.logger.experiment.add_scalar(
Expand Down
16 changes: 15 additions & 1 deletion src/gwnet/model/gwnet.py
Original file line number Diff line number Diff line change
Expand Up @@ -138,6 +138,7 @@ def __init__(
out_channels: int,
dilation: int = 1,
kernel_size: int = 2,
disable_gcn: bool = False,
):
r"""
Wraps the TCN and GCN modules.
Expand All @@ -146,9 +147,14 @@ def __init__(
args (dict): Contains parameters passed from the parent module, namely
flags showing which adjacency matrices to expect and initialise
parameters for.
disable_gcn (bool): If True, the GCN aggregation will not be computed,
so effectively the model will have no graph component.
"""
super().__init__()

self._disable_gcn = disable_gcn

self.tcn = GatedTCN(
in_channels, interm_channels, kernel_size=kernel_size, dilation=dilation
)
Expand All @@ -167,11 +173,17 @@ def forward(self, x: Data, cached_adj: dict[str, torch.Tensor]) -> Data:
to each other, and the node index matters.
Args:
x: The batched Data object.
x (Data): The batched Data object.
cached_adj (dict): Adjacency matrices used in the GCN calculation.
"""
# TCN works on the features alone, and handles the (N,C,L) shape
# internally.
tcn_out = self.tcn(x)

if self._disable_gcn:
return tcn_out

return self.gcn(tcn_out, cached_adj)


Expand All @@ -191,6 +203,7 @@ def __init__(
n_nodes: int | None = None,
forward_diffusion: bool = True,
backward_diffusion: bool = True,
disable_gcn: bool = False,
):
r"""
Initialise the GWnet model.
Expand Down Expand Up @@ -257,6 +270,7 @@ def __init__(
dilation_channels,
residual_channels,
dilation=dilation,
disable_gcn=disable_gcn,
),
)
)
Expand Down

0 comments on commit 2f724bb

Please sign in to comment.