Speed up lm test (#93)
* Use hessian.

* Don't compute hessian.

* Use BFGS.

* Make jac/hess individually optional.

* Make optimizer and gamma_0 an argument.

* Fix test.

* Use pinv sometimes.

* Same for hess.
mlondschien authored Jul 9, 2024
1 parent 45abd96 commit fa32c3d
Showing 2 changed files with 105 additions and 40 deletions.
ivmodels/tests/lagrange_multiplier.py (114 changes: 82 additions, 32 deletions)
@@ -57,7 +57,19 @@ class _LM:
         Projection of ``W`` onto the column space of ``Z``.
     """

-    def __init__(self, X, y, W, dof, Z=None, X_proj=None, y_proj=None, W_proj=None):
+    def __init__(
+        self,
+        X,
+        y,
+        W,
+        dof,
+        Z=None,
+        X_proj=None,
+        y_proj=None,
+        W_proj=None,
+        optimizer="bfgs",
+        gamma_0=None,
+    ):

         self.X = X
         self.y = y.reshape(-1, 1)
@@ -87,6 +99,9 @@ def __init__(self, X, y, W, dof, Z=None, X_proj=None, y_proj=None, W_proj=None):
         # for liml
         self.yS_proj_at_yS = self.yS_proj.T @ self.yS_proj

+        self.optimizer = optimizer
+        self.gamma_0 = ["liml"] if gamma_0 is None else gamma_0
+
     def liml(self, beta=None):
         """
         Efficiently compute the LIML.
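For context, a minimal sketch of how the two new constructor arguments could be passed. The data, shapes, seed, and the dof choice below are illustrative assumptions, not taken from the package's tests.

import numpy as np

from ivmodels.tests.lagrange_multiplier import _LM

rng = np.random.default_rng(0)
n, mx, mw, k = 100, 2, 1, 5

# Synthetic instruments, endogenous regressors, and outcome (illustrative only).
Z = rng.normal(size=(n, k))
X = Z @ rng.normal(size=(k, mx)) + rng.normal(size=(n, mx))
W = Z @ rng.normal(size=(k, mw)) + rng.normal(size=(n, mw))
y = X @ np.ones(mx) + W @ np.ones(mw) + rng.normal(size=n)

# optimizer picks the scipy.optimize.minimize method; gamma_0 lists the
# starting values ("liml" and/or "zero") to try when minimizing over gamma.
lm = _LM(
    X=X, y=y, W=W, Z=Z, dof=n - k, optimizer="newton-cg", gamma_0=["liml", "zero"]
)
statistic = lm.lm(beta=np.zeros(mx))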
@@ -116,7 +131,7 @@ def __init__(self, X, y, W, dof, Z=None, X_proj=None, y_proj=None, W_proj=None):

         return -eigval[1:, 0] / eigval[0, 0]

-    def derivative(self, beta, gamma=None, jac_and_hess=True):
+    def derivative(self, beta, gamma=None, jac=True, hess=True):
         """Return LM and derivative of LM at beta, gamma w.r.t. (beta, gamma)."""
         if gamma is not None:
             one_beta_gamma = np.hstack(([1], -beta.flatten(), -gamma.flatten()))
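The two flags replace the old all-or-nothing ``jac_and_hess`` switch. A short sketch of the three call patterns, assuming an ``lm`` instance, ``mx``, and ``mw`` as in the snippet further up; per the hunks below, the method always returns a three-tuple, with ``None`` in the slots that were not requested.

beta, gamma = np.zeros(mx), np.zeros(mw)

# Statistic only; jac=False implies the Hessian is skipped as well.
stat, _, _ = lm.derivative(beta, gamma, jac=False, hess=False)

# Statistic and its gradient with respect to gamma, e.g. for BFGS.
stat, grad, _ = lm.derivative(beta, gamma, jac=True, hess=False)

# Statistic, gradient, and Hessian, e.g. for newton-cg or trust-region methods.
stat, grad, hess = lm.derivative(beta, gamma, jac=True, hess=True)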
@@ -131,28 +146,44 @@ def derivative(self, beta, gamma=None, jac_and_hess=True):

         St_proj = self.yS_proj[:, 1:] - np.outer(residuals_proj, Sigma)

-        if not jac_and_hess:
+        if not jac:  # not jac -> not hess
             residuals_proj_St = proj(St_proj, residuals_proj)
-            return self.dof * residuals_proj_St.T @ residuals_proj_St / sigma_hat
+            return (
+                self.dof * residuals_proj_St.T @ residuals_proj_St / sigma_hat,
+                None,
+                None,
+            )

         residuals = self.yS @ one_beta_gamma
         St = self.yS[:, 1:] - np.outer(residuals, Sigma)
         St_orth = St - St_proj

         mat = St_proj.T @ St_proj
         cond = np.linalg.cond(mat)
-        if cond > 1e12:
-            mat += 1e-6 * np.eye(mat.shape[0])

-        solved = np.linalg.solve(
-            mat,
-            np.hstack(
+        if hess:
+            f = np.hstack(
                 [
                     St_proj.T @ residuals_proj.reshape(-1, 1),
                     St_orth.T @ St[:, self.mx :],
                 ]
-            ),
+            )
-        )
+            if cond > 1e8:
+                solved = np.linalg.pinv(mat) @ f
+            else:
+                solved = np.linalg.solve(mat, f)
+
+        else:
+            # If mat is well conditioned, both should be equivalent, but the pinv
+            # solution is defined even if mat is singular. In theory, solve should be
+            # faster. In practice, not so clear. The lstsq solution tends to be slower.
+            if cond > 1e8:
+                solved = np.linalg.pinv(mat) @ St_proj.T @ residuals_proj.reshape(-1, 1)
+            else:
+                # solved = scipy.linalg.lstsq(St_proj, residuals_proj.reshape(-1, 1), cond=None, lapack_driver="gelsy")[0]
+                # solved = np.linalg.pinv(mat) @ St_proj.T @ residuals_proj.reshape(-1, 1)
+                solved = np.linalg.solve(mat, St_proj.T @ residuals_proj.reshape(-1, 1))
+
         residuals_proj_St = St_proj @ solved[:, 0]

         ar = residuals_proj.T @ residuals_proj / sigma_hat
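The branch above boils down to a simple fallback: solve the normal equations directly when the matrix is well conditioned, and switch to the pseudo-inverse otherwise. The same pattern in isolation, as a hedged sketch; the helper name is made up for illustration and the 1e8 threshold copies the one above.

import numpy as np

def solve_or_pinv(mat, rhs, cond_threshold=1e8):
    """Solve mat @ x = rhs, falling back to the pseudo-inverse if mat is ill conditioned."""
    if np.linalg.cond(mat) > cond_threshold:
        # pinv is defined even if mat is singular; it agrees with solve otherwise.
        return np.linalg.pinv(mat) @ rhs
    return np.linalg.solve(mat, rhs)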
@@ -161,12 +192,15 @@ def derivative(self, beta, gamma=None, jac_and_hess=True):

         first_term = -St_proj[:, self.mx :].T @ residuals_proj
         second_term = St_orth[:, self.mx :].T @ St @ solved[:, 0]
-        S = self.yS[:, 1:]
-        S_proj = self.yS_proj[:, 1:]
-        S_orth = S - S_proj

         d_lm = 2 * (first_term + kappa * second_term) / sigma_hat

+        if not hess:
+            return (self.dof * lm, self.dof * d_lm, None)
+
+        S = self.yS[:, 1:]
+        S_proj = self.yS_proj[:, 1:]
+        S_orth = S - S_proj
         dd_lm = (
             2
             * (
Expand Down Expand Up @@ -203,34 +237,50 @@ def lm(self, beta):
beta = np.array([[beta]])

if self.mw == 0:
return self.derivative(beta, jac_and_hess=False)
return self.derivative(beta, jac=False, hess=False)[0]

gamma_0 = self.liml(beta=beta)
methods_with_hessian = [
"newton-cg",
"dogleg",
"trust-ncg",
"trust-krylov",
"trust-exact",
]
hess = self.optimizer.lower() in methods_with_hessian

def _derivative(gamma):
result = self.derivative(beta, gamma, jac_and_hess=True)
return (result[0], result[1], result[2])
result = self.derivative(beta, gamma, jac=True, hess=hess)
return result

objective = MemoizeJacHess(_derivative)
jac = objective.derivative
# hess = objective.hessian

res1 = scipy.optimize.minimize(
objective, jac=jac, hess=None, x0=gamma_0, method="newton-cg"
hess = (
objective.hessian
if self.optimizer.lower() in methods_with_hessian
else None
)

res2 = scipy.optimize.minimize(
objective,
jac=jac,
hess=None,
method="newton-cg",
x0=np.zeros_like(gamma_0),
)
results = []
for g in self.gamma_0:
if g == "liml":
gamma_0 = self.liml(beta=beta)
elif g == "zero":
gamma_0 = np.zeros(self.mw)
else:
raise ValueError(f"unknown gamma_0: {g}")

return np.min([res1.fun, res2.fun])
results.append(
scipy.optimize.minimize(
objective, jac=jac, hess=hess, x0=gamma_0, method=self.optimizer
)
)

return np.min([r.fun for r in results])

def lagrange_multiplier_test(Z, X, y, beta, W=None, C=None, fit_intercept=True):

def lagrange_multiplier_test(
Z, X, y, beta, W=None, C=None, fit_intercept=True, **kwargs
):
"""
Perform the Lagrange multiplier test for ``beta`` by :cite:t:`kleibergen2002pivotal`.
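A stand-alone sketch of the optimization strategy used in ``lm`` above: hand the Hessian only to methods that can use one, run the optimizer from each starting value, and keep the smallest objective value. The quadratic objective and the helper are illustrative stand-ins, not the LM statistic or the package's ``MemoizeJacHess`` wrapper.

import numpy as np
import scipy.optimize

METHODS_WITH_HESSIAN = ["newton-cg", "dogleg", "trust-ncg", "trust-krylov", "trust-exact"]

def minimize_multistart(fun, jac, hess, starts, optimizer="bfgs"):
    # Only pass a Hessian callable to optimizers that actually use one.
    hess = hess if optimizer.lower() in METHODS_WITH_HESSIAN else None
    results = [
        scipy.optimize.minimize(fun, x0=x0, jac=jac, hess=hess, method=optimizer)
        for x0 in starts
    ]
    # Keep the best of the runs, as lm does over its starting values.
    return min(r.fun for r in results)

# Illustrative strictly convex quadratic with known gradient and Hessian.
A = np.array([[3.0, 1.0], [1.0, 2.0]])

def fun(x):
    return 0.5 * x @ A @ x

def jac(x):
    return A @ x

def hess(x):
    return A

best = minimize_multistart(
    fun, jac, hess, starts=[np.ones(2), np.full(2, -1.0)], optimizer="newton-cg"
)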
Expand Down Expand Up @@ -298,7 +348,7 @@ def lagrange_multiplier_test(Z, X, y, beta, W=None, C=None, fit_intercept=True):
X, y, Z, W = oproj(C, X, y, Z, W)

if W.shape[1] > 0:
statistic = _LM(X=X, y=y, W=W, Z=Z, dof=n - k - C.shape[1]).lm(beta)
statistic = _LM(X=X, y=y, W=W, Z=Z, dof=n - k - C.shape[1], **kwargs).lm(beta)

p_value = 1 - scipy.stats.chi2.cdf(statistic, df=mx)

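Taken together, the new keyword arguments reach ``_LM`` through the ``**kwargs`` of the public test function. A hedged usage sketch with synthetic data; the shapes are illustrative, and the return of a (statistic, p-value) pair is assumed, since the return statement is not shown in the hunk above.

import numpy as np

from ivmodels.tests.lagrange_multiplier import lagrange_multiplier_test

rng = np.random.default_rng(0)
n, k = 200, 4
Z = rng.normal(size=(n, k))
X = Z @ rng.normal(size=(k, 1)) + rng.normal(size=(n, 1))
W = Z @ rng.normal(size=(k, 1)) + rng.normal(size=(n, 1))
y = X[:, 0] + W[:, 0] + rng.normal(size=n)

# optimizer and gamma_0 are forwarded to _LM via **kwargs; they only matter
# when W is non-empty, since that is the branch constructing _LM above.
statistic, p_value = lagrange_multiplier_test(
    Z, X, y, beta=np.array([1.0]), W=W, optimizer="newton-cg", gamma_0=["liml", "zero"]
)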
tests/tests/test_lagrange_multiplier.py (31 changes: 23 additions, 8 deletions)
@@ -36,6 +36,7 @@ def test__LM__init__(n, mx, mw, k):
         [
             np.all(np.isclose(lm1.__dict__[k], lm2.__dict__[k]))
             for k in lm1.__dict__.keys()
+            if k not in ["optimizer", "gamma_0"]
         ]
     )

@@ -55,23 +56,37 @@ def test_lm_gradient(n, mx, mw, k):
     beta = rng.normal(0, 1, mx)
     gamma = rng.normal(0, 1, mw)

-    grad_approx = scipy.optimize.approx_fprime(
+    grad_approx1 = scipy.optimize.approx_fprime(
         gamma,
-        lambda g: lm.derivative(beta=beta, gamma=g)[0],
+        lambda g: lm.derivative(beta=beta, gamma=g, jac=False, hess=False)[0],
         1e-6,
     )
-    grad = lm.derivative(beta, gamma)[1]
+    grad_approx2 = scipy.optimize.approx_fprime(
+        gamma,
+        lambda g: lm.derivative(beta=beta, gamma=g, jac=True, hess=True)[0],
+        1e-6,
+    )
+    grad1 = lm.derivative(beta, gamma, jac=True, hess=False)[1]
+    grad2 = lm.derivative(beta, gamma, jac=True, hess=True)[1]

-    assert np.allclose(grad, grad_approx, rtol=5e-4, atol=5e-4)
+    assert np.allclose(grad1, grad_approx1, rtol=5e-4, atol=5e-4)
+    assert np.allclose(grad1, grad_approx2, rtol=5e-4, atol=5e-4)
+    assert np.allclose(grad1, grad2, rtol=5e-4, atol=5e-4)

-    hess_approx = scipy.optimize.approx_fprime(
+    hess_approx1 = scipy.optimize.approx_fprime(
         gamma,
+        lambda g: lm.derivative(beta=beta, gamma=g, jac=True, hess=True)[1],
         1e-6,
     )
+    hess_approx2 = scipy.optimize.approx_fprime(
+        gamma,
-        lambda g: lm.derivative(beta=beta, gamma=g)[1],
+        lambda g: lm.derivative(beta=beta, gamma=g, jac=True, hess=False)[1],
+        1e-6,
+    )
-    hess = lm.derivative(beta, gamma)[2]
+    hess = lm.derivative(beta, gamma, jac=True, hess=True)[2]

-    assert np.allclose(hess, hess_approx, rtol=5e-5, atol=5e-5)
+    assert np.allclose(hess, hess_approx1, rtol=5e-5, atol=5e-5)
+    assert np.allclose(hess, hess_approx2, rtol=5e-5, atol=5e-5)


 @pytest.mark.parametrize(
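The test above checks the analytic derivatives against finite differences. The same pattern on a toy function, as a self-contained sketch with an illustrative function and tolerances; as in the test, applying ``scipy.optimize.approx_fprime`` to the gradient approximates the Hessian row by row.

import numpy as np
import scipy.optimize

def f(x):
    return np.sin(x[0]) + x[0] * x[1] ** 2

def grad_f(x):
    return np.array([np.cos(x[0]) + x[1] ** 2, 2 * x[0] * x[1]])

def hess_f(x):
    return np.array([[-np.sin(x[0]), 2 * x[1]], [2 * x[1], 2 * x[0]]])

x = np.array([0.3, -1.2])

# Forward-difference gradient against the analytic gradient.
grad_approx = scipy.optimize.approx_fprime(x, f, 1e-6)
assert np.allclose(grad_f(x), grad_approx, rtol=5e-4, atol=5e-4)

# Finite differences of the gradient against the analytic Hessian.
hess_approx = scipy.optimize.approx_fprime(x, grad_f, 1e-6)
assert np.allclose(hess_f(x), hess_approx, rtol=5e-5, atol=5e-5)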
