diff --git a/ivmodels/tests/lagrange_multiplier.py b/ivmodels/tests/lagrange_multiplier.py index 5343fb4..64be1a4 100644 --- a/ivmodels/tests/lagrange_multiplier.py +++ b/ivmodels/tests/lagrange_multiplier.py @@ -116,7 +116,7 @@ def liml(self, beta=None): return -eigval[1:, 0] / eigval[0, 0] - def derivative(self, beta, gamma=None, jac_and_hess=True): + def derivative(self, beta, gamma=None, jac=True, hess=True): """Return LM and derivative of LM at beta, gamma w.r.t. (beta, gamma).""" if gamma is not None: one_beta_gamma = np.hstack(([1], -beta.flatten(), -gamma.flatten())) @@ -131,67 +131,78 @@ def derivative(self, beta, gamma=None, jac_and_hess=True): St_proj = self.yS_proj[:, 1:] - np.outer(residuals_proj, Sigma) - if not jac_and_hess: + if not jac: # not jac -> not hess residuals_proj_St = proj(St_proj, residuals_proj) - return self.dof * residuals_proj_St.T @ residuals_proj_St / sigma_hat + return ( + self.dof * residuals_proj_St.T @ residuals_proj_St / sigma_hat, + None, + None, + ) residuals = self.yS @ one_beta_gamma St = self.yS[:, 1:] - np.outer(residuals, Sigma) St_orth = St - St_proj mat = St_proj.T @ St_proj - # cond = np.linalg.cond(mat) - # if cond > 1e12: - # mat += 1e-6 * np.eye(mat.shape[0]) - - solved = np.linalg.solve(mat, St_proj.T @ residuals_proj) - # mat, - # np.hstack( - # [ - # St_proj.T @ residuals_proj.reshape(-1, 1), - # St_orth.T @ St[:, self.mx :], - # ] - # ), - # ) - residuals_proj_St = St_proj @ solved + if hess: + cond = np.linalg.cond(mat) + if cond > 1e8: + mat += 1e-8 * np.eye(mat.shape[0]) + + solved = np.linalg.solve( + mat, + np.hstack( + [ + St_proj.T @ residuals_proj.reshape(-1, 1), + St_orth.T @ St[:, self.mx :], + ] + ), + ) + else: + solved = np.linalg.solve(mat, St_proj.T @ residuals_proj.reshape(-1, 1)) + + residuals_proj_St = St_proj @ solved[:, 0] ar = residuals_proj.T @ residuals_proj / sigma_hat lm = residuals_proj_St.T @ residuals_proj_St / sigma_hat kappa = ar - lm first_term = -St_proj[:, self.mx :].T @ residuals_proj - second_term = St_orth[:, self.mx :].T @ St @ solved - # S = self.yS[:, 1:] - # S_proj = self.yS_proj[:, 1:] - # S_orth = S - S_proj + second_term = St_orth[:, self.mx :].T @ St @ solved[:, 0] d_lm = 2 * (first_term + kappa * second_term) / sigma_hat - # dd_lm = ( - # 2 - # * ( - # -3 * kappa * np.outer(second_term, second_term) / sigma_hat - # + kappa**2 * St_orth[:, self.mx :].T @ St_orth @ solved[:, 1:] - # - kappa * St_orth[:, self.mx :].T @ St_orth[:, self.mx :] - # - kappa - # * St_orth[:, self.mx :].T - # @ St_orth - # @ np.outer(solved[:, 0], Sigma[self.mx :]) - # + St[:, self.mx :].T - # @ (S_proj[:, self.mx :] - ar * S_orth[:, self.mx :]) - # - np.outer( - # Sigma[self.mx :], - # (St_proj - kappa * St_orth)[:, self.mx :].T @ St @ solved[:, 0], - # ) - # + 2 - # * kappa - # * np.outer(S_orth[:, self.mx :].T @ St @ solved[:, 0], Sigma[self.mx :]) - # - 2 * np.outer(St_proj[:, self.mx :].T @ residuals, Sigma[self.mx :]) - # ) - # / sigma_hat - # ) - - return (self.dof * lm.item(), self.dof * d_lm.flatten(), None) + if not hess: + return (self.dof * lm, self.dof * d_lm, None) + + S = self.yS[:, 1:] + S_proj = self.yS_proj[:, 1:] + S_orth = S - S_proj + dd_lm = ( + 2 + * ( + -3 * kappa * np.outer(second_term, second_term) / sigma_hat + + kappa**2 * St_orth[:, self.mx :].T @ St_orth @ solved[:, 1:] + - kappa * St_orth[:, self.mx :].T @ St_orth[:, self.mx :] + - kappa + * St_orth[:, self.mx :].T + @ St_orth + @ np.outer(solved[:, 0], Sigma[self.mx :]) + + St[:, self.mx :].T + @ (S_proj[:, self.mx :] - ar * S_orth[:, self.mx :]) + - np.outer( + Sigma[self.mx :], + (St_proj - kappa * St_orth)[:, self.mx :].T @ St @ solved[:, 0], + ) + + 2 + * kappa + * np.outer(S_orth[:, self.mx :].T @ St @ solved[:, 0], Sigma[self.mx :]) + - 2 * np.outer(St_proj[:, self.mx :].T @ residuals, Sigma[self.mx :]) + ) + / sigma_hat + ) + + return (self.dof * lm.item(), self.dof * d_lm.flatten(), self.dof * dd_lm) def lm(self, beta): """ @@ -203,26 +214,26 @@ def lm(self, beta): beta = np.array([[beta]]) if self.mw == 0: - return self.derivative(beta, jac_and_hess=False) + return self.derivative(beta, jac=False, hess=False)[0] gamma_0 = self.liml(beta=beta) def _derivative(gamma): - result = self.derivative(beta, gamma, jac_and_hess=True) + result = self.derivative(beta, gamma, jac=True, hess=False) return (result[0], result[1], result[2]) objective = MemoizeJacHess(_derivative) jac = objective.derivative - # hess = objective.hessian + hess = objective.hessian res1 = scipy.optimize.minimize( - objective, jac=jac, hess=None, x0=gamma_0, method="bfgs" + objective, jac=jac, hess=hess, x0=gamma_0, method="bfgs" ) res2 = scipy.optimize.minimize( objective, jac=jac, - hess=None, + hess=hess, method="bfgs", x0=np.zeros_like(gamma_0), ) diff --git a/tests/tests/test_lagrange_multiplier.py b/tests/tests/test_lagrange_multiplier.py index fb77955..83169da 100644 --- a/tests/tests/test_lagrange_multiplier.py +++ b/tests/tests/test_lagrange_multiplier.py @@ -55,23 +55,37 @@ def test_lm_gradient(n, mx, mw, k): beta = rng.normal(0, 1, mx) gamma = rng.normal(0, 1, mw) - grad_approx = scipy.optimize.approx_fprime( + grad_approx1 = scipy.optimize.approx_fprime( gamma, - lambda g: lm.derivative(beta=beta, gamma=g)[0], + lambda g: lm.derivative(beta=beta, gamma=g, jac=False, hess=False)[0], 1e-6, ) - grad = lm.derivative(beta, gamma)[1] + grad_approx2 = scipy.optimize.approx_fprime( + gamma, + lambda g: lm.derivative(beta=beta, gamma=g, jac=True, hess=True)[0], + 1e-6, + ) + grad1 = lm.derivative(beta, gamma, jac=True, hess=False)[1] + grad2 = lm.derivative(beta, gamma, jac=True, hess=True)[1] - assert np.allclose(grad, grad_approx, rtol=5e-4, atol=5e-4) + assert np.allclose(grad1, grad_approx1, rtol=5e-4, atol=5e-4) + assert np.allclose(grad1, grad_approx2, rtol=5e-4, atol=5e-4) + assert np.allclose(grad1, grad2, rtol=5e-4, atol=5e-4) - hess_approx = scipy.optimize.approx_fprime( + hess_approx1 = scipy.optimize.approx_fprime( + gamma, + lambda g: lm.derivative(beta=beta, gamma=g, jac=True, hess=True)[1], + 1e-6, + ) + hess_approx2 = scipy.optimize.approx_fprime( gamma, - lambda g: lm.derivative(beta=beta, gamma=g)[1], + lambda g: lm.derivative(beta=beta, gamma=g, jac=True, hess=False)[1], 1e-6, ) - hess = lm.derivative(beta, gamma)[2] + hess = lm.derivative(beta, gamma, jac=True, hess=True)[2] - assert np.allclose(hess, hess_approx, rtol=5e-5, atol=5e-5) + assert np.allclose(hess, hess_approx1, rtol=5e-5, atol=5e-5) + assert np.allclose(hess, hess_approx2, rtol=5e-5, atol=5e-5) @pytest.mark.parametrize(