Skip to content

Commit

Permalink
Merge pull request #2 from CBICA/spiros-dev
Browse files Browse the repository at this point in the history
Initial code review for DLWMLS
  • Loading branch information
euroso97 authored Nov 14, 2024
2 parents 2c35c33 + e088660 commit 1cdf459
Show file tree
Hide file tree
Showing 10 changed files with 63 additions and 7,568 deletions.
19 changes: 19 additions & 0 deletions .github/workflows/pre-commit.yml
Original file line number Diff line number Diff line change
@@ -0,0 +1,19 @@
name: pre-commit

on:
pull_request:
push:
branches: [main, spiros-dev]

jobs:
pre-commit:
runs-on: macos-latest
steps:
- uses: actions/checkout@v4
- uses: actions/setup-python@v5
with:
python-version: "3.12"
- name: install pre commit
run: pip install pre-commit && pre-commit install
- name: pre-commit
uses: pre-commit/[email protected]
20 changes: 20 additions & 0 deletions .pre-commit-config.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,20 @@
repos:
- repo: https://github.com/ambv/black
rev: 24.3.0
hooks:
- id: black
- repo: https://github.com/pycqa/isort
rev: 5.11.5
hooks:
- id: isort
args: ["--profile", "black"]
- repo: https://github.com/pycqa/flake8
rev: 5.0.4
hooks:
- id: flake8
- repo: https://github.com/pre-commit/mirrors-mypy
rev: v0.991
hooks:
- id: mypy
args: ["--ignore-missing-imports"]

41 changes: 15 additions & 26 deletions DLWMLS/__main__.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,10 @@
import argparse
import json
import os
from pathlib import Path
import shutil
import sys
import warnings
from pathlib import Path

import torch

Expand All @@ -15,15 +15,16 @@

VERSION = 1.0


def main() -> None:
prog="DLWMLS"
prog = "DLWMLS"
parser = argparse.ArgumentParser(
prog=prog,
description="DLWMLS - MUlti-atlas region Segmentation utilizing Ensembles of registration algorithms and parameters.",
usage="""
DLWMLS v{VERSION}
Segment White Matter Lesions from the ICV-segmented (see DLICV method), LPS oriented brain image (Nifti/.nii.gz format).
Required arguments:
[-i, --in_dir] The filepath of the input directory
[-o, --out_dir] The filepath of the output directory
Expand All @@ -36,8 +37,10 @@ def main() -> None:
-o /path/to/output \
-device cpu|cuda|mps
""".format(VERSION=VERSION),
add_help=False
""".format(
VERSION=VERSION
),
add_help=False,
)

# Required Arguments
Expand All @@ -53,7 +56,7 @@ def main() -> None:
required=True,
help="[REQUIRED] Output folder for the segmentation results in Nifti format (nii.gz).",
)

# Optional Arguments
parser.add_argument(
"-device",
Expand Down Expand Up @@ -99,7 +102,7 @@ def main() -> None:
action="store_true",
required=False,
default=False,
help="Set this flag to clear any cached models before running. This is recommended if a previous download failed."
help="Set this flag to clear any cached models before running. This is recommended if a previous download failed.",
)
parser.add_argument(
"--disable_tta",
Expand All @@ -109,13 +112,6 @@ def main() -> None:
help="[nnUnet Arg] Set this flag to disable test time data augmentation in the form of mirroring. "
"Faster, but less accurate inference. Not recommended.",
)
### DEPRECIATED ####
# parser.add_argument(
# "-m",
# type=str,
# required=True,
# help="Model folder path. The model folder should be named nnunet_results.",
# )
parser.add_argument(
"-d",
type=str,
Expand Down Expand Up @@ -208,25 +204,17 @@ def main() -> None:
required=False,
default=0,
help="[nnUnet Arg] If multiple nnUNetv2_predict exist, which one is this? IDs start with 0 "
"can end with num_parts - 1. So when you submit 5 nnUNetv2_predict calls you need to set "
"can end with num_parts - 1. So when you submit 5 nnUNetv2_predict calls you need to set "
"-num_parts 5 and use -part_id 0, 1, 2, 3 and 4. Note: You are yourself responsible to make these run on separate GPUs! "
"Use CUDA_VISIBLE_DEVICES.",
)



args = parser.parse_args()
args.f = [args.f]

if args.clear_cache:
shutil.rmtree(os.path.join(
Path(__file__).parent,
"nnunet_results"
))
shutil.rmtree(os.path.join(
Path(__file__).parent,
".cache"
))
shutil.rmtree(os.path.join(Path(__file__).parent, "nnunet_results"))
shutil.rmtree(os.path.join(Path(__file__).parent, ".cache"))
if not args.i or not args.o:
print("Cache cleared and missing either -i / -o. Exiting.")
sys.exit(0)
Expand Down Expand Up @@ -263,14 +251,14 @@ def main() -> None:
% (args.d, args.d, args.c),
)


# Check if model exists. If not exist, download using HuggingFace
print(f"Using model folder: {model_folder}")
if not os.path.exists(model_folder):
# HF download model
print("DLWMLS model not found, downloading...")

from huggingface_hub import snapshot_download

local_src = Path(__file__).parent
snapshot_download(repo_id="nichart/DLWMLS", local_dir=local_src)

Expand All @@ -292,6 +280,7 @@ def main() -> None:

if args.device == "cpu":
import multiprocessing

# use half of the available threads in the system.
torch.set_num_threads(multiprocessing.cpu_count() // 2)
device = torch.device("cpu")
Expand Down

This file was deleted.

Loading

0 comments on commit 1cdf459

Please sign in to comment.