Make jac/hess individually optional.
mlondschien committed Jul 9, 2024
1 parent bc7f67f commit ee69bda
Showing 2 changed files with 84 additions and 59 deletions.
ivmodels/tests/lagrange_multiplier.py: 62 additions & 51 deletions
@@ -116,7 +116,7 @@ def liml(self, beta=None):

         return -eigval[1:, 0] / eigval[0, 0]

-    def derivative(self, beta, gamma=None, jac_and_hess=True):
+    def derivative(self, beta, gamma=None, jac=True, hess=True):
         """Return LM and derivative of LM at beta, gamma w.r.t. (beta, gamma)."""
         if gamma is not None:
             one_beta_gamma = np.hstack(([1], -beta.flatten(), -gamma.flatten()))
@@ -131,67 +131,78 @@ def derivative(self, beta, gamma=None, jac_and_hess=True):

         St_proj = self.yS_proj[:, 1:] - np.outer(residuals_proj, Sigma)

-        if not jac_and_hess:
+        if not jac:  # not jac -> not hess
             residuals_proj_St = proj(St_proj, residuals_proj)
-            return self.dof * residuals_proj_St.T @ residuals_proj_St / sigma_hat
+            return (
+                self.dof * residuals_proj_St.T @ residuals_proj_St / sigma_hat,
+                None,
+                None,
+            )

         residuals = self.yS @ one_beta_gamma
         St = self.yS[:, 1:] - np.outer(residuals, Sigma)
         St_orth = St - St_proj

         mat = St_proj.T @ St_proj
-        # cond = np.linalg.cond(mat)
-        # if cond > 1e12:
-        #     mat += 1e-6 * np.eye(mat.shape[0])

-        solved = np.linalg.solve(mat, St_proj.T @ residuals_proj)
-        #     mat,
-        #     np.hstack(
-        #         [
-        #             St_proj.T @ residuals_proj.reshape(-1, 1),
-        #             St_orth.T @ St[:, self.mx :],
-        #         ]
-        #     ),
-        # )
-        residuals_proj_St = St_proj @ solved
+        if hess:
+            cond = np.linalg.cond(mat)
+            if cond > 1e8:
+                mat += 1e-8 * np.eye(mat.shape[0])
+
+            solved = np.linalg.solve(
+                mat,
+                np.hstack(
+                    [
+                        St_proj.T @ residuals_proj.reshape(-1, 1),
+                        St_orth.T @ St[:, self.mx :],
+                    ]
+                ),
+            )
+        else:
+            solved = np.linalg.solve(mat, St_proj.T @ residuals_proj.reshape(-1, 1))
+
+        residuals_proj_St = St_proj @ solved[:, 0]

         ar = residuals_proj.T @ residuals_proj / sigma_hat
         lm = residuals_proj_St.T @ residuals_proj_St / sigma_hat
         kappa = ar - lm

         first_term = -St_proj[:, self.mx :].T @ residuals_proj
-        second_term = St_orth[:, self.mx :].T @ St @ solved
-        # S = self.yS[:, 1:]
-        # S_proj = self.yS_proj[:, 1:]
-        # S_orth = S - S_proj
+        second_term = St_orth[:, self.mx :].T @ St @ solved[:, 0]

         d_lm = 2 * (first_term + kappa * second_term) / sigma_hat

-        # dd_lm = (
-        #     2
-        #     * (
-        #         -3 * kappa * np.outer(second_term, second_term) / sigma_hat
-        #         + kappa**2 * St_orth[:, self.mx :].T @ St_orth @ solved[:, 1:]
-        #         - kappa * St_orth[:, self.mx :].T @ St_orth[:, self.mx :]
-        #         - kappa
-        #         * St_orth[:, self.mx :].T
-        #         @ St_orth
-        #         @ np.outer(solved[:, 0], Sigma[self.mx :])
-        #         + St[:, self.mx :].T
-        #         @ (S_proj[:, self.mx :] - ar * S_orth[:, self.mx :])
-        #         - np.outer(
-        #             Sigma[self.mx :],
-        #             (St_proj - kappa * St_orth)[:, self.mx :].T @ St @ solved[:, 0],
-        #         )
-        #         + 2
-        #         * kappa
-        #         * np.outer(S_orth[:, self.mx :].T @ St @ solved[:, 0], Sigma[self.mx :])
-        #         - 2 * np.outer(St_proj[:, self.mx :].T @ residuals, Sigma[self.mx :])
-        #     )
-        #     / sigma_hat
-        # )

-        return (self.dof * lm.item(), self.dof * d_lm.flatten(), None)
+        if not hess:
+            return (self.dof * lm, self.dof * d_lm, None)
+
+        S = self.yS[:, 1:]
+        S_proj = self.yS_proj[:, 1:]
+        S_orth = S - S_proj
+        dd_lm = (
+            2
+            * (
+                -3 * kappa * np.outer(second_term, second_term) / sigma_hat
+                + kappa**2 * St_orth[:, self.mx :].T @ St_orth @ solved[:, 1:]
+                - kappa * St_orth[:, self.mx :].T @ St_orth[:, self.mx :]
+                - kappa
+                * St_orth[:, self.mx :].T
+                @ St_orth
+                @ np.outer(solved[:, 0], Sigma[self.mx :])
+                + St[:, self.mx :].T
+                @ (S_proj[:, self.mx :] - ar * S_orth[:, self.mx :])
+                - np.outer(
+                    Sigma[self.mx :],
+                    (St_proj - kappa * St_orth)[:, self.mx :].T @ St @ solved[:, 0],
+                )
+                + 2
+                * kappa
+                * np.outer(S_orth[:, self.mx :].T @ St @ solved[:, 0], Sigma[self.mx :])
+                - 2 * np.outer(St_proj[:, self.mx :].T @ residuals, Sigma[self.mx :])
+            )
+            / sigma_hat
+        )
+
+        return (self.dof * lm.item(), self.dof * d_lm.flatten(), self.dof * dd_lm)

     def lm(self, beta):
         """
@@ -203,26 +214,26 @@ def lm(self, beta):
         beta = np.array([[beta]])

         if self.mw == 0:
-            return self.derivative(beta, jac_and_hess=False)
+            return self.derivative(beta, jac=False, hess=False)[0]

         gamma_0 = self.liml(beta=beta)

         def _derivative(gamma):
-            result = self.derivative(beta, gamma, jac_and_hess=True)
+            result = self.derivative(beta, gamma, jac=True, hess=False)
             return (result[0], result[1], result[2])

         objective = MemoizeJacHess(_derivative)
         jac = objective.derivative
-        # hess = objective.hessian
+        hess = objective.hessian

         res1 = scipy.optimize.minimize(
-            objective, jac=jac, hess=None, x0=gamma_0, method="bfgs"
+            objective, jac=jac, hess=hess, x0=gamma_0, method="bfgs"
         )

         res2 = scipy.optimize.minimize(
             objective,
             jac=jac,
-            hess=None,
+            hess=hess,
             method="bfgs",
             x0=np.zeros_like(gamma_0),
         )
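For readers of the diff, below is a minimal, runnable sketch of the calling convention this commit introduces: derivative now always returns a three-tuple (value, jacobian, hessian), with None in the slots that were not requested, and jac=False implies hess=False. The ToyStatistic class and its quadratic objective are hypothetical stand-ins invented for illustration; only the signature and the return shape mirror the hunks above.

import numpy as np


class ToyStatistic:
    """Hypothetical stand-in mirroring only the jac/hess calling convention above."""

    def derivative(self, beta, gamma=None, jac=True, hess=True):
        # Stack the parameters; the quadratic objective below is a toy, not the
        # LM statistic computed in ivmodels/tests/lagrange_multiplier.py.
        if gamma is None:
            x = np.atleast_1d(beta).astype(float)
        else:
            x = np.concatenate([np.atleast_1d(beta), np.atleast_1d(gamma)]).astype(float)

        value = float(x @ x)
        if not jac:  # not jac -> not hess, as in the diff
            return (value, None, None)

        d = 2 * x  # toy gradient
        if not hess:
            return (value, d, None)

        dd = 2 * np.eye(x.size)  # toy hessian
        return (value, d, dd)


stat = ToyStatistic()
beta, gamma = np.array([0.5]), np.array([1.0, -2.0])

value, _, _ = stat.derivative(beta, gamma, jac=False, hess=False)    # value only
value, grad, _ = stat.derivative(beta, gamma, jac=True, hess=False)  # skip the hessian
value, grad, hessian = stat.derivative(beta, gamma, jac=True, hess=True)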
tests/tests/test_lagrange_multiplier.py: 22 additions & 8 deletions
@@ -55,23 +55,37 @@ def test_lm_gradient(n, mx, mw, k):
     beta = rng.normal(0, 1, mx)
     gamma = rng.normal(0, 1, mw)

-    grad_approx = scipy.optimize.approx_fprime(
+    grad_approx1 = scipy.optimize.approx_fprime(
         gamma,
-        lambda g: lm.derivative(beta=beta, gamma=g)[0],
+        lambda g: lm.derivative(beta=beta, gamma=g, jac=False, hess=False)[0],
         1e-6,
     )
-    grad = lm.derivative(beta, gamma)[1]
+    grad_approx2 = scipy.optimize.approx_fprime(
+        gamma,
+        lambda g: lm.derivative(beta=beta, gamma=g, jac=True, hess=True)[0],
+        1e-6,
+    )
+    grad1 = lm.derivative(beta, gamma, jac=True, hess=False)[1]
+    grad2 = lm.derivative(beta, gamma, jac=True, hess=True)[1]

-    assert np.allclose(grad, grad_approx, rtol=5e-4, atol=5e-4)
+    assert np.allclose(grad1, grad_approx1, rtol=5e-4, atol=5e-4)
+    assert np.allclose(grad1, grad_approx2, rtol=5e-4, atol=5e-4)
+    assert np.allclose(grad1, grad2, rtol=5e-4, atol=5e-4)

-    hess_approx = scipy.optimize.approx_fprime(
+    hess_approx1 = scipy.optimize.approx_fprime(
+        gamma,
+        lambda g: lm.derivative(beta=beta, gamma=g, jac=True, hess=True)[1],
+        1e-6,
+    )
+    hess_approx2 = scipy.optimize.approx_fprime(
         gamma,
-        lambda g: lm.derivative(beta=beta, gamma=g)[1],
+        lambda g: lm.derivative(beta=beta, gamma=g, jac=True, hess=False)[1],
         1e-6,
     )
-    hess = lm.derivative(beta, gamma)[2]
+    hess = lm.derivative(beta, gamma, jac=True, hess=True)[2]

-    assert np.allclose(hess, hess_approx, rtol=5e-5, atol=5e-5)
+    assert np.allclose(hess, hess_approx1, rtol=5e-5, atol=5e-5)
+    assert np.allclose(hess, hess_approx2, rtol=5e-5, atol=5e-5)


 @pytest.mark.parametrize(
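As a note on the testing pattern above: the test validates both derivative orders by finite differences, applying scipy.optimize.approx_fprime to the statistic for the gradient check and to the analytic gradient for the hessian check (this relies on a SciPy version whose approx_fprime accepts vector-valued callables, as the test itself does). A minimal, self-contained sketch of that pattern on a toy function, with invented names:

import numpy as np
import scipy.optimize


def f(x):
    # Toy scalar objective standing in for the LM statistic.
    return float(x @ x + x.sum())


def grad(x):
    # Analytic gradient of f.
    return 2 * x + 1


def hess(x):
    # Analytic hessian of f.
    return 2 * np.eye(x.size)


x0 = np.array([0.3, -1.2, 0.7])

# Finite-difference gradient, compared against the analytic gradient.
grad_approx = scipy.optimize.approx_fprime(x0, f, 1e-6)
assert np.allclose(grad(x0), grad_approx, rtol=5e-4, atol=5e-4)

# Differentiating the analytic gradient gives a finite-difference hessian.
hess_approx = scipy.optimize.approx_fprime(x0, grad, 1e-6)
assert np.allclose(hess(x0), hess_approx, rtol=5e-5, atol=5e-5)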
