From aa846863d0ffa6bc33c4dc3f17510c8ae4c42b0f Mon Sep 17 00:00:00 2001 From: Boyko Vodenicharski Date: Sat, 7 Sep 2024 13:06:25 +0100 Subject: [PATCH] Fix data shuffling * Performance now looks better, can use gcn information --- experiments/reproduction.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/experiments/reproduction.py b/experiments/reproduction.py index 9870e86..37c6764 100644 --- a/experiments/reproduction.py +++ b/experiments/reproduction.py @@ -23,11 +23,11 @@ def train(): tb_logger = pl_loggers.TensorBoardLogger(save_dir="logs/") trainer = pl.Trainer( accelerator=device, - max_steps=10000, + max_steps=100000, limit_val_batches=1, # TODO Debugging gradient_clip_val=5.0, # TODO There was something about this in the code. logger=tb_logger, - val_check_interval=500, + val_check_interval=1000, callbacks=[VisualiseSequencePrediction(torch.device(device))], ) @@ -35,7 +35,7 @@ def train(): scaler = TrafficStandardScaler.from_dataset(dataset, n_samples=30000) plmodule.scaler = scaler ## TODO Parametrise. - loader = DataLoader(dataset, batch_size=64, num_workers=num_workers) + loader = DataLoader(dataset, batch_size=32, num_workers=num_workers, shuffle=True) ## TODO Change the validation loader to NOT the training loader! # This is for debugging the visualisation atm.