From 68379def49dcb365225d6127c34645eeb46c38d0 Mon Sep 17 00:00:00 2001 From: Marcel Rosier Date: Tue, 30 Apr 2024 18:00:04 +0200 Subject: [PATCH 1/7] - enhance download weights logic --- brainles_aurora/inferer/constants.py | 4 +- brainles_aurora/utils/__init__.py | 2 +- brainles_aurora/utils/download.py | 39 --------- brainles_aurora/utils/weights.py | 114 +++++++++++++++++++++++++++ 4 files changed, 117 insertions(+), 42 deletions(-) delete mode 100644 brainles_aurora/utils/download.py create mode 100644 brainles_aurora/utils/weights.py diff --git a/brainles_aurora/inferer/constants.py b/brainles_aurora/inferer/constants.py index 7baae07..47044d9 100644 --- a/brainles_aurora/inferer/constants.py +++ b/brainles_aurora/inferer/constants.py @@ -82,5 +82,5 @@ class Device(str, Enum): """Attempt to use GPU, fallback to CPU.""" -WEIGHTS_DIR = "weights" -"""Directory name to store model weights.""" +WEIGHTS_DIR_PATTERN = "weights_v*.*.*" +"""Directory name pattern to store model weights. E.g. weights_v1.0.0""" diff --git a/brainles_aurora/utils/__init__.py b/brainles_aurora/utils/__init__.py index 44a2a61..0fdc3f9 100644 --- a/brainles_aurora/utils/__init__.py +++ b/brainles_aurora/utils/__init__.py @@ -1,2 +1,2 @@ from .utils import remove_path_suffixes -from .download import download_model_weights +from .weights import download_model_weights diff --git a/brainles_aurora/utils/download.py b/brainles_aurora/utils/download.py deleted file mode 100644 index ba6f600..0000000 --- a/brainles_aurora/utils/download.py +++ /dev/null @@ -1,39 +0,0 @@ -from __future__ import annotations - -import logging -import os -import zipfile -from pathlib import Path -from io import BytesIO - - -import requests - -logger = logging.getLogger(__name__) - -DOWNLOAD_URL = "https://zenodo.org/api/records/10557069/files-archive" - - -def download_model_weights(target_folder: str | Path) -> None: - """Download the model weights from Zenodo and extract them to the target folder. - - Args: - target_folder (str | Path): The folder to which the model weights should be downloaded and extracted to. - """ - # Create the target folder if it does not exist - os.makedirs(target_folder, exist_ok=True) - logger.info( - f"Downloading model weights from Zenodo ({DOWNLOAD_URL}). This might take a while..." - ) - # Make a GET request to the URL - response = requests.get(DOWNLOAD_URL) - # Ensure the request was successful - if response.status_code != 200: - logger.error( - f"Failed to download model weights from {DOWNLOAD_URL}. Status code: {response.status_code}" - ) - return - # Extract the downloaded zip file to the target folder - with zipfile.ZipFile(BytesIO(response.content)) as zip_ref: - zip_ref.extractall(target_folder) - logger.info(f"Zip file extracted successfully to {target_folder}") diff --git a/brainles_aurora/utils/weights.py b/brainles_aurora/utils/weights.py new file mode 100644 index 0000000..4c120bf --- /dev/null +++ b/brainles_aurora/utils/weights.py @@ -0,0 +1,114 @@ +from __future__ import annotations + +import logging +import os +import zipfile +from io import BytesIO +from pathlib import Path +from typing import Dict + +import requests +from brainles_aurora.inferer.constants import WEIGHTS_DIR_PATTERN +import sys + +logger = logging.getLogger(__name__) + +ZENODO_RECORD_URL = "https://zenodo.org/api/records/10557069" + + +def check_model_weights(package_folder: Path) -> Path: + """Check if latest model weights are present and download otherwise. + + Args: + package_folder (Path): Package folder path in which the model weights are stored. + + Returns: + Path: Path to the model weights folder. + """ + zenodo_metadata = _get_zenodo_metadata() + + matching_folders = list(package_folder.glob(WEIGHTS_DIR_PATTERN)) + if not matching_folders: + if not zenodo_metadata: + logger.error( + "Model weights not found locally and Zenodo could not be reached. Exiting..." + ) + sys.exit() + logger.info( + f"Model weights not found. Downloading the latest model weights {zenodo_metadata['version']} from Zenodo..." + ) + + return download_model_weights( + package_folder=package_folder, zenodo_metadata=zenodo_metadata + ) + + # Get the latest downloaded weights + latest_downloaded_weights = sorted( + matching_folders, + reverse=True, + key=lambda x: tuple(map(int, x.split("_v")[1].split("."))), + )[0] + + if not zenodo_metadata: + logger.warning( + "Zenodo server could not be reached. Using the latest downloaded weights." + ) + return package_folder / latest_downloaded_weights + + # Compare the latest downloaded weights with the latest Zenodo version + if zenodo_metadata["version"] == latest_downloaded_weights.split("_v")[1]: + logger.info( + f"Latest model weights ({latest_downloaded_weights}) are already present." + ) + return package_folder / latest_downloaded_weights + + logger.info( + f"New model weights available on Zenodo ({zenodo_metadata['version']}). Downloading..." + ) + return download_model_weights( + package_folder=package_folder, zenodo_metadata=zenodo_metadata + ) + + +def _get_zenodo_metadata() -> Dict | None: + """Get the metadata for the Zenodo record. + + Returns: + dict: Metadata for the Zenodo record. + """ + try: + response = requests.get(ZENODO_RECORD_URL) + return response.json()["metadata"] + except requests.exceptions.RequestException as e: + logger.warning(f"Failed to fetch Zenodo metadata: {e}") + return None + + +def download_model_weights(package_folder: Path, zenodo_metadata: Dict) -> None: + """Download the latest model weights from Zenodo and extract them to the target folder. + + Args: + package_folder (Path): Package folder path in which the model weights will be stored. + zenodo_metadata (Dict): Metadata for the Zenodo record. + """ + weights_folder = package_folder / f"weights_v{zenodo_metadata['version']}" + + # ensure folder exists + weights_folder.mkdir(parents=True, exist_ok=True) + + logger.info( + f"Downloading model weights from Zenodo ({ZENODO_RECORD_URL}). This might take a while..." + ) + # Make a GET request to the URL + response = requests.get(f"{ZENODO_RECORD_URL}/files-archive") + # Ensure the request was successful + if response.status_code != 200: + logger.error( + f"Failed to download model weights. Status code: {response.status_code}" + ) + return + # Extract the downloaded zip file to the target folder + with zipfile.ZipFile(BytesIO(response.content)) as zip_ref: + zip_ref.extractall(weights_folder) + logger.info(f"Zip file extracted successfully to {weights_folder}") + return weights_folder From 2e9a441ac7c41bfc1e734445093f6b3696e299d3 Mon Sep 17 00:00:00 2001 From: Marcel Rosier Date: Tue, 30 Apr 2024 18:07:43 +0200 Subject: [PATCH 2/7] add deletion of old weight folders --- brainles_aurora/utils/weights.py | 14 +++++++++++--- 1 file changed, 11 insertions(+), 3 deletions(-) diff --git a/brainles_aurora/utils/weights.py b/brainles_aurora/utils/weights.py index 4c120bf..eee9a7e 100644 --- a/brainles_aurora/utils/weights.py +++ b/brainles_aurora/utils/weights.py @@ -1,7 +1,8 @@ from __future__ import annotations import logging -import os +import shutil +import sys import zipfile from io import BytesIO from pathlib import Path @@ -9,7 +10,6 @@ import requests from brainles_aurora.inferer.constants import WEIGHTS_DIR_PATTERN -import sys logger = logging.getLogger(__name__) @@ -63,7 +63,15 @@ def check_model_weights(package_folder: Path) -> Path: return package_folder / latest_downloaded_weights logger.info( - f"New model weights available on Zenodo ({zenodo_metadata['version']}). Downloading..." + f"New model weights available on Zenodo ({zenodo_metadata['version']}). Deleting old and fetching new weights..." + ) + # delete old weights + shutil.rmtree( + package_folder / "testestsetestset", + ignore_errors=True, + onerror=lambda func, path, excinfo: logger.warning( + f"Failed to delete {path}: {excinfo}" + ), ) return download_model_weights( package_folder=package_folder, zenodo_metadata=zenodo_metadata From 277cbeb969c4cc664b942c0e2327ed7fe4a5d323 Mon Sep 17 00:00:00 2001 From: Marcel Rosier Date: Tue, 30 Apr 2024 18:38:28 +0200 Subject: [PATCH 3/7] fix old weights deletion fix issue on aborted download/ empty latest folder integrate into model class --- brainles_aurora/inferer/model.py | 10 ++---- brainles_aurora/utils/__init__.py | 2 +- brainles_aurora/utils/weights.py | 51 +++++++++++++++++++------------ 3 files changed, 36 insertions(+), 27 deletions(-) diff --git a/brainles_aurora/inferer/model.py b/brainles_aurora/inferer/model.py index 97b6154..e0be2f4 100644 --- a/brainles_aurora/inferer/model.py +++ b/brainles_aurora/inferer/model.py @@ -8,9 +8,8 @@ import numpy as np import torch from brainles_aurora.inferer.config import AuroraInfererConfig -from brainles_aurora.inferer.constants import InferenceMode, Output, WEIGHTS_DIR -from brainles_aurora.inferer.data import DataHandler -from brainles_aurora.utils import download_model_weights +from brainles_aurora.inferer.constants import InferenceMode, Output, WEIGHTS_DIR_PATTERN +from brainles_aurora.utils import check_model_weights from monai.inferers import SlidingWindowInferer from monai.networks.nets import BasicUNet from monai.transforms import RandGaussianNoised @@ -40,10 +39,7 @@ def __init__( self.model = None self.inference_mode = None # download weights if not present - self.lib_path: str = Path(os.path.dirname(os.path.abspath(__file__))) - self.model_weights_folder = self.lib_path.parent / WEIGHTS_DIR - if not self.model_weights_folder.exists(): - download_model_weights(target_folder=str(self.model_weights_folder)) + self.model_weights_folder = check_model_weights() def load_model( self, inference_mode: InferenceMode, num_input_modalities: int diff --git a/brainles_aurora/utils/__init__.py b/brainles_aurora/utils/__init__.py index 0fdc3f9..4d0cb76 100644 --- a/brainles_aurora/utils/__init__.py +++ b/brainles_aurora/utils/__init__.py @@ -1,2 +1,2 @@ from .utils import remove_path_suffixes -from .weights import download_model_weights +from .weights import check_model_weights diff --git a/brainles_aurora/utils/weights.py b/brainles_aurora/utils/weights.py index eee9a7e..413fd84 100644 --- a/brainles_aurora/utils/weights.py +++ b/brainles_aurora/utils/weights.py @@ -6,7 +6,7 @@ import zipfile from io import BytesIO from pathlib import Path -from typing import Dict +from typing import Dict, List import requests from brainles_aurora.inferer.constants import WEIGHTS_DIR_PATTERN @@ -16,19 +16,21 @@ ZENODO_RECORD_URL = "https://zenodo.org/api/records/10557069" -def check_model_weights(package_folder: Path) -> Path: - """Check if latest model weights are present and download otherwise. - - Args: - package_folder (Path): Package folder path in which the model weights are stored. +def check_model_weights() -> Path: + """Check if latest model weights are present and download them otherwise. Returns: Path: Path to the model weights folder. """ + package_folder = Path(__file__).parent.parent + zenodo_metadata = _get_zenodo_metadata() matching_folders = list(package_folder.glob(WEIGHTS_DIR_PATTERN)) - if not matching_folders: + # Get the latest downloaded weights + latest_downloaded_weights = _get_latest_version_folder_name(matching_folders) + + if not latest_downloaded_weights: if not zenodo_metadata: logger.error( "Model weights not found locally and Zenodo could not be reached. Exiting..." @@ -38,16 +40,11 @@ def check_model_weights(package_folder: Path) -> Path: f"Model weights not found. Downloading the latest model weights {zenodo_metadata['version']} from Zenodo..." ) - return download_model_weights( + return _download_model_weights( package_folder=package_folder, zenodo_metadata=zenodo_metadata ) - # Get the latest downloaded weights - latest_downloaded_weights = sorted( - matching_folders, - reverse=True, - key=lambda x: tuple(map(int, x.split("_v")[1].split("."))), - )[0] + logger.info(f"Found downloaded local weights: {latest_downloaded_weights}") if not zenodo_metadata: logger.warning( @@ -67,17 +64,30 @@ def check_model_weights(package_folder: Path) -> Path: ) # delete old weights shutil.rmtree( - package_folder / "testestsetestset", - ignore_errors=True, + package_folder / latest_downloaded_weights, onerror=lambda func, path, excinfo: logger.warning( f"Failed to delete {path}: {excinfo}" ), ) - return download_model_weights( + return _download_model_weights( package_folder=package_folder, zenodo_metadata=zenodo_metadata ) +def _get_latest_version_folder_name(folders: List[Path]) -> str | None: + if not folders: + return None + latest_downloaded_folder = sorted( + folders, + reverse=True, + key=lambda x: tuple(map(int, str(x).split("_v")[1].split("."))), + )[0] + # check folder is not empty + if not list(latest_downloaded_folder.glob("*")): + return None + return latest_downloaded_folder.name + + def _get_zenodo_metadata() -> Dict | None: """Get the metadata for the Zenodo record. @@ -92,12 +102,15 @@ def _get_zenodo_metadata() -> Dict | None: return None -def download_model_weights(package_folder: Path, zenodo_metadata: Dict) -> None: +def _download_model_weights(package_folder: Path, zenodo_metadata: Dict) -> Path: """Download the latest model weights from Zenodo and extract them to the target folder. Args: package_folder (Path): Package folder path in which the model weights will be stored. - zenodo_metadata (Dict): Metadata for the Zenodo record. + zenodo_metadata (Dict): Metadata for the Zenodo record. + + Returns: + Path: Path to the model weights folder. """ weights_folder = package_folder / f"weights_v{zenodo_metadata['version']}" From f0285da39a835f8ea8a3669188d7adee44ad5b26 Mon Sep 17 00:00:00 2001 From: Marcel Rosier Date: Wed, 1 May 2024 12:16:00 +0200 Subject: [PATCH 4/7] simplify console log format --- brainles_aurora/inferer/log_config.json | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/brainles_aurora/inferer/log_config.json b/brainles_aurora/inferer/log_config.json index 7b96edd..06cfe19 100644 --- a/brainles_aurora/inferer/log_config.json +++ b/brainles_aurora/inferer/log_config.json @@ -3,7 +3,7 @@ "disable_existing_loggers": false, "formatters": { "simple": { - "format": "[%(levelname)-8s | %(module)-15s | L%(lineno)-5d] | %(asctime)s: %(message)s", + "format": "[%(levelname)s] %(asctime)s: %(message)s", "datefmt": "%Y-%m-%dT%H:%M:%S%z" } }, From 16b2375dc10f2e3cda94a861243f8ff412b39cb5 Mon Sep 17 00:00:00 2001 From: Marcel Rosier Date: Wed, 1 May 2024 12:16:12 +0200 Subject: [PATCH 5/7] add docstring --- brainles_aurora/utils/weights.py | 8 ++++++++ 1 file changed, 8 insertions(+) diff --git a/brainles_aurora/utils/weights.py b/brainles_aurora/utils/weights.py index 413fd84..5ac2f8e 100644 --- a/brainles_aurora/utils/weights.py +++ b/brainles_aurora/utils/weights.py @@ -75,6 +75,14 @@ def check_model_weights() -> Path: def _get_latest_version_folder_name(folders: List[Path]) -> str | None: + """Get the latest (non empty) version folder name from the list of folders. + + Args: + folders (List[Path]): List of folders matching the pattern. + + Returns: + str | None: Latest version folder name if one exists, else None. + """ if not folders: return None latest_downloaded_folder = sorted( From 31cebd800f292b76719d55e3a7a09eb9dab1d7f9 Mon Sep 17 00:00:00 2001 From: Marcel Rosier Date: Wed, 1 May 2024 12:25:59 +0200 Subject: [PATCH 6/7] fix comment --- brainles_aurora/inferer/model.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/brainles_aurora/inferer/model.py b/brainles_aurora/inferer/model.py index e0be2f4..95449da 100644 --- a/brainles_aurora/inferer/model.py +++ b/brainles_aurora/inferer/model.py @@ -38,7 +38,8 @@ def __init__( # Will be set during infer() call self.model = None self.inference_mode = None - # download weights if not present + + # get location of model weights self.model_weights_folder = check_model_weights() def load_model( From cb547981f99af49ff43cfa728a5f8ad44055fc64 Mon Sep 17 00:00:00 2001 From: Marcel Rosier Date: Tue, 25 Jun 2024 15:43:45 +0200 Subject: [PATCH 7/7] fixed weights update logic fixed deprecated warning --- .gitignore | 3 +- README.md | 4 +-- brainles_aurora/utils/weights.py | 59 +++++++++++++++++++++++--------- example/segmentation_example.py | 2 +- 4 files changed, 48 insertions(+), 20 deletions(-) diff --git a/.gitignore b/.gitignore index a5e1bd8..72837a3 100644 --- a/.gitignore +++ b/.gitignore @@ -130,4 +130,5 @@ dmypy.json .vscode poetry.lock .DS_Store -brainles_aurora/model_weights/* \ No newline at end of file +brainles_aurora/model_weights/* +brainles_aurora/weights/* \ No newline at end of file diff --git a/README.md b/README.md index 8ddf2f0..2f3da41 100644 --- a/README.md +++ b/README.md @@ -17,8 +17,8 @@ pip install brainles-aurora ## Recommended Environment - CUDA 11.4+ (https://developer.nvidia.com/cuda-toolkit) -- Python 3.10+ -- GPU with CUDA support and at least 8GB of VRAM +- Python 3.8+ +- GPU with CUDA support and at least 6GB of VRAM ## Usage BrainLes features Jupyter Notebook [tutorials](https://github.com/BrainLesion/tutorials/tree/main/AURORA) with usage instructions. diff --git a/brainles_aurora/utils/weights.py b/brainles_aurora/utils/weights.py index 5ac2f8e..1051a46 100644 --- a/brainles_aurora/utils/weights.py +++ b/brainles_aurora/utils/weights.py @@ -9,11 +9,12 @@ from typing import Dict, List import requests +from tqdm import tqdm from brainles_aurora.inferer.constants import WEIGHTS_DIR_PATTERN logger = logging.getLogger(__name__) -ZENODO_RECORD_URL = "https://zenodo.org/api/records/10557069" +ZENODO_RECORD_URL = "https://zenodo.org/api/records/10557068" def check_model_weights() -> Path: @@ -24,7 +25,7 @@ def check_model_weights() -> Path: """ package_folder = Path(__file__).parent.parent - zenodo_metadata = _get_zenodo_metadata() + zenodo_metadata, archive_url = _get_zenodo_metadata_and_archive_url() matching_folders = list(package_folder.glob(WEIGHTS_DIR_PATTERN)) # Get the latest downloaded weights @@ -41,7 +42,9 @@ def check_model_weights() -> Path: ) return _download_model_weights( - package_folder=package_folder, zenodo_metadata=zenodo_metadata + package_folder=package_folder, + zenodo_metadata=zenodo_metadata, + archive_url=archive_url, ) logger.info(f"Found downloaded local weights: {latest_downloaded_weights}") @@ -63,14 +66,19 @@ def check_model_weights() -> Path: f"New model weights available on Zenodo ({zenodo_metadata['version']}). Deleting old and fetching new weights..." ) # delete old weights - shutil.rmtree( - package_folder / latest_downloaded_weights, - onerror=lambda func, path, excinfo: logger.warning( - f"Failed to delete {path}: {excinfo}" + try: + shutil.rmtree( + package_folder / latest_downloaded_weights, + ) + except Exception as e: + logger.warning( + f"Failed to delete {package_folder / latest_downloaded_weights}: {e}" ), - ) + return _download_model_weights( - package_folder=package_folder, zenodo_metadata=zenodo_metadata + package_folder=package_folder, + zenodo_metadata=zenodo_metadata, + archive_url=archive_url, ) @@ -96,26 +104,33 @@ def _get_latest_version_folder_name(folders: List[Path]) -> str | None: return latest_downloaded_folder.name -def _get_zenodo_metadata() -> Dict | None: - """Get the metadata for the Zenodo record. +def _get_zenodo_metadata_and_archive_url() -> Dict | None: + """Get the metadata for the Zenodo record and the files archive url. Returns: - dict: Metadata for the Zenodo record. + Tuple: (dict: Metadata for the Zenodo record, str: URL to the archive file) """ try: response = requests.get(ZENODO_RECORD_URL) - return response.json()["metadata"] + if response.status_code != 200: + logger.error(f"Cant find model weights on Zenodo. Exiting...") + data = response.json() + return data["metadata"], data["links"]["archive"] + except requests.exceptions.RequestException as e: logger.warning(f"Failed to fetch Zenodo metadata: {e}") return None -def _download_model_weights(package_folder: Path, zenodo_metadata: Dict) -> Path: +def _download_model_weights( + package_folder: Path, zenodo_metadata: Dict, archive_url: str +) -> Path: """Download the latest model weights from Zenodo and extract them to the target folder. Args: package_folder (Path): Package folder path in which the model weights will be stored. zenodo_metadata (Dict): Metadata for the Zenodo record. + archive_url (str): URL to the model weights archive file. Returns: Path: Path to the model weights folder. @@ -129,15 +144,27 @@ def _download_model_weights(package_folder: Path, zenodo_metadata: Dict) -> Path f"Downloading model weights from Zenodo ({ZENODO_RECORD_URL}). This might take a while..." ) # Make a GET request to the URL - response = requests.get(f"{ZENODO_RECORD_URL}/files-archive") + response = requests.get(archive_url, stream=True) # Ensure the request was successful if response.status_code != 200: logger.error( f"Failed to download model weights. Status code: {response.status_code}" ) return + # Download with progress bar + chunk_size = 1024 # 1KB + bytes_io = BytesIO() + with tqdm( + total=0, # unknown size since content length not given + unit="B", + unit_scale=True, + ) as pbar: + for data in response.iter_content(chunk_size=chunk_size): + bytes_io.write(data) + pbar.update(len(data)) + # Extract the downloaded zip file to the target folder - with zipfile.ZipFile(BytesIO(response.content)) as zip_ref: + with zipfile.ZipFile(bytes_io) as zip_ref: zip_ref.extractall(weights_folder) logger.info(f"Zip file extracted successfully to {weights_folder}") return weights_folder diff --git a/example/segmentation_example.py b/example/segmentation_example.py index 8dc7537..e2b980d 100644 --- a/example/segmentation_example.py +++ b/example/segmentation_example.py @@ -74,6 +74,6 @@ def gpu_np(): if __name__ == "__main__": - logging.basicConfig(level=logging.DEBUG) + # logging.basicConfig(level=) gpu_nifti() # gpu_nifti_2()