diff --git a/nngeometry/generator/jacobian/grads.py b/nngeometry/generator/jacobian/grads.py index b7372c0..5d73ebf 100644 --- a/nngeometry/generator/jacobian/grads.py +++ b/nngeometry/generator/jacobian/grads.py @@ -319,8 +319,6 @@ def flat_grad(cls, buffer, mod, layer, x, gy): norm2 = (mod.weight**2).sum(dim=(1, 2, 3), keepdim=True) + mod.eps gw = gw_prime / torch.sqrt(norm2).unsqueeze(0) - # print((gw_prime * mod.weight.unsqueeze(0)).sum(dim=(2,3,4), keepdim=True).size()) - # print((mod.weight * norm2**(-1.5)).unsqueeze(0).size()) gw -= (gw_prime * mod.weight.unsqueeze(0)).sum(dim=(2, 3, 4), keepdim=True) * ( mod.weight * norm2 ** (-1.5)