Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix early stopping, by making it stricter #242

Open
wants to merge 10 commits into
base: master
Choose a base branch
from
Open
14 changes: 14 additions & 0 deletions expan/core/early_stopping.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@
cache_sampling_results = False
sampling_results = {} # memorized sampling results

OBRIEN_FLEMING_DIVISION_FACTOR = 100

def obrien_fleming(information_fraction, alpha=0.05):
""" Calculate an approximation of the O'Brien-Fleming alpha spending function.
Expand All @@ -33,6 +34,19 @@ def obrien_fleming(information_fraction, alpha=0.05):
:return: redistributed alpha value at the time point with the given information fraction
:rtype: float
"""

alpha = alpha/OBRIEN_FLEMING_DIVISION_FACTOR
"""
The following tests needed to be adjusted to take account of this correction:
- tests/tests_core/test_early_stopping.py::
GroupSequentialTestCases::
test_obrien_fleming
test_group_sequential
test_group_sequential_actual_size_larger_than_estimated
- tests_core/test_experiment.py::
StatisticalTestTestCases::
test_group_sequential
"""
return (1 - norm.cdf(norm.ppf(1 - alpha / 2) / np.sqrt(information_fraction))) * 2


Expand Down
7 changes: 5 additions & 2 deletions expan/core/util.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@ def __repr__(self):
return self.toJson()


def find_value_by_key_with_condition(items, condition_key, condition_value, lookup_key):
def find_value_by_key_with_condition(items, condition_key, condition_value, lookup_key, tol=None):
""" Find the value of lookup key where the dictionary contains condition key = condition value.

:param items: list of dictionaries
Expand All @@ -31,7 +31,10 @@ def find_value_by_key_with_condition(items, condition_key, condition_value, look

:return: lookup value or found value for the lookup key
"""
return [item[lookup_key] for item in items if item[condition_key] == condition_value][0]
if tol is None:
return [item[lookup_key] for item in items if item[condition_key] == condition_value][0]
else:
return [item[lookup_key] for item in items if abs(item[condition_key]-condition_value) < tol][0]


def is_nan(obj):
Expand Down
22 changes: 11 additions & 11 deletions tests/tests_core/test_early_stopping.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,16 +37,16 @@ def test_obrien_fleming(self):
""" Check the O'Brien-Fleming spending function."""
# Check array as input
res_1 = es.obrien_fleming(np.linspace(0, 1, 5 + 1)[1:])
expected_res = [1.17264468e-05, 1.94191300e-03, 1.13964185e-02, 2.84296308e-02, 5.00000000e-02]
expected_res = [7.1054274e-15,3.7219966e-08,7.0016877e-06,9.9583700e-05,5.0000000e-04]
np.testing.assert_almost_equal(res_1, expected_res)

# Check float as input
res_2 = es.obrien_fleming(0.5)
self.assertAlmostEqual(res_2, 0.005574596680784305)
self.assertAlmostEqual(res_2, 8.5431190077756014e-07)

# Check int as input
res_3 = es.obrien_fleming(1)
self.assertAlmostEqual(res_3, 0.05)
self.assertAlmostEqual(res_3, 0.0005)

def test_group_sequential(self):
""" Check the group sequential function."""
Expand All @@ -60,10 +60,10 @@ def test_group_sequential(self):
self.assertAlmostEqual(res.control_statistics.variance, 0.9373337542827797)

self.assertAlmostEqual(res.delta, -0.15887364780635896)
value025 = find_value_by_key_with_condition(res.confidence_interval, 'percentile', 2.5, 'value')
value975 = find_value_by_key_with_condition(res.confidence_interval, 'percentile', 97.5, 'value')
np.testing.assert_almost_equal(value025, -0.24461812530841959, decimal=5)
np.testing.assert_almost_equal(value975, -0.07312917030429833, decimal=5)
value025 = find_value_by_key_with_condition(res.confidence_interval, 'percentile', 2.5/es.OBRIEN_FLEMING_DIVISION_FACTOR, 'value', 1e-5)
value975 = find_value_by_key_with_condition(res.confidence_interval, 'percentile', 100-2.5/es.OBRIEN_FLEMING_DIVISION_FACTOR, 'value', 1e-5)
np.testing.assert_almost_equal(value025, -0.31130760395377599, decimal=5)
np.testing.assert_almost_equal(value975, -0.0064396916589367081, decimal=5)

self.assertAlmostEqual(res.p, 0.0002863669955157941)
self.assertAlmostEqual(res.statistical_power, 0.9529152504960496)
Expand All @@ -75,10 +75,10 @@ def test_group_sequential_actual_size_larger_than_estimated(self):
"""
res = es.group_sequential(self.rand_s1, self.rand_s2, estimated_sample_size=100)

value025 = find_value_by_key_with_condition(res.confidence_interval, 'percentile', 2.5, 'value')
value975 = find_value_by_key_with_condition(res.confidence_interval, 'percentile', 97.5, 'value')
np.testing.assert_almost_equal (value025, -0.24461812530841959, decimal=5)
np.testing.assert_almost_equal (value975, -0.07312917030429833, decimal=5)
value025 = find_value_by_key_with_condition(res.confidence_interval, 'percentile', 2.5/es.OBRIEN_FLEMING_DIVISION_FACTOR, 'value', tol=1e-5)
value975 = find_value_by_key_with_condition(res.confidence_interval, 'percentile', 100-2.5/es.OBRIEN_FLEMING_DIVISION_FACTOR, 'value', tol=1e-5)
np.testing.assert_almost_equal (value025, -0.31130760395377599, decimal=5)
np.testing.assert_almost_equal (value975, -0.00643969165893670, decimal=5)


class BayesFactorTestCases(EarlyStoppingTestCase):
Expand Down
9 changes: 5 additions & 4 deletions tests/tests_core/test_experiment.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
from expan.core.results import CombinedTestStatistics
from expan.core.statistical_test import *
from expan.core.experiment import Experiment
import expan.core.early_stopping as es
from expan.core.util import generate_random_data, find_value_by_key_with_condition


Expand Down Expand Up @@ -129,10 +130,10 @@ def test_group_sequential(self):

self.assertAlmostEqual(res.result.delta, 0.033053, ndecimals)

lower_bound_ci = find_value_by_key_with_condition(res.result.confidence_interval, 'percentile', 2.5, 'value')
upper_bound_ci = find_value_by_key_with_condition(res.result.confidence_interval, 'percentile', 97.5, 'value')
self.assertAlmostEqual(lower_bound_ci, -0.007135, ndecimals)
self.assertAlmostEqual(upper_bound_ci, 0.073240, ndecimals)
lower_bound_ci = find_value_by_key_with_condition(res.result.confidence_interval, 'percentile', 2.5/es.OBRIEN_FLEMING_DIVISION_FACTOR, 'value', 1e-5)
upper_bound_ci = find_value_by_key_with_condition(res.result.confidence_interval, 'percentile', 100-2.5/es.OBRIEN_FLEMING_DIVISION_FACTOR, 'value', 1e-5)
self.assertAlmostEqual(lower_bound_ci, -0.0383319, ndecimals)
self.assertAlmostEqual(upper_bound_ci, 0.104437, ndecimals)

self.assertEqual(res.result.treatment_statistics.sample_size, 6108)
self.assertEqual(res.result.control_statistics.sample_size, 3892)
Expand Down