diff --git a/.github/workflows/run_tests.yml b/.github/workflows/run_tests.yml new file mode 100644 index 0000000..9badd79 --- /dev/null +++ b/.github/workflows/run_tests.yml @@ -0,0 +1,27 @@ +name: Test JSON Manifests + +on: + pull_request: + types: [opened, synchronize, reopened, review_requested, edited, closed] + push: + branches: [main] + +jobs: + build: + runs-on: ubuntu-latest + strategy: + matrix: + python-version: ["3.9"] + steps: + - uses: actions/checkout@v3 + - name: Set up Python ${{ matrix.python-version }} + uses: actions/setup-python@v4 + with: + python-version: ${{ matrix.python-version }} + - name: Install dependencies + run: | + python -m pip install --upgrade pip + python -m pip install . + - name: Run tests + run: | + python -m pytest -v -s tests/ \ No newline at end of file diff --git a/.gitignore b/.gitignore index df3131b..368974a 100644 --- a/.gitignore +++ b/.gitignore @@ -1 +1,57 @@ -# Add local config files / secrets here to ignore them when you commit. +# Byte-compiled / optimized / DLL files +__pycache__/ +*.py[cod] +*$py.class + +# C extensions +*.so + +# Distribution / packaging +.Python +build/ +develop-eggs/ +dist/ +downloads/ +eggs/ +.eggs/ +lib/ +lib64/ +parts/ +sdist/ +var/ +wheels/ +share/python-wheels/ +*.egg-info/ +.installed.cfg +*.egg +MANIFEST + +# PyInstaller +# Usually these files are written by a python script from a template +# before PyInstaller builds the exe, so as to inject date/other infos into it. +*.manifest +*.spec + +# Installer logs +pip-log.txt +pip-delete-this-directory.txt + +# Unit test / coverage reports +htmlcov/ +.tox/ +.nox/ +.coverage +.coverage.* +.cache +nosetests.xml +coverage.xml +*.cover +*.py,cover +.hypothesis/ +.pytest_cache/ +cover/ + +# Sphinx documentation +docs/_build/ + +.DS_Store \ No newline at end of file diff --git a/aiod_registry/__init__.py b/aiod_registry/__init__.py new file mode 100644 index 0000000..45e0670 --- /dev/null +++ b/aiod_registry/__init__.py @@ -0,0 +1 @@ +from aiod_registry.schema import ModelManifest diff --git a/aiod_registry/manifests/mitonet.json b/aiod_registry/manifests/mitonet.json new file mode 100644 index 0000000..f2ce70f --- /dev/null +++ b/aiod_registry/manifests/mitonet.json @@ -0,0 +1,71 @@ +{ + "name": "Mitonet", + "versions": [ + { + "name": "MitoNet v1", + "tasks": [ + { + "task": "mito", + "location": "https://zenodo.org/record/6861565/files/MitoNet_v1.pth?download=1" + } + ] + }, + { + "name": "MitoNet Mini v1", + "tasks": [ + { + "task": "mito", + "location": "https://zenodo.org/record/6861565/files/MitoNet_v1_mini.pth?download=1" + } + ] + } + ], + "params": [ + { + "name": "Plane", + "value": ["XY", "XZ", "YZ", "All"], + "tooltip": "Whether to use all planes (XY, XZ, YZ) or a single plane" + }, + { + "name": "Downsampling", + "value": [1, 2, 4, 8, 16, 32, 64], + "tooltip": "Downsampling factor for the input image" + }, + { + "name": "Segmentation threshold", + "short_name": "conf_threshold", + "value": 0.5, + "tooltip": "Confidence threshold for the segmentation" + }, + { + "name": "Center threshold", + "short_name": "center_threshold", + "value": 0.1, + "tooltip": "Confidence threshold for the center" + }, + { + "name": "Minimum distance", + "short_name": "min_distance", + "value": 3, + "tooltip": "Minimum distance between object centers" + }, + { + "name": "Maximum objects", + "short_name": "max_objects", + "value": 1000, + "tooltip": "Maximum number of objects to segment per class" + }, + { + "name": "Semantic only", + "short_name": "semantic_only", + "value": false, + "tooltip": "Only run semantic segmentation for all classes" + }, + { + "name": "Fine boundaries", + "short_name": "fine_boundaries", + "value": false, + "tooltip": "Finer boundaries between objects" + } + ] + } \ No newline at end of file diff --git a/aiod_registry/manifests/sam_test.json b/aiod_registry/manifests/sam_test.json new file mode 100644 index 0000000..967a07c --- /dev/null +++ b/aiod_registry/manifests/sam_test.json @@ -0,0 +1,124 @@ +{ + "name": "Segment Anything", + "short_name": "sam", + "versions": [ + { + "name": "default", + "tasks": [ + { + "task": "everything", + "location": "https://dl.fbaipublicfiles.com/segment_anything/sam_vit_h_4b8939.pth", + "config_path": null + } + ] + }, + { + "name": "vit_h", + "tasks": [ + { + "task": "everything", + "location": "https://dl.fbaipublicfiles.com/segment_anything/sam_vit_h_4b8939.pth", + "config_path": null + } + ] + }, + { + "name": "vit_l", + "tasks": [ + { + "task": "everything", + "location": "https://dl.fbaipublicfiles.com/segment_anything/sam_vit_l_0b3195.pth", + "config_path": null + } + ] + }, + { + "name": "vit_b", + "tasks": [ + { + "task": "Mito", + "location":"https://dl.fbaipublicfiles.com/segment_anything/sam_vit_b_01ec64.pth", + "config_path": null + } + ] + }, + { + "name": "MedSAM", + "tasks": [ + { + "task": "everything", + "location": "https://syncandshare.desy.de/index.php/s/yLfdFbpfEGSHJWY/download/medsam_20230423_vit_b_0.0.1.pth", + "config_path": null + } + ] + } + ], + "params": [ + { + "name": "Points per side", + "short_name": "points_per_side", + "value": 32, + "tooltip": "" + }, + { + "name": "Points per batch", + "short_name": "points_per_batch", + "value": 64, + "tooltip": "" + }, + { + "name": "Pred IoU threshold", + "short_name": "pred_iou_thresh", + "value": 0.88, + "tooltip": "" + }, + { + "name": "Stability score threshold", + "short_name": "stability_score_thresh", + "value": 0.95, + "tooltip": "" + }, + { + "name": "Stability score offset", + "short_name": "stability_score_offset", + "value": 1, + "tooltip": "" + }, + { + "name": "Box nms_thresh", + "short_name": "box_nms_thresh", + "value": 0.7, + "tooltip": "" + }, + { + "name": "Crop N layers", + "short_name": "crop_n_layers", + "value": 0, + "tooltip": "" + }, + { + "name": "Crop NMS thresh", + "short_name": "crop_nms_thresh", + "value": 0.7, + "tooltip": "" + }, + { + "name": "Crop overlap ratio", + "short_name": "crop_overlap_ratio", + "value": 0.34133, + "tooltip": "" + }, + { + "name": "Crop B points downscale factor", + "short_name": "crop_n_points_downscale_factor", + "value": 0.5, + "tooltip": "" + }, + { + "name": "Min mask region area", + "short_name": "min_mask_region_area", + "value": 3, + "tooltip": "" + } + ] +} \ No newline at end of file diff --git a/aiod_registry/manifests/unet_seai.json b/aiod_registry/manifests/unet_seai.json new file mode 100644 index 0000000..8de2444 --- /dev/null +++ b/aiod_registry/manifests/unet_seai.json @@ -0,0 +1,31 @@ +{ + "name": "SEAI U-Net", + "short_name": "seai_unet", + "versions": [ + { + "name": "U-Net", + "tasks": [ + { + "task": "mito", + "location": "/nemo/stp/ddt/working/shandc/aiod_models/mito_5nm_intensity_augs_warp.best.969.pt", + "config_path": "/nemo/stp/ddt/working/shandc/aiod_models/mito_5nm_intensity_augs_warp.yml" + } + ] + }, + { + "name": "Attention U-Net", + "tasks": [ + { + "task": "mito", + "location": "/nemo/stp/ddt/working/shandc/aiod_models/Attention_HUNet_3e5_Adam_restart_12_16.best.1266.pt", + "config_path": "/nemo/stp/ddt/working/shandc/aiod_models/Attention_HUNet_3e5_Adam_restart_12_16.yml" + }, + { + "task": "ne", + "location": "/nemo/stp/ddt/working/shandc/aiod_models/Attention_HUNet_NE.best.368.pt", + "config_path": "/nemo/stp/ddt/working/shandc/aiod_models/Attention_HUNet_NE.yml" + } + ] + } + ] +} \ No newline at end of file diff --git a/aiod_registry/schema.py b/aiod_registry/schema.py new file mode 100644 index 0000000..c03578a --- /dev/null +++ b/aiod_registry/schema.py @@ -0,0 +1,55 @@ +import json +from pathlib import Path +from typing import Optional, Union + +from pydantic import BaseModel, ConfigDict, Field, model_validator, AnyUrl + + +def shorten_name(name: str) -> str: + return "_".join(name.lower().split(" ")) + + +class StrictModel(BaseModel): + model_config = ConfigDict(extra="forbid") + + +class ModelVersionTask(StrictModel): + # Regex pattern to match task names, ignoring case + task: str = Field(..., pattern=r"^(?i:mito|er|ne|everything)$") + location: Union[Path, AnyUrl, str] = Field( + ..., + description="Either a url or a filepath (will be skipped if the path does not exist/cannot be read)", + ) + config_path: Optional[Union[Path, str]] = None + + +class ModelVersion(StrictModel): + name: str = Field(..., min_length=1, max_length=50) + tasks: list[ModelVersionTask] + + +class ModelParam(StrictModel): + name: str = Field(..., min_length=1, max_length=50) + short_name: Optional[str] = None + value: Union[str, int, float, bool, list[Union[str, int, float, bool]]] + tooltip: Optional[str] = None + + @model_validator(mode="after") + def create_short_name(self): + if self.short_name is None: + self.short_name = shorten_name(self.name) + return self + + +class ModelManifest(StrictModel): + name: str = Field(..., min_length=1, max_length=50) + short_name: Optional[str] = None + versions: list[ModelVersion] + params: Optional[list[ModelParam]] = None + config: Optional[Path] = None + + @model_validator(mode="after") + def create_short_name(self): + if self.short_name is None: + self.short_name = shorten_name(self.name) + return self diff --git a/pyproject.toml b/pyproject.toml new file mode 100644 index 0000000..bfb3aeb --- /dev/null +++ b/pyproject.toml @@ -0,0 +1,23 @@ +[build-system] +build-backend = 'setuptools.build_meta' +requires = [ + 'setuptools>=61.0.0', +] + +[project] +name = "aiod-registry" +description = "A registry of models for use with AI OnDemand (AIoD)" +version = "0.0.1" +authors = [ + {name = "Cameron Shand", email = "cameron.shand@crick.ac.uk"}, + {name = "Jon Smith", email = "jon.smith@crick.ac.uk"} +] +dependencies = [ + "pytest", + "pydantic >= 2.0" +] +requires-python = ">=3.9" +dynamic = ["readme"] + +[tool.setuptools.dynamic] +readme = {file = ["README.md"], content-type = "text/markdown"} \ No newline at end of file diff --git a/tests/test_jsons.py b/tests/test_jsons.py new file mode 100644 index 0000000..a4be928 --- /dev/null +++ b/tests/test_jsons.py @@ -0,0 +1,23 @@ +import json +from pathlib import Path + +from pydantic import ValidationError +import pytest + +from aiod_registry import ModelManifest + + +def get_jsons(): + json_dir = Path(__file__).parent.parent / "aiod_registry" / "manifests" + print(json_dir) + return json_dir.glob("*.json") + + +@pytest.mark.parametrize("json_path", get_jsons()) +def test_manifest(json_path): + with open(json_path, "r") as f: + json_manifest = json.load(f) + try: + ModelManifest.model_validate(json_manifest) + except ValidationError as e: + raise e