Skip to content

Commit

Permalink
Merge pull request #88 from tfjgeorge/ekfac_pw
Browse files Browse the repository at this point in the history
better pytest for ekfac pow
  • Loading branch information
tfjgeorge authored Oct 7, 2024
2 parents 41ff0d4 + 5299f19 commit c03985d
Showing 1 changed file with 8 additions and 2 deletions.
10 changes: 8 additions & 2 deletions tests/test_jacobian_ekfac.py
Original file line number Diff line number Diff line change
Expand Up @@ -85,11 +85,17 @@ def test_pspace_ekfac_vs_direct():
check_tensors(mv_direct, mv_ekfac.get_flat_representation())

# Test pow
M_pow = M_ekfac**2
check_tensors(
M_pow.get_dense_tensor(),
(M_ekfac**2).get_dense_tensor(),
torch.mm(M_ekfac.get_dense_tensor(), M_ekfac.get_dense_tensor()),
)
check_tensors(
torch.mm(
(M_ekfac ** (1 / 3)).get_dense_tensor(),
(M_ekfac ** (2 / 3)).get_dense_tensor(),
),
M_ekfac.get_dense_tensor(),
)

# Test inverse
regul = 1e-5
Expand Down

0 comments on commit c03985d

Please sign in to comment.