Download weight file from URL #349

Merged
merged 19 commits into from
Aug 2, 2024
5 changes: 3 additions & 2 deletions CHANGELOG.md
@@ -8,11 +8,12 @@ The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.0.0/),

### Added

- During training, model checkpoints will now be saved at the end of each training epoch in addition to the checkpoints saved at the end of every validation run.
- During training, model checkpoints will be saved at the end of each training epoch in addition to the checkpoints saved at the end of every validation run.
- In addition to a local file path, model weights can be specified as a URL. After the initial download, the weights file is cached for future reuse.

### Fixed

- Precursor charges are now exported as integers instead of floats in the mzTab output file, in compliance with the mzTab specification.
- Precursor charges are exported as integers instead of floats in the mzTab output file, in compliance with the mzTab specification.

## [4.2.1] - 2024-06-25

200 changes: 166 additions & 34 deletions casanovo/casanovo.py
@@ -2,12 +2,14 @@

import datetime
import functools
import hashlib
import logging
import os
import re
import shutil
import sys
import time
import urllib.parse
import warnings
from pathlib import Path
from typing import Optional, Tuple
@@ -60,10 +62,9 @@ def __init__(self, *args, **kwargs) -> None:
click.Option(
("-m", "--model"),
help="""
The model weights (.ckpt file). If not provided, Casanovo
will try to download the latest release.
Either the model weights (.ckpt file) or a URL pointing to the model weights
file. If not provided, Casanovo will try to download the latest release automatically.
""",
type=click.Path(exists=True, dir_okay=False),
),
click.Option(
("-o", "--output"),
@@ -365,22 +366,34 @@ def setup_model(
seed_everything(seed=config["random_seed"], workers=True)

# Download model weights if these were not specified (except when training).
if model is None and not is_train:
try:
model = _get_model_weights()
except github.RateLimitExceededException:
logger.error(
"GitHub API rate limit exceeded while trying to download the "
"model weights. Please download compatible model weights "
"manually from the official Casanovo code website "
"(https://github.com/Noble-Lab/casanovo) and specify these "
"explicitly using the `--model` parameter when running "
"Casanovo."
cache_dir = Path(appdirs.user_cache_dir("casanovo", False, opinion=False))
if model is None:
if not is_train:
try:
model = _get_model_weights(cache_dir)
except github.RateLimitExceededException:
logger.error(
"GitHub API rate limit exceeded while trying to download the "
"model weights. Please download compatible model weights "
"manually from the official Casanovo code website "
"(https://github.com/Noble-Lab/casanovo) and specify these "
"explicitly using the `--model` parameter when running "
"Casanovo."
)
raise PermissionError(
"GitHub API rate limit exceeded while trying to download the "
"model weights"
) from None
else:
if _is_valid_url(model):
model = _get_weights_from_url(model, cache_dir)
elif not Path(model).is_file():
error_msg = (
f"{model} is not a valid URL or checkpoint file path, "
"--model argument must be a URL or checkpoint file path"
)
raise PermissionError(
"GitHub API rate limit exceeded while trying to download the "
"model weights"
) from None
logger.error(error_msg)
raise ValueError(error_msg)

# Log the active configuration.
logger.info("Casanovo version %s", str(__version__))
@@ -393,7 +406,7 @@ def setup_model(
return config, model


def _get_model_weights() -> str:
def _get_model_weights(cache_dir: Path) -> Path:
"""
Use cached model weights or download them from GitHub.

@@ -407,12 +420,16 @@ def _get_model_weights() -> str:
Note that the GitHub API is limited to 60 requests from the same IP per
hour.

Parameters
----------
cache_dir : Path
Model weights cache directory path.

Returns
-------
Path
The path to the model weights file.
"""
cache_dir = appdirs.user_cache_dir("casanovo", False, opinion=False)
os.makedirs(cache_dir, exist_ok=True)
version = utils.split_version(__version__)
version_match: Tuple[Optional[str], Optional[str], int] = None, None, 0
@@ -436,7 +453,7 @@ def _get_model_weights() -> str:
"Model weights file %s retrieved from local cache",
version_match[0],
)
return version_match[0]
return Path(version_match[0])
# Otherwise try to find compatible model weights on GitHub.
else:
repo = github.Github().get_repo("Noble-Lab/casanovo")
@@ -469,19 +486,9 @@ def _get_model_weights() -> str:
# Download the model weights if a matching release was found.
if version_match[2] > 0:
filename, url, _ = version_match
logger.info(
"Downloading model weights file %s from %s", filename, url
)
r = requests.get(url, stream=True, allow_redirects=True)
r.raise_for_status()
file_size = int(r.headers.get("Content-Length", 0))
desc = "(Unknown total file size)" if file_size == 0 else ""
r.raw.read = functools.partial(r.raw.read, decode_content=True)
with tqdm.tqdm.wrapattr(
r.raw, "read", total=file_size, desc=desc
) as r_raw, open(filename, "wb") as f:
shutil.copyfileobj(r_raw, f)
return filename
cache_file_path = cache_dir / filename
_download_weights(url, cache_file_path)
return cache_file_path
else:
logger.error(
"No matching model weights for release v%s found, please "
@@ -496,5 +503,130 @@ def _get_model_weights() -> str:
)


def _get_weights_from_url(
file_url: str,
cache_dir: Path,
force_download: Optional[bool] = False,
) -> Path:
"""
Resolve weight file from URL

Attempt to download weight file from URL if weights are not already
cached - otherwise use cached weights. Downloaded weight files will be
cached.

Parameters
----------
file_url : str
URL pointing to model weights file.
cache_dir : Path
Model weights cache directory path.
force_download : Optional[bool], default=False
If True, forces a new download of the weight file even if it exists in
the cache.

Returns
-------
Path
Path to the cached weights file.
"""
if not _is_valid_url(file_url):
raise ValueError("file_url must point to a valid URL")

os.makedirs(cache_dir, exist_ok=True)
cache_file_name = Path(urllib.parse.urlparse(file_url).path).name
url_hash = hashlib.shake_256(file_url.encode("utf-8")).hexdigest(5)
cache_file_dir = cache_dir / url_hash
cache_file_path = cache_file_dir / cache_file_name

if cache_file_path.is_file() and not force_download:
cache_time = cache_file_path.stat()
url_last_modified = 0

try:
file_response = requests.head(file_url)
if file_response.ok:
if "Last-Modified" in file_response.headers:
url_last_modified = datetime.datetime.strptime(
file_response.headers["Last-Modified"],
"%a, %d %b %Y %H:%M:%S %Z",
).timestamp()
else:
logger.warning(
"Attempted HEAD request to %s yielded non-ok status code - using cached file",
file_url,
)
except (
requests.ConnectionError,
requests.Timeout,
requests.TooManyRedirects,
):
logger.warning(
"Failed to reach %s to get remote last modified time - using cached file",
file_url,
)

if cache_time.st_mtime > url_last_modified:
logger.info(
"Model weights %s retrieved from local cache", file_url
)
return cache_file_path

_download_weights(file_url, cache_file_path)
return cache_file_path
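
The cache check above compares the cached file's modification time with the server's `Last-Modified` header. Note that `datetime.strptime` returns a naive datetime, so its `.timestamp()` is interpreted in the local timezone. The following is only a sketch of a timezone-aware alternative, not the PR's implementation; `is_cache_fresh` is a hypothetical helper name, and it relies on the standard library's `email.utils.parsedate_to_datetime`.

```python
from email.utils import parsedate_to_datetime
from typing import Optional


def is_cache_fresh(cache_mtime: float, last_modified: Optional[str]) -> bool:
    """Hypothetical helper: True if the cached copy should be reused."""
    if last_modified is None:
        # No header available: fall back to the cached file.
        return True
    try:
        # HTTP dates such as "Wed, 21 Oct 2015 07:28:00 GMT" parse to a
        # timezone-aware datetime, so timestamp() is independent of the
        # local timezone.
        remote_mtime = parsedate_to_datetime(last_modified).timestamp()
    except (TypeError, ValueError):
        # Unparseable header: keep using the cached file.
        return True
    return cache_mtime > remote_mtime
```

In this sketch a missing or malformed header keeps the cached file, which mirrors the PR's behavior of defaulting the remote timestamp to 0.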


def _download_weights(file_url: str, download_path: Path) -> None:
"""
Download weights file from URL

Download the model weights file from the specified URL and save it to the
given path. Ensures the download directory exists, and uses a progress
bar to indicate download status.

Parameters
----------
file_url : str
URL pointing to the model weights file.
download_path : Path
Path where the downloaded weights file will be saved.
"""
download_file_dir = download_path.parent
os.makedirs(download_file_dir, exist_ok=True)
response = requests.get(file_url, stream=True, allow_redirects=True)
response.raise_for_status()
file_size = int(response.headers.get("Content-Length", 0))
desc = "(Unknown total file size)" if file_size == 0 else ""
response.raw.read = functools.partial(
response.raw.read, decode_content=True
)

with tqdm.tqdm.wrapattr(
response.raw, "read", total=file_size, desc=desc
) as r_raw, open(download_path, "wb") as file:
shutil.copyfileobj(r_raw, file)


def _is_valid_url(file_url: str) -> bool:
"""
Determine whether file_url is a valid URL.

Parameters
----------
file_url : str
URL to verify.

Returns
-------
is_url : bool
Whether file_url is a valid URL.
"""
try:
result = urllib.parse.urlparse(file_url)
return all([result.scheme, result.netloc])
except ValueError:
return False
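
To illustrate what this check accepts, here is a standalone copy of the same logic (renamed `is_valid_url` for the example) applied to a few inputs. The example URLs and paths are made up. A string counts as a URL only if it has both a scheme and a network location, so plain file paths and `file:///` URLs with an empty host are rejected.

```python
import urllib.parse


def is_valid_url(file_url: str) -> bool:
    """Standalone copy of the check: require a scheme and a netloc."""
    try:
        result = urllib.parse.urlparse(file_url)
        return all([result.scheme, result.netloc])
    except ValueError:
        return False


print(is_valid_url("https://example.com/weights.ckpt"))  # True
print(is_valid_url("weights.ckpt"))  # False: no scheme, no netloc
print(is_valid_url("/home/user/weights.ckpt"))  # False: local path
print(is_valid_url("file:///tmp/weights.ckpt"))  # False: empty netloc
```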


if __name__ == "__main__":
main()
15 changes: 15 additions & 0 deletions docs/file_formats.md
@@ -2,6 +2,8 @@

## Input file formats for Casanovo

### MS/MS spectra

When you're ready to use Casanovo for *de novo* peptide sequencing, you can input your MS/MS spectra in one of the following formats:

- **[mzML](https://doi.org/10.1074/mcp.R110.000133)**: XML-based mass spectrometry community standard file format developed by the Proteomics Standards Initiative (PSI).
@@ -11,6 +13,19 @@ When you're ready to use Casanovo for *de novo* peptide sequencing, you can inpu
All three of the above file formats can be used as input to Casanovo for *de novo* peptide sequencing.
As the official PSI standard format containing the complete information from a mass spectrometry run, mzML should typically be preferred.

### Model weights

In addition to MS/MS spectra, Casanovo optionally accepts a model weights file (`.ckpt` extension) as input when running in training, sequencing, or evaluation mode.
These weights are the learned parameters of the Casanovo neural network.

If no input weights file is provided, Casanovo will automatically use the most recent compatible weights from the [official Casanovo GitHub repository](https://github.com/Noble-Lab/casanovo), downloading and caching them locally if they are not already in the cache.
Model weights are retrieved by matching the Casanovo release version, which has the form (major, minor, patch).
If no model weights for an identical release are available, weights from a release with matching (i) major and minor versions, or (ii) major version only, will be used instead.
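
The fallback order described above can be sketched as a simple scoring function. This is an illustrative sketch only: `match_score` and `best_match` are hypothetical names, and the numeric scores are an assumption rather than the values used in Casanovo's code.

```python
from typing import Iterable, Optional, Tuple

Version = Tuple[int, int, int]  # (major, minor, patch)


def match_score(wanted: Version, candidate: Version) -> int:
    """Hypothetical score: 3 = identical, 2 = major+minor, 1 = major, 0 = none."""
    if wanted[0] != candidate[0]:
        return 0
    if wanted[1] != candidate[1]:
        return 1
    return 3 if wanted[2] == candidate[2] else 2


def best_match(wanted: Version, candidates: Iterable[Version]) -> Optional[Version]:
    """Return the best-scoring candidate, or None if no major version matches."""
    best, best_score = None, 0
    for candidate in candidates:
        score = match_score(wanted, candidate)
        if score > best_score:
            best, best_score = candidate, score
    return best
```

For example, under these assumptions a request for `(4, 2, 1)` with available releases `(3, 0, 0)`, `(4, 0, 0)`, and `(4, 2, 0)` would select `(4, 2, 0)`, since it matches on major and minor versions.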

Alternatively, you can input custom model weights in the form of a local file system path or a URL pointing to a compatible Casanovo model weights file.
If a URL is provided, the upstream weights file will be downloaded and cached locally for later use.
See the [command line interface documentation](cli.rst) for more details.
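
As a rough illustration of where a downloaded file ends up, the sketch below derives a cache location from a URL in the same way as the `_get_weights_from_url` change in this PR: a short hash of the full URL names a subdirectory, and the file keeps its original name. The function name `cache_path_for_url` and the example URL are hypothetical.

```python
import hashlib
import urllib.parse
from pathlib import Path


def cache_path_for_url(file_url: str, cache_dir: Path) -> Path:
    """Hypothetical helper: map a weights URL to its cache file path."""
    # File name is the last component of the URL path.
    file_name = Path(urllib.parse.urlparse(file_url).path).name
    # 5 bytes of SHAKE-256 output give a 10-character hex directory name.
    url_hash = hashlib.shake_256(file_url.encode("utf-8")).hexdigest(5)
    return cache_dir / url_hash / file_name


path = cache_path_for_url(
    "https://example.com/models/weights.ckpt", Path("cache")
)
print(path.name)  # weights.ckpt
```

Hashing the full URL keeps files with the same name but different origins in separate cache subdirectories.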

## Output: Understanding the mzTab format

After Casanovo processes your input file(s), it provides the sequencing results in an **[mzTab](https://doi.org/10.1074/mcp.O113.036681)** file.