Skip to content

Commit

Permalink
Replace CSVs with chi2.ppf for GLRT test (#456)
Browse files Browse the repository at this point in the history
* Replace CSVs with `chi2.ppf` for GLRT test

* default 0.001

* test smaller alphas

* specify log in temp dir
  • Loading branch information
scottstanie authored Oct 21, 2024
1 parent 53df4b8 commit eb00339
Show file tree
Hide file tree
Showing 5 changed files with 19 additions and 4,050 deletions.
63 changes: 9 additions & 54 deletions src/dolphin/shp/_glrt.py
Original file line number Diff line number Diff line change
@@ -1,14 +1,12 @@
from __future__ import annotations

import csv
from functools import lru_cache
from math import log
from pathlib import Path
from typing import Optional

import numba
import numpy as np
from numpy.typing import ArrayLike
from scipy import stats

from dolphin._types import Strides
from dolphin.utils import _get_slices, compute_out_shape
Expand Down Expand Up @@ -78,7 +76,8 @@ def estimate_neighbors(
half_row, half_col = halfwin_rowcol
rows, cols = mean.shape

threshold = get_cutoff(alpha=alpha, N=nslc)
# 1 Degree of freedom, regardless of N
threshold = stats.chi2.ppf(1 - alpha, df=1)

strides_rowcol = (strides["y"], strides["x"])
out_rows, out_cols = compute_out_shape((rows, cols), Strides(*strides_rowcol))
Expand All @@ -88,6 +87,7 @@ def estimate_neighbors(
return _loop_over_pixels(
mean,
var,
nslc,
halfwin_rowcol,
strides_rowcol,
threshold,
Expand All @@ -96,47 +96,18 @@ def estimate_neighbors(
)


def get_cutoff(alpha: float, N: int) -> float:
r"""Compute the upper cutoff for the GLRT test statistic.
Statistic is
\[
2\log(\sigma_{pooled}) - \log(\sigma_{p}) -\log(\sigma_{q})
\]
Parameters
----------
alpha: float
Significance level (0 < alpha < 1).
N: int
Number of samples.
Returns
-------
float
Cutoff value for the GLRT test statistic.
"""
n_alpha_to_cutoff = _read_cutoff_csv()
try:
return n_alpha_to_cutoff[(N, alpha)]
except KeyError as e:
msg = f"Not implemented for {N = }, {alpha = }"
raise NotImplementedError(msg) from e


@numba.njit(nogil=True)
def _compute_glrt_test_stat(scale_1, scale_2):
def _compute_glrt_test_stat(scale_sq_1, scale_sq_2, N):
"""Compute the GLRT test statistic."""
scale_pooled = (scale_1 + scale_2) / 2
return 2 * log(scale_pooled) - log(scale_1) - log(scale_2)
scale_pooled = (scale_sq_1 + scale_sq_2) / 2
return N * (2 * log(scale_pooled) - log(scale_sq_1) - log(scale_sq_2))


@numba.njit(nogil=True, parallel=True)
def _loop_over_pixels(
mean: ArrayLike,
var: ArrayLike,
N: int,
halfwin_rowcol: tuple[int, int],
strides_rowcol: tuple[int, int],
threshold: float,
Expand Down Expand Up @@ -180,27 +151,11 @@ def _loop_over_pixels(
continue
scale_2 = scale_squared[in_r2, in_c2]

T = _compute_glrt_test_stat(scale_1, scale_2)
T = _compute_glrt_test_stat(scale_1, scale_2, N)

is_shp[out_r, out_c, r_off, c_off] = threshold > T
if prune_disconnected:
# For this pixel, prune the groups not connected to the center
remove_unconnected(is_shp[out_r, out_c], inplace=True)

return is_shp


@lru_cache
def _read_cutoff_csv():
filename = Path(__file__).parent / "glrt_cutoffs.csv"

result = {}
with open(filename) as file:
reader = csv.DictReader(file)
for row in reader:
n = int(row["N"])
alpha = float(row["alpha"])
cutoff = float(row["cutoff"])
result[(n, alpha)] = cutoff

return result
Loading

0 comments on commit eb00339

Please sign in to comment.