Skip to content

Commit

Permalink
Merge branch 'main' of github.com:aangelopoulos/ppi_py
Browse files Browse the repository at this point in the history
  • Loading branch information
aangelopoulos committed Oct 6, 2023
2 parents ffe213f + 7b5240d commit e3d5a24
Showing 1 changed file with 19 additions and 7 deletions.
26 changes: 19 additions & 7 deletions ppi_py/ppi.py
Original file line number Diff line number Diff line change
Expand Up @@ -402,6 +402,7 @@ def ppi_logistic_pointestimate(
theta -= step_size * grad
return theta


def ppi_logistic_ci(
X,
Y,
Expand All @@ -411,7 +412,7 @@ def ppi_logistic_ci(
alpha=0.1,
step_size=1e-3, # Optimizer step size
grad_tol=5e-16, # Optimizer grad tol
alternative='two-sided'
alternative="two-sided",
):
"""Computes the prediction-powered confidence interval for the logistic regression coefficients.
Expand Down Expand Up @@ -443,13 +444,19 @@ def ppi_logistic_ci(
grad_tol=grad_tol,
)

mu_til = expit(X_unlabeled@ppi_pointest)
mu_til = expit(X_unlabeled @ ppi_pointest)

Hessian = np.zeros((d,d))
Hessian = np.zeros((d, d))
grads_til = np.zeros(X_unlabeled.shape)
for i in range(N):
Hessian += 1/N * mu_til[i] * (1-mu_til[i]) * np.outer(X_unlabeled[i], X_unlabeled[i])
grads_til[i,:] = X_unlabeled[i,:]*(mu_til[i] - Yhat_unlabeled[i])
Hessian += (
1
/ N
* mu_til[i]
* (1 - mu_til[i])
* np.outer(X_unlabeled[i], X_unlabeled[i])
)
grads_til[i, :] = X_unlabeled[i, :] * (mu_til[i] - Yhat_unlabeled[i])

inv_Hessian = np.linalg.inv(Hessian)
var_unlabeled = np.cov(grads_til.T)
Expand All @@ -458,9 +465,14 @@ def ppi_logistic_ci(
grad_diff = np.diag(pred_error) @ X
var = np.cov(grad_diff.T)

Sigma_hat = inv_Hessian @ (n/N * var_unlabeled + var) @ inv_Hessian
Sigma_hat = inv_Hessian @ (n / N * var_unlabeled + var) @ inv_Hessian

return _zconfint_generic(ppi_pointest, np.sqrt(np.diag(Sigma_hat)/n), alpha=alpha, alternative=alternative)
return _zconfint_generic(
ppi_pointest,
np.sqrt(np.diag(Sigma_hat) / n),
alpha=alpha,
alternative=alternative,
)


"""
Expand Down

0 comments on commit e3d5a24

Please sign in to comment.