Skip to content

Commit

Permalink
download weights from URL
Browse files Browse the repository at this point in the history
  • Loading branch information
Lilferrit committed Jun 27, 2024
1 parent 70ea9fc commit 6431cbf
Showing 1 changed file with 92 additions and 6 deletions.
98 changes: 92 additions & 6 deletions casanovo/casanovo.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,8 @@
import torch
import tqdm
from lightning.pytorch import seed_everything
from hashlib import shake_256
from urllib.parse import urlparse

from . import __version__
from . import utils
Expand All @@ -59,10 +61,9 @@ def __init__(self, *args, **kwargs) -> None:
click.Option(
("-m", "--model"),
help="""
The model weights (.ckpt file). If not provided, Casanovo
will try to download the latest release.
Either the model weights (.ckpt file) or a URL pointing to the model weights
file. If not provided, Casanovo will try to download the latest release.
""",
type=click.Path(exists=True, dir_okay=False),
),
click.Option(
("-o", "--output"),
Expand Down Expand Up @@ -354,9 +355,10 @@ def setup_model(
seed_everything(seed=config["random_seed"], workers=True)

# Download model weights if these were not specified (except when training).
cache_dir = appdirs.user_cache_dir("casanovo", False, opinion=False)
if model is None and not is_train:
try:
model = _get_model_weights()
model = _get_model_weights(cache_dir)

Check warning on line 361 in casanovo/casanovo.py

View check run for this annotation

Codecov / codecov/patch

casanovo/casanovo.py#L361

Added line #L361 was not covered by tests
except github.RateLimitExceededException:
logger.error(
"GitHub API rate limit exceeded while trying to download the "
Expand All @@ -371,6 +373,17 @@ def setup_model(
"model weights"
) from None

# Download model from URL if model is a valid url
is_url = _is_valid_url(model)
if (model is not None) and is_url:
model = _get_weights_from_url(model, Path(cache_dir))

Check warning on line 379 in casanovo/casanovo.py

View check run for this annotation

Codecov / codecov/patch

casanovo/casanovo.py#L379

Added line #L379 was not covered by tests

if (model is not None) and (not is_url) and (not Path(model).is_file()):
raise ValueError(

Check warning on line 382 in casanovo/casanovo.py

View check run for this annotation

Codecov / codecov/patch

casanovo/casanovo.py#L382

Added line #L382 was not covered by tests
f"{model} is not a valid URL or checkpoint file path, "
"--model argument must be a URL or checkpoint file path"
)

# Log the active configuration.
logger.info("Casanovo version %s", str(__version__))
logger.debug("model = %s", model)
Expand All @@ -382,7 +395,76 @@ def setup_model(
return config, model


def _get_model_weights() -> str:
def _get_weights_from_url(
    file_url: str,
    cache_dir: Path,
) -> str:
    """
    Resolve a model weights URL to a local checkpoint file.

    If the weights for this URL are already in the cache directory, the
    cached file is used. Otherwise the weights are downloaded and stored
    in the cache for later runs.

    The cache file name is the SHAKE-256 digest of the URL with a
    ``.ckpt`` suffix, so the same URL always maps to the same cache file.

    Parameters
    ----------
    file_url : str
        URL pointing to the model weights file.
    cache_dir : Path
        Model weights cache directory path. Created if it does not exist.

    Returns
    -------
    str
        Path to the cached weights file.

    Raises
    ------
    ConnectionError
        If the server responds with an unsuccessful status code.
        Exceptions raised by ``requests`` itself (e.g. on timeout) are not
        caught here and propagate to the caller.
    """
    cache_dir = Path(cache_dir)
    os.makedirs(cache_dir, exist_ok=True)
    url_hash = shake_256(file_url.encode("utf-8")).hexdigest(20)
    cache_file_path = cache_dir / f"{url_hash}.ckpt"

    if cache_file_path.is_file():
        logger.info(f"Model weights {file_url} retrieved from local cache")
        return str(cache_file_path)

    logger.info(f"Model weights {file_url} not in local cache, downloading")

    # Download into a process-specific temporary file and only move it to
    # its final name once complete. Writing straight to the final path
    # would leave a truncated checkpoint behind on interruption, which the
    # cache check above would then accept as valid on the next run.
    partial_path = cache_dir / f"{url_hash}.ckpt.{os.getpid()}.part"

    # (connect, read) timeouts in seconds so a stalled server cannot hang
    # the program forever; stream=True avoids holding the entire weights
    # file in memory at once.
    with requests.get(
        file_url, stream=True, timeout=(10, 60)
    ) as file_response:
        if not file_response.ok:
            logger.error(f"Failed to download weights from {file_url}")
            logger.error(
                f"Server Response: {file_response.status_code}: {file_response.reason}"
            )
            raise ConnectionError(
                f"Failed to download weights file: {file_url}"
            )

        logger.info("Model weights downloaded, writing to cache")
        try:
            with open(partial_path, "wb") as cache_file:
                for chunk in file_response.iter_content(
                    chunk_size=1024 * 1024
                ):
                    cache_file.write(chunk)

            # Atomic rename: the final cache path is either absent or a
            # complete file, never a partial one.
            os.replace(partial_path, cache_file_path)
        finally:
            # Only still present if the download or the rename failed.
            if partial_path.exists():
                partial_path.unlink()

    logger.info("Model weights cached")
    return str(cache_file_path)

Check warning on line 443 in casanovo/casanovo.py

View check run for this annotation

Codecov / codecov/patch

casanovo/casanovo.py#L442-L443

Added lines #L442 - L443 were not covered by tests


def _is_valid_url(file_url: str) -> bool:
"""
Determine whether file URL is a valid URL
Parameters
----------
file_url : str
url to verify
Return
------
is_url : bool
whether file_url is a valid url
"""
try:
result = urlparse(file_url)
return all([result.scheme, result.netloc])
except ValueError:
return False

Check warning on line 464 in casanovo/casanovo.py

View check run for this annotation

Codecov / codecov/patch

casanovo/casanovo.py#L463-L464

Added lines #L463 - L464 were not covered by tests


def _get_model_weights(cache_dir: str) -> str:
"""
Use cached model weights or download them from GitHub.
Expand All @@ -396,12 +478,16 @@ def _get_model_weights() -> str:
Note that the GitHub API is limited to 60 requests from the same IP per
hour.
Parameters
----------
cache_dir : str
model weights cache directory path
Returns
-------
str
The name of the model weights file.
"""
cache_dir = appdirs.user_cache_dir("casanovo", False, opinion=False)
os.makedirs(cache_dir, exist_ok=True)
version = utils.split_version(__version__)
version_match: Tuple[Optional[str], Optional[str], int] = None, None, 0
Expand Down

0 comments on commit 6431cbf

Please sign in to comment.