diff --git a/concept_erasure/optimal_transport.py b/concept_erasure/optimal_transport.py index 625d246..1f5acf1 100644 --- a/concept_erasure/optimal_transport.py +++ b/concept_erasure/optimal_transport.py @@ -77,7 +77,7 @@ def ot_barycenter( new_loss = mu.trace() + trace_avg - 2 * inner.mul(weights).sum(dim=0).trace() # Break if the loss is not decreasing - if loss - new_loss < tol: + if loss.real - new_loss.real < tol: break else: loss = new_loss