Skip to content

Commit

Permalink
Update tests
Browse files Browse the repository at this point in the history
  • Loading branch information
jbeilstenedmands committed Dec 13, 2023
1 parent 8b74868 commit 51d1eef
Show file tree
Hide file tree
Showing 2 changed files with 26 additions and 9 deletions.
6 changes: 4 additions & 2 deletions src/dials/command_line/index.py
Original file line number Diff line number Diff line change
Expand Up @@ -158,7 +158,8 @@ def _index_single_imageset(experiments, reflections, params, log_text=None):
# update the identifiers so that the unindexed has id 0, and the rest 1,2,3.. etc
idxr.unindexed_reflections["id"] = flex.int(idxr.unindexed_reflections.size(), 0)
idxr.unindexed_reflections.experiment_identifiers()[0] = unindexed[0].identifier
idxr.unindexed_reflections.clean_experiment_identifiers_map()
if idxr.unindexed_reflections.size():
idxr.unindexed_reflections.clean_experiment_identifiers_map()

idx_refl = idxr.refined_reflections
for id_ in sorted(set(idx_refl["id"]), reverse=True):
Expand All @@ -179,7 +180,8 @@ def _index_single_imageset(experiments, reflections, params, log_text=None):
indexed_experiments = ExperimentList(unindexed)
indexed_experiments.extend(idxr.refined_experiments)
idx_refl.extend(idxr.unindexed_reflections)
idx_refl.assert_experiment_identifiers_are_consistent()
idx_refl.assert_experiment_identifiers_are_consistent(indexed_experiments)

return indexed_experiments, idx_refl


Expand Down
29 changes: 22 additions & 7 deletions tests/algorithms/indexing/test_index.py
Original file line number Diff line number Diff line change
Expand Up @@ -68,12 +68,14 @@ def run_indexing(
assert out_refls.is_file()

experiments_list = load.experiment_list(out_expts, check_format=False)
assert len(experiments_list.crystals()) == n_expected_lattices
assert len([c for c in experiments_list.crystals() if c]) == n_expected_lattices
indexed_reflections = flex.reflection_table.from_file(out_refls)
indexed_reflections.assert_experiment_identifiers_are_consistent(experiments_list)
rmsds = None

for i, experiment in enumerate(experiments_list):
if experiment.crystal is None:
continue
assert unit_cells_are_similar(
experiment.crystal.get_unit_cell(),
expected_unit_cell,
Expand Down Expand Up @@ -609,7 +611,7 @@ def test_refinement_failure_on_max_lattices_a15(dials_data, tmp_path):
experiments_list = load.experiment_list(
tmp_path / "indexed.expt", check_format=False
)
assert len(experiments_list) == 2
assert len([c for c in experiments_list.crystals() if c]) == 2

# now try to reindex with existing model
result = subprocess.run(
Expand All @@ -628,7 +630,7 @@ def test_refinement_failure_on_max_lattices_a15(dials_data, tmp_path):
experiments_list = load.experiment_list(
tmp_path / "indexed.expt", check_format=False
)
assert len(experiments_list) == 2
assert len([c for c in experiments_list.crystals() if c]) == 2


@pytest.mark.parametrize(
Expand Down Expand Up @@ -688,6 +690,7 @@ def test_index_multi_lattice_multi_sweep(dials_data, tmp_path):
"max_lattices=2",
"joint_indexing=False",
"n_macro_cycles=2",
"output.retain_unindexed_experiment=False",
],
cwd=tmp_path,
capture_output=True,
Expand Down Expand Up @@ -959,7 +962,8 @@ def test_index_known_orientation(dials_data, tmp_path):
)


def test_all_expt_ids_have_expts(dials_data, tmp_path):
@pytest.mark.parametrize("retain_unindexed", [True, False])
def test_all_expt_ids_have_expts(dials_data, tmp_path, retain_unindexed):
result = subprocess.run(
[
shutil.which("dials.index"),
Expand All @@ -972,6 +976,7 @@ def test_all_expt_ids_have_expts(dials_data, tmp_path):
"max_lattices=8",
"beam.fix=all",
"detector.fix=all",
f"retain_unindexed_experiment={retain_unindexed}",
],
cwd=tmp_path,
capture_output=True,
Expand All @@ -982,9 +987,18 @@ def test_all_expt_ids_have_expts(dials_data, tmp_path):

refl = flex.reflection_table.from_file(tmp_path / "indexed.refl")
expt = ExperimentList.from_file(tmp_path / "indexed.expt", check_format=False)
assert (refl["id"] != -1).count(True) == refl.get_flags(refl.flags.indexed).count(
True
)
if retain_unindexed:
id_ = None
for i in refl.experiment_identifiers().keys():
if refl.experiment_identifiers()[i] == expt[0].identifier:
id_ = i
break
sel = refl["id"] == id_
assert sel.count(False) == refl.get_flags(refl.flags.indexed).count(True)
else:
assert (refl["id"] != -1).count(True) == refl.get_flags(
refl.flags.indexed
).count(True)
refl.assert_experiment_identifiers_are_consistent(expt)

assert flex.max(refl["id"]) + 1 == len(expt)
Expand All @@ -1002,6 +1016,7 @@ def test_multi_lattice_multi_sweep_joint(dials_data, tmp_path):
dials_data("l_cysteine_dials_output", pathlib=True) / "indexed.expt",
dials_data("l_cysteine_dials_output", pathlib=True) / "indexed.refl",
"max_lattices=2",
"output.retain_unindexed_experiment=False",
],
cwd=tmp_path,
capture_output=True,
Expand Down

0 comments on commit 51d1eef

Please sign in to comment.