diff --git a/tests/test_jacobian_ekfac.py b/tests/test_jacobian_ekfac.py index 3ffc572..1617eb3 100644 --- a/tests/test_jacobian_ekfac.py +++ b/tests/test_jacobian_ekfac.py @@ -85,11 +85,17 @@ def test_pspace_ekfac_vs_direct(): check_tensors(mv_direct, mv_ekfac.get_flat_representation()) # Test pow - M_pow = M_ekfac**2 check_tensors( - M_pow.get_dense_tensor(), + (M_ekfac**2).get_dense_tensor(), torch.mm(M_ekfac.get_dense_tensor(), M_ekfac.get_dense_tensor()), ) + check_tensors( + torch.mm( + (M_ekfac ** (1 / 3)).get_dense_tensor(), + (M_ekfac ** (2 / 3)).get_dense_tensor(), + ), + M_ekfac.get_dense_tensor(), + ) # Test inverse regul = 1e-5