Skip to content

Commit

Permalink
Merge pull request #23 from sdaza/dev
Browse files Browse the repository at this point in the history
regression covariates
  • Loading branch information
sdaza authored Dec 12, 2024
2 parents d732f6e + 8d672c3 commit f3f596c
Show file tree
Hide file tree
Showing 3 changed files with 50 additions and 5 deletions.
2 changes: 1 addition & 1 deletion .github/workflows/ci.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,7 @@ jobs:
# Install flake8
pip install flake8
# Run flake8
flake8 . --ignore=E501,F401,F403,F405,W504
flake8 . --ignore=E501,F401,F403,F405,W504,E125
- name: Run tests
run: |
Expand Down
31 changes: 27 additions & 4 deletions experiment_utils/experiment_analyzer.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,7 @@ def __init__(
instrument_col: str = None,
alpha: float = 0.05,
regression_covariates: List = None,
assess_overlap=False):
assess_overlap=False):

"""
Initialize ExperimentAnalyzer
Expand Down Expand Up @@ -95,6 +95,10 @@ def __check_input(self):
if self.data.isEmpty():
log_and_raise_error(self.logger, "Dataframe is empty!")

# impute covariates from regression covariates
if (len(self.covariates) == 0) & (len(self.regression_covariates) > 0):
self.covariates = self.regression_covariates

# regression covariates has to be a subset of covariates
if len(self.regression_covariates) > 0:
if not set(self.regression_covariates).issubset(set(self.covariates)):
Expand Down Expand Up @@ -166,6 +170,25 @@ def standardize_covariates(self, data: pd.DataFrame, covariates: List[str]) -> p
data[f"z_{covariate}"] = (data[covariate] - data[covariate].mean()) / data[covariate].std()
return data

def __create_formula(self, outcome_variable, type: str = 'regression'):
    """
    Create formula for final regression model

    Parameters
    ----------
    outcome_variable : str
        Name of the outcome column, used as the left-hand side of the formula.
    type : str, optional
        'regression' for a plain treatment formula, or 'iv' for a formula in
        which the treatment is instrumented by ``self.instrument_col``,
        by default 'regression'

    Returns
    -------
    str
        Formula string. Every regression covariate that is also present in
        ``self.final_covariates`` is appended as a ``z_<covariate>`` term,
        in the order given by ``self.regression_covariates``.

    Raises
    ------
    KeyError
        If ``type`` is neither 'regression' nor 'iv'.
    """

    formula_dict = {
        'regression': f"{outcome_variable} ~ 1 + {self.treatment_col}",
        'iv': f"{outcome_variable} ~ 1 + [{self.treatment_col} ~ {self.instrument_col}]"
    }
    formula = formula_dict[type]

    # Keep the order in which regression covariates were supplied (duplicates
    # dropped) rather than iterating over a set intersection: set order for
    # strings can change between interpreter runs, which would make the
    # formula and the order of the fitted terms non-reproducible.
    final_covs = set(self.final_covariates)
    reg_covs = [cov for cov in dict.fromkeys(self.regression_covariates) if cov in final_covs]

    if reg_covs:
        # Covariates enter the model through their "z_"-prefixed columns.
        formula += ' + ' + ' + '.join(f"z_{cov}" for cov in reg_covs)

    return formula

def linear_regression(self, data: pd.DataFrame, outcome_variable: str) -> Dict:
"""
Runs a linear regression of the outcome variable on the treatment variable.
Expand All @@ -183,7 +206,7 @@ def linear_regression(self, data: pd.DataFrame, outcome_variable: str) -> Dict:
Regression results
"""

formula = f"{outcome_variable} ~ {self.treatment_col}"
formula = self.__create_formula(outcome_variable=outcome_variable)
model = smf.ols(formula, data=data)
results = model.fit(cov_type="HC3")

Expand Down Expand Up @@ -223,7 +246,7 @@ def weighted_least_squares(self, data: pd.DataFrame, outcome_variable: str) -> D
Regression results
"""

formula = f"{outcome_variable} ~ 1 + {self.treatment_col}"
formula = self.__create_formula(outcome_variable=outcome_variable)
model = smf.wls(
formula,
data=data,
Expand Down Expand Up @@ -270,7 +293,7 @@ def iv_regression(self, data: pd.DataFrame, outcome_variable: str) -> Dict:
if not self.instrument_col:
log_and_raise_error(self.logger, "Instrument column must be specified for IV adjustment")

formula = f"{outcome_variable} ~ 1 + [{self.treatment_col} ~ {self.instrument_col}]"
formula = self.__create_formula(outcome_variable=outcome_variable, type='iv')
model = IV2SLS.from_formula(formula, data)
results = model.fit(cov_type='robust')

Expand Down
22 changes: 22 additions & 0 deletions tests/test_experiment_analyzer.py
Original file line number Diff line number Diff line change
Expand Up @@ -94,6 +94,28 @@ def test_no_covariates(sample_data):
pytest.fail(f" raised an exception: {e}")


def test_regression_covariates(sample_data):
    """Check that get_effects runs when regression covariates are given without covariates"""
    analyzer_kwargs = {
        "data": sample_data,
        "outcomes": "conversion",
        "treatment_col": "treatment",
        "experiment_identifier": "experiment",
        "regression_covariates": "baseline_conversion",
    }
    analyzer = ExperimentAnalyzer(**analyzer_kwargs)

    # Any exception from estimating or reading the results fails the test.
    try:
        analyzer.get_effects()
        analyzer.results
    except Exception as e:
        pytest.fail(f" raised an exception: {e}")


def test_no_adjustment(sample_data):
"""Test get_effects no adjustments"""
outcomes = "conversion"
Expand Down

0 comments on commit f3f596c

Please sign in to comment.