From 485c953f42d3985cd3cbf66b5614d128f84a8f85 Mon Sep 17 00:00:00 2001 From: Malte Londschien <61679398+mlondschien@users.noreply.github.com> Date: Fri, 7 Jun 2024 17:54:08 +0200 Subject: [PATCH] Use `scipy.linalg.lstsq` instead of `np.linalg.lstsq` (#83) * Use scipy.linalg.lstsq * Remove line_profiler line. * gelsy does not support empty target. --- ivmodels/utils.py | 37 +++++++++++++++++-------------------- tests/test_utils.py | 5 +++++ 2 files changed, 22 insertions(+), 20 deletions(-) diff --git a/ivmodels/utils.py b/ivmodels/utils.py index 70aa349..5b2e2a2 100644 --- a/ivmodels/utils.py +++ b/ivmodels/utils.py @@ -1,4 +1,5 @@ import numpy as np +import scipy try: import pandas as pd @@ -36,13 +37,24 @@ def proj(Z, *args): raise ValueError(f"Shape mismatch: Z.shape={Z.shape}, f.shape={f.shape}.") if len(args) == 1: - return np.dot(Z, np.linalg.lstsq(Z, args[0], rcond=None)[0]) + # The gelsy driver raises in this case - we handle it separately + if len(args[0].shape) == 2 and args[0].shape[1] == 0: + return np.zeros_like(args[0]) + + return np.dot( + Z, scipy.linalg.lstsq(Z, args[0], cond=None, lapack_driver="gelsy")[0] + ) csum = np.cumsum([f.shape[1] if len(f.shape) == 2 else 1 for f in args]) csum = [0] + csum.tolist() fs = np.hstack([f.reshape(Z.shape[0], -1) for f in args]) - fs = np.dot(Z, np.linalg.lstsq(Z, fs, rcond=None)[0]) + + if fs.shape[1] == 0: + # The gelsy driver raises in this case - we handle it separately + return (*(np.zeros_like(f) for f in args),) + + fs = np.dot(Z, scipy.linalg.lstsq(Z, fs, cond=None, lapack_driver="gelsy")[0]) return ( *(fs[:, i:j].reshape(f.shape) for i, j, f in zip(csum[:-1], csum[1:], args)), @@ -68,26 +80,11 @@ def oproj(Z, *args): if Z is None: return (*args,) - for f in args: - if len(f.shape) > 2: - raise ValueError( - f"*args should have shapes (n, d_f) or (n,). Got {f.shape}." 
- ) - if f.shape[0] != Z.shape[0]: - raise ValueError(f"Shape mismatch: Z.shape={Z.shape}, f.shape={f.shape}.") - if len(args) == 1: - return args[0] - np.dot(Z, np.linalg.lstsq(Z, args[0], rcond=None)[0]) - - csum = np.cumsum([f.shape[1] if len(f.shape) == 2 else 1 for f in args]) - csum = [0] + csum.tolist() - - fs = np.hstack([f.reshape(Z.shape[0], -1) for f in args]) - fs = fs - np.dot(Z, np.linalg.lstsq(Z, fs, rcond=None)[0]) + return args[0] - proj(Z, args[0]) - return ( - *(fs[:, i:j].reshape(f.shape) for i, j, f in zip(csum[:-1], csum[1:], args)), - ) + else: + return (*(x - x_proj for x, x_proj in zip(args, proj(Z, *args))),) def to_numpy(x): diff --git a/tests/test_utils.py b/tests/test_utils.py index 6adef43..b584530 100644 --- a/tests/test_utils.py +++ b/tests/test_utils.py @@ -36,6 +36,11 @@ def test_proj_multiple_args(): ) assert np.allclose(proj(X, z1), X @ np.linalg.inv(X.T @ X) @ X.T @ z1) assert np.allclose(proj(X, z2), X @ np.linalg.inv(X.T @ X) @ X.T @ z2) + assert np.allclose( + proj(X, z2, z2), + X @ np.linalg.inv(X.T @ X) @ X.T @ z2, + X @ np.linalg.inv(X.T @ X) @ X.T @ z2, + ) assert np.allclose(proj(X, z3), X @ np.linalg.inv(X.T @ X) @ X.T @ z3)