diff --git a/src/xia2/Modules/MultiCrystal/cluster_analysis.py b/src/xia2/Modules/MultiCrystal/cluster_analysis.py index cced2b03b..27646bd19 100644 --- a/src/xia2/Modules/MultiCrystal/cluster_analysis.py +++ b/src/xia2/Modules/MultiCrystal/cluster_analysis.py @@ -4,7 +4,11 @@ import logging import pathlib +import iotbx.phil +from dials.algorithms.correlation.analysis import CorrelationMatrix +from dials.algorithms.correlation.cluster import ClusterInfo from dials.array_family import flex +from dxtbx.model import ExperimentList from xia2.Modules.MultiCrystalAnalysis import MultiCrystalAnalysis @@ -65,7 +69,11 @@ """ -def clusters_and_types(cos_angle_clusters, cc_clusters, methods): +def clusters_and_types( + cos_angle_clusters: list[ClusterInfo], + cc_clusters: list[ClusterInfo], + methods: list[str], +) -> tuple[list[ClusterInfo], list[str]]: if "cos_angle" in methods and "correlation" not in methods: clusters = cos_angle_clusters ctype = ["cos"] * len(clusters) @@ -83,7 +91,12 @@ def clusters_and_types(cos_angle_clusters, cc_clusters, methods): return clusters, ctype -def get_subclusters(params, ids_to_identifiers_map, cos_angle_clusters, cc_clusters): +def get_subclusters( + params: iotbx.phil.scope_extract, + ids_to_identifiers_map: dict[int, str], + cos_angle_clusters: list[ClusterInfo], + cc_clusters: list[ClusterInfo], +) -> list[tuple[str, list[str], ClusterInfo]]: subclusters = [] min_completeness = params.min_completeness @@ -146,7 +159,14 @@ def get_subclusters(params, ids_to_identifiers_map, cos_angle_clusters, cc_clust return subclusters -def output_cluster(new_folder, experiments, reflections, ids): +def output_cluster( + new_folder: pathlib.Path, + experiments: ExperimentList, + reflections: list[flex.reflection_table], + ids: list[str], +) -> None: + if not new_folder.parent.exists(): + pathlib.Path.mkdir(new_folder.parent) expts = copy.deepcopy(experiments) expts.select_on_experiment_identifiers(ids) @@ -157,14 +177,19 @@ def output_cluster(new_folder, experiments, reflections, ids): joint_refl = flex.reflection_table.concat(refl) - if not pathlib.Path.exists(new_folder): + if not new_folder.exists(): pathlib.Path.mkdir(new_folder) expts.as_file(new_folder / "cluster.expt") joint_refl.as_file(new_folder / "cluster.refl") -def output_hierarchical_clusters(params, MCA, experiments, reflections): +def output_hierarchical_clusters( + params: iotbx.phil.scope_extract, + MCA: CorrelationMatrix, + experiments: ExperimentList, + reflections: list[flex.reflection_table], +) -> None: cwd = pathlib.Path.cwd() # First get subclusters that meet the required thresholds @@ -180,11 +205,11 @@ def output_hierarchical_clusters(params, MCA, experiments, reflections): # if not doing distinct cluster analysis, can now output clusters if not params.clustering.hierarchical.distinct_clusters: for c, cluster_identifiers, cluster in subclusters: - cluster_dir = cwd / f"{c}_clusters/cluster_{cluster.cluster_id}" + output_dir = cwd / f"{c}_clusters/cluster_{cluster.cluster_id}" logger.info(f"Outputting {c} cluster {cluster.cluster_id}:") logger.info(cluster) output_cluster( - cluster_dir, + output_dir, experiments, reflections, cluster_identifiers, diff --git a/src/xia2/cli/cluster_analysis.py b/src/xia2/cli/cluster_analysis.py index 30d8d7070..fafe665c5 100644 --- a/src/xia2/cli/cluster_analysis.py +++ b/src/xia2/cli/cluster_analysis.py @@ -16,6 +16,7 @@ ) from dials.util.options import ArgumentParser, reflections_and_experiments_from_files from dials.util.version import dials_version +from dxtbx.model import ExperimentList from jinja2 import ChoiceLoader, Environment, PackageLoader import xia2.Handlers.Streams @@ -145,11 +146,6 @@ def run(args=sys.argv[1:]): logger.info(tabulate(MCA.cos_table, headers="firstrow", tablefmt="rst")) cwd = pathlib.Path.cwd() - if not pathlib.Path.exists(cwd / "cos_clusters"): - pathlib.Path.mkdir(cwd / "cos_clusters") - if not pathlib.Path.exists(cwd / "cc_clusters"): - pathlib.Path.mkdir(cwd / "cc_clusters") - # First get any specific requested/excluded clusters # These are options that are only available to xia2.cluster_analysis @@ -251,11 +247,7 @@ def run(args=sys.argv[1:]): if "hierarchical" in params.clustering.method: output_hierarchical_clusters(params, MCA, experiments, reflections) if "coordinate" in params.clustering.method: - from dxtbx.model import ExperimentList - clusters = MCA.significant_clusters - if not pathlib.Path.exists(cwd / "coordinate_clusters"): - pathlib.Path.mkdir(cwd / "coordinate_clusters") count = 0 for c in clusters: if c.completeness < params.clustering.min_completeness: @@ -266,6 +258,11 @@ def run(args=sys.argv[1:]): continue if count >= params.clustering.max_output_clusters: continue + # for the first cluster, make the directory if not exists + if not count and not pathlib.Path.exists( + cwd / "coordinate_clusters" + ): + pathlib.Path.mkdir(cwd / "coordinate_clusters") cluster_dir = f"coordinate_clusters/cluster_{c.cluster_id}" logger.info(f"Outputting: {cluster_dir}") if not pathlib.Path.exists(cwd / cluster_dir): diff --git a/tests/regression/test_cluster_analysis.py b/tests/regression/test_cluster_analysis.py index 6cafd65d6..9fd74cda4 100644 --- a/tests/regression/test_cluster_analysis.py +++ b/tests/regression/test_cluster_analysis.py @@ -8,7 +8,6 @@ from dials.util.options import ArgumentParser from libtbx import phil -from xia2.cli.multiplex import run as run_multiplex from xia2.Modules.MultiCrystalAnalysis import MultiCrystalAnalysis phil_scope = phil.parse( @@ -113,34 +112,25 @@ def test_serial_data(dials_data, tmp_path, output_clusters, interesting_clusters ) assert not result_generate_scaled.returncode and not result_generate_scaled.stderr result = subprocess.run(args_test_clustering, cwd=tmp_path, capture_output=True) - assert not result.returncode # and not result.stderr + assert not result.returncode and not result.stderr check_output(tmp_path, output_clusters, interesting_clusters) def test_rotation_data(dials_data, run_in_tmp_path): rot = dials_data("vmxi_proteinase_k_sweeps", pathlib=True) - expt_1 = os.fspath(rot / "experiments_0.expt") - expt_2 = os.fspath(rot / "experiments_1.expt") - expt_3 = os.fspath(rot / "experiments_2.expt") - expt_4 = os.fspath(rot / "experiments_3.expt") - refl_1 = os.fspath(rot / "reflections_0.refl") - refl_2 = os.fspath(rot / "reflections_1.refl") - refl_3 = os.fspath(rot / "reflections_2.refl") - refl_4 = os.fspath(rot / "reflections_3.refl") - expt_scaled = os.fspath(run_in_tmp_path / "scaled.expt") - refl_scaled = os.fspath(run_in_tmp_path / "scaled.refl") - run_multiplex( - [ - expt_1, - refl_1, - expt_2, - refl_2, - expt_3, - refl_3, - expt_4, - refl_4, - ] + + # First scale the data to get suitable input + cmd = "dials.scale" + if os.name == "nt": + cmd += ".bat" + result = subprocess.run( + [cmd] + + [rot / f"experiments_{i}.expt" for i in range(0, 4)] + + [rot / f"reflections_{i}.refl" for i in range(0, 4)], + capture_output=True, ) + assert not result.returncode + cmd = "xia2.cluster_analysis" if os.name == "nt": cmd += ".bat" @@ -150,16 +140,36 @@ def test_rotation_data(dials_data, run_in_tmp_path): "clustering.min_cluster_size=2", "clustering.hierarchical.method=cos_angle+correlation", "clustering.output_clusters=True", - expt_scaled, - refl_scaled, + "scaled.expt", + "scaled.refl", "output.json=xia2.cluster_analysis.json", ] result = subprocess.run(args_clustering, capture_output=True) - assert not result.returncode # and not result.stderr + assert not result.returncode and not result.stderr assert (run_in_tmp_path / "xia2.cluster_analysis.json").is_file() assert (run_in_tmp_path / "xia2.cluster_analysis.log").is_file() assert (run_in_tmp_path / "xia2.cluster_analysis.html").is_file() assert (run_in_tmp_path / "cc_clusters" / "cluster_2").exists() + assert not (run_in_tmp_path / "coordinate_clusters").exists() + # now run coordinate clustering + args_clustering = [ + cmd, + "clustering.method=coordinate", + "clustering.min_cluster_size=2", + "clustering.output_clusters=True", + "scaled.expt", + "scaled.refl", + "output.json=xia2.cluster_analysis.json", + ] + result = subprocess.run(args_clustering, capture_output=True) + assert not result.returncode and not result.stderr + assert (run_in_tmp_path / "coordinate_clusters" / "cluster_0").exists() + assert ( + run_in_tmp_path / "coordinate_clusters" / "cluster_0" / "cluster.refl" + ).exists() + assert ( + run_in_tmp_path / "coordinate_clusters" / "cluster_0" / "cluster.expt" + ).exists() def check_output(main_dir, output_clusters=True, interesting_clusters=False):