From 5299f194e62da14cf3b3e44763b78d9d080a642f Mon Sep 17 00:00:00 2001 From: Thomas George Date: Mon, 7 Oct 2024 10:32:15 +0200 Subject: [PATCH] better pytest for ekfac pow --- tests/test_jacobian_ekfac.py | 10 ++++++++-- 1 file changed, 8 insertions(+), 2 deletions(-) diff --git a/tests/test_jacobian_ekfac.py b/tests/test_jacobian_ekfac.py index 3ffc572..1617eb3 100644 --- a/tests/test_jacobian_ekfac.py +++ b/tests/test_jacobian_ekfac.py @@ -85,11 +85,17 @@ def test_pspace_ekfac_vs_direct(): check_tensors(mv_direct, mv_ekfac.get_flat_representation()) # Test pow - M_pow = M_ekfac**2 check_tensors( - M_pow.get_dense_tensor(), + (M_ekfac**2).get_dense_tensor(), torch.mm(M_ekfac.get_dense_tensor(), M_ekfac.get_dense_tensor()), ) + check_tensors( + torch.mm( + (M_ekfac ** (1 / 3)).get_dense_tensor(), + (M_ekfac ** (2 / 3)).get_dense_tensor(), + ), + M_ekfac.get_dense_tensor(), + ) # Test inverse regul = 1e-5