Source code for geomfum.metric._base

"""Module containing metrics to calculate distances on a Shape."""

import abc

import gsops.backend as gs

from geomfum._registry import HeatDistanceMetricRegistry, MeshWhichRegistryMixins


class Metric(abc.ABC):
    """Base interface shared by every distance metric defined on a shape.

    Parameters
    ----------
    shape : Shape
        Object regarded as a manifold; subclasses read it back through
        ``self._shape``.
    """

    def __init__(self, shape):
        # Kept on a private attribute: concrete metrics access it directly.
        self._shape = shape

    @abc.abstractmethod
    def dist(self, point_a, point_b):
        """Compute the distance separating two points.

        Parameters
        ----------
        point_a : array-like, shape=[...]
            Index of the first point.
        point_b : array-like, shape=[...]
            Index of the second point.

        Returns
        -------
        dist : array-like, shape=[...,]
            Distance between the two points.
        """
class FinitePointSetMetric(Metric, abc.ABC):
    """Metric on a discrete point set.

    Besides pairwise distances, implementations must expose the full
    distance matrix and the distances from a source to every point.
    """

    @abc.abstractmethod
    def dist_matrix(self):
        """Compute the distance between every pair of points of the shape.

        Returns
        -------
        dist_matrix : array-like, shape=[n_vertices, n_vertices]
            Pairwise distance matrix.
        """

    @abc.abstractmethod
    def dist_from_source(self, source_point):
        """Compute the distances from one or several source points.

        Parameters
        ----------
        source_point : array-like, shape=[...]
            Index (or indices) of the source point(s).

        Returns
        -------
        dist : array-like, shape=[...] or list-like[array-like]
            Distances from the source(s).
        target_point : array-like, shape=[n_targets] or list-like[array-like]
            Index of the target associated with each distance.
        """
class VertexEuclideanMetric(FinitePointSetMetric):
    """Euclidean distance metric in ambient embedding space.

    Distances are straight-line norms between vertex coordinates
    (``self._shape.vertices``); the last axis of the coordinate array is the
    one reduced by the norm.
    """

    def dist(self, point_a, point_b):
        """Distances between shape vertices.

        Parameters
        ----------
        point_a : array-like, shape=[...]
            Index of source point.
        point_b : array-like, shape=[...]
            Index of target point.

        Returns
        -------
        dist : array-like, shape=[...]
            Distance. The gathered coordinates of ``point_a`` and ``point_b``
            are subtracted with standard broadcasting.
        """
        vertices = self._shape.vertices
        diff = vertices[point_a] - vertices[point_b]
        return gs.linalg.norm(diff, axis=diff.ndim - 1)

    def dist_from_source(self, source_point):
        """Distances from source point.

        Parameters
        ----------
        source_point : array-like, shape=[...]
            Index of source point. Any number of batch dimensions is
            supported.

        Returns
        -------
        dist : array-like, shape=[..., n_vertices]
            Distance from each source to every vertex.
        target_point : array-like, shape=[..., n_vertices]
            Target index, broadcast to the shape of ``dist``.
        """
        vertices = self._shape.vertices
        source_vertices = vertices[source_point]
        if source_vertices.ndim > 1:
            # Insert the target axis right before the coordinate axis so that
            # [..., 1, dim] broadcasts against [n_vertices, dim] for any number
            # of batch dims (a fixed axis 1 breaks when ndim > 2).
            source_vertices = gs.expand_dims(
                source_vertices, source_vertices.ndim - 1
            )
        diff = source_vertices - vertices
        dist = gs.linalg.norm(diff, axis=diff.ndim - 1)
        # Every source is compared against all vertices, so the target indices
        # are simply 0..n_vertices-1 repeated over the batch dims of ``dist``.
        target_point = gs.broadcast_to(
            gs.arange(self._shape.n_vertices), dist.shape
        )
        return dist, target_point

    def dist_matrix(self):
        """Distances between all shape vertices.

        Returns
        -------
        dist_matrix : array-like, shape=[n_vertices, n_vertices]
            Distance matrix.
        """
        return self.dist_from_source(gs.arange(self._shape.n_vertices))[0]
class HeatDistanceMetric(MeshWhichRegistryMixins):
    """Approximate geodesic distances on a mesh with the heat method.

    References
    ----------
    .. [CWW2017] Crane, K., Weischedel, C., Wardetzky, M., 2017.
        The heat method for distance computation.
        Commun. ACM 60, 90–99. https://doi.org/10.1145/3131280
    """

    # Registry of heat-distance implementations; presumably consulted by
    # MeshWhichRegistryMixins to pick a concrete backend — confirm there.
    _Registry = HeatDistanceMetricRegistry
class _SingleDispatchMixins: """Mixin providing scalar-to-batch dispatch for distance computations.""" def dist(self, point_a, point_b): """Distances between mesh vertices. Parameters ---------- point_a : array-like, shape=[...] Index of source point. point_b : array-like, shape=[...] Index of target point. Returns ------- dist : array-like, shape=[...,] Distance. """ point_a = gs.asarray(point_a) point_b = gs.asarray(point_b) if point_a.ndim == 0 and point_b.ndim == 0: return self._dist_single(point_a, point_b) point_a, point_b = gs.broadcast_arrays(point_a, point_b) return gs.stack( [ self._dist_single(point_a_, point_b_) for point_a_, point_b_ in zip(point_a, point_b) ] ) def dist_from_source(self, source_point): """Distance between mesh vertices. Parameters ---------- source_point : array-like, shape=[...] Index of source point. Returns ------- dist : array-like, shape=[...,] or list[array-like] Distance. target_point : array-like, shape=[n_targets,] or list[array-like] Target index. """ source_point = gs.asarray(source_point) if source_point.ndim == 0: return self._dist_from_source_single(source_point) out = [ self._dist_from_source_single(source_index_) for source_index_ in source_point ] return list(zip(*out)) @abc.abstractmethod def _dist_from_source_single(self, source_point): pass @abc.abstractmethod def _dist_single(self, point_a, point_b): pass