Skip to content

Commit

Permalink
Add type hints and expand tests for cluster_analysis
Browse files Browse the repository at this point in the history
  • Loading branch information
jbeilstenedmands committed Oct 21, 2024
1 parent ddcada2 commit 7c34e6a
Show file tree
Hide file tree
Showing 3 changed files with 74 additions and 42 deletions.
39 changes: 32 additions & 7 deletions src/xia2/Modules/MultiCrystal/cluster_analysis.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,11 @@
import logging
import pathlib

import iotbx.phil
from dials.algorithms.correlation.analysis import CorrelationMatrix
from dials.algorithms.correlation.cluster import ClusterInfo
from dials.array_family import flex
from dxtbx.model import ExperimentList

from xia2.Modules.MultiCrystalAnalysis import MultiCrystalAnalysis

Expand Down Expand Up @@ -65,7 +69,11 @@
"""


def clusters_and_types(cos_angle_clusters, cc_clusters, methods):
def clusters_and_types(
cos_angle_clusters: list[ClusterInfo],
cc_clusters: list[ClusterInfo],
methods: list[str],
) -> tuple[list[ClusterInfo], list[str]]:
if "cos_angle" in methods and "correlation" not in methods:
clusters = cos_angle_clusters
ctype = ["cos"] * len(clusters)
Expand All @@ -83,7 +91,12 @@ def clusters_and_types(cos_angle_clusters, cc_clusters, methods):
return clusters, ctype


def get_subclusters(params, ids_to_identifiers_map, cos_angle_clusters, cc_clusters):
def get_subclusters(
params: iotbx.phil.scope_extract,
ids_to_identifiers_map: dict[int, str],
cos_angle_clusters: list[ClusterInfo],
cc_clusters: list[ClusterInfo],
) -> list[tuple[str, list[str], ClusterInfo]]:
subclusters = []

min_completeness = params.min_completeness
Expand Down Expand Up @@ -146,7 +159,14 @@ def get_subclusters(params, ids_to_identifiers_map, cos_angle_clusters, cc_clust
return subclusters


def output_cluster(new_folder, experiments, reflections, ids):
def output_cluster(
new_folder: pathlib.Path,
experiments: ExperimentList,
reflections: list[flex.reflection_table],
ids: list[str],
) -> None:
if not new_folder.parent.exists():
pathlib.Path.mkdir(new_folder.parent)
expts = copy.deepcopy(experiments)
expts.select_on_experiment_identifiers(ids)

Expand All @@ -157,14 +177,19 @@ def output_cluster(new_folder, experiments, reflections, ids):

joint_refl = flex.reflection_table.concat(refl)

if not pathlib.Path.exists(new_folder):
if not new_folder.exists():
pathlib.Path.mkdir(new_folder)

expts.as_file(new_folder / "cluster.expt")
joint_refl.as_file(new_folder / "cluster.refl")


def output_hierarchical_clusters(params, MCA, experiments, reflections):
def output_hierarchical_clusters(
params: iotbx.phil.scope_extract,
MCA: CorrelationMatrix,
experiments: ExperimentList,
reflections: list[flex.reflection_table],
) -> None:
cwd = pathlib.Path.cwd()

# First get subclusters that meet the required thresholds
Expand All @@ -180,11 +205,11 @@ def output_hierarchical_clusters(params, MCA, experiments, reflections):
# if not doing distinct cluster analysis, can now output clusters
if not params.clustering.hierarchical.distinct_clusters:
for c, cluster_identifiers, cluster in subclusters:
cluster_dir = cwd / f"{c}_clusters/cluster_{cluster.cluster_id}"
output_dir = cwd / f"{c}_clusters/cluster_{cluster.cluster_id}"
logger.info(f"Outputting {c} cluster {cluster.cluster_id}:")
logger.info(cluster)
output_cluster(
cluster_dir,
output_dir,
experiments,
reflections,
cluster_identifiers,
Expand Down
15 changes: 6 additions & 9 deletions src/xia2/cli/cluster_analysis.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
)
from dials.util.options import ArgumentParser, reflections_and_experiments_from_files
from dials.util.version import dials_version
from dxtbx.model import ExperimentList
from jinja2 import ChoiceLoader, Environment, PackageLoader

import xia2.Handlers.Streams
Expand Down Expand Up @@ -145,11 +146,6 @@ def run(args=sys.argv[1:]):
logger.info(tabulate(MCA.cos_table, headers="firstrow", tablefmt="rst"))

cwd = pathlib.Path.cwd()
if not pathlib.Path.exists(cwd / "cos_clusters"):
pathlib.Path.mkdir(cwd / "cos_clusters")
if not pathlib.Path.exists(cwd / "cc_clusters"):
pathlib.Path.mkdir(cwd / "cc_clusters")

# First get any specific requested/excluded clusters

# These are options that are only available to xia2.cluster_analysis
Expand Down Expand Up @@ -251,11 +247,7 @@ def run(args=sys.argv[1:]):
if "hierarchical" in params.clustering.method:
output_hierarchical_clusters(params, MCA, experiments, reflections)
if "coordinate" in params.clustering.method:
from dxtbx.model import ExperimentList

clusters = MCA.significant_clusters
if not pathlib.Path.exists(cwd / "coordinate_clusters"):
pathlib.Path.mkdir(cwd / "coordinate_clusters")
count = 0
for c in clusters:
if c.completeness < params.clustering.min_completeness:
Expand All @@ -266,6 +258,11 @@ def run(args=sys.argv[1:]):
continue
if count >= params.clustering.max_output_clusters:
continue
# for the first output cluster, create the parent directory if it does not already exist
if not count and not pathlib.Path.exists(
cwd / "coordinate_clusters"
):
pathlib.Path.mkdir(cwd / "coordinate_clusters")
cluster_dir = f"coordinate_clusters/cluster_{c.cluster_id}"
logger.info(f"Outputting: {cluster_dir}")
if not pathlib.Path.exists(cwd / cluster_dir):
Expand Down
62 changes: 36 additions & 26 deletions tests/regression/test_cluster_analysis.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,6 @@
from dials.util.options import ArgumentParser
from libtbx import phil

from xia2.cli.multiplex import run as run_multiplex
from xia2.Modules.MultiCrystalAnalysis import MultiCrystalAnalysis

phil_scope = phil.parse(
Expand Down Expand Up @@ -113,34 +112,25 @@ def test_serial_data(dials_data, tmp_path, output_clusters, interesting_clusters
)
assert not result_generate_scaled.returncode and not result_generate_scaled.stderr
result = subprocess.run(args_test_clustering, cwd=tmp_path, capture_output=True)
assert not result.returncode # and not result.stderr
assert not result.returncode and not result.stderr
check_output(tmp_path, output_clusters, interesting_clusters)


def test_rotation_data(dials_data, run_in_tmp_path):
rot = dials_data("vmxi_proteinase_k_sweeps", pathlib=True)
expt_1 = os.fspath(rot / "experiments_0.expt")
expt_2 = os.fspath(rot / "experiments_1.expt")
expt_3 = os.fspath(rot / "experiments_2.expt")
expt_4 = os.fspath(rot / "experiments_3.expt")
refl_1 = os.fspath(rot / "reflections_0.refl")
refl_2 = os.fspath(rot / "reflections_1.refl")
refl_3 = os.fspath(rot / "reflections_2.refl")
refl_4 = os.fspath(rot / "reflections_3.refl")
expt_scaled = os.fspath(run_in_tmp_path / "scaled.expt")
refl_scaled = os.fspath(run_in_tmp_path / "scaled.refl")
run_multiplex(
[
expt_1,
refl_1,
expt_2,
refl_2,
expt_3,
refl_3,
expt_4,
refl_4,
]

# First scale the data to get suitable input
cmd = "dials.scale"
if os.name == "nt":
cmd += ".bat"
result = subprocess.run(
[cmd]
+ [rot / f"experiments_{i}.expt" for i in range(0, 4)]
+ [rot / f"reflections_{i}.refl" for i in range(0, 4)],
capture_output=True,
)
assert not result.returncode

cmd = "xia2.cluster_analysis"
if os.name == "nt":
cmd += ".bat"
Expand All @@ -150,16 +140,36 @@ def test_rotation_data(dials_data, run_in_tmp_path):
"clustering.min_cluster_size=2",
"clustering.hierarchical.method=cos_angle+correlation",
"clustering.output_clusters=True",
expt_scaled,
refl_scaled,
"scaled.expt",
"scaled.refl",
"output.json=xia2.cluster_analysis.json",
]
result = subprocess.run(args_clustering, capture_output=True)
assert not result.returncode # and not result.stderr
assert not result.returncode and not result.stderr
assert (run_in_tmp_path / "xia2.cluster_analysis.json").is_file()
assert (run_in_tmp_path / "xia2.cluster_analysis.log").is_file()
assert (run_in_tmp_path / "xia2.cluster_analysis.html").is_file()
assert (run_in_tmp_path / "cc_clusters" / "cluster_2").exists()
assert not (run_in_tmp_path / "coordinate_clusters").exists()
# now run coordinate clustering
args_clustering = [
cmd,
"clustering.method=coordinate",
"clustering.min_cluster_size=2",
"clustering.output_clusters=True",
"scaled.expt",
"scaled.refl",
"output.json=xia2.cluster_analysis.json",
]
result = subprocess.run(args_clustering, capture_output=True)
assert not result.returncode and not result.stderr
assert (run_in_tmp_path / "coordinate_clusters" / "cluster_0").exists()
assert (
run_in_tmp_path / "coordinate_clusters" / "cluster_0" / "cluster.refl"
).exists()
assert (
run_in_tmp_path / "coordinate_clusters" / "cluster_0" / "cluster.expt"
).exists()


def check_output(main_dir, output_clusters=True, interesting_clusters=False):
Expand Down

0 comments on commit 7c34e6a

Please sign in to comment.