diff --git a/tests/test_utils.py b/tests/test_utils.py index b584530..229d9da 100644 --- a/tests/test_utils.py +++ b/tests/test_utils.py @@ -7,17 +7,20 @@ def test_proj(): rng = np.random.RandomState(0) X = rng.normal(0, 1, (100, 2)) - Z = rng.normal(0, 1, (100, 5)) + Z = rng.normal(0, 1, (100, 4)) @ rng.normal(0, 1, (4, 6)) - assert np.allclose(proj(Z, X), Z @ np.linalg.inv(Z.T @ Z) @ Z.T @ X) + assert np.allclose(proj(Z, X), Z @ np.linalg.pinv(Z.T @ Z) @ Z.T @ X) + assert np.allclose(proj(Z, X), proj(Z, proj(Z, X))) + assert np.allclose(proj(proj(Z, X), X), proj(Z, X)) def test_oproj(): rng = np.random.RandomState(0) X = rng.normal(0, 1, (100, 2)) - Z = rng.normal(0, 1, (100, 5)) + Z = rng.normal(0, 1, (100, 4)) @ rng.normal(0, 1, (4, 6)) assert np.allclose(X - proj(Z, X), oproj(Z, X)) + assert np.allclose(oproj(Z, X), oproj(Z, oproj(Z, X))) def test_proj_multiple_args():