Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

WIP: Object-oriented interface for parametric tests. #13

Draft
wants to merge 11 commits into
base: main
Choose a base branch
from
2 changes: 1 addition & 1 deletion afqinsight/datasets.py
Original file line number Diff line number Diff line change
Expand Up @@ -662,7 +662,7 @@ def from_study(study, verbose=None):
dataset_kwargs = {
"sarica": {
"dwi_metrics": ["md", "fa"],
"target_cols": ["class"],
"target_cols": ["class", "age"],
"label_encode_cols": ["class"],
},
"weston-havens": {"dwi_metrics": ["md", "fa"], "target_cols": ["Age"]},
Expand Down
143 changes: 97 additions & 46 deletions afqinsight/parametric.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
"""Perform linear modeling at leach node along the tract."""
"""Perform linear modeling at each node along the tract."""

import numpy as np
import pandas as pd
Expand All @@ -11,11 +11,11 @@
def node_wise_regression(
afq_dataset,
tract,
metric,
formula,
group="group",
group=None,
lme=False,
rand_eff="subjectID",
impute="median",
):
"""Model group differences using node-wise regression along the length of the tract.

Expand All @@ -26,13 +26,10 @@ def node_wise_regression(
----------
afq_dataset: AFQDataset
Loaded AFQDataset object

tract: str
String specifying the tract to model

metric: str
String specifying which diffusion metric to use as an outcome
eg. 'fa'

formula: str
An R-style formula <https://www.statsmodels.org/dev/example_formulas.html>
specifying the regression model to fit at each node. This can take the form
Expand All @@ -46,20 +43,23 @@ def node_wise_regression(
mixed-effects models. If using anything other than the default value,
this column must be present in the 'target_cols' of the AFQDataset object

impute: str or None, default='median'
String specifying the imputation strategy to use for missing data.


Returns
-------
tract_dict: dict
A dictionary with the following key-value pairs:

{'tract': tract,
'reference_coefs': coefs_default,
'group_coefs': coefs_treat,
'reference_CI': cis_default,
'group_CI': cis_treat,
'pvals': pvals,
'reject_idx': reject_idx,
'model_fits': fits}
'reference_coefs': coefs_default,
'group_coefs': coefs_treat,
'reference_CI': cis_default,
'group_CI': cis_treat,
'pvals': pvals,
'reject_idx': reject_idx,
'model_fits': fits}

tract: str
The tract described by this dictionary
Expand All @@ -72,7 +72,7 @@ def node_wise_regression(
group_coefs: list of floats
A list of beta-weights representing the average group effect metric
for the treatment group on a diffusion metric at a given location
along the tract
along the tract, if group None this will be a list of zeros.

reference_CI: np.array of np.array
A numpy array containing a series of numpy arrays indicating the
Expand All @@ -82,7 +82,8 @@ def node_wise_regression(
group_CI: np.array of np.array
A numpy array containing a series of numpy arrays indicating the
95% confidence interval around the estimated beta-weight of the
treatment effect at a given location along the tract
treatment effect at a given location along the tract. If group is
None, this will be an array of zeros.

pvals: list of floats
A list of p-values testing whether or not the beta-weight of the
Expand All @@ -96,8 +97,13 @@ def node_wise_regression(
A list of the statsmodels object fit along the length of the nodes

"""
X = SimpleImputer(strategy="median").fit_transform(afq_dataset.X)
afq_dataset.target_cols[0] = group
if impute is not None:
X = SimpleImputer(strategy=impute).fit_transform(afq_dataset.X)

if group is not None:
afq_dataset.target_cols[0] = group

metric = formula.split("~")[0].strip()

tract_data = (
pd.DataFrame(columns=afq_dataset.feature_names, data=X)
Expand All @@ -106,12 +112,13 @@ def node_wise_regression(
)

pvals = np.zeros(tract_data.shape[-1])
pvals_corrected = np.zeros(tract_data.shape[-1])
coefs_default = np.zeros(tract_data.shape[-1])
coefs_treat = np.zeros(tract_data.shape[-1])
cis_default = np.zeros((tract_data.shape[-1], 2))
cis_treat = np.zeros((tract_data.shape[-1], 2))
reject = np.zeros(tract_data.shape[-1], dtype=bool)
fits = {}

# Loop through each node and fit model
for ii, column in enumerate(tract_data.columns):
# fit linear mixed-effects model
Expand All @@ -125,7 +132,6 @@ def node_wise_regression(

model = smf.mixedlm(formula, this, groups=rand_eff)
fit = model.fit()
fits[column] = fit

# fit OLS model
else:
Expand All @@ -135,31 +141,76 @@ def node_wise_regression(

model = OLS.from_formula(formula, this)
fit = model.fit()
fits[column] = fit

fits[ii] = fit
# pull out coefficients, CIs, and p-values from our model
coefs_default[ii] = fit.params.filter(regex="Intercept", axis=0).iloc[0]
coefs_treat[ii] = fit.params.filter(regex=group, axis=0).iloc[0]

cis_default[ii] = (
fit.conf_int(alpha=0.05).filter(regex="Intercept", axis=0).values
)
cis_treat[ii] = fit.conf_int(alpha=0.05).filter(regex=group, axis=0).values
pvals[ii] = fit.pvalues.filter(regex=group, axis=0).iloc[0]

# Correct p-values for multiple comparisons
reject, pval_corrected, _, _ = multipletests(pvals, alpha=0.05, method="fdr_bh")
reject_idx = np.where(reject)

tract_dict = {
"tract": tract,
"reference_coefs": coefs_default,
"group_coefs": coefs_treat,
"reference_CI": cis_default,
"group_CI": cis_treat,
"pvals": pvals,
"reject_idx": reject_idx,
"model_fits": fits,
}

return tract_dict

if group is not None:
coefs_treat[ii] = fit.params.filter(regex=group, axis=0).iloc[0]

cis_default[ii] = (
fit.conf_int(alpha=0.05).filter(regex="Intercept", axis=0).values
)
cis_treat[ii] = fit.conf_int(alpha=0.05).filter(regex=group, axis=0).values
pvals[ii] = fit.pvalues.filter(regex=group, axis=0).iloc[0]

# Correct p-values for multiple comparisons
reject, pvals_corrected, _, _ = multipletests(
pvals, alpha=0.05, method="fdr_bh"
)

reject = np.where(reject, 1, 0)

return pd.DataFrame(
{
"reference_coefs": coefs_default,
"group_coefs": coefs_treat,
"reference_CI_lb": cis_default[:, 0],
"reference_CI_ub": cis_default[:, 1],
"group_CI_lb": cis_treat[:, 0],
"group_CI_ub": cis_treat[:, 1],
"pvals": pvals,
"pvals_corrected": pvals_corrected,
"reject_idx": reject,
}
), fits


class RegressionResults(object):
def __init__(self, kwargs):
self.tract = kwargs.get("tract", None)
self.reference_coefs = kwargs.get("reference_coefs", None)
self.group_coefs = kwargs.get("group_coefs", None)
self.reference_ci = kwargs.get("reference_ci", None)
self.group_ci = kwargs.get("group_ci", None)
self.pvals = kwargs.get("pvals", None)
self.pvals_corrected = kwargs.get("pvals_corrected", None)
self.reject_idx = kwargs.get("reject_idx", None)
self.model_fits = kwargs.get("model_fits", None)


class NodeWiseRegression(object):
def __init__(self, formula, lme=False):
self.formula = formula
self.lme = lme

def fit(self, dataset, tracts, group=None, rand_eff="subjectID"):
self.result_ = {}
for tract in tracts:
self.result_[tract] = node_wise_regression(
dataset,
tract,
self.formula,
lme=self.lme,
group=group,
rand_eff=rand_eff,
)
self.is_fitted = True
return self

def predict(self, dataset, tract, metric, group="group", rand_eff="subjectID"):
if not self.is_fitted:
raise ValueError("Model not fitted yet. Please call fit() method first.")
result = self.result_.get(tract, None)
if result is None:
raise ValueError(f"Tract {tract} not found in the fitted model.")
2 changes: 1 addition & 1 deletion afqinsight/tests/test_datasets.py
Original file line number Diff line number Diff line change
Expand Up @@ -358,7 +358,7 @@ def test_from_study(study):
"n_subjects": 48,
"n_features": 4000,
"n_groups": 40,
"target_cols": ["class"],
"target_cols": ["class", "age"],
},
"weston-havens": {
"n_subjects": 77,
Expand Down
43 changes: 43 additions & 0 deletions afqinsight/tests/test_parametric.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,43 @@
import numpy as np

from afqinsight import AFQDataset
from afqinsight.parametric import NodeWiseRegression, node_wise_regression


def test_node_wise_regression():
# Store results
group_dict = {}
group_age_dict = {}
age_dict = {}

data = AFQDataset.from_study("sarica")
tracts = ["Right Corticospinal", "Right SLF"]
for tract in tracts:
for lme in [True, False]:
# Run different versions of this: with age, without age, only with
# age:

group_dict[tract] = node_wise_regression(
data, tract, "fa ~ C(group)", lme=lme, group="group"
)
group_age_dict[tract] = node_wise_regression(
data, tract, "fa ~ C(group) + age", lme=lme, group="group"
)
age_dict[tract] = node_wise_regression(data, tract, "fa ~ age", lme=lme)

assert group_dict[tract]["pvals"].shape == (100,)
assert group_age_dict[tract]["pvals"].shape == (100,)
assert age_dict[tract]["pvals"].shape == (100,)

assert np.any(group_dict["Right Corticospinal"]["pvals_corrected"] < 0.05)
assert np.all(group_dict["Right SLF"]["pvals_corrected"] > 0.05)
assert np.any(group_age_dict["Right Corticospinal"]["pvals_corrected"] < 0.05)
assert np.all(group_age_dict["Right SLF"]["pvals_corrected"] > 0.05)


def test_NodeWiseRegression():
data = AFQDataset.from_study("sarica")
tracts = ["Left Corticospinal", "Left SLF"]
for lme in [True, False]:
model = NodeWiseRegression("fa ~ C(group) + age", lme=lme)
model.fit(data, tracts, group="group")
3 changes: 2 additions & 1 deletion examples/plot_als_classification.py
Original file line number Diff line number Diff line change
Expand Up @@ -54,6 +54,7 @@

X = afqdata.X
y = afqdata.y.astype(float) # SGL expects float targets
is_als = y[:, 0]
groups = afqdata.groups
feature_names = afqdata.feature_names
group_names = afqdata.group_names
Expand Down Expand Up @@ -117,7 +118,7 @@
# scikit-learn functions

scores = cross_validate(
pipe, X, y, cv=5, return_train_score=True, return_estimator=True
pipe, X, is_als, cv=5, return_train_score=True, return_estimator=True
)

# Display results
Expand Down
8 changes: 4 additions & 4 deletions examples/plot_als_comparison.py
Original file line number Diff line number Diff line change
Expand Up @@ -72,12 +72,12 @@


# Loop through the data and generate plots
for i, tract in enumerate(tracts):
for ii, tract in enumerate(tracts):
# fit node-wise regression for each tract based on model formula
tract_dict = node_wise_regression(afqdata, tract, "fa", "fa ~ C(group)")
tract_dict = node_wise_regression(afqdata, tract, "fa ~ C(group)", group="group")

row = i // num_cols
col = i % num_cols
row = ii // num_cols
col = ii % num_cols

axes[row][col].set_title(tract)

Expand Down
Loading