diff --git a/apps/imagej/main.py b/apps/imagej/main.py index 690a748..18c2ba0 100644 --- a/apps/imagej/main.py +++ b/apps/imagej/main.py @@ -1,211 +1,14 @@ import logging import sys, os -import imagej -import scyjava as sj import asyncio -import traceback -import numpy as np -import xarray as xr -from jpype import JOverride, JImplements from hypha.utils import launch_external_services -import argparse -from imjoy_rpc.hypha import connect_to_server - - -logger = logging.getLogger(__name__) -os.environ["JAVA_HOME"] = os.sep.join(sys.executable.split(os.sep)[:-2] + ["jre"]) - -def capture_console(ij, print=True): - logs = {} - logs["stdout"] = [] - logs["stderr"] = [] - - @JImplements("org.scijava.console.OutputListener") - class JavaOutputListener: - @JOverride - def outputOccurred(self, e): - source = e.getSource().toString - output = e.getOutput() - - if print: - if source == "STDOUT": - sys.stdout.write(output) - logs["stdout"].append(output) - elif source == "STDERR": - sys.stderr.write(output) - logs["stderr"].append(output) - else: - output = "[{}] {}".format(source, output) - sys.stderr.write(output) - logs["stderr"].append(output) - - ij.py._outputMapper = JavaOutputListener() - ij.console().addOutputListener(ij.py._outputMapper) - return logs - - -def format_logs(logs): - output = "" - if logs["stdout"]: - output += "STDOUT:\n" - output += "\n".join(logs["stdout"]) - output += "\n" - if logs["stderr"]: - output += "STDERR:\n" - output += "\n".join(logs["stderr"]) - output += "\n" - return output - - -def get_module_info(ij, custom_script, name=None): - name = name or "scijava_script" - ScriptInfo = sj.jimport("org.scijava.script.ScriptInfo") - StringReader = sj.jimport("java.io.StringReader") - moduleinfo = ScriptInfo(ij.getContext(), name, StringReader(custom_script)) - inputs = {} - outputs = {} - - for inp in ij.py.from_java(moduleinfo.inputs()): - input_type = str(inp.getType().getName()) - input_name = str(inp.getName()) - print(input_type, input_name) - inputs[input_name] = {"name": input_name, "type": input_type} - - for outp in ij.py.from_java(moduleinfo.outputs()): - output_type = str(outp.getType().getName()) - output_name = str(outp.getName()) - outputs[output_name] = {"name": output_name, "type": output_type} - - return {"id": moduleinfo.getIdentifier(), "outputs": outputs, "inputs": inputs} - - -def check_size(array): - result_bytes = array.tobytes() - if len(result_bytes) > 20000000: # 20MB - raise Exception( - f"The data is too large ({len(result_bytes)} bytes) to be transfered." - ) +from pathlib import Path -async def execute(config, context=None): - loop = asyncio.get_event_loop() - return await loop.run_in_executor(None, run_imagej, config) - - -def run_imagej(config): - mode = config.get("mode", "headless") - logger.info("Initializing ImageJ...") - ij = imagej.init(os.environ.get("IMAGEJ_DIR"), mode=mode) - logger.info("Running ImageJ macro...") - try: - WindowManager = sj.jimport("ij.WindowManager") - ImagePlus = sj.jimport("ij.ImagePlus") - logs = capture_console(ij) - script = config.get("script") - lang = config.get("lang", "ijm") - assert script is not None, "script is required" - module_info = get_module_info(ij, script) - inputs_info = module_info["inputs"] - outputs_info = module_info["outputs"] - inputs = config.get("inputs", {}) - select_outputs = config.get("select_outputs") - args = {} - for k in inputs: - if isinstance(inputs[k], (np.ndarray, np.generic, dict)): - if isinstance(inputs[k], (np.ndarray, np.generic)): - if inputs[k].ndim == 2: - dims = ["x", "y"] - elif inputs[k].ndim == 3 and inputs[k].shape[2] in [1, 3, 4]: - dims = ["x", "y", "c"] - elif inputs[k].ndim == 3 and inputs[k].shape[0] in [1, 3, 4]: - dims = ["c", "x", "y"] - elif inputs[k].ndim == 3: - dims = ["z", "x", "y"] - elif inputs[k].ndim == 4: - dims = ["z", "x", "y", "c"] - elif inputs[k].ndim == 5: - dims = ["t", "z", "x", "y", "c"] - else: - raise Exception(f"Unsupported ndim: {inputs[k].ndim}") - inputs[k] = {"data": inputs[k], "dims": dims} - - img = inputs[k] - assert isinstance( - img, dict - ), f"input {k} must be a dictionary or a numpy array" - assert "data" in img, f"data is required for {k}" - assert "dims" in img, f"dims is required for {k}" - da = xr.DataArray( - data=img["data"], - dims=img["dims"], - attrs=img.get("attrs", {}), - name=k, - ) - inputs[k] = ij.py.to_java(da) - if lang == "ijm": - # convert to ImagePlus - inputs[k] = ij.convert().convert(inputs[k], ImagePlus) - if inputs[k]: - inputs[k].setTitle(k) - # Display the image - if mode != "headless": - inputs[k].show() - else: - raise NotImplementedError( - "Don't know how to display the image (only ijm is supported)." - ) - if k in inputs_info: - args[k] = ij.py.to_java(inputs[k]) - - # Run the script - macro_result = ij.py.run_script(lang, script, args) - results = {} - if select_outputs is None: - select_outputs = list(outputs_info.keys()) - for k in select_outputs: - if k in outputs_info: - results[k] = macro_result.getOutput(k) - if results[k] and not isinstance(results[k], (int, str, float, bool)): - try: - results[k] = ij.py.from_java(results[k]).to_numpy() - check_size(results[k]) - except Exception: - # TODO: This is needed due to a bug in pyimagej for converting java string - if str(type(results[k])) == "": - results[k] = str(results[k]) - else: - results[k] = { - "type": str(type(results[k])), - "text": str(results[k]), - } - else: - # If the output name is not in the script annotation, - # Try to get the image from the WindowManager by title - img = WindowManager.getImage(k) - if not img: - raise Exception(f"Output not found: {k}\n{format_logs(logs)}") - results[k] = ij.py.from_java(img).to_numpy() - check_size(results[k]) - except Exception as exp: - raise exp - finally: - ij.dispose() - - return {"outputs": results, "logs": logs} - - - -test_macro = """ -#@ String name -#@ int age -#@ String city -#@output Object greeting -greeting = "Hi " + name + ". You are " + age + " years old, and live in " + city + "." -""" - +logger = logging.getLogger(__name__) async def hypha_startup(server): - file_path = os.path.abspath(__file__) + file_path = Path(os.path.abspath(__file__)).parent / "run_imagej.py" server_url = server.config.local_base_url workspace = server.config.workspace token = await server.generate_token() @@ -221,52 +24,4 @@ async def hypha_startup(server): ) logger.info("ImageJ service is ready.") - -async def main(): - - parser = argparse.ArgumentParser() - parser.add_argument("--server-url", type=str, help="Server URL") - parser.add_argument("--workspace", type=str, help="Workspace") - parser.add_argument("--token", type=str, help="Token") - parser.add_argument("--service-id", type=str, help="ImageJ Service ID") - args = parser.parse_args() - logger.info("Connecting to the server %s, workspace: %s", args.server_url, args.workspace) - # get server_url and token from url - server = await connect_to_server({ - "server_url": args.server_url, - "workspace": args.workspace, - "token": args.token, - }) - try: - logger.info("Testing the imagej service...") - ret = await execute( - { - "script": test_macro, - "inputs": {"name": "Tom", "age": 20, "city": "Shanghai"}, - } - ) - outputs = ret["outputs"] - assert ( - outputs["greeting"] == "Hi Tom. You are 20 years old, and live in Shanghai." - ) - except Exception: - print(traceback.format_exc()) - sys.exit(1) - - logger.info("Starting the imagej service...") - svc = await server.register_service( - { - "id": args.service_id, - "type": "imagej", - "config": {"require_context": True, "visibility": "public"}, - "execute": execute, - } - ) - logger.info(f"ImageJ service is registered as `{svc['id']}`") - -if __name__ == "__main__": - loop = asyncio.get_event_loop() - loop.create_task(main()) - loop.run_forever() - \ No newline at end of file diff --git a/apps/imagej/run_imagej.py b/apps/imagej/run_imagej.py new file mode 100644 index 0000000..eb73023 --- /dev/null +++ b/apps/imagej/run_imagej.py @@ -0,0 +1,251 @@ +import logging +import sys, os +import imagej +import scyjava as sj +import asyncio +import traceback +import numpy as np +import xarray as xr +from jpype import JOverride, JImplements +import argparse +from imjoy_rpc.hypha import connect_to_server + + +logger = logging.getLogger(__name__) +os.environ["JAVA_HOME"] = os.sep.join(sys.executable.split(os.sep)[:-2] + ["jre"]) + +def capture_console(ij, print=True): + logs = {} + logs["stdout"] = [] + logs["stderr"] = [] + + @JImplements("org.scijava.console.OutputListener") + class JavaOutputListener: + @JOverride + def outputOccurred(self, e): + source = e.getSource().toString + output = e.getOutput() + + if print: + if source == "STDOUT": + sys.stdout.write(output) + logs["stdout"].append(output) + elif source == "STDERR": + sys.stderr.write(output) + logs["stderr"].append(output) + else: + output = "[{}] {}".format(source, output) + sys.stderr.write(output) + logs["stderr"].append(output) + + ij.py._outputMapper = JavaOutputListener() + ij.console().addOutputListener(ij.py._outputMapper) + return logs + + +def format_logs(logs): + output = "" + if logs["stdout"]: + output += "STDOUT:\n" + output += "\n".join(logs["stdout"]) + output += "\n" + if logs["stderr"]: + output += "STDERR:\n" + output += "\n".join(logs["stderr"]) + output += "\n" + return output + + +def get_module_info(ij, custom_script, name=None): + name = name or "scijava_script" + ScriptInfo = sj.jimport("org.scijava.script.ScriptInfo") + StringReader = sj.jimport("java.io.StringReader") + moduleinfo = ScriptInfo(ij.getContext(), name, StringReader(custom_script)) + inputs = {} + outputs = {} + + for inp in ij.py.from_java(moduleinfo.inputs()): + input_type = str(inp.getType().getName()) + input_name = str(inp.getName()) + print(input_type, input_name) + inputs[input_name] = {"name": input_name, "type": input_type} + + for outp in ij.py.from_java(moduleinfo.outputs()): + output_type = str(outp.getType().getName()) + output_name = str(outp.getName()) + outputs[output_name] = {"name": output_name, "type": output_type} + + return {"id": moduleinfo.getIdentifier(), "outputs": outputs, "inputs": inputs} + + +def check_size(array): + result_bytes = array.tobytes() + if len(result_bytes) > 20000000: # 20MB + raise Exception( + f"The data is too large ({len(result_bytes)} bytes) to be transfered." + ) + + +async def execute(config, context=None): + loop = asyncio.get_event_loop() + return await loop.run_in_executor(None, run_imagej, config) + + +def run_imagej(config): + mode = config.get("mode", "headless") + logger.info("Initializing ImageJ...") + ij = imagej.init(os.environ.get("IMAGEJ_DIR"), mode=mode) + logger.info("Running ImageJ macro...") + try: + WindowManager = sj.jimport("ij.WindowManager") + ImagePlus = sj.jimport("ij.ImagePlus") + logs = capture_console(ij) + script = config.get("script") + lang = config.get("lang", "ijm") + assert script is not None, "script is required" + module_info = get_module_info(ij, script) + inputs_info = module_info["inputs"] + outputs_info = module_info["outputs"] + inputs = config.get("inputs", {}) + select_outputs = config.get("select_outputs") + args = {} + for k in inputs: + if isinstance(inputs[k], (np.ndarray, np.generic, dict)): + if isinstance(inputs[k], (np.ndarray, np.generic)): + if inputs[k].ndim == 2: + dims = ["x", "y"] + elif inputs[k].ndim == 3 and inputs[k].shape[2] in [1, 3, 4]: + dims = ["x", "y", "c"] + elif inputs[k].ndim == 3 and inputs[k].shape[0] in [1, 3, 4]: + dims = ["c", "x", "y"] + elif inputs[k].ndim == 3: + dims = ["z", "x", "y"] + elif inputs[k].ndim == 4: + dims = ["z", "x", "y", "c"] + elif inputs[k].ndim == 5: + dims = ["t", "z", "x", "y", "c"] + else: + raise Exception(f"Unsupported ndim: {inputs[k].ndim}") + inputs[k] = {"data": inputs[k], "dims": dims} + + img = inputs[k] + assert isinstance( + img, dict + ), f"input {k} must be a dictionary or a numpy array" + assert "data" in img, f"data is required for {k}" + assert "dims" in img, f"dims is required for {k}" + da = xr.DataArray( + data=img["data"], + dims=img["dims"], + attrs=img.get("attrs", {}), + name=k, + ) + inputs[k] = ij.py.to_java(da) + if lang == "ijm": + # convert to ImagePlus + inputs[k] = ij.convert().convert(inputs[k], ImagePlus) + if inputs[k]: + inputs[k].setTitle(k) + # Display the image + if mode != "headless": + inputs[k].show() + else: + raise NotImplementedError( + "Don't know how to display the image (only ijm is supported)." + ) + if k in inputs_info: + args[k] = ij.py.to_java(inputs[k]) + + # Run the script + macro_result = ij.py.run_script(lang, script, args) + results = {} + if select_outputs is None: + select_outputs = list(outputs_info.keys()) + for k in select_outputs: + if k in outputs_info: + results[k] = macro_result.getOutput(k) + if results[k] and not isinstance(results[k], (int, str, float, bool)): + try: + results[k] = ij.py.from_java(results[k]).to_numpy() + check_size(results[k]) + except Exception: + # TODO: This is needed due to a bug in pyimagej for converting java string + if str(type(results[k])) == "": + results[k] = str(results[k]) + else: + results[k] = { + "type": str(type(results[k])), + "text": str(results[k]), + } + else: + # If the output name is not in the script annotation, + # Try to get the image from the WindowManager by title + img = WindowManager.getImage(k) + if not img: + raise Exception(f"Output not found: {k}\n{format_logs(logs)}") + results[k] = ij.py.from_java(img).to_numpy() + check_size(results[k]) + except Exception as exp: + raise exp + finally: + ij.dispose() + + return {"outputs": results, "logs": logs} + + +test_macro = """ +#@ String name +#@ int age +#@ String city +#@output Object greeting +greeting = "Hi " + name + ". You are " + age + " years old, and live in " + city + "." +""" + + +async def main(): + parser = argparse.ArgumentParser() + parser.add_argument("--server-url", type=str, help="Server URL") + parser.add_argument("--workspace", type=str, help="Workspace") + parser.add_argument("--token", type=str, help="Token") + parser.add_argument("--service-id", type=str, help="ImageJ Service ID") + + args = parser.parse_args() + logger.info("Connecting to the server %s, workspace: %s", args.server_url, args.workspace) + # get server_url and token from url + server = await connect_to_server({ + "server_url": args.server_url, + "workspace": args.workspace, + "token": args.token, + }) + try: + logger.info("Testing the imagej service...") + ret = await execute( + { + "script": test_macro, + "inputs": {"name": "Tom", "age": 20, "city": "Shanghai"}, + } + ) + outputs = ret["outputs"] + assert ( + outputs["greeting"] == "Hi Tom. You are 20 years old, and live in Shanghai." + ) + except Exception: + print(traceback.format_exc()) + sys.exit(1) + + logger.info("Starting the imagej service...") + svc = await server.register_service( + { + "id": args.service_id, + "type": "imagej", + "config": {"require_context": True, "visibility": "public"}, + "execute": execute, + } + ) + logger.info(f"ImageJ service is registered as `{svc['id']}`") + +if __name__ == "__main__": + loop = asyncio.get_event_loop() + loop.create_task(main()) + loop.run_forever() + diff --git a/docs/README.md b/docs/README.md index 2e5b9cb..e00ed95 100644 --- a/docs/README.md +++ b/docs/README.md @@ -66,6 +66,13 @@ To get started with BioEngine, please see our [tutorial for I2K 2023](https://sl Please read the documentation at: https://bioimage-io.github.io/bioengine/ -## BioEngine Tutorial - * [Tutorial for I2K 2023](https://slides.imjoy.io/?slides=https://raw.githubusercontent.com/bioimage-io/BioEngine/main/slides/i2k-2023-bioengine-workshop.md) - +## TODO + ++ Runtime types support via hypha-launcher: + * [ ] HPC: Slurm / PBS / LFS ... + * [ ] Conda environment + * [ ] Docker / Apptainer / podman ... + * [ ] Web Browser + * [ ] pytriton(python package) + * [ ] SSH + X(other runtime types) + * [ ] K8S