Source code for geomfum.matcher.base

"""Base classes for shape matchers."""

import abc
from dataclasses import dataclass

import gsops.backend as gs

from geomfum.convert import BaseNeighborFinder, NeighborFinder
from geomfum.descriptor.pipeline import DescriptorPipeline, L2InnerNormalizer
from geomfum.descriptor.spectral import WaveKernelSignature


@dataclass
class CorrespondenceResult:
    """Container for the output of a correspondence computation.

    A single output format shared by every correspondence method, i.e. both
    classical functional map matchers and learning-based models. Every field
    defaults to ``None``; a method fills in only what it actually computes.

    Parameters
    ----------
    fmap12 : array-like, shape=[spectrum_size_b, spectrum_size_a]
        Functional map matrix from shape_a to shape_b.
    p2p21 : array-like, shape=[n_vertices_b]
        Point-to-point correspondence from shape_b to shape_a:
        ``p2p21[i]`` is the index of the vertex of shape_a matched to
        vertex ``i`` of shape_b.
    fmap21 : array-like, shape=[spectrum_size_a, spectrum_size_b], optional
        Functional map matrix from shape_b to shape_a (bidirectional case).
    p2p12 : array-like, shape=[n_vertices_a], optional
        Point-to-point correspondence from shape_a to shape_b
        (bidirectional case).
    descr_a : array-like, shape=[n_descr, n_vertices_a], optional
        Descriptors on shape_a.
    descr_b : array-like, shape=[n_descr, n_vertices_b], optional
        Descriptors on shape_b.
    refined_fmap12 : array-like, shape=[spectrum_size_b, spectrum_size_a], optional
        Refined functional map matrix (if refinement was applied).
    refined_fmap21 : array-like, shape=[spectrum_size_a, spectrum_size_b], optional
        Refined functional map matrix from B to A (if bidirectional
        refinement was applied).
    soft_perm_ab : array-like, shape=[n_vertices_a, n_vertices_b], optional
        Soft permutation matrix mapping b vertices to a domain
        (P12 in RobustFMNet). ``soft_perm_ab[i, j]`` is the probability that
        vertex ``i`` in a corresponds to vertex ``j`` in b.
    soft_perm_ba : array-like, shape=[n_vertices_b, n_vertices_a], optional
        Soft permutation matrix mapping a vertices to b domain
        (P21 in RobustFMNet). ``soft_perm_ba[i, j]`` is the probability that
        vertex ``i`` in b corresponds to vertex ``j`` in a.
    overlap_ab : array-like, shape=[n_vertices_a], optional
        Predicted overlap scores (0–1) for shape A. ``overlap_ab[i]`` is the
        probability that vertex ``i`` in A lies in the region overlapping
        shape B.
    overlap_ba : array-like, shape=[n_vertices_b], optional
        Predicted overlap scores (0–1) for shape B. ``overlap_ba[j]`` is the
        probability that vertex ``j`` in B lies in the region overlapping
        shape A.
    """

    fmap12: "gs.ndarray" = None
    p2p21: "gs.ndarray" = None
    fmap21: "gs.ndarray" = None
    p2p12: "gs.ndarray" = None
    descr_a: "gs.ndarray" = None
    descr_b: "gs.ndarray" = None
    refined_fmap12: "gs.ndarray" = None
    refined_fmap21: "gs.ndarray" = None
    soft_perm_ab: "gs.ndarray" = None
    soft_perm_ba: "gs.ndarray" = None
    overlap_ab: "gs.ndarray" = None
    overlap_ba: "gs.ndarray" = None

    def to_dict(self):
        """Return the populated fields as a plain dictionary.

        Kept for backward compatibility with dictionary-based results.

        Returns
        -------
        dict
            Maps field name to value for every field that is not None,
            in field declaration order. Values are the stored objects
            themselves, not copies.

        Notes
        -----
        ``dataclasses.asdict()`` is deliberately not used: it deep-copies
        the values, which fails for PyTorch tensors that are part of the
        computation graph (non-leaf tensors) during training.
        """
        populated = {}
        for field_name in self.__dataclass_fields__:
            value = getattr(self, field_name)
            if value is None:
                continue
            populated[field_name] = value
        return populated

    @property
    def is_bidirectional(self):
        """Tell whether correspondences in both directions are stored.

        Returns
        -------
        bool
            True if both fmap21 and p2p12 are available.
        """
        return not (self.fmap21 is None or self.p2p12 is None)
class BaseMatcher(abc.ABC):
    """Common interface implemented by every shape matcher.

    A matcher is a callable taking two shapes; concrete subclasses must
    override ``__call__``.
    """

    @abc.abstractmethod
    def __call__(self, shape_a, shape_b):
        """Match two shapes against each other.

        Parameters
        ----------
        shape_a : Shape
            First shape (target for p2p21).
        shape_b : Shape
            Second shape (source for p2p21).

        Returns
        -------
        result : CorrespondenceResult
            Correspondence result holding:

            - p2p21: point-to-point correspondence from B to A
            - fmap12: functional map from A to B (if applicable)
        """
