diff --git a/config/onboard_classifier.yaml b/config/onboard_classifier.yaml index 8b6c875..e0434d8 100644 --- a/config/onboard_classifier.yaml +++ b/config/onboard_classifier.yaml @@ -3,8 +3,8 @@ data_loader: events_path: events_201402.json image_size: 512 pol: all - batch_size: 4 - num_workers: 2 + batch_size: 1 + num_workers: 1 train_pct: 0.85 valid_pct: 0.1 @@ -16,7 +16,7 @@ train: ckpt_dir: /mnt/training_classifier//checkpoints/ out_dir: /mnt/training_classifier// - wandb_mode: offline # online or offline or disabled + wandb_mode: online # online or offline or disabled resume_from_checkpoint: False ckpt_path: last # path or "last" diff --git a/src/icarus/onboard/classifier/train.py b/src/icarus/onboard/classifier/train.py index cafdffe..27433aa 100644 --- a/src/icarus/onboard/classifier/train.py +++ b/src/icarus/onboard/classifier/train.py @@ -129,7 +129,14 @@ def validation_step(self, batch, batch_idx): self.log( "val/loss", loss, prog_bar=True, sync_dist=True, batch_size=self.batch_size ) - # ToDo: add prediction visualisation + if batch_idx == self.validation_batch_index: + x_vis = (x[0, 0] - x[0, 0].min()) / (x[0, 0].max() - x[0, 0].min()) + table = wandb.Table( + columns=["Observation_image", "Predicted_label", "Ground truth label"] + ) + img = wandb.Image(x_vis.cpu().numpy()) + table.add_data(img, torch.nn.functional.sigmoid(y[0]), y_target[0]) + wandb.log({"Table": table}) return loss @@ -222,7 +229,9 @@ def get_config(config_path): entity="ssa_live_twin", config=config, mode=config["train"]["wandb_mode"], - name=config["train"]["run_id"], + name=config["train"][ + "run_id" + ], # ToDo: This doesn't work for more than 1 batch size ) wandb_logger = WandbLogger()