From 2cff73e3845f3f584d6fc131adb5822110e6b9cc Mon Sep 17 00:00:00 2001 From: Nora Belrose Date: Thu, 25 Jan 2024 05:07:24 +0000 Subject: [PATCH] Fixed ruff formatting issue --- concept_erasure/shrinkage.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/concept_erasure/shrinkage.py b/concept_erasure/shrinkage.py index 16fac2c..eab4291 100644 --- a/concept_erasure/shrinkage.py +++ b/concept_erasure/shrinkage.py @@ -30,11 +30,11 @@ def optimal_linear_shrinkage( trace_S = trace(S_n) - # Since sigma0 is I * tr(S_n) / p, its squared Frobenius norm is just tr(S_n) ** 2 / p. - sigma0_norm_sq = trace_S ** 2 / p + # Since sigma0 is I * tr(S_n) / p, its squared Frobenius norm is tr(S_n) ** 2 / p. + sigma0_norm_sq = trace_S**2 / p S_norm_sq = S_n.norm(dim=(-2, -1), keepdim=True) ** 2 - prod_trace = sigma0_norm_sq #torch.linalg.diagonal(S_n) # trace(S_n @ sigma0) + prod_trace = sigma0_norm_sq top = trace_S * trace_S.conj() * sigma0_norm_sq / n bottom = S_norm_sq * sigma0_norm_sq - prod_trace * prod_trace.conj()