class DescriptorMatcher(BaseMatcher):
    """Descriptor-based matcher using nearest neighbors in descriptor space.

    Correspondences are computed directly by:

    1. Computing descriptors/features on both shapes with the configured
       descriptor pipeline.
    2. Finding, for each vertex, its nearest neighbor in descriptor space.

    No refinement of the resulting correspondence is performed by this
    class.

    This is simpler and faster than the functional map approach, but may be
    less robust for complex deformations.

    Parameters
    ----------
    descriptor_pipeline : DescriptorPipeline, optional
        Descriptor pipeline to compute descriptors.
        If None, uses default WKS-based pipeline.
    neighbor_finder : NeighborFinder, optional
        Nearest neighbor finder. If None, uses a ``NeighborFinder`` with
        ``n_neighbors=1``.
    """

    def __init__(
        self,
        descriptor_pipeline: DescriptorPipeline = None,
        neighbor_finder: BaseNeighborFinder = None,
    ):
        # Compare against None explicitly instead of relying on truthiness:
        # a user-supplied object that happens to evaluate falsy (e.g. one
        # defining ``__len__`` or ``__bool__``) must not be silently swapped
        # for the default.
        if descriptor_pipeline is None:
            descriptor_pipeline = self._build_default_descriptor_pipeline()
        if neighbor_finder is None:
            neighbor_finder = NeighborFinder(n_neighbors=1)

        self.descriptor_pipeline = descriptor_pipeline
        self.neighbor_finder = neighbor_finder

    def _build_default_descriptor_pipeline(self):
        """Build the default descriptor pipeline.

        Returns
        -------
        pipeline : DescriptorPipeline
            Wave kernel signature (``n_domain=200``, ``k=200``) followed by
            an ``L2InnerNormalizer``.
        """
        return DescriptorPipeline(
            [WaveKernelSignature(n_domain=200, k=200), L2InnerNormalizer()]
        )

    def __call__(self, shape_a, shape_b, bidirectional=False):
        """Compute correspondence between two shapes.

        Parameters
        ----------
        shape_a : Shape
            First shape (target for p2p21).
        shape_b : Shape
            Second shape (source for p2p21).
        bidirectional : bool
            If True, compute correspondences in both directions.

        Returns
        -------
        result : CorrespondenceResult
            Matching result containing:

            - p2p21: point-to-point correspondence from B to A
            - p2p12: (if bidirectional=True) correspondence from A to B
            - descr_a, descr_b: the descriptors computed on each shape
        """
        # Compute descriptors
        descr_a = self.descriptor_pipeline.apply(shape_a)
        descr_b = self.descriptor_pipeline.apply(shape_b)

        # Find nearest neighbors in descriptor space
        # descr shape is [n_descr, n_vertices], we need [n_vertices, n_descr]
        feat_a = descr_a.T
        feat_b = descr_b.T

        # Find for each vertex in B, the nearest vertex in A (p2p21: B -> A)
        p2p21 = self.neighbor_finder(feat_b, feat_a).flatten()

        # Compute reverse direction if bidirectional
        p2p12 = None
        if bidirectional:
            p2p12 = self.neighbor_finder(feat_a, feat_b).flatten()

        return CorrespondenceResult(
            p2p21=p2p21,
            p2p12=p2p12,
            descr_a=descr_a,
            descr_b=descr_b,
        )
class SpatialNearestNeighborMatcher(BaseMatcher):
    """Matcher based on spatial nearest neighbors.

    Correspondences are obtained by querying the neighbor finder directly
    on the vertex coordinates of the two shapes, i.e. by finding the nearest
    vertex in Euclidean space. No descriptors or functional maps are used.

    Parameters
    ----------
    neighbor_finder : BaseNeighborFinder, optional
        Nearest neighbor finder. If None, uses a ``NeighborFinder`` with
        ``n_neighbors=1``.
    """

    def __init__(
        self,
        neighbor_finder: BaseNeighborFinder = None,
    ):
        # Compare against None explicitly instead of relying on truthiness,
        # so a user-supplied finder that happens to evaluate falsy is not
        # silently replaced by the default.
        if neighbor_finder is None:
            neighbor_finder = NeighborFinder(n_neighbors=1)
        self.neighbor_finder = neighbor_finder

    def __call__(self, shape_a, shape_b, bidirectional=False):
        """Compute correspondence between two shapes.

        Parameters
        ----------
        shape_a : Shape
            First shape (target for p2p21).
        shape_b : Shape
            Second shape (source for p2p21).
        bidirectional : bool
            If True, compute correspondences in both directions.

        Returns
        -------
        result : CorrespondenceResult
            Matching result containing:

            - p2p21: point-to-point correspondence from B to A
            - p2p12: (if bidirectional=True) correspondence from A to B

            Descriptor fields are left at their default of None.
        """
        vertices_a = shape_a.vertices
        vertices_b = shape_b.vertices

        # Find for each vertex in B, the nearest vertex in A (p2p21: B -> A)
        p2p21 = self.neighbor_finder(vertices_b, vertices_a).flatten()

        # Compute reverse direction if bidirectional
        p2p12 = None
        if bidirectional:
            p2p12 = self.neighbor_finder(vertices_a, vertices_b).flatten()

        return CorrespondenceResult(p2p21=p2p21, p2p12=p2p12)