diff --git a/models/ecoli/analysis/AnalysisPaths.py b/models/ecoli/analysis/AnalysisPaths.py index 1a860cec59..e43c14ec60 100644 --- a/models/ecoli/analysis/AnalysisPaths.py +++ b/models/ecoli/analysis/AnalysisPaths.py @@ -42,8 +42,8 @@ class AnalysisPaths(object): def __init__(self, out_dir, *, variant_plot: bool = False, multi_gen_plot: bool = False, cohort_plot: bool = False) -> None: - assert variant_plot + multi_gen_plot + cohort_plot == 1, ( - "Must specify exactly one plot type!") + assert variant_plot + multi_gen_plot + cohort_plot <= 1, ( + "Can only specify one analysis type!") generation_dirs = [] # type: List[str] if variant_plot: @@ -142,11 +142,14 @@ def __init__(self, out_dir, *, self._path_data["variantkb"] = variant_kb self._path_data["successful"] = successful - self.n_generation = len(set(generations)) - self.n_variant = len(set(variants)) - self.n_seed = len(set(seeds)) + self._calculate_n() - def get_cells(self, variant = None, seed = None, generation = None, only_successful=False): + def _calculate_n(self): + self.n_generation = len(set(self._path_data["generation"])) + self.n_variant = len(set(self._path_data["variant"])) + self.n_seed = len(set(self._path_data["seed"])) + + def _get_cells(self, variant=None, seed=None, generation=None, only_successful=False): # type: (Optional[Iterable[int]], Optional[Iterable[int]], Optional[Iterable[int]], bool) -> np.ndarray """Returns file paths for all the simulated cells matching the given variant number, seed number, and generation number collections, where @@ -172,7 +175,19 @@ def get_cells(self, variant = None, seed = None, generation = None, only_success else: successful_bool = np.ones(self._path_data.shape) - return self._path_data['path'][np.logical_and.reduce((variantBool, seedBool, generationBool, successful_bool))] + return np.logical_and.reduce((variantBool, seedBool, generationBool, successful_bool)) + + def get_cells(self, variant=None, seed=None, generation=None, only_successful=False): + mask = self._get_cells(variant=variant, seed=seed, + generation=generation, only_successful=only_successful) + return self._path_data['path'][mask] + + def update_cells(self, variant=None, seed=None, generation=None, only_successful=False): + mask = self._get_cells(variant=variant, seed=seed, + generation=generation, only_successful=only_successful) + + self._path_data = self._path_data[mask] + self._calculate_n() def get_variant_kb(self, variant): # type: (Union[int, str]) -> str diff --git a/models/ecoli/analysis/analysisPlot.py b/models/ecoli/analysis/analysisPlot.py index 87d840c944..126c98e558 100644 --- a/models/ecoli/analysis/analysisPlot.py +++ b/models/ecoli/analysis/analysisPlot.py @@ -47,6 +47,7 @@ class AnalysisPlot(metaclass=abc.ABCMeta): def __init__(self, cpus=0): self.cpus = parallelization.cpus(cpus) self._axeses = {} + self.ap = None @staticmethod def read_sim_data_file(sim_path: str) -> SimulationDataEcoli: @@ -158,8 +159,9 @@ def do_plot(): @classmethod def main(cls, inputDir, plotOutDir, plotOutFileName, simDataFile, - validationDataFile=None, metadata=None, cpus=0): + validationDataFile=None, metadata=None, cpus=0, analysis_paths=None): """Run an analysis plot for a Firetask.""" instance = cls(cpus) + instance.ap = analysis_paths instance.plot(inputDir, plotOutDir, plotOutFileName, simDataFile, validationDataFile, metadata) diff --git a/runscripts/manual/analysisBase.py b/runscripts/manual/analysisBase.py index 544de1fece..cde7827933 100644 --- a/runscripts/manual/analysisBase.py +++ b/runscripts/manual/analysisBase.py @@ -9,7 +9,7 @@ import argparse import os import sys -from typing import Any, Dict, Optional +from typing import Any, Dict, List, Optional from models.ecoli.analysis.analysisPlot import AnalysisPlot from wholecell.utils import constants, data, scriptBase, parallelization @@ -79,6 +79,27 @@ def select_analysis_keys(self, args): """Select key/value pairs specific to analysis tasks""" return data.select_keys(vars(args), scriptBase.ANALYSIS_KEYS) + def define_path_selection(self, parser, *selections): + # type: (argparse.ArgumentParser, *str) -> None + """ + Adds options to the arg parser for path selections based on variant, + seed, generation, or successful completion of sims. + + selections should take the options: ['variant', 'seed', 'generation'] + """ + + for selection in selections: + option = '{}_path'.format(selection) + upper = selection.upper() + parser.add_argument(scriptBase.dashize('--{}'.format(option + '_range')), nargs=2, default=None, type=int, + metavar=('START_{}'.format(upper), 'END_{}'.format(upper)), + help=f'The range of {selection} paths to include for analysis.') + parser.add_argument(scriptBase.dashize('--{}'.format(option)), nargs='*', default=None, type=int, + metavar=selection.upper(), + help=f'Specific {selection} paths to include for analysis.') + + self.define_parameter_bool(parser, 'only_successful', help='Only include successful sims in analysis.') + def update_args(self, args): # type: (argparse.Namespace) -> None """Update the command line args in an `argparse.Namespace`, including @@ -91,6 +112,16 @@ def update_args(self, args): Overrides should first call super(). """ + + def analysis_paths(path, path_range): + values = set() + if arg := getattr(args, path, None): + values.update(arg) + if arg := getattr(args, path_range, None): + values.update(range(arg[0], arg[1])) + + return sorted(values) if values else None + super(AnalysisBase, self).update_args(args) if self.plot_name: @@ -110,3 +141,7 @@ def update_args(self, args): metadata['variant_index'] = variant_index args.cpus = parallelization.cpus(args.cpus) + + args.variant_paths = analysis_paths('variant_path', 'variant_path_range') + args.seed_paths = analysis_paths('seed_path', 'seed_path_range') + args.generation_paths = analysis_paths('generation_path', 'generation_path_range') diff --git a/runscripts/manual/analysisCohort.py b/runscripts/manual/analysisCohort.py index 60450e79da..37d94f05a2 100644 --- a/runscripts/manual/analysisCohort.py +++ b/runscripts/manual/analysisCohort.py @@ -20,6 +20,7 @@ def define_parameters(self, parser): super(AnalysisCohort, self).define_parameters(parser) self.define_parameter_variant_index(parser) self.define_range_options(parser, 'variant') + self.define_path_selection(parser, 'seed', 'generation') def run(self, args): sim_path = args.sim_path @@ -38,6 +39,9 @@ def run(self, args): output_plots_directory=output_dir, metadata=args.metadata, output_filename_prefix=args.output_prefix, + seed_paths=args.seed_paths, + generation_paths=args.generation_paths, + only_successful=args.only_successful, **self.select_analysis_keys(args) ) task.run_task({}) diff --git a/runscripts/manual/analysisMultigen.py b/runscripts/manual/analysisMultigen.py index 3d01c82e62..b68c468772 100644 --- a/runscripts/manual/analysisMultigen.py +++ b/runscripts/manual/analysisMultigen.py @@ -24,6 +24,7 @@ def define_parameters(self, parser): help='The initial simulation number (int). The value will get' ' formatted as a subdirectory name like "000000". Default = 0.') self.define_range_options(parser, 'variant', 'seed') + self.define_path_selection(parser, 'generation') def update_args(self, args): super(AnalysisMultigen, self).update_args(args) @@ -50,6 +51,8 @@ def run(self, args): output_plots_directory=output_dir, metadata=args.metadata, output_filename_prefix=args.output_prefix, + generation_paths=args.generation_paths, + only_successful=args.only_successful, **self.select_analysis_keys(args) ) task.run_task({}) diff --git a/runscripts/manual/analysisVariant.py b/runscripts/manual/analysisVariant.py index 2b0d810745..a0fccd76b8 100644 --- a/runscripts/manual/analysisVariant.py +++ b/runscripts/manual/analysisVariant.py @@ -17,6 +17,10 @@ class AnalysisVariant(AnalysisBase): """Runs some or all the ACTIVE variant analysis plots for a given sim.""" + def define_parameters(self, parser): + super().define_parameters(parser) + self.define_path_selection(parser, 'variant', 'seed', 'generation') + def update_args(self, args): super(AnalysisVariant, self).update_args(args) @@ -40,6 +44,10 @@ def run(self, args): output_plots_directory=output_dir, metadata=args.metadata, output_filename_prefix=args.output_prefix, + variant_paths=args.variant_paths, + seed_paths=args.seed_paths, + generation_paths=args.generation_paths, + only_successful=args.only_successful, **self.select_analysis_keys(args) ) task.run_task({}) diff --git a/wholecell/fireworks/firetasks/analysisBase.py b/wholecell/fireworks/firetasks/analysisBase.py index 69d634e3a4..5d8580287b 100644 --- a/wholecell/fireworks/firetasks/analysisBase.py +++ b/wholecell/fireworks/firetasks/analysisBase.py @@ -14,13 +14,14 @@ import sys import time import traceback -from typing import List +from typing import Dict, List from fireworks import FiretaskBase import matplotlib as mpl from PIL import Image from six.moves import zip +from models.ecoli.analysis.AnalysisPaths import AnalysisPaths from wholecell.utils import data from wholecell.utils import parallelization import wholecell.utils.filepath as fp @@ -67,6 +68,8 @@ class AnalysisBase(FiretaskBase): Optional params include plot, output_filename_prefix, cpus. """ + analysis_path_options = {} # type: Dict[str, bool] + @abc.abstractmethod def plotter_args(self, module_filename): """(Abstract) Return a tuple of arguments to pass to the analysis plot @@ -167,6 +170,16 @@ def run_task(self, fw_spec): if cpus > 1: pool = parallelization.pool(cpus) + # Set analysis paths from args + input_dir = self.plotter_args('')[0] + variant_paths = self.get('variant_paths') + seed_paths = self.get('seed_paths') + generation_paths = self.get('generation_paths') + only_successful = self.get('only_successful', False) + analysis_paths = AnalysisPaths(input_dir, **self.analysis_path_options) + analysis_paths.update_cells(variant=variant_paths, seed=seed_paths, + generation=generation_paths, only_successful=only_successful) + exceptionFileList = [] for f in fileList: try: @@ -179,12 +192,12 @@ def run_task(self, fw_spec): args = self.plotter_args(f) if pool: - results[f] = pool.apply_async(run_plot, args=(mod.Plot, args, f)) + results[f] = pool.apply_async(run_plot, args=(mod.Plot, args, f, analysis_paths)) else: print("{}: Running {}".format(time.ctime(), f)) # noinspection PyBroadException try: - mod.Plot.main(*args) + mod.Plot.main(*args, analysis_paths=analysis_paths) except Exception: traceback.print_exc() exceptionFileList.append(f) @@ -212,14 +225,14 @@ def run_task(self, fw_spec): print('Completed analysis in {}'.format(duration)) -def run_plot(plot_class, args, name): +def run_plot(plot_class, args, name, analysis_paths): """Run the given plot class in a Pool worker. Since this Firetask is running multiple plot classes in parallel, ask them to use just 1 CPU core each. """ try: print("{}: Running {}".format(time.ctime(), name)) - plot_class.main(*args, cpus=1) + plot_class.main(*args, cpus=1, analysis_paths=analysis_paths) except KeyboardInterrupt: sys.exit(1) except Exception as e: diff --git a/wholecell/fireworks/firetasks/analysisCohort.py b/wholecell/fireworks/firetasks/analysisCohort.py index f691e56772..0a07a87190 100644 --- a/wholecell/fireworks/firetasks/analysisCohort.py +++ b/wholecell/fireworks/firetasks/analysisCohort.py @@ -29,9 +29,13 @@ class AnalysisCohortTask(AnalysisBase): "output_filename_prefix", "cpus", "compile", + "seed_paths", + "generation_paths", + "only_successful", ] MODULE_PATH = 'models.ecoli.analysis.cohort' TAGS = models.ecoli.analysis.cohort.TAGS + analysis_path_options = {'cohort_plot': True} def plotter_args(self, module_filename): self["metadata"] = dict(self["metadata"], analysis_type = "cohort") diff --git a/wholecell/fireworks/firetasks/analysisMultiGen.py b/wholecell/fireworks/firetasks/analysisMultiGen.py index 20e59e8b99..dd4b0e1908 100644 --- a/wholecell/fireworks/firetasks/analysisMultiGen.py +++ b/wholecell/fireworks/firetasks/analysisMultiGen.py @@ -29,9 +29,12 @@ class AnalysisMultiGenTask(AnalysisBase): "output_filename_prefix", "cpus", "compile", + "generation_paths", + "only_successful", ] MODULE_PATH = 'models.ecoli.analysis.multigen' TAGS = models.ecoli.analysis.multigen.TAGS + analysis_path_options = {'multi_gen_plot': True} def plotter_args(self, module_filename): self["metadata"] = dict(self["metadata"], analysis_type = "multigen") diff --git a/wholecell/fireworks/firetasks/analysisVariant.py b/wholecell/fireworks/firetasks/analysisVariant.py index c33361210b..f48b2e9591 100644 --- a/wholecell/fireworks/firetasks/analysisVariant.py +++ b/wholecell/fireworks/firetasks/analysisVariant.py @@ -30,9 +30,14 @@ class AnalysisVariantTask(AnalysisBase): "output_filename_prefix", "cpus", "compile", + "variant_paths", + "seed_paths", + "generation_paths", + "only_successful", ] MODULE_PATH = 'models.ecoli.analysis.variant' TAGS = models.ecoli.analysis.variant.TAGS + analysis_path_options = {'variant_plot': True} def plotter_args(self, module_filename): self["metadata"] = dict(self["metadata"], analysis_type = "variant") diff --git a/wholecell/utils/scriptBase.py b/wholecell/utils/scriptBase.py index 3835dd42f8..067766a880 100644 --- a/wholecell/utils/scriptBase.py +++ b/wholecell/utils/scriptBase.py @@ -532,7 +532,7 @@ def define_range_options(self, parser, *range_keys): override = dashize('--{}'.format(RANGE_ARGS[option])) parser.add_argument(dashize('--{}'.format(option)), nargs=2, default=None, type=int, metavar=('START_{}'.format(upper), 'END_{}'.format(upper)), - help='The range of variants to run. Will override {} option.'.format(override)) + help=f'The range of {key}s to run. Will override {override} option.') self.range_options.append(option) def parse_args(self):