Source code for geomfum.wrap.pot

"""Python Optimal Transport (POT) wrapper."""

import gsops.backend as gs
import ot

from geomfum.convert import BaseNeighborFinder


class PotSinkhornNeighborFinder(BaseNeighborFinder):
    """Neighbor finder based on Optimal Transport maps computed with Sinkhorn regularization.

    For each query point, the neighbors are the reference points that receive
    the most mass in the entropy-regularized transport plan between uniform
    distributions on the query and reference point sets.

    Parameters
    ----------
    n_neighbors : int, default=1
        Number of neighbors to find.
    lambd : float, default=1e-1
        Regularization parameter for Sinkhorn algorithm (forwarded to
        ``ot.sinkhorn`` as ``reg``).
    method : str, default="sinkhorn"
        Method to use for Sinkhorn algorithm (forwarded to ``ot.sinkhorn`` as
        ``method``).
    max_iter : int, default=100
        Maximum number of iterations for Sinkhorn algorithm (forwarded to
        ``ot.sinkhorn`` as ``numItermax``).

    References
    ----------
    .. [Cuturi2013] Marco Cuturi. "Sinkhorn Distances: Lightspeed Computation
        of Optimal Transport." Advances in Neural Information Processing
        Systems (NIPS), 2013. http://marcocuturi.net/SI.html
    """

    def __init__(self, n_neighbors=1, lambd=1e-1, method="sinkhorn", max_iter=100):
        super().__init__(n_neighbors=n_neighbors)
        self.lambd = lambd
        self.max_iter = max_iter
        self.method = method

    def __call__(self, X, Y):
        """Find k nearest neighbors using Sinkhorn regularization.

        Parameters
        ----------
        X : array-like, shape=[n_points_x, n_features]
            Query points.
        Y : array-like, shape=[n_points_y, n_features]
            Reference points.

        Returns
        -------
        indices : array-like, shape=[n_points_x, n_neighbors]
            Indices of the nearest neighbors: for each row of ``X``, the rows
            of ``Y`` carrying the most transport mass, in decreasing order.
        """
        # Ground cost (small = close). ot.sinkhorn builds its own Gibbs kernel
        # exp(-M / reg) from it, so M must be a cost, not an already
        # exponentiated similarity (which would favour distant pairs).
        M = ot.dist(X, Y)

        n, m = M.shape
        if n == 0 or m == 0:
            # Nothing to transport; argsort of the empty cost yields an empty
            # integer index array of the expected shape.
            return gs.argsort(M, axis=1)[:, : self.n_neighbors]

        # Rescale the cost into [0, 1] (as in the POT examples) so the kernel
        # does not underflow for large raw distances at small regularization.
        scale = M.max()
        if scale > 0:
            M = M / scale

        # Uniform marginals created with M's backend, dtype and device.
        a = ot.unif(n, type_as=M)
        b = ot.unif(m, type_as=M)

        # TODO: implement as sinkhorn solver?
        Gs = ot.sinkhorn(
            a, b, M, reg=self.lambd, method=self.method, numItermax=self.max_iter
        )

        # Strongest correspondences carry the most mass: sort in descending order.
        return gs.argsort(-Gs, axis=1)[:, : self.n_neighbors]