From 644ec946b6582cadc93a652d2d3ec066d483dbab Mon Sep 17 00:00:00 2001 From: Nora Belrose Date: Tue, 19 Sep 2023 21:44:24 +0000 Subject: [PATCH] Bugfix for ot_barycenter with complex numbers --- concept_erasure/optimal_transport.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/concept_erasure/optimal_transport.py b/concept_erasure/optimal_transport.py index 625d246..1f5acf1 100644 --- a/concept_erasure/optimal_transport.py +++ b/concept_erasure/optimal_transport.py @@ -77,7 +77,7 @@ def ot_barycenter( new_loss = mu.trace() + trace_avg - 2 * inner.mul(weights).sum(dim=0).trace() # Break if the loss is not decreasing - if loss - new_loss < tol: + if loss.real - new_loss.real < tol: break else: loss = new_loss