From dd9e19acaa476c7f4f1a983224efffb86f8b44ca Mon Sep 17 00:00:00 2001 From: James Beilsten-Edmands <30625594+jbeilstenedmands@users.noreply.github.com> Date: Mon, 18 Sep 2023 16:05:03 +0100 Subject: [PATCH] Use a weighted cchalf fit to improve resolution estimates --- newsfragments/XXX.feature | 1 + src/dials/util/resolution_analysis.py | 102 ++++++++++-------- .../command_line/test_estimate_resolution.py | 2 +- 3 files changed, 59 insertions(+), 46 deletions(-) create mode 100644 newsfragments/XXX.feature diff --git a/newsfragments/XXX.feature b/newsfragments/XXX.feature new file mode 100644 index 0000000000..02fa9289b5 --- /dev/null +++ b/newsfragments/XXX.feature @@ -0,0 +1 @@ +``dials.estimate_resolution``: Improved cc1/2 fitting by using a weighted tanh fit diff --git a/src/dials/util/resolution_analysis.py b/src/dials/util/resolution_analysis.py index 80329891e3..7d9f56710c 100644 --- a/src/dials/util/resolution_analysis.py +++ b/src/dials/util/resolution_analysis.py @@ -9,13 +9,16 @@ import math import typing +import numpy as np +import scipy.optimize + import iotbx.merging_statistics import iotbx.mtz import iotbx.phil from cctbx import miller, uctbx from cctbx.array_family import flex from iotbx.reflection_file_utils import label_table -from scitbx.math import curve_fitting, five_number_summary +from scitbx.math import curve_fitting from dials.algorithms.scaling.scaling_library import determine_best_unit_cell from dials.report import plots @@ -42,7 +45,7 @@ class metrics(enum.Enum): COMPLETENESS = "completeness" -def polynomial_fit(x, y, degree=5): +def polynomial_fit(x, y, degree=5, n_obs=None): """ Fit a polynomial to the values y(x) and return this fit @@ -57,36 +60,52 @@ def polynomial_fit(x, y, degree=5): return f(x) -def tanh_fit(x, y, iqr_multiplier=None): +def tanh_inzspace(x, r, s0): + # from the x's, calculate y and then transform into fisher z-space. + xprime = (x - s0) / r + yprime = 0.5 * (1 - np.tanh(xprime)) + # avoid math errors if y>=1. + delta = 1e-9 + zprime = np.array([math.atanh(yp) if yp < 1 else 1.0 - delta for yp in yprime]) + return zprime + + +def tanh_fit(x, y, degree=None, n_obs=None): """ Fit a tanh function to the values y(x) and return this fit x, y should be iterables containing floats of the same size. This is used for fitting a curve to CC½. """ + assert n_obs + if (n_obs > 3).count(True) < 3: + raise RuntimeError("Not enough reflections for fitting") - tf = curve_fitting.tanh_fit(x, y) - f = curve_fitting.tanh(*tf.params) + # To do a weighted tanh_fit to cc1/2, first use the fisher Z-transformation + # which has the convenient property of having symmetric errors. So do the + # cc1/2 fit in z-space + standard_errors = [ + 1 / math.sqrt(n - 3.0) if n > 4 else 100 for n in n_obs + ] # 100 or another large number. + p0 = np.array([0.2, 0.4]) # starting parameter estimates + sigma = np.array(standard_errors) + x = np.array(x) + # avoid math errors if y>=1. + delta = 1e-9 + yinz = np.array([math.atanh(yi) if yi < 1 else (1 - delta) for yi in y]) - if iqr_multiplier: - assert iqr_multiplier > 0 - yc = f(x) - dy = y - yc + result = scipy.optimize.curve_fit(tanh_inzspace, x, yinz, p0, sigma=sigma) - min_x, q1_x, med_x, q3_x, max_x = five_number_summary(dy) - iqr_x = q3_x - q1_x - cut_x = iqr_multiplier * iqr_x - outliers = (dy > q3_x + cut_x) | (dy < q1_x - cut_x) - if outliers.count(True) > 0: - xo = x.select(~outliers) - yo = y.select(~outliers) - tf = curve_fitting.tanh_fit(xo, yo) - f = curve_fitting.tanh(*tf.params) + r = result[0][0] + s0 = result[0][1] - return f(x) + xprime = [(xi - s0) / r for xi in x] + yprime = [0.5 * (1 - np.tanh(xp)) for xp in xprime] + + return flex.double(yprime) -def log_fit(x, y, degree=5): +def log_fit(x, y, degree=5, n_obs=None): """Fit the values log(y(x)) then return exp() to this fit. x, y should be iterables containing floats of the same size. The order is the order @@ -99,7 +118,7 @@ def log_fit(x, y, degree=5): return flex.exp(f(x)) -def log_inv_fit(x, y, degree=5): +def log_inv_fit(x, y, degree=5, n_obs=None): """Fit the values log(1 / y(x)) then return the inverse of this fit. x, y should be iterables, the order of the polynomial for the transformed @@ -112,7 +131,7 @@ def log_inv_fit(x, y, degree=5): return 1 / flex.exp(f(x)) -def resolution_fit_from_merging_stats(merging_stats, metric, model, limit, sel=None): +def resolution_fit_from_merging_stats(merging_stats, metric, model, limit): """Estimate a resolution limit based on the input `metric` The function defined by `model` will be fit to the selected `metric` which has been @@ -130,20 +149,21 @@ def resolution_fit_from_merging_stats(merging_stats, metric, model, limit, sel=N input x (d_star_sq) and y (the metric to be fitted) values, returning the fitted y(x) values. limit (float): The resolution limit criterion. - sel (scitbx.array_family.flex.bool): An optional selection to apply to the - `merging_stats` bins. Returns: The estimated resolution limit in units of Å^-1 """ y_obs = flex.double(getattr(b, metric) for b in merging_stats.bins).reversed() + n_obs = flex.double( + getattr(b, "cc_one_half_n_refl") for b in merging_stats.bins + ).reversed() d_star_sq = flex.double( uctbx.d_as_d_star_sq(b.d_min) for b in merging_stats.bins ).reversed() - return resolution_fit(d_star_sq, y_obs, model, limit, sel=sel) + return resolution_fit(d_star_sq, y_obs, model, limit, n_obs) -def resolution_fit(d_star_sq, y_obs, model, limit, sel=None): +def resolution_fit(d_star_sq, y_obs, model, limit, n_obs): """Estimate a resolution limit based on the input merging statistics The function defined by `model` will be fit to the input `d_star_sq` and `y_obs`. @@ -159,24 +179,16 @@ def resolution_fit(d_star_sq, y_obs, model, limit, sel=None): (d_star_sq) and y (the metric to be fitted) values, returning the fitted y(x) values. limit (float): The resolution limit criterion. - sel (scitbx.array_family.flex.bool): An optional selection to apply to the - `d_star_sq` and `y_obs` values. + n_obs (scitbx.array_family.flex.int): The number of observations in each bin + relevant to the fit. Returns: The estimated resolution limit in units of Å^-1 Raises: - RuntimeError: Raised if no `y_obs` values remain after application of the - selection `sel` + RuntimeError: Raised if not enough observations to perform tanh_fit` """ - if not sel: - sel = flex.bool(len(d_star_sq), True) - sel &= y_obs > 0 - y_obs = y_obs.select(sel) - d_star_sq = d_star_sq.select(sel) - - if not len(y_obs): - raise RuntimeError("No reflections left for fitting") - y_fit = model(d_star_sq, y_obs, 6) + + y_fit = model(d_star_sq, y_obs, degree=6, n_obs=n_obs) logger.debug( tabulate( [("d*2", "d", "obs", "fit")] @@ -255,14 +267,11 @@ def resolution_cc_half( Returns: The estimated resolution limit in units of Å^-1 """ - sel = _get_cc_half_significance(merging_stats, cc_half_method) metric = "cc_one_half_sigma_tau" if cc_half_method == "sigma_tau" else "cc_one_half" - result = resolution_fit_from_merging_stats( - merging_stats, metric, model, limit, sel=sel - ) + result = resolution_fit_from_merging_stats(merging_stats, metric, model, limit) critical_values = _get_cc_half_critical_values(merging_stats, cc_half_method) if critical_values: - result = result._replace(critical_values=critical_values.select(sel)) + result = result._replace(critical_values=critical_values) return result @@ -695,5 +704,8 @@ def _resolution_cc_ref(self, limit=None): d_star_sq = flex.double( 1 / b.d_min**2 for b in self._merging_statistics.bins ).reversed() + n_obs = flex.double( + getattr(b, "cc_one_half_n_refl") for b in self._merging_statistics.bins + ).reversed() - return resolution_fit(d_star_sq, cc_s, fit, limit) + return resolution_fit(d_star_sq, cc_s, fit, limit, n_obs) diff --git a/tests/command_line/test_estimate_resolution.py b/tests/command_line/test_estimate_resolution.py index 89ade49e85..f14180559b 100644 --- a/tests/command_line/test_estimate_resolution.py +++ b/tests/command_line/test_estimate_resolution.py @@ -109,7 +109,7 @@ def test_handle_fit_failure(dials_data, run_in_tmp_path, capsys): captured = capsys.readouterr() expected_output = ( - "Resolution fit against cc_half failed: No reflections left for fitting", + "Resolution fit against cc_half failed: Not enough reflections for fitting", "Resolution Mn(I/sig): 0.62", ) for line in expected_output: