diff --git a/mesmerize_core/algorithms/_utils.py b/mesmerize_core/algorithms/_utils.py
new file mode 100644
index 0000000..e7dbdca
--- /dev/null
+++ b/mesmerize_core/algorithms/_utils.py
@@ -0,0 +1,48 @@
+import caiman as cm
+from contextlib import contextmanager
+from ipyparallel import DirectView
+from multiprocessing.pool import Pool
+import os
+import psutil
+from typing import Union, Optional, Generator
+
+Cluster = Union[Pool, DirectView]
+
+
+def get_n_processes(dview: Optional[Cluster]) -> int:
+    """Infer the number of processes in a multiprocessing or ipyparallel cluster"""
+    if isinstance(dview, Pool) and hasattr(dview, '_processes'):
+        return dview._processes
+    elif isinstance(dview, DirectView):
+        return len(dview)
+    else:
+        return 1
+
+
+@contextmanager
+def ensure_server(dview: Optional[Cluster]) -> Generator[tuple[Cluster, int], None, None]:
+    """
+    Context manager that passes through an existing 'dview' or
+    opens up a multiprocessing server if none is passed in.
+    If a server was opened, closes it upon exit.
+    Usage: `with ensure_server(dview) as (dview, n_processes):`
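+
+    Example (illustrative; `existing_pool` is any cluster you manage yourself):
+
+        with ensure_server(existing_pool) as (dview, n_processes):
+            ...  # existing_pool is passed through and left running on exit
+
+        with ensure_server(None) as (dview, n_processes):
+            ...  # a cluster is started here and stopped when the block exits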
+    """
+    if dview is not None:
+        yield dview, get_n_processes(dview)
+    else:
+        # no cluster passed in, so open one
+        if "MESMERIZE_N_PROCESSES" in os.environ.keys():
+            try:
+                n_processes = int(os.environ["MESMERIZE_N_PROCESSES"])
+            except ValueError:
+                n_processes = psutil.cpu_count() - 1
+        else:
+            n_processes = psutil.cpu_count() - 1
+
+        # Start cluster for parallel processing
+        _, dview, n_processes = cm.cluster.setup_cluster(
+            backend="multiprocessing", n_processes=n_processes, single_thread=False
+        )
+        try:
+            yield dview, n_processes
+        finally:
+            cm.stop_server(dview=dview)
diff --git a/mesmerize_core/algorithms/cnmf.py b/mesmerize_core/algorithms/cnmf.py
index 4c71174..ea23626 100644
--- a/mesmerize_core/algorithms/cnmf.py
+++ b/mesmerize_core/algorithms/cnmf.py
@@ -3,24 +3,24 @@
 import caiman as cm
 from caiman.source_extraction.cnmf import cnmf as cnmf
 from caiman.source_extraction.cnmf.params import CNMFParams
-import psutil
 import numpy as np
 import traceback
 from pathlib import Path, PurePosixPath
 from shutil import move as move_file
-import os
 import time

 # prevent circular import
 if __name__ in ["__main__", "__mp_main__"]:  # when running in subprocess
     from mesmerize_core import set_parent_raw_data_path, load_batch
     from mesmerize_core.utils import IS_WINDOWS
+    from mesmerize_core.algorithms._utils import ensure_server
 else:  # when running with local backend
     from ..batch_utils import set_parent_raw_data_path, load_batch
     from ..utils import IS_WINDOWS
+    from ._utils import ensure_server


-def run_algo(batch_path, uuid, data_path: str = None):
+def run_algo(batch_path, uuid, data_path: str = None, dview=None):
     algo_start = time.time()
     set_parent_raw_data_path(data_path)

@@ -41,102 +41,84 @@ def run_algo(batch_path, uuid, data_path: str = None, dview=None):
         f"Starting CNMF item:\n{item}\nWith params:{params}"
     )

-    # adapted from current demo notebook
-    if "MESMERIZE_N_PROCESSES" in os.environ.keys():
-        try:
-            n_processes = int(os.environ["MESMERIZE_N_PROCESSES"])
-        except:
-            n_processes = psutil.cpu_count() - 1
-    else:
-        n_processes = psutil.cpu_count() - 1
-    # Start cluster for parallel processing
-    c, dview, n_processes = cm.cluster.setup_cluster(
-        backend="local", n_processes=n_processes, single_thread=False
-    )
+    with ensure_server(dview) as (dview, n_processes):

-    # merge cnmf and eval kwargs into one dict
-    cnmf_params = CNMFParams(params_dict=params["main"])
-    # Run CNMF, denote boolean 'success' if CNMF completes w/out error
-    try:
-        fname_new = cm.save_memmap(
-            [input_movie_path], base_name=f"{uuid}_cnmf-memmap_", order="C", dview=dview
-        )
+        # merge cnmf and eval kwargs into one dict
+        cnmf_params = CNMFParams(params_dict=params["main"])
+        # Run CNMF, denote boolean 'success' if CNMF completes w/out error
+        try:
+            fname_new = cm.save_memmap(
+                [input_movie_path], base_name=f"{uuid}_cnmf-memmap_", order="C", dview=dview
+            )

-        print("making memmap")
+            print("making memmap")
+
+            Yr, dims, T = cm.load_memmap(fname_new)
+
+            images = np.reshape(Yr.T, [T] + list(dims), order="F")
+
+            proj_paths = dict()
+            for proj_type in ["mean", "std", "max"]:
+                p_img = getattr(np, f"nan{proj_type}")(images, axis=0)
+                proj_paths[proj_type] = output_dir.joinpath(
+                    f"{uuid}_{proj_type}_projection.npy"
+                )
+                np.save(str(proj_paths[proj_type]), p_img)
+
+            print("performing CNMF")
+            cnm = cnmf.CNMF(n_processes, params=cnmf_params, dview=dview)
+
+            print("fitting images")
+            cnm = cnm.fit(images)
+
+            if "refit" in params.keys():
+                if params["refit"] is True:
+                    print("refitting")
+                    cnm = cnm.refit(images, dview=dview)
+
+            print("performing eval")
+            cnm.estimates.evaluate_components(images, cnm.params, dview=dview)
+
+            output_path = output_dir.joinpath(f"{uuid}.hdf5").resolve()
+
+            cnm.save(str(output_path))
+
+            Cn = cm.local_correlations(images.transpose(1, 2, 0))
+            Cn[np.isnan(Cn)] = 0
+
+            corr_img_path = output_dir.joinpath(f"{uuid}_cn.npy").resolve()
+            np.save(str(corr_img_path), Cn, allow_pickle=False)
+
+            # output dict for dataframe row (pd.Series)
+            d = dict()
+
+            cnmf_memmap_path = output_dir.joinpath(Path(fname_new).name)
+            if IS_WINDOWS:
+                Yr._mmap.close()  # accessing private attr but windows is annoying otherwise
+            move_file(fname_new, cnmf_memmap_path)
+
+            # save paths as relative path strings with forward slashes
+            cnmf_hdf5_path = str(PurePosixPath(output_path.relative_to(output_dir.parent)))
+            cnmf_memmap_path = str(PurePosixPath(cnmf_memmap_path.relative_to(output_dir.parent)))
+            corr_img_path = str(PurePosixPath(corr_img_path.relative_to(output_dir.parent)))
+            for proj_type in proj_paths.keys():
+                d[f"{proj_type}-projection-path"] = str(PurePosixPath(proj_paths[proj_type].relative_to(
+                    output_dir.parent
+                )))
+
+            d.update(
+                {
+                    "cnmf-hdf5-path": cnmf_hdf5_path,
+                    "cnmf-memmap-path": cnmf_memmap_path,
+                    "corr-img-path": corr_img_path,
+                    "success": True,
+                    "traceback": None,
+                }
+            )

-        Yr, dims, T = cm.load_memmap(fname_new)
-        images = np.reshape(Yr.T, [T] + list(dims), order="F")
+        except Exception:
+            d = {"success": False, "traceback": traceback.format_exc()}
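+            # the failure is recorded in the batch dataframe below rather
+            # than re-raised, so one bad item doesn't abort a whole batch run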
-        proj_paths = dict()
-        for proj_type in ["mean", "std", "max"]:
-            p_img = getattr(np, f"nan{proj_type}")(images, axis=0)
-            proj_paths[proj_type] = output_dir.joinpath(
-                f"{uuid}_{proj_type}_projection.npy"
-            )
-            np.save(str(proj_paths[proj_type]), p_img)
-
-        # in fname new load in memmap order C
-        cm.stop_server(dview=dview)
-        c, dview, n_processes = cm.cluster.setup_cluster(
-            backend="local", n_processes=None, single_thread=False
-        )
-
-        print("performing CNMF")
-        cnm = cnmf.CNMF(n_processes, params=cnmf_params, dview=dview)
-
-        print("fitting images")
-        cnm = cnm.fit(images)
-        #
-        if "refit" in params.keys():
-            if params["refit"] is True:
-                print("refitting")
-                cnm = cnm.refit(images, dview=dview)
-
-        print("performing eval")
-        cnm.estimates.evaluate_components(images, cnm.params, dview=dview)
-
-        output_path = output_dir.joinpath(f"{uuid}.hdf5").resolve()
-
-        cnm.save(str(output_path))
-
-        Cn = cm.local_correlations(images.transpose(1, 2, 0))
-        Cn[np.isnan(Cn)] = 0
-
-        corr_img_path = output_dir.joinpath(f"{uuid}_cn.npy").resolve()
-        np.save(str(corr_img_path), Cn, allow_pickle=False)
-
-        # output dict for dataframe row (pd.Series)
-        d = dict()
-
-        cnmf_memmap_path = output_dir.joinpath(Path(fname_new).name)
-        if IS_WINDOWS:
-            Yr._mmap.close()  # accessing private attr but windows is annoying otherwise
-        move_file(fname_new, cnmf_memmap_path)
-
-        # save paths as relative path strings with forward slashes
-        cnmf_hdf5_path = str(PurePosixPath(output_path.relative_to(output_dir.parent)))
-        cnmf_memmap_path = str(PurePosixPath(cnmf_memmap_path.relative_to(output_dir.parent)))
-        corr_img_path = str(PurePosixPath(corr_img_path.relative_to(output_dir.parent)))
-        for proj_type in proj_paths.keys():
-            d[f"{proj_type}-projection-path"] = str(PurePosixPath(proj_paths[proj_type].relative_to(
-                output_dir.parent
-            )))
-
-        d.update(
-            {
-                "cnmf-hdf5-path": cnmf_hdf5_path,
-                "cnmf-memmap-path": cnmf_memmap_path,
-                "corr-img-path": corr_img_path,
-                "success": True,
-                "traceback": None,
-            }
-        )
-
-    except:
-        d = {"success": False, "traceback": traceback.format_exc()}
-
-    cm.stop_server(dview=dview)
-
     runtime = round(time.time() - algo_start, 2)
     df.caiman.update_item_with_results(uuid, d, runtime)
diff --git a/mesmerize_core/algorithms/cnmfe.py b/mesmerize_core/algorithms/cnmfe.py
index e053869..1d8c601 100644
--- a/mesmerize_core/algorithms/cnmfe.py
+++ b/mesmerize_core/algorithms/cnmfe.py
@@ -3,22 +3,22 @@
 import caiman as cm
 from caiman.source_extraction.cnmf import cnmf as cnmf
 from caiman.source_extraction.cnmf.params import CNMFParams
-import psutil
 import traceback
 from pathlib import Path, PurePosixPath
 from shutil import move as move_file
-import os
 import time

 if __name__ in ["__main__", "__mp_main__"]:  # when running in subprocess
     from mesmerize_core import set_parent_raw_data_path, load_batch
     from mesmerize_core.utils import IS_WINDOWS
+    from mesmerize_core.algorithms._utils import ensure_server
 else:  # when running with local backend
     from ..batch_utils import set_parent_raw_data_path, load_batch
     from ..utils import IS_WINDOWS
+    from ._utils import ensure_server


-def run_algo(batch_path, uuid, data_path: str = None):
+def run_algo(batch_path, uuid, data_path: str = None, dview=None):
     algo_start = time.time()
     set_parent_raw_data_path(data_path)

@@ -35,91 +35,77 @@ def run_algo(batch_path, uuid, data_path: str = None, dview=None):
     params = item["params"]
     print("cnmfe params:", params)

-    # adapted from current demo notebook
-    if "MESMERIZE_N_PROCESSES" in os.environ.keys():
+    with ensure_server(dview) as (dview, n_processes):
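+        # a cluster passed in by the caller (e.g. from the local_async
+        # backend) is reused as-is; otherwise ensure_server starts one and
+        # stops it again when this block exits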
         try:
-            n_processes = int(os.environ["MESMERIZE_N_PROCESSES"])
-        except:
-            n_processes = psutil.cpu_count() - 1
-    else:
-        n_processes = psutil.cpu_count() - 1
-    # Start cluster for parallel processing
-    c, dview, n_processes = cm.cluster.setup_cluster(
-        backend="local", n_processes=n_processes, single_thread=False
-    )
-
-    try:
-        fname_new = cm.save_memmap(
-            [input_movie_path], base_name=f"{uuid}_cnmf-memmap_", order="C", dview=dview
-        )
-
-        print("making memmap")
-        Yr, dims, T = cm.load_memmap(fname_new)
-        images = np.reshape(Yr.T, [T] + list(dims), order="F")
-
-        # TODO: if projections already exist from mcorr we don't
-        # need to waste compute time re-computing them here
-        proj_paths = dict()
-        for proj_type in ["mean", "std", "max"]:
-            p_img = getattr(np, f"nan{proj_type}")(images, axis=0)
-            proj_paths[proj_type] = output_dir.joinpath(
-                f"{uuid}_{proj_type}_projection.npy"
-            )
-            np.save(str(proj_paths[proj_type]), p_img)
-
-        d = dict()  # for output
-
-        # force the CNMFE params
-        cnmfe_params_dict = {
-            "method_init": "corr_pnr",
-            "n_processes": n_processes,
-            "only_init": True,  # for 1p
-            "center_psf": True,  # for 1p
-            "normalize_init": False,  # for 1p
-        }
-
-        params_dict = {**cnmfe_params_dict, **params["main"]}
-
-        cnmfe_params_dict = CNMFParams(params_dict=params_dict)
-        cnm = cnmf.CNMF(
-            n_processes=n_processes, dview=dview, params=cnmfe_params_dict
-        )
-        print("Performing CNMFE")
-        cnm = cnm.fit(images)
-        print("evaluating components")
-        cnm.estimates.evaluate_components(images, cnm.params, dview=dview)
-
-        cnmf_hdf5_path = output_dir.joinpath(f"{uuid}.hdf5").resolve()
-        cnm.save(str(cnmf_hdf5_path))
-
-        # save output paths to outputs dict
-        d["cnmf-hdf5-path"] = cnmf_hdf5_path.relative_to(output_dir.parent)
-
-        for proj_type in proj_paths.keys():
-            d[f"{proj_type}-projection-path"] = proj_paths[proj_type].relative_to(
-                output_dir.parent
+            fname_new = cm.save_memmap(
+                [input_movie_path], base_name=f"{uuid}_cnmf-memmap_", order="C", dview=dview
             )
-
-        cnmf_memmap_path = output_dir.joinpath(Path(fname_new).name)
-        if IS_WINDOWS:
-            Yr._mmap.close()  # accessing private attr but windows is annoying otherwise
-        move_file(fname_new, cnmf_memmap_path)
-
-        # save path as relative path strings with forward slashes
-        cnmfe_memmap_path = str(PurePosixPath(cnmf_memmap_path.relative_to(output_dir.parent)))
-
-        d.update(
-            {
-                "cnmf-memmap-path": cnmfe_memmap_path,
-                "success": True,
-                "traceback": None,
+
+            print("making memmap")
+            Yr, dims, T = cm.load_memmap(fname_new)
+            images = np.reshape(Yr.T, [T] + list(dims), order="F")
+
+            # TODO: if projections already exist from mcorr we don't
+            # need to waste compute time re-computing them here
+            proj_paths = dict()
+            for proj_type in ["mean", "std", "max"]:
+                p_img = getattr(np, f"nan{proj_type}")(images, axis=0)
+                proj_paths[proj_type] = output_dir.joinpath(
+                    f"{uuid}_{proj_type}_projection.npy"
+                )
+                np.save(str(proj_paths[proj_type]), p_img)
+
+            d = dict()  # for output
+
+            # force the CNMFE params
+            cnmfe_params_dict = {
+                "method_init": "corr_pnr",
+                "n_processes": n_processes,
+                "only_init": True,  # for 1p
+                "center_psf": True,  # for 1p
+                "normalize_init": False,  # for 1p
             }
-        )
-
-    except:
-        d = {"success": False, "traceback": traceback.format_exc()}
+
+            params_dict = {**cnmfe_params_dict, **params["main"]}
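+            # user-supplied params come last in the merge, so they can
+            # override the forced defaults above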
-    cm.stop_server(dview=dview)
+
+            cnmfe_params_dict = CNMFParams(params_dict=params_dict)
+            cnm = cnmf.CNMF(
+                n_processes=n_processes, dview=dview, params=cnmfe_params_dict
+            )
+            print("Performing CNMFE")
+            cnm = cnm.fit(images)
+            print("evaluating components")
+            cnm.estimates.evaluate_components(images, cnm.params, dview=dview)
+
+            cnmf_hdf5_path = output_dir.joinpath(f"{uuid}.hdf5").resolve()
+            cnm.save(str(cnmf_hdf5_path))
+
+            # save output paths to outputs dict as relative path strings
+            # with forward slashes, matching the other algorithms
+            d["cnmf-hdf5-path"] = str(PurePosixPath(cnmf_hdf5_path.relative_to(output_dir.parent)))
+
+            for proj_type in proj_paths.keys():
+                d[f"{proj_type}-projection-path"] = str(PurePosixPath(
+                    proj_paths[proj_type].relative_to(output_dir.parent)
+                ))
+
+            cnmf_memmap_path = output_dir.joinpath(Path(fname_new).name)
+            if IS_WINDOWS:
+                Yr._mmap.close()  # accessing private attr but windows is annoying otherwise
+            move_file(fname_new, cnmf_memmap_path)
+
+            cnmfe_memmap_path = str(PurePosixPath(cnmf_memmap_path.relative_to(output_dir.parent)))
+
+            d.update(
+                {
+                    "cnmf-memmap-path": cnmfe_memmap_path,
+                    "success": True,
+                    "traceback": None,
+                }
+            )
+
+        except Exception:
+            d = {"success": False, "traceback": traceback.format_exc()}

     runtime = round(time.time() - algo_start, 2)
     df.caiman.update_item_with_results(uuid, d, runtime)
diff --git a/mesmerize_core/algorithms/mcorr.py b/mesmerize_core/algorithms/mcorr.py
index 755f697..860c58f 100644
--- a/mesmerize_core/algorithms/mcorr.py
+++ b/mesmerize_core/algorithms/mcorr.py
@@ -4,7 +4,6 @@
 from caiman.source_extraction.cnmf.params import CNMFParams
 from caiman.motion_correction import MotionCorrect
 from caiman.summary_images import local_correlations_movie_offline
-import psutil
 import os
 from pathlib import Path, PurePosixPath
 import numpy as np
@@ -14,11 +13,13 @@
 # prevent circular import
 if __name__ in ["__main__", "__mp_main__"]:  # when running in subprocess
     from mesmerize_core import set_parent_raw_data_path, load_batch
+    from mesmerize_core.algorithms._utils import ensure_server
 else:  # when running with local backend
     from ..batch_utils import set_parent_raw_data_path, load_batch
+    from ._utils import ensure_server


-def run_algo(batch_path, uuid, data_path: str = None):
+def run_algo(batch_path, uuid, data_path: str = None, dview=None):
     algo_start = time.time()
     set_parent_raw_data_path(data_path)

@@ -39,111 +40,97 @@ def run_algo(batch_path, uuid, data_path: str = None, dview=None):

     params = item["params"]

-    # adapted from current demo notebook
-    if "MESMERIZE_N_PROCESSES" in os.environ.keys():
+    with ensure_server(dview) as (dview, n_processes):
+        print("starting mc")
+
+        rel_params = dict(params["main"])
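+        # defensive shallow copy: avoids mutating the params stored in the
+        # batch dataframe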
+        opts = CNMFParams(params_dict=rel_params)
+        # Run MC, denote boolean 'success' if MC completes w/out error
         try:
-            n_processes = int(os.environ["MESMERIZE_N_PROCESSES"])
-        except:
-            n_processes = psutil.cpu_count() - 1
-    else:
-        n_processes = psutil.cpu_count() - 1
-
-    print("starting mc")
-    # Start cluster for parallel processing
-    c, dview, n_processes = cm.cluster.setup_cluster(
-        backend="local", n_processes=n_processes, single_thread=False
-    )
-
-    rel_params = dict(params["main"])
-    opts = CNMFParams(params_dict=rel_params)
-    # Run MC, denote boolean 'success' if MC completes w/out error
-    try:
-        # Run MC
-        fnames = [input_movie_path]
-        mc = MotionCorrect(fnames, dview=dview, **opts.get_group("motion"))
-        mc.motion_correct(save_movie=True)
-
-        # find path to mmap file
-        memmap_output_path_temp = df.paths.resolve(mc.mmap_file[0])
-
-        # filename to move the output back to data dir
-        mcorr_memmap_path = output_dir.joinpath(
-            f"{uuid}-{memmap_output_path_temp.name}"
-        )
-
-        # move the output file
-        move_file(memmap_output_path_temp, mcorr_memmap_path)
-
-        print("mc finished successfully!")
-
-        print("computing projections")
-        Yr, dims, T = cm.load_memmap(str(mcorr_memmap_path))
-        images = np.reshape(Yr.T, [T] + list(dims), order="F")
-
-        proj_paths = dict()
-        for proj_type in ["mean", "std", "max"]:
-            p_img = getattr(np, f"nan{proj_type}")(images, axis=0)
-            proj_paths[proj_type] = output_dir.joinpath(
-                f"{uuid}_{proj_type}_projection.npy"
-            )
-            np.save(str(proj_paths[proj_type]), p_img)
-
-        print("Computing correlation image")
-        Cns = local_correlations_movie_offline(
-            [str(mcorr_memmap_path)],
-            remove_baseline=True,
-            window=1000,
-            stride=1000,
-            winSize_baseline=100,
-            quantil_min_baseline=10,
-            dview=dview,
-        )
-        Cn = Cns.max(axis=0)
-        Cn[np.isnan(Cn)] = 0
-        cn_path = output_dir.joinpath(f"{uuid}_cn.npy")
-        np.save(str(cn_path), Cn, allow_pickle=False)
-
-        # output dict for pandas series for dataframe row
-        d = dict()
-
-        print("finished computing correlation image")
-
-        # Compute shifts
-        if opts.motion["pw_rigid"] == True:
-            x_shifts = mc.x_shifts_els
-            y_shifts = mc.y_shifts_els
-            shifts = [x_shifts, y_shifts]
-            shift_path = output_dir.joinpath(f"{uuid}_shifts.npy")
-            np.save(str(shift_path), shifts)
-        else:
-            shifts = mc.shifts_rig
-            shift_path = output_dir.joinpath(f"{uuid}_shifts.npy")
-            np.save(str(shift_path), shifts)
-
-        # save paths as relative path strings with forward slashes
-        cn_path = str(PurePosixPath(cn_path.relative_to(output_dir.parent)))
-        mcorr_memmap_path = str(PurePosixPath(mcorr_memmap_path.relative_to(output_dir.parent)))
-        shift_path = str(PurePosixPath(shift_path.relative_to(output_dir.parent)))
-        for proj_type in proj_paths.keys():
-            d[f"{proj_type}-projection-path"] = str(PurePosixPath(proj_paths[proj_type].relative_to(
-                output_dir.parent
-            )))
-
-        d.update(
-            {
-                "mcorr-output-path": mcorr_memmap_path,
-                "corr-img-path": cn_path,
-                "shifts": shift_path,
-                "success": True,
-                "traceback": None,
-            }
-        )
-
-    except:
-        d = {"success": False, "traceback": traceback.format_exc()}
-        print("mc failed, stored traceback in output")
-
-    cm.stop_server(dview=dview)
+            # Run MC
+            fnames = [input_movie_path]
+            mc = MotionCorrect(fnames, dview=dview, **opts.get_group("motion"))
+            mc.motion_correct(save_movie=True)
+
+            # find path to mmap file
+            memmap_output_path_temp = df.paths.resolve(mc.mmap_file[0])
+
+            # filename to move the output back to data dir
+            mcorr_memmap_path = output_dir.joinpath(
+                f"{uuid}-{memmap_output_path_temp.name}"
+            )
+
+            # move the output file
+            move_file(memmap_output_path_temp, mcorr_memmap_path)
+
+            print("mc finished successfully!")
+
+            print("computing projections")
+            Yr, dims, T = cm.load_memmap(str(mcorr_memmap_path))
+            images = np.reshape(Yr.T, [T] + list(dims), order="F")
+
+            proj_paths = dict()
+            for proj_type in ["mean", "std", "max"]:
+                p_img = getattr(np, f"nan{proj_type}")(images, axis=0)
+                proj_paths[proj_type] = output_dir.joinpath(
+                    f"{uuid}_{proj_type}_projection.npy"
+                )
+                np.save(str(proj_paths[proj_type]), p_img)
+
+            print("Computing correlation image")
+            Cns = local_correlations_movie_offline(
+                [str(mcorr_memmap_path)],
+                remove_baseline=True,
+                window=1000,
+                stride=1000,
+                winSize_baseline=100,
+                quantil_min_baseline=10,
+                dview=dview,
+            )
+            Cn = Cns.max(axis=0)
+            Cn[np.isnan(Cn)] = 0
+            cn_path = output_dir.joinpath(f"{uuid}_cn.npy")
+            np.save(str(cn_path), Cn, allow_pickle=False)
+
+            # output dict for pandas series for dataframe row
+            d = dict()
+
+            print("finished computing correlation image")
+
+            # Compute shifts
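+            # pw_rigid means piecewise-rigid: per-patch (elastic) x/y shifts;
+            # otherwise a single rigid shift per frame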
traceback in output") runtime = round(time.time() - algo_start, 2) df.caiman.update_item_with_results(uuid, d, runtime) diff --git a/mesmerize_core/batch_utils.py b/mesmerize_core/batch_utils.py index 800a288..4d5c1cd 100644 --- a/mesmerize_core/batch_utils.py +++ b/mesmerize_core/batch_utils.py @@ -13,8 +13,9 @@ COMPUTE_BACKEND_SUBPROCESS = "subprocess" #: subprocess backend COMPUTE_BACKEND_SLURM = "slurm" #: SLURM backend COMPUTE_BACKEND_LOCAL = "local" +COMPUTE_BACKEND_ASYNC = "local_async" -COMPUTE_BACKENDS = [COMPUTE_BACKEND_SUBPROCESS, COMPUTE_BACKEND_SLURM, COMPUTE_BACKEND_LOCAL] +COMPUTE_BACKENDS = [COMPUTE_BACKEND_SUBPROCESS, COMPUTE_BACKEND_SLURM, COMPUTE_BACKEND_LOCAL, COMPUTE_BACKEND_ASYNC] DATAFRAME_COLUMNS = ["algo", "item_name", "input_movie_path", "params", "outputs", "added_time", "ran_time", "algo_duration", "comments", "uuid"] diff --git a/mesmerize_core/caiman_extensions/common.py b/mesmerize_core/caiman_extensions/common.py index bd78fb6..462fdad 100644 --- a/mesmerize_core/caiman_extensions/common.py +++ b/mesmerize_core/caiman_extensions/common.py @@ -10,6 +10,7 @@ import time from copy import deepcopy import shlex +from concurrent.futures import ThreadPoolExecutor, Future import numpy as np import pandas as pd @@ -21,6 +22,7 @@ COMPUTE_BACKENDS, COMPUTE_BACKEND_SUBPROCESS, COMPUTE_BACKEND_LOCAL, + COMPUTE_BACKEND_ASYNC, get_parent_raw_data_path, load_batch, ) @@ -458,12 +460,26 @@ def get_parent(self, index: Union[int, str, UUID]) -> Union[UUID, None]: return r["uuid"] -class DummyProcess: +class Waitable(Protocol): + """An object that we can call "wait" on""" + def wait(self) -> None: ... + + +class DummyProcess(Waitable): """Dummy process for local backend""" - def wait(self): + def wait(self) -> None: pass +class WaitableFuture(Waitable): + """Adaptor for future returned from Executor.submit""" + def __init__(self, future: Future[None]): + self.future = future + + def wait(self) -> None: + return self.future.result() + + @pd.api.extensions.register_series_accessor("caiman") class CaimanSeriesExtensions: """ @@ -472,7 +488,7 @@ class CaimanSeriesExtensions: def __init__(self, s: pd.Series): self._series = s - self.process: Popen = None + self.process: Optional[Waitable] = None def _run_local( self, @@ -480,15 +496,37 @@ def _run_local( batch_path: Path, uuid: UUID, data_path: Union[Path, None], - ): + dview=None + ) -> DummyProcess: algo_module = getattr(algorithms, algo) algo_module.run_algo( batch_path=str(batch_path), uuid=str(uuid), - data_path=str(data_path) + data_path=str(data_path), + dview=dview ) + self.process = DummyProcess() + return self.process - return DummyProcess() + def _run_local_async( + self, + algo: str, + batch_path: Path, + uuid: UUID, + data_path: Union[Path, None], + dview=None + ) -> WaitableFuture: + algo_module = getattr(algorithms, algo) + with ThreadPoolExecutor(max_workers=1) as executor: + future = executor.submit( + algo_module.run_algo, + batch_path=str(batch_path), + uuid=str(uuid), + data_path=str(data_path), + dview=dview + ) + self.process = WaitableFuture(future) + return self.process def _run_subprocess( self, @@ -599,13 +637,14 @@ def run( batch_path = self._series.paths.get_batch_path() - if backend == COMPUTE_BACKEND_LOCAL: - print(f"Running {self._series.uuid} with local backend") - return self._run_local( + if backend in [COMPUTE_BACKEND_LOCAL, COMPUTE_BACKEND_ASYNC]: + print(f"Running {self._series.uuid} with {backend} backend") + return getattr(self, f"_run_{backend}")( algo=self._series["algo"], 
+            return getattr(self, f"_run_{backend}")(
                 algo=self._series["algo"],
                 batch_path=batch_path,
                 uuid=self._series["uuid"],
                 data_path=get_parent_raw_data_path(),
+                dview=kwargs.get("dview")
             )

         # Create the runfile in the batch dir using this Series' UUID as the filename
diff --git a/tests/test_core.py b/tests/test_core.py
index 163189e..4e10012 100644
--- a/tests/test_core.py
+++ b/tests/test_core.py
@@ -1,5 +1,4 @@
 import os
-
 import numpy as np
 from caiman.utils.utils import load_dict_from_hdf5
 from caiman.source_extraction.cnmf import cnmf
@@ -12,8 +11,14 @@
     CaimanSeriesExtensions,
     set_parent_raw_data_path,
 )
-from mesmerize_core.batch_utils import DATAFRAME_COLUMNS, COMPUTE_BACKEND_SUBPROCESS, get_full_raw_data_path
+from mesmerize_core.batch_utils import (
+    DATAFRAME_COLUMNS,
+    COMPUTE_BACKEND_SUBPROCESS,
+    COMPUTE_BACKEND_LOCAL,
+    COMPUTE_BACKEND_ASYNC,
+    get_full_raw_data_path)
 from mesmerize_core.utils import IS_WINDOWS
+from mesmerize_core.algorithms._utils import ensure_server
 from uuid import uuid4
 from typing import *
 import pytest
@@ -30,6 +35,8 @@
 import tifffile
 from copy import deepcopy

+pytest_plugins = ('pytest_asyncio',)
+
 tmp_dir = Path(os.path.dirname(os.path.abspath(__file__)), "tmp")
 vid_dir = Path(os.path.dirname(os.path.abspath(__file__)), "videos")
 ground_truths_dir = Path(os.path.dirname(os.path.abspath(__file__)), "ground_truths")
@@ -1254,3 +1261,48 @@ def test_cache():
     output2 = df.iloc[1].cnmf.get_output(return_copy=False)
     assert(hex(id(output)) == hex(id(output2)))
     assert(hex(id(cnmf.cnmf_cache.get_cache().iloc[-1]["return_val"])) == hex(id(output)))
+
+
+def test_backends():
+    """test subprocess, local, and local_async backends"""
+    set_parent_raw_data_path(vid_dir)
+    algo = "mcorr"
+    df, batch_path = _create_tmp_batch()
+    input_movie_path = get_datafile(algo)
+
+    # make small version of movie for quick testing
+    movie = tifffile.imread(input_movie_path)
+    small_movie_path = input_movie_path.parent.joinpath("small_movie.tif")
+    tifffile.imwrite(small_movie_path, movie[:1001])
+    print(input_movie_path)
+
+    # put backends that can run in the background first to save time
+    backends = [COMPUTE_BACKEND_SUBPROCESS, COMPUTE_BACKEND_ASYNC, COMPUTE_BACKEND_LOCAL]
+    for backend in backends:
+        df.caiman.add_item(
+            algo="mcorr",
+            item_name=f"test-{backend}",
+            input_movie_path=small_movie_path,
+            params=test_params["mcorr"],
+        )
+
+    # run using each backend, sharing one cluster
+    procs = []
+    with ensure_server(None) as (dview, _):
+        for backend, (_, item) in zip(backends, df.iterrows()):
+            procs.append(item.caiman.run(backend=backend, dview=dview, wait=False))
+
+        # wait for all to finish before the shared cluster is stopped
+        for proc in procs:
+            proc.wait()
+
+    # compare results
+    df = load_batch(batch_path)
+    for i, item in df.iterrows():
+        output = item.mcorr.get_output()
+
+        if i == 0:
+            # save to compare to other results
+            first_output = output
+        else:
+            np.testing.assert_array_equal(output, first_output)
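+
+    # all three backends run the same mcorr computation on the same input,
+    # so their outputs are expected to match exactly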