diff --git a/src/cloudcasting/validation.py b/src/cloudcasting/validation.py index efdb41d..2b52f23 100644 --- a/src/cloudcasting/validation.py +++ b/src/cloudcasting/validation.py @@ -2,6 +2,8 @@ import importlib.util import inspect +import logging +import os import sys from collections.abc import Callable from functools import partial @@ -43,6 +45,9 @@ ) from cloudcasting.utils import numpy_validation_collate_fn +logging.basicConfig(level=logging.INFO, format="%(asctime)s - %(levelname)s - %(message)s") +logger = logging.getLogger(__name__) + # defined in manchester prize technical document WANDB_ENTITY = "manchester_prize" VIDEO_SAMPLE_DATES = [ @@ -243,6 +248,9 @@ def get_pix_function( # we probably want to accumulate metrics here instead of taking the mean of means! loop_steps = len(valid_dataloader) if batch_limit is None else batch_limit + info_str = f"Validating model on {loop_steps} batches..." + logger.info(info_str) + for i, (X, y) in tqdm(enumerate(valid_dataloader), total=loop_steps): y_hat = model(X) @@ -355,8 +363,16 @@ def validate( msg = f"Failed to run the model forward due to the following error: {err}" raise ValueError(msg) from err - # Login to wandb - wandb.login() + # grab api key from environment variable + wandb_api_key = os.environ.get("WANDB_API_KEY") + + if not wandb_api_key: + msg = "WANDB_API_KEY environment variable not set. Attempting interactive login..." + logger.warning(msg) + wandb.login() + else: + logger.info("API key found. Logging in to WandB...") + wandb.login(key=wandb_api_key) # Set up the validation dataset valid_dataset = ValidationSatelliteDataset(