diff --git a/bioimageio/engine/__main__.py b/bioimageio/engine/__main__.py index afd6c67..d12b5af 100644 --- a/bioimageio/engine/__main__.py +++ b/bioimageio/engine/__main__.py @@ -1,10 +1,12 @@ +import os import sys import secrets import asyncio import subprocess +import ray +from hypha_rpc import connect_to_server from hypha_launcher.api import HyphaLauncher from hypha_launcher.utils.log import get_logger - import fire logger = get_logger() @@ -35,15 +37,100 @@ async def start_server(host = "0.0.0.0", port = 9000, public_base_url = "", laun def connect_server(server_url, login_required=False): from engine import connect_server - server_url = server_url loop = asyncio.get_event_loop() loop.create_task(connect_server(server_url)) loop.run_forever() -def serve_ray_apps(): - from ray import serve - from bioimageio.engine.ray_app_loader import app - serve.run(app) +async def _run_ray_server_apps(address, ready_timeout): + if not address: + address = os.environ.get("RAY_ADDRESS") + with ray.init(address=address) as client_context: + dashboard_url = f"http://{client_context.dashboard_url}" + logger.info(f"Dashboard URL: {dashboard_url}") + server_url = os.environ.get("HYPHA_SERVER_URL") + workspace = os.environ.get("HYPHA_WORKSPACE") + token = os.environ.get("HYPHA_TOKEN") + assert server_url, "HYPHA_SERVER_URL environment variable is not set" + health_check_url = f"{server_url}/{workspace}/services/ray-apps" + logger.info(f"Health check URL: {health_check_url}") + shutdown_command = "serve shutdown -y" + (f" --address={dashboard_url}") + server = await connect_to_server({"server_url": server_url, "token": token, "workspace": workspace}) + serve_command = f"serve run bioimageio.engine.ray_app_manager:app" + (f" --address={address}" if address else "") + + proc = None + from simpervisor import SupervisedProcess + import httpx + + async def ready_function(_): + try: + async with httpx.AsyncClient() as client: + response = await client.get(health_check_url) + return response.status_code == 200 + except Exception as e: + logger.warning(f"Error checking readiness: {str(e)}") + return False + + async def serve_apps(): + nonlocal proc + if proc: + if await proc.ready(): + return f"Ray Apps Manager already started at {server_url}/{workspace}/services/{svc.id.split('/')[1]}" + if proc: + await proc.kill() + command = [c.strip() for c in serve_command.split() if c.strip()] + name = "ray-apps-manager" + proc = SupervisedProcess( + name, + *command, + env=os.environ.copy(), + always_restart=False, + ready_func=ready_function, + ready_timeout=ready_timeout, + log=logger, + ) + + try: + await proc.start() + + is_ready = await proc.ready() + + if not is_ready and proc and proc.running: + await proc.kill() + raise Exception(f"External services ({name}) failed to start") + except: + if logger: + logger.exception(f"External services ({name}) failed to start") + raise + return f"Ray Apps Manager started at {server_url}/{workspace}/services/{svc.id.split('/')[1]}" + + async def shutdown(): + nonlocal proc + if proc: + if proc.running: + await proc.kill() + proc = None + else: + loop = asyncio.get_event_loop() + await loop.run_in_executor(None, os.system, shutdown_command) + + svc = await server.register_service({ + "name": "Ray App Manager", + "id": "ray-app-manager", + "config": { + "visibility": "public", + "run_in_executor": True, + }, + "serve": serve_apps, + "shutdown": shutdown + }) + + # Start apps serve + await serve_apps() + +def serve_ray_apps(address: str=None, ready_timeout: int=120): + loop = asyncio.get_event_loop() + loop.create_task(_run_ray_server_apps(address, ready_timeout)) + loop.run_forever() if __name__ == '__main__': fire.Fire({ diff --git a/bioimageio/engine/ray_app_loader.py b/bioimageio/engine/ray_app_loader.py index 435bb5d..d70ffb9 100644 --- a/bioimageio/engine/ray_app_loader.py +++ b/bioimageio/engine/ray_app_loader.py @@ -1,5 +1,4 @@ -"""Provide main entrypoint.""" -import asyncio +"""Provide ray app manager.""" import re from ray import serve import logging @@ -8,13 +7,12 @@ import sys from pathlib import Path import urllib.request -from starlette.requests import Request from hypha_rpc.utils import ObjectProxy -from hypha_rpc.sync import connect_to_server + logging.basicConfig(stream=sys.stdout) -logger = logging.getLogger("app_launcher") +logger = logging.getLogger("ray_app_manager") logger.setLevel(logging.INFO) @@ -34,110 +32,80 @@ def load_app(app_file, manifest): if app_file.endswith(".py"): app_info = ObjectProxy.fromDict(manifest) import hypha_rpc + def export(app_class): # make sure app_class is a class, not an instance if not isinstance(app_class, type): raise RuntimeError("exported object must be a class") app_info.app_class = app_class + # expose the app class to the current module + globals()[app_class.__name__] = app_class + hypha_rpc.api = ObjectProxy(export=export) - exec(content, globals()) # pylint: disable=exec-used + exec(content, globals()) logger.info(f"App loaded: {app_info.name}") # load manifest file if exists return app_info else: raise RuntimeError(f"Invalid script file type ({app_file})") - -def load_all_apps(work_dir): +def create_ray_serve_config(app_info): + ray_serve_config = app_info.get( + "ray_serve_config", {"ray_actor_options": {"runtime_env": {}}} + ) + assert ( + "ray_actor_options" in ray_serve_config + ), "ray_actor_options must be provided in ray_serve_config" + assert ( + "runtime_env" in ray_serve_config["ray_actor_options"] + ), "runtime_env must be provided in ray_actor_options" + runtime_env = ray_serve_config["ray_actor_options"]["runtime_env"] + if not runtime_env.get("pip"): + runtime_env["pip"] = ["hypha-rpc"] + else: + if "hypha-rpc" not in runtime_env["pip"]: + runtime_env["pip"].append("hypha-rpc") + runtime_env["pip"].append( + "https://github.com/bioimage-io/bioengine/archive/refs/heads/main.zip" + ) + return ray_serve_config + +def load_all_apps() -> dict: + current_dir = Path(os.path.dirname(os.path.realpath(__file__))) + apps_dir = current_dir / "ray_apps" ray_apps = {} - apps_dir = work_dir / "ray_apps" for sub_dir in apps_dir.iterdir(): - # check the subfolder for apps - # there should be a file named "manifest.yaml" in the subfolder - # if yes, load the app - # by parsing the manifest.yaml file first, - # find the entrypoint key with the file path - # set it to app_file if sub_dir.is_dir(): manifest_file = sub_dir / "manifest.yaml" if manifest_file.is_file(): with open(manifest_file, "r") as f: manifest = yaml.safe_load(f) - + # make sure the app_id is in lower case, no spaces, only underscores, letters, and numbers pattern = r"^[a-z0-9_]*$" - assert re.match(pattern, manifest["id"]), "App ID must be in lower case, no spaces, only underscores, letters, and numbers" + assert re.match( + pattern, manifest["id"] + ), "App ID must be in lower case, no spaces, only underscores, letters, and numbers" assert manifest["runtime"] == "ray", "Only ray apps are supported" app_file = sub_dir / manifest["entrypoint"] + + app_info = load_app(str(app_file), manifest) + ray_serve_config = create_ray_serve_config(app_info) + # runtime_env["env_vars"] = dict(os.environ) + app_deployment = serve.deployment( + name=app_info.id, **ray_serve_config + )(app_info.app_class).bind() + app_info.app_bind = app_deployment + app_info.methods = [ + m for m in dir(app_info.app_class) if not m.startswith("_") + ] + ray_apps[app_info.id] = app_info - if app_file.is_file() and app_file.suffix == ".py": - app_info = load_app(str(app_file), manifest) - ray_serve_config = manifest.get("ray_serve_config", {}) - app_deployment = serve.deployment(name=app_info.id, **ray_serve_config)(app_info.app_class).bind() - manifest["app_bind"] = app_deployment - manifest["methods"] = [m for m in dir(app_info.app_class) if not m.startswith("_")] - ray_apps[app_info.id] = manifest - print("Loaded apps:", ray_apps.keys()) assert len(ray_apps) > 0, "No apps loaded" return ray_apps -@serve.deployment -class HyphaRayAppManager: - def __init__(self, server_url, workspace, token, ray_apps): - self.server_url = server_url - self._apps = ray_apps - self._hypha_server = connect_to_server({"server_url": server_url, "token": token, "workspace": workspace}) - - def create_service_function(name, app_bind, method_name): - async def service_function(*args, **kwargs): - method = getattr(app_bind, method_name) - return await method.remote(*args, **kwargs) - service_function.__name__ = name - return service_function - - for app_id, app_info in self._apps.items(): - app_bind = app_info["app_bind"] - methods = app_info["methods"] - app_service = { - "id": app_id, - "name": app_info["name"], - "description": app_info["description"], - "config":{ - "visibility": "protected" - }, - } - for method in methods: - print(f"Registering method {method} for app {app_id}") - app_service[method] = create_service_function(method, app_bind, method) - info = self._hypha_server.register_service(app_service, {"overwrite":True}) - print(f"Added service {app_id} with id {info.id}, use it at {self.server_url}/{workspace}/services/{info.id.split('/')[1]}") - - async def __call__(self, request: Request): - # return a json object with the services - services = {} - for app_id, app_info in self._apps.items(): - services[app_id] = app_info["methods"] - return services - - -current_dir = Path(os.path.dirname(os.path.realpath(__file__))) -ray_apps = load_all_apps(current_dir) - -# Getting config from environment -server_url = os.environ.get("HYPHA_SERVER_URL") -workspace = os.environ.get("HYPHA_WORKSPACE") -token = os.environ.get("HYPHA_TOKEN") - -assert server_url, "Server URL is not provided" - -app = HyphaRayAppManager.bind(server_url, workspace, token, ray_apps) - -if __name__ == "__main__": - serve.start() - serve.run(app, name="bioengine-apps") - import asyncio - asyncio.get_event_loop().run_forever() +ray_apps = load_all_apps() \ No newline at end of file diff --git a/bioimageio/engine/ray_app_manager.py b/bioimageio/engine/ray_app_manager.py new file mode 100644 index 0000000..70eb3ef --- /dev/null +++ b/bioimageio/engine/ray_app_manager.py @@ -0,0 +1,154 @@ +"""Provide ray app loader.""" + +import asyncio +from ray import serve +import logging +import os +import sys +from starlette.requests import Request +from bioimageio.engine.ray_app_loader import ray_apps + +logging.basicConfig(stream=sys.stdout) +logger = logging.getLogger("ray_app_launcher") +logger.setLevel(logging.INFO) + + +@serve.deployment( + ray_actor_options={ + "runtime_env": { + "pip": [ + "hypha-rpc", + "https://github.com/bioimage-io/bioengine/archive/refs/heads/main.zip", + ] + } + } +) +class HyphaRayAppManager: + def __init__(self, server_url, workspace, token, ray_apps): + from hypha_rpc.sync import connect_to_server + + self.server_url = server_url + self._apps = ray_apps + self._ongoing_requests = {} # Track ongoing requests per app and method + self._scale_down_flags = {} # Flags to mark apps for scaling down + + assert server_url, "Server URL is required" + self._hypha_server = connect_to_server( + {"server_url": server_url, "token": token, "workspace": workspace} + ) + + def create_service_function(app_id, method_name, app_bind): + key = f"{app_id}:{method_name}" + self._ongoing_requests[key] = 0 # Initialize counter + self._scale_down_flags[app_id] = False # Initialize scale down flag + + async def service_function(*args, **kwargs): + # Mark other apps to scale down if they are not the current app + self.mark_apps_for_scaling_down(app_id) + + # Track the start of a request + self._ongoing_requests[key] += 1 + logger.info( + f"Starting request for {key}, ongoing: {self._ongoing_requests[key]}" + ) + + try: + method = getattr(app_bind, method_name) + results = await method.remote(*args, **kwargs) + return results + except Exception as e: + # Log the error and raise it + logger.error(f"Error in {key}: {str(e)}") + raise + finally: + # Track the end of a request + self._ongoing_requests[key] -= 1 + logger.info( + f"Completed request for {key}, ongoing: {self._ongoing_requests[key]}" + ) + + # If no ongoing requests and flag is set, scale down + if ( + self._ongoing_requests[key] == 0 + and self._scale_down_flags[app_id] + ): + self.scale_down_if_idle(app_id) + + service_function.__name__ = method_name + return service_function + + for app_id, app_info in self._apps.items(): + app_bind = app_info["app_bind"] + methods = app_info["methods"] + app_service = { + "id": app_id, + "name": app_info["name"], + "description": app_info["description"], + "config": {"visibility": "public"}, + } + if app_info.get("service_config"): + svc_config = app_info["service_config"] + app_service["config"].update(svc_config) + + for method in methods: + logger.info(f"Registering method {method} for app {app_id}") + app_service[method] = create_service_function(app_id, method, app_bind) + info = self._hypha_server.register_service(app_service, {"overwrite": True}) + logger.info( + f"Added service {app_id} with id {info.id}, use it at {self.server_url}/{workspace}/services/{info.id.split('/')[1]}" + ) + + apps = {} + for app_id, app_info in self._apps.items(): + apps[app_id] = app_info["methods"] + + info = self._hypha_server.register_service( + { + "id": "ray-apps", + "name": app_info["name"], + "description": app_info["description"], + "config": {"visibility": "public"}, + "apps": apps, + }, + {"overwrite": True}, + ) + logger.info(f"Registered Ray Apps service with id {info.id}") + + def mark_apps_for_scaling_down(self, current_app_id): + # Iterate through all apps to set scale down flags + for app_id, app_info in self._apps.items(): + if app_id != current_app_id: # Skip the current app + for method_name in app_info["methods"]: + key = f"{app_id}:{method_name}" + if self._ongoing_requests.get(key, 0) > 0: + self._scale_down_flags[app_id] = True + logger.info( + f"Marked {app_id} for scaling down after completion" + ) + + def scale_down_if_idle(self, app_id): + pass + # deployment_handle = serve.get_deployment_handle(app_id) + # deployment_handle.options(num_replicas=0).deploy() + # logger.info(f"Scaled down deployment {app_id} due to inactivity") + + async def __call__(self, request: Request): + # Return a JSON object with the services + services = {} + for app_id, app_info in self._apps.items(): + services[app_id] = app_info["methods"] + return services + + +# Environment variables for connecting to the Hypha server +server_url = os.environ.get("HYPHA_SERVER_URL") +workspace = os.environ.get("HYPHA_WORKSPACE") +token = os.environ.get("HYPHA_TOKEN") + +# Bind the deployment +app = HyphaRayAppManager.bind(server_url, workspace, token, ray_apps) + +if __name__ == "__main__": + serve.start() + serve.run(app) + asyncio.get_event_loop().run_forever() diff --git a/bioimageio/engine/ray_apps/cellpose/__init__.py b/bioimageio/engine/ray_apps/cellpose/__init__.py index d6ed7a7..a96ce04 100644 --- a/bioimageio/engine/ray_apps/cellpose/__init__.py +++ b/bioimageio/engine/ray_apps/cellpose/__init__.py @@ -3,7 +3,7 @@ class CellposeModel: def __init__(self): # Load model - pass + import torch def predict(self, image: str) -> str: prediction = "prediction of cellpose model on image: " + image diff --git a/bioimageio/engine/ray_apps/cellpose/manifest.yaml b/bioimageio/engine/ray_apps/cellpose/manifest.yaml index 5f1d394..c8769cf 100644 --- a/bioimageio/engine/ray_apps/cellpose/manifest.yaml +++ b/bioimageio/engine/ray_apps/cellpose/manifest.yaml @@ -5,10 +5,12 @@ runtime: ray entrypoint: __init__.py ray_serve_config: ray_actor_options: + num_gpus: 1 runtime_env: pip: - - numpy - - torch + - torch==2.3.1 + - torchvision==0.18.1 autoscaling_config: + downscale_delay_s: 1 min_replicas: 0 max_replicas: 2 \ No newline at end of file diff --git a/bioimageio/engine/ray_apps/micro_sam/__init__.py b/bioimageio/engine/ray_apps/micro_sam/__init__.py new file mode 100644 index 0000000..d1384bb --- /dev/null +++ b/bioimageio/engine/ray_apps/micro_sam/__init__.py @@ -0,0 +1,155 @@ +import io +from logging import getLogger +from typing import Union +from hypha_rpc import api +import numpy as np + +class MicroSAM: + def __init__(self, model_timeout: int = 3600, embedding_timeout: int = 600): + from cachetools import TTLCache + + # Set up logger + self.logger = getLogger(__name__) + self.logger.setLevel("INFO") + # Define model URLs + self.model_urls = { + "vit_b": "https://dl.fbaipublicfiles.com/segment_anything/sam_vit_b_01ec64.pth", + "vit_b_lm": "https://uk1s3.embassy.ebi.ac.uk/public-datasets/bioimage.io/diplomatic-bug/1/files/vit_b.pt", + "vit_b_em_organelles": "https://uk1s3.embassy.ebi.ac.uk/public-datasets/bioimage.io/noisy-ox/1/files/vit_b.pt", + } + # Set up cache with per-item time-to-live + self.models = TTLCache( + maxsize=len(self.model_urls), ttl=model_timeout + ) # TODO: what if multiple users download the same model? + self.embeddings = TTLCache(maxsize=np.inf, ttl=embedding_timeout) + + def _load_model(self, model_name: str): + import torch + import requests + from segment_anything import sam_model_registry + + if model_name not in self.model_urls: + raise ValueError( + f"Model {model_name} not found. Available models: {list(self.model_urls.keys())}" + ) + # Check cache first + if model_name in self.models: + return self.models[model_name] + + # Download model if not in cache (takes approx. 4 seconds) + model_url = self.model_urls[model_name] + self.logger.info(f"Loading model {model_name} from {model_url}...") + response = requests.get(model_url) + if response.status_code != 200: + raise RuntimeError(f"Failed to download model from {model_url}") + buffer = io.BytesIO(response.content) + + # Load model state + device = "cuda" if torch.cuda.is_available() else "cpu" + ckpt = torch.load(buffer, map_location=device) + model_type = model_name[:5] + sam = sam_model_registry[model_type]() + sam.load_state_dict(ckpt) + + # Cache the model + self.logger.info(f"Caching model {model_name} (device={device})...") + self.models[model_name] = sam + + return sam + + def _to_image(self, input_): + + # we require the input to be uint8 + if input_.dtype != np.dtype("uint8"): + # first normalize the input to [0, 1] + input_ = input_.astype("float32") - input_.min() + input_ = input_ / input_.max() + # then bring to [0, 255] and cast to uint8 + input_ = (input_ * 255).astype("uint8") + if input_.ndim == 2: + image = np.concatenate([input_[..., None]] * 3, axis=-1) + elif input_.ndim == 3 and input_.shape[-1] == 3: + image = input_ + else: + raise ValueError( + f"Invalid input image of shape {input_.shape}. Expect either 2D grayscale or 3D RGB image." + ) + return image + + def compute_embedding( + self, model_name: str, image: np.ndarray, context: dict = None + ) -> bool: + from segment_anything import SamPredictor + + user_id = context["user"].get("id") + if not user_id: + self.logger.info("User ID not found in context.") + return False + sam = self._load_model(model_name) + self.logger.info(f"User {user_id} - computing embedding...") + predictor = SamPredictor(sam) + predictor.set_image(self._to_image(image)) + # Save computed predictor values + self.logger.info(f"User {user_id} - caching embedding...") + predictor_dict = { + "model_name": model_name, + "original_size": predictor.original_size, + "input_size": predictor.input_size, + "features": predictor.features, # embedding + "is_image_set": predictor.is_image_set, + } + self.embeddings[user_id] = predictor_dict + return True + + def reset_embedding(self, context: dict = None) -> bool: + user_id = context["user"].get("id") + if user_id not in self.embeddings: + self.logger.info(f"User {user_id} not found in cache.") + return False + else: + self.logger.info(f"User {user_id} - resetting embedding...") + del self.embeddings[user_id] + return True + + def segment( + self, + point_coordinates: Union[list, np.ndarray], + point_labels: Union[list, np.ndarray], + context: dict = None, + ) -> list: + from kaibu_utils import mask_to_features + from segment_anything import SamPredictor + user_id = context["user"].get("id") + if user_id not in self.embeddings: + self.logger.info(f"User {user_id} not found in cache.") + return [] + self.logger.info( + f"User {user_id} - segmenting with model {self.embeddings[user_id]['model_name']}..." + ) + # Load the model with the pre-computed embedding + sam = self._load_model(self.embeddings[user_id]["model_name"]) + predictor = SamPredictor(sam) + for key, value in self.embeddings[user_id].items(): + if key != "model_name": + setattr(predictor, key, value) + # Run the segmentation + self.logger.debug( + f"User {user_id} - point coordinates: {point_coordinates}, {point_labels}" + ) + if isinstance(point_coordinates, list): + point_coordinates = np.array(point_coordinates, dtype=np.float32) + if isinstance(point_labels, list): + point_labels = np.array(point_labels, dtype=np.float32) + mask, scores, logits = predictor.predict( + point_coords=point_coordinates[ + :, ::-1 + ], # SAM has reversed XY conventions + point_labels=point_labels, + multimask_output=False, + ) + self.logger.debug(f"User {user_id} - predicted mask of shape {mask.shape}") + features = mask_to_features(mask[0]) + return features + + +api.export(MicroSAM) diff --git a/bioimageio/engine/ray_apps/micro_sam/manifest.yaml b/bioimageio/engine/ray_apps/micro_sam/manifest.yaml new file mode 100644 index 0000000..bd061d0 --- /dev/null +++ b/bioimageio/engine/ray_apps/micro_sam/manifest.yaml @@ -0,0 +1,23 @@ +name: microSAM +id: micro_sam +description: Segment Anything for Microscopy implements automatic and interactive annotation for microscopy data. +runtime: ray +entrypoint: __init__.py +service_config: + require_context: true +ray_serve_config: + ray_actor_options: + num_gpus: 1 + runtime_env: + pip: + - cachetools==5.5.0 + - kaibu-utils==0.1.14 + - numpy==1.26.4 + - requests==2.31.0 + - segment_anything==1.0 + - torch==2.3.1 + - torchvision==0.18.1 + autoscaling_config: + downscale_delay_s: 1 + min_replicas: 0 + max_replicas: 2 \ No newline at end of file diff --git a/bioimageio/engine/ray_apps/translator/manifest.yaml b/bioimageio/engine/ray_apps/translator/manifest.yaml index cb83c73..a19458c 100644 --- a/bioimageio/engine/ray_apps/translator/manifest.yaml +++ b/bioimageio/engine/ray_apps/translator/manifest.yaml @@ -5,9 +5,13 @@ runtime: ray entrypoint: translator.py ray_serve_config: ray_actor_options: + num_gpus: 1 runtime_env: pip: + - torch==2.3.1 + - torchvision==0.18.1 - transformers autoscaling_config: + downscale_delay_s: 1 min_replicas: 0 max_replicas: 2 \ No newline at end of file diff --git a/bioimageio/engine/ray_apps/translator/translator.py b/bioimageio/engine/ray_apps/translator/translator.py index 56d9397..62f78fe 100644 --- a/bioimageio/engine/ray_apps/translator/translator.py +++ b/bioimageio/engine/ray_apps/translator/translator.py @@ -1,9 +1,9 @@ from hypha_rpc import api -from transformers import pipeline class Translator: def __init__(self): + from transformers import pipeline # Load model self.model = pipeline("translation_en_to_fr", model="t5-small") diff --git a/tests/test_micro_sam.py b/tests/test_micro_sam.py new file mode 100644 index 0000000..3f36e98 --- /dev/null +++ b/tests/test_micro_sam.py @@ -0,0 +1,26 @@ +from hypha_rpc.sync import connect_to_server +import numpy as np + + +def test_get_service( + server_url: str="https://hypha.aicell.io", + workspace_name: str="bioengine-apps", + service_id: str="micro_sam", + ): + client = connect_to_server({"server_url": server_url, "method_timeout": 5}) + assert client + + sid = f"{workspace_name}/{service_id}" + segment_svc = client.get_service(sid) + assert segment_svc.config.workspace == workspace_name + assert segment_svc.get("compute_embedding") + assert segment_svc.get("segment") + assert segment_svc.get("reset_embedding") + + assert segment_svc.compute_embedding("vit_b", np.random.rand(256, 256)) + features = segment_svc.segment([[128, 128]], [1]) + assert features + assert segment_svc.reset_embedding() + +if __name__ == "__main__": + test_get_service() \ No newline at end of file