Skip to content

Commit

Permalink
Fix data shuffling
Browse files Browse the repository at this point in the history
* Performance now looks better, can use gcn information
  • Loading branch information
boykovdn committed Sep 7, 2024
1 parent e1963ca commit aa84686
Showing 1 changed file with 3 additions and 3 deletions.
6 changes: 3 additions & 3 deletions experiments/reproduction.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,19 +23,19 @@ def train():
tb_logger = pl_loggers.TensorBoardLogger(save_dir="logs/")
trainer = pl.Trainer(
accelerator=device,
max_steps=10000,
max_steps=100000,
limit_val_batches=1, # TODO Debugging
gradient_clip_val=5.0, # TODO There was something about this in the code.
logger=tb_logger,
val_check_interval=500,
val_check_interval=1000,
callbacks=[VisualiseSequencePrediction(torch.device(device))],
)

dataset = METRLA("./")
scaler = TrafficStandardScaler.from_dataset(dataset, n_samples=30000)
plmodule.scaler = scaler
## TODO Parametrise.
loader = DataLoader(dataset, batch_size=64, num_workers=num_workers)
loader = DataLoader(dataset, batch_size=32, num_workers=num_workers, shuffle=True)

## TODO Change the validation loader to NOT the training loader!
# This is for debugging the visualisation atm.
Expand Down

0 comments on commit aa84686

Please sign in to comment.