Commit

Renaming, bugfix
jbeilstenedmands committed Apr 10, 2024
1 parent 2ff4d5a commit ef78a47
Showing 2 changed files with 34 additions and 22 deletions.
52 changes: 32 additions & 20 deletions src/xia2/Modules/SSX/batch_scale.py
@@ -28,10 +28,12 @@
 from xia2.Modules.SSX.data_reduction_definitions import FilePair


-class BatchScale(object):
+class BatchScaleReindex(object):
     def __init__(
         self, batches: List[FilePair], reference: Path, space_group: sgtbx.space_group
     ):
+        # The purpose here is to do quick KB scaling of batches of datasets,
+        # in order to test reindexing against a reference.
         phil_scope = phil.parse(
             """
 include scope dials.command_line.scale.phil_scope
@@ -48,9 +50,11 @@ def __init__(
         params.reflection_selection.method = "intensity_ranges"
         params.reflection_selection.Isigma_range = [2, 0]
         params.cut_data.partiality_cutoff = 0.25
+        params.output.html = None
         self.params = params
         self.input_sg = space_group
         self.reference = reference
+        assert self.reference is not None
         self.input_batches = batches
         self._experiments = ExperimentList([])
         self._reflections = []
@@ -60,7 +64,9 @@ def __init__(
         # create a single table and expt per batch
         all_expts = ExperimentList([])

-        class SANoStats(ScalingAlgorithm):
+        class ScalingAlgorithmNoStats(ScalingAlgorithm):
+            # merging stats calculation is very slow for lots of data;
+            # skip it, as it is not needed here.
             def calculate_merging_stats(self):
                 pass

@@ -72,8 +78,9 @@ def calculate_merging_stats(self):
                 table["intensity.sum.value"] / (table["intensity.sum.variance"] ** 0.5)
             ) > 2.0
             n_sel = sel.count(True)
-            if (n_sel / sel.size()) < 0.05 and n_sel > 10:
-                params.reflection_selection.Isigma_range = [0, 0]
+            if (n_sel / sel.size()) < 0.05 or n_sel < 10:  # revert to safe defaults
+                self.params.reflection_selection.Isigma_range = [-5, 0]
+                self.params.reflection_selection.method = "quasi_random"
             else:
                 table = table.select(
                     (
@@ -98,37 +105,42 @@ def calculate_merging_stats(self):
             table.experiment_identifiers()[i] = str(i)
             self._reflections.append(table)

-        wavelength = np.mean([expt.beam.get_wavelength() for expt in expts])
+        self.wavelength = np.mean([expt.beam.get_wavelength() for expt in expts])
         best_unit_cell = determine_best_unit_cell(all_expts)
         for expt in self._experiments:
             expt.crystal.set_unit_cell(best_unit_cell)
-            expt.beam.set_wavelength(wavelength)
+            expt.beam.set_wavelength(self.wavelength)

-        self.algorithm = SANoStats(params, self._experiments, self._reflections)
+        self.algorithm = ScalingAlgorithmNoStats(
+            self.params, self._experiments, self._reflections
+        )

     def run(self):
         self.algorithm.run()
         del self._reflections
-        if self.reference:
-            wavelength = np.mean(
-                [expt.beam.get_wavelength() for expt in self._experiments]
-            )
-            reference_miller_set = intensities_from_reference_file(
-                os.fspath(self.reference), wavelength=wavelength
-            )
-            test_miller_set = self.algorithm.scaled_miller_array
-            change_of_basis_op = determine_reindex_operator_against_reference(
-                test_miller_set, reference_miller_set
-            )
+        reference_miller_set = intensities_from_reference_file(
+            os.fspath(self.reference), wavelength=self.wavelength
+        )
+        test_miller_set = self.algorithm.scaled_miller_array
+        self.change_of_basis_op = determine_reindex_operator_against_reference(
+            test_miller_set, reference_miller_set
+        )
+        if self.change_of_basis_op.as_abc() != "a,b,c":
             for i, fp in enumerate(self.input_batches):
                 expts = load.experiment_list(fp.expt, check_format=False)
                 refl = flex.reflection_table.from_file(fp.refl)
                 for expt in expts:
-                    expt.crystal = expt.crystal.change_basis(change_of_basis_op)
+                    expt.crystal = expt.crystal.change_basis(self.change_of_basis_op)
                     expt.crystal.set_space_group(self.input_sg)
                 expts.as_file(f"processed_{i}.expt")

-                refl["miller_index"] = change_of_basis_op.apply(refl["miller_index"])
+                refl["miller_index"] = self.change_of_basis_op.apply(
+                    refl["miller_index"]
+                )
                 refl.as_file(f"processed_{i}.refl")
                 self._output_expt_files.append(f"processed_{i}.expt")
                 self._output_refl_files.append(f"processed_{i}.refl")
+        else:  # don't need to do anything
+            for i, fp in enumerate(self.input_batches):
+                self._output_expt_files.append(fp.expt)
+                self._output_refl_files.append(fp.refl)
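
The bugfix in this file is the reflection-selection fallback above: the old test `(n_sel / sel.size()) < 0.05 and n_sel > 10` could only fire when more than 10 strong reflections were already selected, and it set `Isigma_range = [0, 0]` on the local `params` without changing the selection method. The fix reverts to quasi-random selection with the default I/sigma range whenever strong reflections are sparse. A minimal standalone sketch of the corrected decision follows (illustrative names, not the xia2 source):

# Standalone sketch of the corrected fallback; choose_selection is a
# hypothetical helper, not part of xia2.
def choose_selection(n_sel: int, n_total: int) -> tuple[str, list[int]]:
    """Pick a reflection-selection strategy for quick KB scaling."""
    if (n_sel / n_total) < 0.05 or n_sel < 10:
        # Too few strong (I/sigma > 2) reflections: revert to safe defaults.
        return "quasi_random", [-5, 0]
    return "intensity_ranges", [2, 0]

print(choose_selection(8, 1000))    # -> ('quasi_random', [-5, 0])
print(choose_selection(500, 1000))  # -> ('intensity_ranges', [2, 0])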
4 changes: 2 additions & 2 deletions src/xia2/Modules/SSX/data_reduction_programs.py
@@ -39,7 +39,7 @@
 from xia2.Driver.timing import record_step
 from xia2.Handlers.Files import FileHandler
 from xia2.Modules.SSX.batch_cosym import BatchCosym
-from xia2.Modules.SSX.batch_scale import BatchScale
+from xia2.Modules.SSX.batch_scale import BatchScaleReindex
 from xia2.Modules.SSX.data_reduction_definitions import FilePair, ReductionParams
 from xia2.Modules.SSX.reporting import condensed_unit_cell_info
 from xia2.Modules.SSX.util import log_to_file, run_in_directory
@@ -888,7 +888,7 @@ def scale_reindex(
     with run_in_directory(working_directory), record_step("scale_reindex"), log_to_file(
         logfile
    ):
-        s = BatchScale(batches_to_scale, reference, space_group)
+        s = BatchScaleReindex(batches_to_scale, reference, space_group)
         s.run()
         xia2_logger.info("Reindexed against reference file")
         outfiles = []
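
For orientation, a hedged usage sketch of the renamed class, mirroring the `scale_reindex` call above; the file paths, reference, and space group are placeholders, and `FilePair` is assumed to take the .expt/.refl paths positionally:

from pathlib import Path

from cctbx import sgtbx

from xia2.Modules.SSX.batch_scale import BatchScaleReindex
from xia2.Modules.SSX.data_reduction_definitions import FilePair

# Placeholder inputs: two batches of integrated experiments/reflections.
batches = [
    FilePair(Path("batch_0.expt"), Path("batch_0.refl")),
    FilePair(Path("batch_1.expt"), Path("batch_1.refl")),
]
scaler = BatchScaleReindex(
    batches,
    reference=Path("reference.mtz"),  # now required: asserted non-None in __init__
    space_group=sgtbx.space_group_info("P212121").group(),
)
scaler.run()  # writes processed_{i}.expt/.refl only if a reindex is needed

Storing `wavelength` and `change_of_basis_op` on the instance rather than as locals is what lets `run()` reuse the mean wavelength computed in `__init__`, and lets callers inspect `scaler.change_of_basis_op` afterwards.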