Skip to content

Commit

Permalink
login to wandb via env variable
Browse files Browse the repository at this point in the history
  • Loading branch information
phinate committed Sep 26, 2024
1 parent d84dde0 commit 930fa32
Showing 1 changed file with 18 additions and 2 deletions.
20 changes: 18 additions & 2 deletions src/cloudcasting/validation.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,8 @@

import importlib.util
import inspect
import logging
import os
import sys
from collections.abc import Callable
from functools import partial
Expand Down Expand Up @@ -43,6 +45,9 @@
)
from cloudcasting.utils import numpy_validation_collate_fn

logging.basicConfig(level=logging.INFO, format="%(asctime)s - %(levelname)s - %(message)s")
logger = logging.getLogger(__name__)

# defined in manchester prize technical document
WANDB_ENTITY = "manchester_prize"
VIDEO_SAMPLE_DATES = [
Expand Down Expand Up @@ -243,6 +248,9 @@ def get_pix_function(
# we probably want to accumulate metrics here instead of taking the mean of means!
loop_steps = len(valid_dataloader) if batch_limit is None else batch_limit

info_str = f"Validating model on {loop_steps} batches..."
logger.info(info_str)

for i, (X, y) in tqdm(enumerate(valid_dataloader), total=loop_steps):
y_hat = model(X)

Expand Down Expand Up @@ -355,8 +363,16 @@ def validate(
msg = f"Failed to run the model forward due to the following error: {err}"
raise ValueError(msg) from err

# Login to wandb
wandb.login()
# grab api key from environment variable
wandb_api_key = os.environ.get("WANDB_API_KEY")

if not wandb_api_key:
msg = "WANDB_API_KEY environment variable not set. Attempting interactive login..."
logger.warning(msg)
wandb.login()
else:
logger.info("API key found. Logging in to WandB...")
wandb.login(key=wandb_api_key)

# Set up the validation dataset
valid_dataset = ValidationSatelliteDataset(
Expand Down

0 comments on commit 930fa32

Please sign in to comment.