Reduced memory usage for shrinkage
norabelrose committed Jan 25, 2024
1 parent 793ab5a commit 2844e7b
Showing 3 changed files with 13 additions and 10 deletions.
concept_erasure/leace.py (1 addition, 1 deletion)

@@ -259,7 +259,7 @@ def sigma_xx(self) -> Tensor:
 
         # Apply Random Matrix Theory-based shrinkage
         if self.shrinkage:
-            return optimal_linear_shrinkage(S_hat / self.n, self.n)
+            return optimal_linear_shrinkage(S_hat / self.n, self.n, inplace=True)
 
         # Just apply Bessel's correction
        else:
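Note on the call site: `S_hat / self.n` allocates a fresh tensor, so the shrinkage routine is free to overwrite its argument; the running accumulator `S_hat` is never aliased. A minimal sketch of that aliasing argument (names and shapes here are illustrative, not taken from the repo):

    import torch

    p = 4
    S_hat = torch.eye(p)      # running unnormalized second-moment accumulator
    n = 10

    S = S_hat / n             # fresh allocation; shares no storage with S_hat
    S.mul_(0.5)               # mutating the temporary in place...
    assert torch.equal(S_hat, torch.eye(p))  # ...leaves the accumulator intact

Passing `inplace=True` therefore saves one (p, p) allocation per call without changing any observable state.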
concept_erasure/quadratic.py (1 addition, 1 deletion)

@@ -195,7 +195,7 @@ def sigma_xx(self) -> Tensor:
         # Apply Random Matrix Theory-based shrinkage
         n = self.n.view(-1, 1, 1)
         if self.shrinkage:
-            return optimal_linear_shrinkage(S_hat / n, n)
+            return optimal_linear_shrinkage(S_hat / n, n, inplace=True)
 
         # Just apply Bessel's correction
         else:
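The quadratic eraser keeps one covariance per class, so the statistics are batched: reshaping `n` to (-1, 1, 1) lets a per-class count vector broadcast against a (k, p, p) stack of matrices, and the shrinkage intensity is then computed independently for each class. A small sketch of the broadcasting involved, under assumed illustrative shapes:

    import torch

    k, p = 3, 5                          # k classes, p features (hypothetical sizes)
    S_hat = torch.randn(k, p, p)
    S_hat = S_hat @ S_hat.mT             # a batch of PSD matrices
    n = torch.tensor([10.0, 20.0, 30.0]).view(-1, 1, 1)

    S = S_hat / n                        # per-class normalization over (k, p, p)
    print(S.shape)                       # torch.Size([3, 5, 5])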
concept_erasure/shrinkage.py (11 additions, 8 deletions)

@@ -2,12 +2,14 @@
 from torch import Tensor
 
 
-def optimal_linear_shrinkage(S_n: Tensor, n: int | Tensor) -> Tensor:
+def optimal_linear_shrinkage(
+    S_n: Tensor, n: int | Tensor, *, inplace: bool = False
+) -> Tensor:
     """Optimal linear shrinkage for a sample covariance matrix or batch thereof.
 
     Given a sample covariance matrix `S_n` of shape (*, p, p) and a sample size `n`,
     this function computes the optimal shrinkage coefficients `alpha` and `beta`, then
-    returns the covariance estimate `alpha * S_n + beta * Sigma0`, where ``Sigma0` is
+    returns the covariance estimate `alpha * S_n + beta * Sigma0`, where `Sigma0` is
     an isotropic covariance matrix with the same trace as `S_n`.
 
     The formula is distribution-free and asymptotically optimal in the Frobenius norm
@@ -26,15 +28,13 @@ def optimal_linear_shrinkage(S_n: Tensor, n: int | Tensor) -> Tensor:
     p = S_n.shape[-1]
     assert S_n.shape[-2:] == (p, p)
 
-    # TODO: Make this configurable, try using diag(S_n) or something
-    eye = torch.eye(p, dtype=S_n.dtype, device=S_n.device).expand_as(S_n)
     trace_S = trace(S_n)
-    sigma0 = eye * trace_S / p
 
-    sigma0_norm_sq = sigma0.norm(dim=(-2, -1), keepdim=True) ** 2
+    # Since sigma0 is I * tr(S_n) / p, its squared Frobenius norm is just tr(S_n) ** 2 / p.
+    sigma0_norm_sq = trace_S ** 2 / p
     S_norm_sq = S_n.norm(dim=(-2, -1), keepdim=True) ** 2
 
-    prod_trace = trace(S_n @ sigma0)
+    prod_trace = sigma0_norm_sq  # Equals trace(S_n @ sigma0), since sigma0 is isotropic
     top = trace_S * trace_S.conj() * sigma0_norm_sq / n
     bottom = S_norm_sq * sigma0_norm_sq - prod_trace * prod_trace.conj()

@@ -45,7 +45,10 @@ def optimal_linear_shrinkage(S_n: Tensor, n: int | Tensor) -> Tensor:
     alpha = 1 - (top + eps) / (bottom + eps)
     beta = (1 - alpha) * (prod_trace + eps) / (sigma0_norm_sq + eps)
 
-    return alpha * S_n + beta * sigma0
+    ret = S_n.mul_(alpha) if inplace else alpha * S_n
+    diag = beta * trace_S / p
+    torch.linalg.diagonal(ret).add_(diag.squeeze(-1))
+    return ret
 
 
 def trace(matrices: Tensor) -> Tensor:
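The memory savings come from two observations about the isotropic target `sigma0 = I * tr(S_n) / p`. First, the two quantities the old code computed from a materialized `sigma0`, its squared Frobenius norm and `trace(S_n @ sigma0)`, both reduce algebraically to `tr(S_n) ** 2 / p`, so neither the (p, p) identity matrix nor the matrix product is needed. Second, adding `beta * sigma0` only touches the diagonal, so the result can be built by scaling `S_n` (in place when `inplace=True`) and adding a scalar to a diagonal view, avoiding every full-size temporary. A sketch verifying both claims numerically (the names here are illustrative, not part of the library's API):

    import torch

    p = 6
    A = torch.randn(p, p)
    S = A @ A.mT / p                     # stand-in for a sample covariance

    trace_S = torch.diagonal(S).sum()
    sigma0 = torch.eye(p) * trace_S / p  # the old, materialized shrinkage target

    # Both identities reduce to tr(S) ** 2 / p.
    assert torch.allclose(sigma0.norm() ** 2, trace_S ** 2 / p)
    assert torch.allclose(torch.trace(S @ sigma0), trace_S ** 2 / p)

    # alpha * S + beta * sigma0, built without materializing beta * sigma0.
    alpha, beta = 0.9, 0.1
    dense = alpha * S + beta * sigma0
    lean = alpha * S
    torch.linalg.diagonal(lean).add_(beta * trace_S / p)
    assert torch.allclose(dense, lean)

In the batched code path `diag` has shape (*, 1, 1); `squeeze(-1)` turns it into (*, 1), which broadcasts against the (*, p) diagonal view returned by `torch.linalg.diagonal`.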
