Skip to content

Commit

Permalink
complete cellpose
Browse files Browse the repository at this point in the history
  • Loading branch information
oeway committed Sep 2, 2024
1 parent fae531e commit 46aa604
Show file tree
Hide file tree
Showing 2 changed files with 45 additions and 8 deletions.
51 changes: 43 additions & 8 deletions bioimageio/engine/ray_apps/cellpose/__init__.py
Original file line number Diff line number Diff line change
@@ -1,17 +1,52 @@
from hypha_rpc import api
import numpy as np

class CellposeModel:
def __init__(self):
# Load model
import torch
from cellpose import core
# Check if GPU is available
self.use_GPU = core.use_gpu()
print('>>> GPU activated? %d' % self.use_GPU)

# Initialize model caching attributes
self.cached_model_type = None
self.model = None

def predict(self, image: str) -> str:
prediction = "prediction of cellpose model on image: " + image
return prediction
def _load_model(self, model_type):
from cellpose import models
if self.model is None or model_type != self.cached_model_type:
print(f'Loading model: {model_type}')
self.model = models.Cellpose(gpu=self.use_GPU, model_type=model_type)
self.cached_model_type = model_type
else:
print(f'Reusing cached model: {model_type}')
return self.model

def train(self, data: str, config: str) -> str:
training = "training cellpose model on data: " + data + "with config:" + config
return training
def predict(self, images: list[np.ndarray], channels=None, diameter=None, flow_threshold=None, model_type='cyto3'):
"""Run segmentation on the provided images using the specified model type."""
# Load the model, utilizing caching
model = self._load_model(model_type)

if channels is None:
# Default channels if not provided
channels = [[2, 3]] * len(images)

# Perform segmentation using the model
masks, flows, styles, diams = model.eval(images, diameter=diameter, flow_threshold=flow_threshold, channels=channels)

# Prepare the response with masks and diameters
results = {
'masks': [mask.tolist() for mask in masks], # Converting numpy arrays to lists for JSON serialization
'diameters': diams # List of estimated diameters for each image
}

return results

def train(self, images, labels, config):
"""Train the model using the provided images and labels."""
# This method would handle the training process.
# Currently, returning a placeholder response.
raise NotImplementedError("Training functionality not implemented yet")

# Export the CellposeModel class using Hypha RPC API
api.export(CellposeModel)
2 changes: 2 additions & 0 deletions bioimageio/engine/ray_apps/cellpose/manifest.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,8 @@ ray_serve_config:
num_gpus: 1
runtime_env:
pip:
- opencv-python-headless==4.2.0.34
- cellpose==3.0.11
- torch==2.3.1
- torchvision==0.18.1
autoscaling_config:
Expand Down

0 comments on commit 46aa604

Please sign in to comment.