Skip to content

Commit

Permalink
Fix tests
Browse files Browse the repository at this point in the history
  • Loading branch information
jbeilstenedmands committed Oct 16, 2023
1 parent 90f2688 commit 8b2e3f3
Show file tree
Hide file tree
Showing 3 changed files with 55 additions and 19 deletions.
10 changes: 5 additions & 5 deletions src/dials/algorithms/indexing/assign_indices.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,8 +50,8 @@ def __call__(self, reflections, experiments, d_min=None):

miller_indices = result.miller_indices()
crystal_ids = result.crystal_ids()

expt_ids = flex.int(crystal_ids.size(), -1)
print(set(crystal_ids))
expt_ids = flex.int(crystal_ids.size(), 0)
for i_cryst, cryst in enumerate(experiments.crystals()):
sel_cryst = crystal_ids == i_cryst
for i_expt in experiments.where(crystal=cryst, imageset=imgset):
Expand All @@ -64,7 +64,7 @@ def __call__(self, reflections, experiments, d_min=None):
reflections["miller_index"].set_selected(
isel.select(sel_imgset), miller_indices
)
# reflections["id"].set_selected(isel.select(sel_imgset), expt_ids)
reflections["id"].set_selected(isel.select(sel_imgset), expt_ids)
reflections.unset_flags(
flex.bool(reflections.size(), True), reflections.flags.indexed
)
Expand Down Expand Up @@ -96,8 +96,8 @@ def __call__(self, reflections, experiments, d_min=None):
inside_resolution_limit = d_spacings > d_min
else:
inside_resolution_limit = flex.bool(reciprocal_lattice_points.size(), True)
sel = inside_resolution_limit & ~reflections.get_flags(
reflections.flags.indexed
sel = inside_resolution_limit & (
~reflections.get_flags(reflections.flags.indexed)
) # (reflections["id"] == -1)
isel = sel.iselection()
rlps = reciprocal_lattice_points.select(isel)
Expand Down
19 changes: 14 additions & 5 deletions src/dials/algorithms/indexing/indexer.py
Original file line number Diff line number Diff line change
Expand Up @@ -529,13 +529,19 @@ def index(self):
min_reflections_for_indexing = cutoff_fraction * len(
self.reflections.select(d_spacings > d_min_indexed)
)
crystal_ids = self.reflections.select(d_spacings > d_min_indexed)["id"]
if (crystal_ids == -1).count(True) < min_reflections_for_indexing:
sel_refls = self.reflections.select(d_spacings > d_min_indexed)
unindexed = ~sel_refls.get_flags(sel_refls.flags.indexed)

if unindexed.count(True) < min_reflections_for_indexing:
logger.info(
"Finish searching for more lattices: %i unindexed reflections remaining.",
(crystal_ids == -1).count(True),
unindexed.count(True),
)
break
else:
logger.info(
f"{unindexed.count(True)} unindexed reflections remaining."
)

n_lattices_previous_cycle = len(experiments)

Expand Down Expand Up @@ -676,9 +682,12 @@ def index(self):
logger.info(
"Removing %d reflections with id %d", sel.count(True), last
)
# refined_reflections["id"].set_selected(sel, -1)
refined_reflections["id"].set_selected(sel, 0)
del refined_reflections.experiment_identifiers()[last]
refined_reflections.unset_flags(
sel.iselection(), refined_reflections.flags.indexed
)
# sel.set_selected(self.reflections["id"] == -1, True)

break

self._unit_cell_volume_sanity_check(experiments, refined_experiments)
Expand Down
45 changes: 36 additions & 9 deletions tests/algorithms/indexing/test_assign_indices.py
Original file line number Diff line number Diff line change
Expand Up @@ -182,13 +182,18 @@ def __init__(self, experiment, reflections, expected_miller_indices):

# index reflections using simple "global" method
self.reflections_global = copy.deepcopy(reflections)
self.reflections_global["id"] = flex.int(len(self.reflections_global), -1)
# self.reflections_global["id"] = flex.int(len(self.reflections_global), -1)
self.reflections_global["id"] = flex.int(len(self.reflections_global), 0)
self.reflections_global.unset_flags(
flex.bool(self.reflections_global.size(), True),
self.reflections_global.flags.indexed,
)
self.reflections_global["imageset_id"] = flex.int(
len(self.reflections_global), 0
)
index_reflections_global(self.reflections_global, ExperimentList([experiment]))
non_zero_sel = self.reflections_global["miller_index"] != (0, 0, 0)
assert self.reflections_global["id"].select(~non_zero_sel).all_eq(-1)
# assert self.reflections_global["id"].select(~non_zero_sel).all_eq(-1)
self.misindexed_global = (
(expected_miller_indices == self.reflections_global["miller_index"])
.select(non_zero_sel)
Expand All @@ -200,10 +205,16 @@ def __init__(self, experiment, reflections, expected_miller_indices):

# index reflections using xds-style "local" method
self.reflections_local = copy.deepcopy(reflections)
self.reflections_local["id"] = flex.int(len(self.reflections_local), -1)
# self.reflections_local["id"] = flex.int(len(self.reflections_local), -1)
self.reflections_local["id"] = flex.int(len(self.reflections_local), 0)
self.reflections_local.unset_flags(
flex.bool(self.reflections_local.size(), True),
self.reflections_local.flags.indexed,
)

index_reflections_local(self.reflections_local, ExperimentList([experiment]))
non_zero_sel = self.reflections_local["miller_index"] != (0, 0, 0)
assert self.reflections_local["id"].select(~non_zero_sel).all_eq(-1)
# assert self.reflections_local["id"].select(~non_zero_sel).all_eq(-1)
self.misindexed_local = (
(expected_miller_indices == self.reflections_local["miller_index"])
.select(non_zero_sel)
Expand Down Expand Up @@ -240,11 +251,21 @@ def test_index_reflections(dials_regression: Path):
reflections.centroid_px_to_mm(experiments)
reflections.map_centroids_to_reciprocal_space(experiments)
reflections["imageset_id"] = flex.int(len(reflections), 0)
reflections["id"] = flex.int(len(reflections), -1)
# reflections["id"] = flex.int(len(reflections), -1)
reflections["id"] = flex.int(len(reflections), 0)
reflections.unset_flags(
flex.bool(reflections.size(), True), reflections.flags.indexed
)
AssignIndicesGlobal(tolerance=0.3)(reflections, experiments)
assert "miller_index" in reflections
counts = reflections["id"].counts()
assert dict(counts) == {-1: 1390, 0: 114692}
# counts = reflections["id"].counts()
indexed = reflections.get_flags(reflections.flags.indexed)
assert indexed.count(True) == 114692
assert indexed.count(False) == 1390
indexed = reflections["miller_index"] != (0, 0, 0)
assert indexed.count(True) == 114692
assert indexed.count(False) == 1390
# assert dict(counts) == {-1: 1390, 0: 114692}


def test_local_multiple_rotations(dials_data):
Expand All @@ -260,7 +281,10 @@ def test_local_multiple_rotations(dials_data):
# Generate some predicted reflections
reflections = flex.reflection_table.from_predictions(experiments[0], dmin=4)
reflections["imageset_id"] = flex.int(len(reflections), 0)
reflections["id"] = flex.int(len(reflections), -1)
reflections["id"] = flex.int(len(reflections), 0)
reflections.unset_flags(
flex.bool(reflections.size(), True), reflections.flags.indexed
)
reflections["xyzobs.px.value"] = reflections["xyzcal.px"]
reflections["xyzobs.mm.value"] = reflections["xyzcal.mm"]
predicted_miller_indices = reflections["miller_index"]
Expand All @@ -284,7 +308,10 @@ def test_local_multiple_rotations(dials_data):

# Reset miller indices and re-map to reciprocal space
reflections["miller_index"] = flex.miller_index(len(reflections), (0, 0, 0))
reflections["id"] = flex.int(len(reflections), -1)
reflections["id"] = flex.int(len(reflections), 0)
reflections.unset_flags(
flex.bool(reflections.size(), True), reflections.flags.indexed
)
reflections.centroid_px_to_mm(experiments)
reflections.map_centroids_to_reciprocal_space(experiments)

Expand Down

0 comments on commit 8b2e3f3

Please sign in to comment.