Skip to content

Commit

Permalink
added wandb visualisations
Browse files Browse the repository at this point in the history
  • Loading branch information
hlamdouar committed Aug 15, 2023
1 parent d21f05e commit f782781
Show file tree
Hide file tree
Showing 2 changed files with 14 additions and 5 deletions.
6 changes: 3 additions & 3 deletions config/onboard_classifier.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -3,8 +3,8 @@ data_loader:
events_path: events_201402.json
image_size: 512
pol: all
batch_size: 4
num_workers: 2
batch_size: 1
num_workers: 1
train_pct: 0.85
valid_pct: 0.1

Expand All @@ -16,7 +16,7 @@ train:
ckpt_dir: /mnt/training_classifier/<run_id>/checkpoints/
out_dir: /mnt/training_classifier/<run_id>/

wandb_mode: offline # online or offline or disabled
wandb_mode: online # online or offline or disabled

resume_from_checkpoint: False
ckpt_path: last # path or "last"
Expand Down
13 changes: 11 additions & 2 deletions src/icarus/onboard/classifier/train.py
Original file line number Diff line number Diff line change
Expand Up @@ -129,7 +129,14 @@ def validation_step(self, batch, batch_idx):
self.log(
"val/loss", loss, prog_bar=True, sync_dist=True, batch_size=self.batch_size
)
# ToDo: add prediction visualisation
if batch_idx == self.validation_batch_index:
x_vis = (x[0, 0] - x[0, 0].min()) / (x[0, 0].max() - x[0, 0].min())
table = wandb.Table(
columns=["Observation_image", "Predicted_label", "Ground truth label"]
)
img = wandb.Image(x_vis.cpu().numpy())
table.add_data(img, torch.nn.functional.sigmoid(y[0]), y_target[0])
wandb.log({"Table": table})

return loss

Expand Down Expand Up @@ -222,7 +229,9 @@ def get_config(config_path):
entity="ssa_live_twin",
config=config,
mode=config["train"]["wandb_mode"],
name=config["train"]["run_id"],
name=config["train"][
"run_id"
], # ToDo: This doesn't work for more than 1 batch size
)
wandb_logger = WandbLogger()

Expand Down

0 comments on commit f782781

Please sign in to comment.