Skip to content

Commit

Permalink
Select subset of sims for analysis (#1250)
Browse files Browse the repository at this point in the history
* Save analysis_paths as plot class attribute from firetask
* Fix CLI help typo for analysis range options
* Pass analysis path selections to firetasks from manual scripts
* Specific args for analysis paths for each analysis type
* Remove check for == 1 plot type option in AnalysisPaths
  • Loading branch information
tahorst authored Feb 1, 2022
1 parent 4eb39c3 commit 631eee6
Show file tree
Hide file tree
Showing 11 changed files with 107 additions and 15 deletions.
29 changes: 22 additions & 7 deletions models/ecoli/analysis/AnalysisPaths.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,8 +42,8 @@ class AnalysisPaths(object):
def __init__(self, out_dir, *,
variant_plot: bool = False, multi_gen_plot: bool = False,
cohort_plot: bool = False) -> None:
assert variant_plot + multi_gen_plot + cohort_plot == 1, (
"Must specify exactly one plot type!")
assert variant_plot + multi_gen_plot + cohort_plot <= 1, (
"Can only specify one analysis type!")

generation_dirs = [] # type: List[str]
if variant_plot:
Expand Down Expand Up @@ -142,11 +142,14 @@ def __init__(self, out_dir, *,
self._path_data["variantkb"] = variant_kb
self._path_data["successful"] = successful

self.n_generation = len(set(generations))
self.n_variant = len(set(variants))
self.n_seed = len(set(seeds))
self._calculate_n()

def get_cells(self, variant = None, seed = None, generation = None, only_successful=False):
def _calculate_n(self):
self.n_generation = len(set(self._path_data["generation"]))
self.n_variant = len(set(self._path_data["variant"]))
self.n_seed = len(set(self._path_data["seed"]))

def _get_cells(self, variant=None, seed=None, generation=None, only_successful=False):
# type: (Optional[Iterable[int]], Optional[Iterable[int]], Optional[Iterable[int]], bool) -> np.ndarray
"""Returns file paths for all the simulated cells matching the given
variant number, seed number, and generation number collections, where
Expand All @@ -172,7 +175,19 @@ def get_cells(self, variant = None, seed = None, generation = None, only_success
else:
successful_bool = np.ones(self._path_data.shape)

return self._path_data['path'][np.logical_and.reduce((variantBool, seedBool, generationBool, successful_bool))]
return np.logical_and.reduce((variantBool, seedBool, generationBool, successful_bool))

def get_cells(self, variant=None, seed=None, generation=None, only_successful=False):
mask = self._get_cells(variant=variant, seed=seed,
generation=generation, only_successful=only_successful)
return self._path_data['path'][mask]

def update_cells(self, variant=None, seed=None, generation=None, only_successful=False):
mask = self._get_cells(variant=variant, seed=seed,
generation=generation, only_successful=only_successful)

self._path_data = self._path_data[mask]
self._calculate_n()

def get_variant_kb(self, variant):
# type: (Union[int, str]) -> str
Expand Down
4 changes: 3 additions & 1 deletion models/ecoli/analysis/analysisPlot.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,6 +47,7 @@ class AnalysisPlot(metaclass=abc.ABCMeta):
def __init__(self, cpus=0):
self.cpus = parallelization.cpus(cpus)
self._axeses = {}
self.ap = None

@staticmethod
def read_sim_data_file(sim_path: str) -> SimulationDataEcoli:
Expand Down Expand Up @@ -158,8 +159,9 @@ def do_plot():

@classmethod
def main(cls, inputDir, plotOutDir, plotOutFileName, simDataFile,
validationDataFile=None, metadata=None, cpus=0):
validationDataFile=None, metadata=None, cpus=0, analysis_paths=None):
"""Run an analysis plot for a Firetask."""
instance = cls(cpus)
instance.ap = analysis_paths
instance.plot(inputDir, plotOutDir, plotOutFileName, simDataFile,
validationDataFile, metadata)
37 changes: 36 additions & 1 deletion runscripts/manual/analysisBase.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@
import argparse
import os
import sys
from typing import Any, Dict, Optional
from typing import Any, Dict, List, Optional

from models.ecoli.analysis.analysisPlot import AnalysisPlot
from wholecell.utils import constants, data, scriptBase, parallelization
Expand Down Expand Up @@ -79,6 +79,27 @@ def select_analysis_keys(self, args):
"""Select key/value pairs specific to analysis tasks"""
return data.select_keys(vars(args), scriptBase.ANALYSIS_KEYS)

def define_path_selection(self, parser, *selections):
# type: (argparse.ArgumentParser, *str) -> None
"""
Adds options to the arg parser for path selections based on variant,
seed, generation, or successful completion of sims.
selections should take the options: ['variant', 'seed', 'generation']
"""

for selection in selections:
option = '{}_path'.format(selection)
upper = selection.upper()
parser.add_argument(scriptBase.dashize('--{}'.format(option + '_range')), nargs=2, default=None, type=int,
metavar=('START_{}'.format(upper), 'END_{}'.format(upper)),
help=f'The range of {selection} paths to include for analysis.')
parser.add_argument(scriptBase.dashize('--{}'.format(option)), nargs='*', default=None, type=int,
metavar=selection.upper(),
help=f'Specific {selection} paths to include for analysis.')

self.define_parameter_bool(parser, 'only_successful', help='Only include successful sims in analysis.')

def update_args(self, args):
# type: (argparse.Namespace) -> None
"""Update the command line args in an `argparse.Namespace`, including
Expand All @@ -91,6 +112,16 @@ def update_args(self, args):
Overrides should first call super().
"""

def analysis_paths(path, path_range):
values = set()
if arg := getattr(args, path, None):
values.update(arg)
if arg := getattr(args, path_range, None):
values.update(range(arg[0], arg[1]))

return sorted(values) if values else None

super(AnalysisBase, self).update_args(args)

if self.plot_name:
Expand All @@ -110,3 +141,7 @@ def update_args(self, args):
metadata['variant_index'] = variant_index

args.cpus = parallelization.cpus(args.cpus)

args.variant_paths = analysis_paths('variant_path', 'variant_path_range')
args.seed_paths = analysis_paths('seed_path', 'seed_path_range')
args.generation_paths = analysis_paths('generation_path', 'generation_path_range')
4 changes: 4 additions & 0 deletions runscripts/manual/analysisCohort.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@ def define_parameters(self, parser):
super(AnalysisCohort, self).define_parameters(parser)
self.define_parameter_variant_index(parser)
self.define_range_options(parser, 'variant')
self.define_path_selection(parser, 'seed', 'generation')

def run(self, args):
sim_path = args.sim_path
Expand All @@ -38,6 +39,9 @@ def run(self, args):
output_plots_directory=output_dir,
metadata=args.metadata,
output_filename_prefix=args.output_prefix,
seed_paths=args.seed_paths,
generation_paths=args.generation_paths,
only_successful=args.only_successful,
**self.select_analysis_keys(args)
)
task.run_task({})
Expand Down
3 changes: 3 additions & 0 deletions runscripts/manual/analysisMultigen.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@ def define_parameters(self, parser):
help='The initial simulation number (int). The value will get'
' formatted as a subdirectory name like "000000". Default = 0.')
self.define_range_options(parser, 'variant', 'seed')
self.define_path_selection(parser, 'generation')

def update_args(self, args):
super(AnalysisMultigen, self).update_args(args)
Expand All @@ -50,6 +51,8 @@ def run(self, args):
output_plots_directory=output_dir,
metadata=args.metadata,
output_filename_prefix=args.output_prefix,
generation_paths=args.generation_paths,
only_successful=args.only_successful,
**self.select_analysis_keys(args)
)
task.run_task({})
Expand Down
8 changes: 8 additions & 0 deletions runscripts/manual/analysisVariant.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,10 @@
class AnalysisVariant(AnalysisBase):
"""Runs some or all the ACTIVE variant analysis plots for a given sim."""

def define_parameters(self, parser):
super().define_parameters(parser)
self.define_path_selection(parser, 'variant', 'seed', 'generation')

def update_args(self, args):
super(AnalysisVariant, self).update_args(args)

Expand All @@ -40,6 +44,10 @@ def run(self, args):
output_plots_directory=output_dir,
metadata=args.metadata,
output_filename_prefix=args.output_prefix,
variant_paths=args.variant_paths,
seed_paths=args.seed_paths,
generation_paths=args.generation_paths,
only_successful=args.only_successful,
**self.select_analysis_keys(args)
)
task.run_task({})
Expand Down
23 changes: 18 additions & 5 deletions wholecell/fireworks/firetasks/analysisBase.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,13 +14,14 @@
import sys
import time
import traceback
from typing import List
from typing import Dict, List

from fireworks import FiretaskBase
import matplotlib as mpl
from PIL import Image
from six.moves import zip

from models.ecoli.analysis.AnalysisPaths import AnalysisPaths
from wholecell.utils import data
from wholecell.utils import parallelization
import wholecell.utils.filepath as fp
Expand Down Expand Up @@ -67,6 +68,8 @@ class AnalysisBase(FiretaskBase):
Optional params include plot, output_filename_prefix, cpus.
"""

analysis_path_options = {} # type: Dict[str, bool]

@abc.abstractmethod
def plotter_args(self, module_filename):
"""(Abstract) Return a tuple of arguments to pass to the analysis plot
Expand Down Expand Up @@ -167,6 +170,16 @@ def run_task(self, fw_spec):
if cpus > 1:
pool = parallelization.pool(cpus)

# Set analysis paths from args
input_dir = self.plotter_args('')[0]
variant_paths = self.get('variant_paths')
seed_paths = self.get('seed_paths')
generation_paths = self.get('generation_paths')
only_successful = self.get('only_successful', False)
analysis_paths = AnalysisPaths(input_dir, **self.analysis_path_options)
analysis_paths.update_cells(variant=variant_paths, seed=seed_paths,
generation=generation_paths, only_successful=only_successful)

exceptionFileList = []
for f in fileList:
try:
Expand All @@ -179,12 +192,12 @@ def run_task(self, fw_spec):
args = self.plotter_args(f)

if pool:
results[f] = pool.apply_async(run_plot, args=(mod.Plot, args, f))
results[f] = pool.apply_async(run_plot, args=(mod.Plot, args, f, analysis_paths))
else:
print("{}: Running {}".format(time.ctime(), f))
# noinspection PyBroadException
try:
mod.Plot.main(*args)
mod.Plot.main(*args, analysis_paths=analysis_paths)
except Exception:
traceback.print_exc()
exceptionFileList.append(f)
Expand Down Expand Up @@ -212,14 +225,14 @@ def run_task(self, fw_spec):
print('Completed analysis in {}'.format(duration))


def run_plot(plot_class, args, name):
def run_plot(plot_class, args, name, analysis_paths):
"""Run the given plot class in a Pool worker.
Since this Firetask is running multiple plot classes in parallel, ask them
to use just 1 CPU core each.
"""
try:
print("{}: Running {}".format(time.ctime(), name))
plot_class.main(*args, cpus=1)
plot_class.main(*args, cpus=1, analysis_paths=analysis_paths)
except KeyboardInterrupt:
sys.exit(1)
except Exception as e:
Expand Down
4 changes: 4 additions & 0 deletions wholecell/fireworks/firetasks/analysisCohort.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,9 +29,13 @@ class AnalysisCohortTask(AnalysisBase):
"output_filename_prefix",
"cpus",
"compile",
"seed_paths",
"generation_paths",
"only_successful",
]
MODULE_PATH = 'models.ecoli.analysis.cohort'
TAGS = models.ecoli.analysis.cohort.TAGS
analysis_path_options = {'cohort_plot': True}

def plotter_args(self, module_filename):
self["metadata"] = dict(self["metadata"], analysis_type = "cohort")
Expand Down
3 changes: 3 additions & 0 deletions wholecell/fireworks/firetasks/analysisMultiGen.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,9 +29,12 @@ class AnalysisMultiGenTask(AnalysisBase):
"output_filename_prefix",
"cpus",
"compile",
"generation_paths",
"only_successful",
]
MODULE_PATH = 'models.ecoli.analysis.multigen'
TAGS = models.ecoli.analysis.multigen.TAGS
analysis_path_options = {'multi_gen_plot': True}

def plotter_args(self, module_filename):
self["metadata"] = dict(self["metadata"], analysis_type = "multigen")
Expand Down
5 changes: 5 additions & 0 deletions wholecell/fireworks/firetasks/analysisVariant.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,9 +30,14 @@ class AnalysisVariantTask(AnalysisBase):
"output_filename_prefix",
"cpus",
"compile",
"variant_paths",
"seed_paths",
"generation_paths",
"only_successful",
]
MODULE_PATH = 'models.ecoli.analysis.variant'
TAGS = models.ecoli.analysis.variant.TAGS
analysis_path_options = {'variant_plot': True}

def plotter_args(self, module_filename):
self["metadata"] = dict(self["metadata"], analysis_type = "variant")
Expand Down
2 changes: 1 addition & 1 deletion wholecell/utils/scriptBase.py
Original file line number Diff line number Diff line change
Expand Up @@ -532,7 +532,7 @@ def define_range_options(self, parser, *range_keys):
override = dashize('--{}'.format(RANGE_ARGS[option]))
parser.add_argument(dashize('--{}'.format(option)), nargs=2, default=None, type=int,
metavar=('START_{}'.format(upper), 'END_{}'.format(upper)),
help='The range of variants to run. Will override {} option.'.format(override))
help=f'The range of {key}s to run. Will override {override} option.')
self.range_options.append(option)

def parse_args(self):
Expand Down

0 comments on commit 631eee6

Please sign in to comment.