diff --git a/casanovo/casanovo.py b/casanovo/casanovo.py index 8bdfa58f..54d6d649 100644 --- a/casanovo/casanovo.py +++ b/casanovo/casanovo.py @@ -37,6 +37,8 @@ import torch import tqdm from lightning.pytorch import seed_everything +from hashlib import shake_256 +from urllib.parse import urlparse from . import __version__ from . import utils @@ -59,10 +61,9 @@ def __init__(self, *args, **kwargs) -> None: click.Option( ("-m", "--model"), help=""" - The model weights (.ckpt file). If not provided, Casanovo - will try to download the latest release. + Either the model weights (.ckpt file) or a URL pointing to the model weights + file. If not provided, Casanovo will try to download the latest release. """, - type=click.Path(exists=True, dir_okay=False), ), click.Option( ("-o", "--output"), @@ -354,9 +355,10 @@ def setup_model( seed_everything(seed=config["random_seed"], workers=True) # Download model weights if these were not specified (except when training). + cache_dir = appdirs.user_cache_dir("casanovo", False, opinion=False) if model is None and not is_train: try: - model = _get_model_weights() + model = _get_model_weights(cache_dir) except github.RateLimitExceededException: logger.error( "GitHub API rate limit exceeded while trying to download the " @@ -371,6 +373,17 @@ def setup_model( "model weights" ) from None + # Download model from URL if model is a valid url + is_url = _is_valid_url(model) + if (model is not None) and is_url: + model = _get_weights_from_url(model, Path(cache_dir)) + + if (model is not None) and (not is_url) and (not Path(model).is_file()): + raise ValueError( + f"{model} is not a valid URL or checkpoint file path, " + "--model argument must be a URL or checkpoint file path" + ) + # Log the active configuration. 
logger.info("Casanovo version %s", str(__version__)) logger.debug("model = %s", model) @@ -382,7 +395,76 @@ def setup_model( return config, model -def _get_model_weights() -> str: +def _get_weights_from_url( + file_url: Optional[str], + cache_dir: Path, +) -> str: + """ + Attempt to download weight file from URL if weights are not already + cached. Otherwise use cached weights. Downloaded weight files will be + cached. + + Parameters + ---------- + file_url : str + url pointing to model weights file + cache_dir : Path + model weights cache directory path + + Returns + ------- + str + path to cached weights file + """ + os.makedirs(cache_dir, exist_ok=True) + url_hash = shake_256(file_url.encode("utf-8")).hexdigest(20) + cache_file_name = url_hash + ".ckpt" + cache_file_path = cache_dir / cache_file_name + + if cache_file_path.is_file(): + logger.info(f"Model weights {file_url} retrieved from local cache") + return str(cache_file_path) + + logger.info(f"Model weights {file_url} not in local cache, downloading") + file_response = requests.get(file_url) + + if not file_response.ok: + logger.error(f"Failed to download weights from {file_url}") + logger.error( + f"Server Response: {file_response.status_code}: {file_response.reason}" + ) + raise ConnectionError(f"Failed to download weights file: {file_url}") + + logger.info("Model weights downloaded, writing to cache") + with open(cache_file_path, "wb") as cache_file: + cache_file.write(file_response.content) + + logger.info("Model weights cached") + return str(cache_file_path) + + +def _is_valid_url(file_url: str) -> bool: + """ + Determine whether file URL is a valid URL + + Parameters + ---------- + file_url : str + url to verify + + Returns + ------- + is_url : bool + whether file_url is a valid url + """ + try: + result = urlparse(file_url) + return all([result.scheme, result.netloc]) + except ValueError: + return False + + +def _get_model_weights(cache_dir: str) -> str: """ Use cached model weights or download them 
from GitHub. @@ -396,12 +478,16 @@ def _get_model_weights() -> str: Note that the GitHub API is limited to 60 requests from the same IP per hour. + Parameters + ---------- + cache_dir : str + model weights cache directory path + Returns ------- str The name of the model weights file. """ - cache_dir = appdirs.user_cache_dir("casanovo", False, opinion=False) os.makedirs(cache_dir, exist_ok=True) version = utils.split_version(__version__) version_match: Tuple[Optional[str], Optional[str], int] = None, None, 0