Skip to content

Commit

Permalink
Merge pull request #51 from BrainLesion/46-feature-request-improve-we…
Browse files Browse the repository at this point in the history
…ights-download-logic-and-robustness

46 feature request improve weights download logic and robustness
  • Loading branch information
MarcelRosier authored Jun 25, 2024
2 parents 5ad95d1 + cb54798 commit b534519
Show file tree
Hide file tree
Showing 9 changed files with 184 additions and 55 deletions.
3 changes: 2 additions & 1 deletion .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -130,4 +130,5 @@ dmypy.json
.vscode
poetry.lock
.DS_Store
brainles_aurora/model_weights/*
brainles_aurora/model_weights/*
brainles_aurora/weights/*
4 changes: 2 additions & 2 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -17,8 +17,8 @@ pip install brainles-aurora
## Recommended Environment

- CUDA 11.4+ (https://developer.nvidia.com/cuda-toolkit)
- Python 3.10+
- GPU with CUDA support and at least 8GB of VRAM
- Python 3.8+
- GPU with CUDA support and at least 6GB of VRAM

## Usage
BrainLes features Jupyter Notebook [tutorials](https://github.com/BrainLesion/tutorials/tree/main/AURORA) with usage instructions.
Expand Down
4 changes: 2 additions & 2 deletions brainles_aurora/inferer/constants.py
Original file line number Diff line number Diff line change
Expand Up @@ -82,5 +82,5 @@ class Device(str, Enum):
"""Attempt to use GPU, fallback to CPU."""


WEIGHTS_DIR = "weights"
"""Directory name to store model weights."""
WEIGHTS_DIR_PATTERN = "weights_v*.*.*"
"""Directory name pattern to store model weights. E.g. weights_v1.0.0"""
2 changes: 1 addition & 1 deletion brainles_aurora/inferer/log_config.json
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
"disable_existing_loggers": false,
"formatters": {
"simple": {
"format": "[%(levelname)-8s | %(module)-15s | L%(lineno)-5d] | %(asctime)s: %(message)s",
"format": "[%(levelname)s] %(asctime)s: %(message)s",
"datefmt": "%Y-%m-%dT%H:%M:%S%z"
}
},
Expand Down
13 changes: 5 additions & 8 deletions brainles_aurora/inferer/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,9 +8,8 @@
import numpy as np
import torch
from brainles_aurora.inferer.config import AuroraInfererConfig
from brainles_aurora.inferer.constants import InferenceMode, Output, WEIGHTS_DIR
from brainles_aurora.inferer.data import DataHandler
from brainles_aurora.utils import download_model_weights
from brainles_aurora.inferer.constants import InferenceMode, Output, WEIGHTS_DIR_PATTERN
from brainles_aurora.utils import check_model_weights
from monai.inferers import SlidingWindowInferer
from monai.networks.nets import BasicUNet
from monai.transforms import RandGaussianNoised
Expand Down Expand Up @@ -39,11 +38,9 @@ def __init__(
# Will be set during infer() call
self.model = None
self.inference_mode = None
# download weights if not present
self.lib_path: str = Path(os.path.dirname(os.path.abspath(__file__)))
self.model_weights_folder = self.lib_path.parent / WEIGHTS_DIR
if not self.model_weights_folder.exists():
download_model_weights(target_folder=str(self.model_weights_folder))

# get location of model weights
self.model_weights_folder = check_model_weights()

def load_model(
self, inference_mode: InferenceMode, num_input_modalities: int
Expand Down
2 changes: 1 addition & 1 deletion brainles_aurora/utils/__init__.py
Original file line number Diff line number Diff line change
@@ -1,2 +1,2 @@
from .utils import remove_path_suffixes
from .download import download_model_weights
from .weights import check_model_weights
39 changes: 0 additions & 39 deletions brainles_aurora/utils/download.py

This file was deleted.

170 changes: 170 additions & 0 deletions brainles_aurora/utils/weights.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,170 @@
from __future__ import annotations

import logging
import shutil
import sys
import zipfile
from io import BytesIO
from pathlib import Path
from typing import Dict, List

import requests
from tqdm import tqdm
from brainles_aurora.inferer.constants import WEIGHTS_DIR_PATTERN

logger = logging.getLogger(__name__)

ZENODO_RECORD_URL = "https://zenodo.org/api/records/10557068"


def check_model_weights() -> Path:
"""Check if latest model weights are present and download them otherwise.
Returns:
Path: Path to the model weights folder.
"""
package_folder = Path(__file__).parent.parent

zenodo_metadata, archive_url = _get_zenodo_metadata_and_archive_url()

matching_folders = list(package_folder.glob(WEIGHTS_DIR_PATTERN))
# Get the latest downloaded weights
latest_downloaded_weights = _get_latest_version_folder_name(matching_folders)

if not latest_downloaded_weights:
if not zenodo_metadata:
logger.error(
"Model weights not found locally and Zenodo could not be reached. Exiting..."
)
sys.exit()
logger.info(
f"Model weights not found. Downloading the latest model weights {zenodo_metadata['version']} from Zenodo..."
)

return _download_model_weights(
package_folder=package_folder,
zenodo_metadata=zenodo_metadata,
archive_url=archive_url,
)

logger.info(f"Found downloaded local weights: {latest_downloaded_weights}")

if not zenodo_metadata:
logger.warning(
"Zenodo server could not be reached. Using the latest downloaded weights."
)
return package_folder / latest_downloaded_weights

# Compare the latest downloaded weights with the latest Zenodo version
if zenodo_metadata["version"] == latest_downloaded_weights.split("_v")[1]:
logger.info(
f"Latest model weights ({latest_downloaded_weights}) are already present."
)
return package_folder / latest_downloaded_weights

logger.info(
f"New model weights available on Zenodo ({zenodo_metadata['version']}). Deleting old and fetching new weights..."
)
# delete old weights
try:
shutil.rmtree(
package_folder / latest_downloaded_weights,
)
except Exception as e:
logger.warning(
f"Failed to delete {package_folder / latest_downloaded_weights}: {e}"
),

return _download_model_weights(
package_folder=package_folder,
zenodo_metadata=zenodo_metadata,
archive_url=archive_url,
)


def _get_latest_version_folder_name(folders: List[Path]) -> str | None:
"""Get the latest (non empty) version folder name from the list of folders.
Args:
folders (List[Path]): List of folders matching the pattern.
Returns:
str | None: Latest version folder name if one exists, else None.
"""
if not folders:
return None
latest_downloaded_folder = sorted(
folders,
reverse=True,
key=lambda x: tuple(map(int, str(x).split("_v")[1].split("."))),
)[0]
# check folder is not empty
if not list(latest_downloaded_folder.glob("*")):
return None
return latest_downloaded_folder.name


def _get_zenodo_metadata_and_archive_url() -> Dict | None:
"""Get the metadata for the Zenodo record and the files archive url.
Returns:
Tuple: (dict: Metadata for the Zenodo record, str: URL to the archive file)
"""
try:
response = requests.get(ZENODO_RECORD_URL)
if response.status_code != 200:
logger.error(f"Cant find model weights on Zenodo. Exiting...")
data = response.json()
return data["metadata"], data["links"]["archive"]

except requests.exceptions.RequestException as e:
logger.warning(f"Failed to fetch Zenodo metadata: {e}")
return None


def _download_model_weights(
package_folder: Path, zenodo_metadata: Dict, archive_url: str
) -> Path:
"""Download the latest model weights from Zenodo and extract them to the target folder.
Args:
package_folder (Path): Package folder path in which the model weights will be stored.
zenodo_metadata (Dict): Metadata for the Zenodo record.
archive_url (str): URL to the model weights archive file.
Returns:
Path: Path to the model weights folder.
"""
weights_folder = package_folder / f"weights_v{zenodo_metadata['version']}"

# ensure folder exists
weights_folder.mkdir(parents=True, exist_ok=True)

logger.info(
f"Downloading model weights from Zenodo ({ZENODO_RECORD_URL}). This might take a while..."
)
# Make a GET request to the URL
response = requests.get(archive_url, stream=True)
# Ensure the request was successful
if response.status_code != 200:
logger.error(
f"Failed to download model weights. Status code: {response.status_code}"
)
return
# Download with progress bar
chunk_size = 1024 # 1KB
bytes_io = BytesIO()
with tqdm(
total=0, # unknown size since content length not given
unit="B",
unit_scale=True,
) as pbar:
for data in response.iter_content(chunk_size=chunk_size):
bytes_io.write(data)
pbar.update(len(data))

# Extract the downloaded zip file to the target folder
with zipfile.ZipFile(bytes_io) as zip_ref:
zip_ref.extractall(weights_folder)
logger.info(f"Zip file extracted successfully to {weights_folder}")
return weights_folder
2 changes: 1 addition & 1 deletion example/segmentation_example.py
Original file line number Diff line number Diff line change
Expand Up @@ -74,6 +74,6 @@ def gpu_np():


if __name__ == "__main__":
logging.basicConfig(level=logging.DEBUG)
# logging.basicConfig(level=)
gpu_nifti()
# gpu_nifti_2()

0 comments on commit b534519

Please sign in to comment.