diff --git a/README.md b/README.md
index f622c5aab..732a65801 100644
--- a/README.md
+++ b/README.md
@@ -334,3 +334,4 @@ distances between Gaussian distributions](https://hal.science/hal-03197398v2/fil
 [59] Taylor A. B. (2017). [Convex interpolation and performance estimation of first-order methods for convex optimization.](https://dial.uclouvain.be/pr/boreal/object/boreal%3A182881/datastream/PDF_01/view) PhD thesis, Catholic University of Louvain, Louvain-la-Neuve, Belgium, 2017.
+[60] Thornton, James, and Marco Cuturi. [Rethinking initialization of the Sinkhorn algorithm](https://arxiv.org/pdf/2206.07630.pdf). International Conference on Artificial Intelligence and Statistics, PMLR, 2023.
diff --git a/ot/__init__.py b/ot/__init__.py
index f16b6fcfc..32077b06b 100644
--- a/ot/__init__.py
+++ b/ot/__init__.py
@@ -40,7 +40,7 @@
 from .lp import (emd, emd2, emd_1d, emd2_1d, wasserstein_1d,
                  binary_search_circle, wasserstein_circle,
                  semidiscrete_wasserstein2_unif_circle)
-from .bregman import sinkhorn, sinkhorn2, barycenter
+from .bregman import (sinkhorn, sinkhorn2, barycenter, empirical_sinkhorn,
+                      empirical_sinkhorn2, empirical_sinkhorn_divergence)
 from .unbalanced import (sinkhorn_unbalanced, barycenter_unbalanced,
                          sinkhorn_unbalanced2)
 from .da import sinkhorn_lpl1_mm
@@ -61,6 +61,7 @@
            'datasets', 'bregman', 'lp', 'tic', 'toc', 'toq', 'gromov',
            'emd2_1d', 'wasserstein_1d', 'backend', 'gaussian',
            'dist', 'unif', 'barycenter', 'sinkhorn_lpl1_mm', 'da', 'optim',
+           'empirical_sinkhorn', 'empirical_sinkhorn2', 'empirical_sinkhorn_divergence',
            'sinkhorn_unbalanced', 'barycenter_unbalanced',
            'sinkhorn_unbalanced2', 'sliced_wasserstein_distance', 'sliced_wasserstein_sphere',
            'gromov_wasserstein', 'gromov_wasserstein2', 'gromov_barycenters', 'fused_gromov_wasserstein',
diff --git a/ot/bregman.py b/ot/bregman.py
index c90d89986..30176b4c0 100644
--- a/ot/bregman.py
+++ b/ot/bregman.py
@@ -23,6 +23,7 @@
 from ot.utils import dist, list_to_array, unif
 from .backend import get_backend
+from .gaussian import dual_gaussian_init
 
 
 def sinkhorn(a, b, M, reg, method='sinkhorn', numItermax=1000, stopThr=1e-9,
@@ -541,6 +542,7 @@ def sinkhorn_knopp(a, b, M, reg, numItermax=1000, stopThr=1e-9,
         log['niter'] = ii
         log['u'] = u
         log['v'] = v
+        log['warmstart'] = (nx.log(u), nx.log(v))
 
     if n_hists:  # return only loss
         res = nx.einsum('ik,ij,jk,ij->k', u, K, v, M)
@@ -697,6 +699,7 @@ def sinkhorn_log(a, b, M, reg, numItermax=1000, stopThr=1e-9, verbose=False,
                 'log_v': nx.stack(lst_v, 1), }
             log['u'] = nx.exp(log['log_u'])
             log['v'] = nx.exp(log['log_v'])
+            log['warmstart'] = (log['log_u'], log['log_v'])
             return res, log
         else:
             return res
@@ -2999,15 +3002,23 @@ def empirical_sinkhorn(X_s, X_t, reg, a=None, b=None, metric='sqeuclidean',
     if b is None:
         b = nx.from_numpy(unif(nt), type_as=X_s)
 
+    if warmstart is None:
+        f, g = nx.zeros((ns,), type_as=a), nx.zeros((nt,), type_as=a)
+    elif warmstart == 'gaussian':
+        # initialize both potentials from the Gaussian (Bures-Wasserstein)
+        # approximation of the two point clouds; the g init matters most,
+        # since f is the first potential updated in the Sinkhorn loop
+        f = dual_gaussian_init(X_s, X_t, a[:, None], b[:, None])
+        g = dual_gaussian_init(X_t, X_s, b[:, None], a[:, None])
+    elif isinstance(warmstart, (tuple, list)) and len(warmstart) == 2:
+        f, g = warmstart
+    else:
+        raise ValueError(
+            "warmstart must be None, 'gaussian' or a tuple of two arrays")
+
     if isLazy:
         if log:
             dict_log = {"err": []}
 
         log_a, log_b = nx.log(a), nx.log(b)
-        if warmstart is None:
-            f, g = nx.zeros((ns,), type_as=a), nx.zeros((nt,), type_as=a)
-        else:
-            f, g = warmstart
 
         if isinstance(batchSize, int):
             bs, bt = batchSize, batchSize
@@ -3075,6 +3086,7 @@ def empirical_sinkhorn(X_s, X_t, reg, a=None, b=None, metric='sqeuclidean',
         if log:
             dict_log["u"] = f
             dict_log["v"] = g
+            dict_log["warmstart"] = (f, g)
             return (f, g, dict_log)
         else:
             return (f, g)
@@ -3083,11 +3095,11 @@ def empirical_sinkhorn(X_s, X_t, reg, a=None, b=None, metric='sqeuclidean',
         M = dist(X_s, X_t, metric=metric)
         if log:
             pi, log = sinkhorn(a, b, M, reg, numItermax=numIterMax, stopThr=stopThr,
-                               verbose=verbose, log=True, warmstart=warmstart, **kwargs)
+                               verbose=verbose, log=True, warmstart=(f, g), **kwargs)
             return pi, log
         else:
             pi = sinkhorn(a, b, M, reg, numItermax=numIterMax, stopThr=stopThr,
-                          verbose=verbose, log=False, warmstart=warmstart, **kwargs)
+                          verbose=verbose, log=False, warmstart=(f, g), **kwargs)
             return pi
@@ -3201,6 +3213,19 @@ def empirical_sinkhorn2(X_s, X_t, reg, a=None, b=None, metric='sqeuclidean',
     if b is None:
         b = nx.from_numpy(unif(nt), type_as=X_s)
 
+    if warmstart is None:
+        warmstart = nx.zeros((ns,), type_as=a), nx.zeros((nt,), type_as=a)
+    elif warmstart == 'gaussian':
+        # initialize both potentials from the Gaussian (Bures-Wasserstein)
+        # approximation of the two point clouds
+        f = dual_gaussian_init(X_s, X_t, a[:, None], b[:, None])
+        g = dual_gaussian_init(X_t, X_s, b[:, None], a[:, None])
+        warmstart = (f, g)
+    elif isinstance(warmstart, (tuple, list)) and len(warmstart) == 2:
+        pass  # already a valid (f, g) pair of potentials
+    else:
+        raise ValueError(
+            "warmstart must be None, 'gaussian' or a tuple of two arrays")
+
     if isLazy:
         if log:
             f, g, dict_log = empirical_sinkhorn(X_s, X_t, reg, a, b, metric,
diff --git a/ot/gaussian.py b/ot/gaussian.py
index 708f9eb16..6d77ac8a0 100644
--- a/ot/gaussian.py
+++ b/ot/gaussian.py
@@ -645,3 +645,51 @@ def empirical_gaussian_gromov_wasserstein_mapping(xs, xt, ws=None,
         return A, b, log
     else:
         return A, b
+
+
+def dual_gaussian_init(xs, xt, ws=None, wt=None, reg=1e-6):
+    r"""Return the Gaussian initialization of the source dual potential.
+
+    This function returns a dual potential that can be used to warm start the
+    Sinkhorn algorithm. The initialization is based on the Monge mapping
+    between the source and target distributions seen as two Gaussian
+    distributions [60].
+
+    Parameters
+    ----------
+    xs : array-like (ns,ds)
+        samples in the source domain
+    xt : array-like (nt,dt)
+        samples in the target domain
+    ws : array-like (ns,1), optional
+        weights for the source samples
+    wt : array-like (nt,1), optional
+        weights for the target samples
+    reg : float, optional
+        regularization added to the diagonals of covariances (>0)
+
+    Returns
+    -------
+    f : array-like (ns,)
+        dual potential for the source samples
+
+    References
+    ----------
+    .. [60] Thornton, James, and Marco Cuturi. "Rethinking initialization of
+        the Sinkhorn algorithm." International Conference on Artificial
+        Intelligence and Statistics. PMLR, 2023.
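+
+    Examples
+    --------
+    A minimal sketch on synthetic data (the values depend on the random
+    draw, so only the output shape is checked):
+
+    >>> import numpy as np
+    >>> import ot
+    >>> xs = np.random.randn(50, 2)
+    >>> xt = np.random.randn(40, 2) + 2
+    >>> ot.gaussian.dual_gaussian_init(xs, xt).shape
+    (50,)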
+ """ + + nx = get_backend(xs, xt) + + if ws is None: + ws = nx.ones((xs.shape[0], 1), type_as=xs) / xs.shape[0] + + if wt is None: + wt = nx.ones((xt.shape[0], 1), type_as=xt) / xt.shape[0] + + # estimate mean and covariance + mu_s = nx.dot(ws.T, xs) / nx.sum(ws) + mu_t = nx.dot(wt.T, xt) / nx.sum(wt) + + A, b = empirical_bures_wasserstein_mapping(xs, xt, ws=ws, wt=wt, reg=reg) + + xsc = xs - mu_s + + # compute the dual potential (see appendix D in [60]) + f = nx.sum(xs**2 - nx.dot(xsc, A) * xsc - mu_t * xs, 1) + + return f diff --git a/test/test_bregman.py b/test/test_bregman.py index 8627df3c6..e3536d383 100644 --- a/test/test_bregman.py +++ b/test/test_bregman.py @@ -1041,6 +1041,9 @@ def test_empirical_sinkhorn(nx): ot.bregman.empirical_sinkhorn2(X_sb, X_tb, 1)) loss_sinkhorn = nx.to_numpy(ot.sinkhorn2(ab, bb, M_nx, 1)) + loss_emp_sinkhorn_gausss_warmstart = nx.to_numpy( + ot.bregman.empirical_sinkhorn2(X_sb, X_tb, 1, warmstart='gaussian')) + # check constraints np.testing.assert_allclose( sinkhorn_sqe.sum(1), G_sqe.sum(1), atol=1e-05) # metric sqeuclidian @@ -1055,6 +1058,7 @@ def test_empirical_sinkhorn(nx): np.testing.assert_allclose( sinkhorn_m.sum(0), G_m.sum(0), atol=1e-05) # metric euclidian np.testing.assert_allclose(loss_emp_sinkhorn, loss_sinkhorn, atol=1e-05) + np.testing.assert_allclose(loss_emp_sinkhorn_gausss_warmstart, loss_sinkhorn, atol=1e-05) def test_lazy_empirical_sinkhorn(nx): @@ -1095,6 +1099,9 @@ def test_lazy_empirical_sinkhorn(nx): loss_emp_sinkhorn = nx.to_numpy(loss_emp_sinkhorn) loss_sinkhorn = nx.to_numpy(ot.sinkhorn2(ab, bb, M_nx, 1)) + loss_emp_sinkhorn_gausss_warmstart = nx.to_numpy( + ot.bregman.empirical_sinkhorn2(X_sb, X_tb, 1, warmstart='gaussian', isLazy=True)) + # check constraints np.testing.assert_allclose( sinkhorn_sqe.sum(1), G_sqe.sum(1), atol=1e-05) # metric sqeuclidian @@ -1109,6 +1116,7 @@ def test_lazy_empirical_sinkhorn(nx): np.testing.assert_allclose( sinkhorn_m.sum(0), G_m.sum(0), atol=1e-05) # metric euclidian np.testing.assert_allclose(loss_emp_sinkhorn, loss_sinkhorn, atol=1e-05) + np.testing.assert_allclose(loss_emp_sinkhorn_gausss_warmstart, loss_sinkhorn, atol=1e-05) def test_empirical_sinkhorn_divergence(nx): diff --git a/test/test_gaussian.py b/test/test_gaussian.py index 4e3c2df7b..bbbb7dfbb 100644 --- a/test/test_gaussian.py +++ b/test/test_gaussian.py @@ -175,3 +175,22 @@ def test_gaussian_gromov_wasserstein_mapping(nx, d_target): if d_target >= 2: np.testing.assert_allclose(Cs, Ctt) + + +def test_gaussian_init(nx): + ns = 50 + nt = 50 + + Xs, ys = make_data_classif('3gauss', ns) + Xt, yt = make_data_classif('3gauss2', nt) + + a_s = np.ones((ns, 1)) / ns + a_t = np.ones((nt, 1)) / nt + + Xsb, Xtb, a_sb, a_tb = nx.from_numpy(Xs, Xt, a_s, a_t) + + f = ot.gaussian.dual_gaussian_init(Xsb, Xtb) + + f2 = ot.gaussian.dual_gaussian_init(Xsb, Xtb, a_sb, a_tb) + + np.testing.assert_allclose(nx.to_numpy(f), nx.to_numpy(f2))