# Source code for geomfum.learning.models

"""Models for learning features for functional maps.

References
----------
.. [1] Nicolas Donati, Abhishek Sharma, Maks Ovsjanikov, "Deep Geometric Functional Maps: Robust Feature Learning for Shape Correspondence".
.. [2] O. Litany, T. Remez, E. Rodola, A. Bronstein, M. Bronstein, "Deep Functional Maps: Structured Prediction for Dense Shape Correspondence".
"""

import abc

import gsops.backend as gs
import torch.nn as nn

from geomfum.convert import (
    FmFromP2pConverter,
    P2pFromFmConverter,
    SoftmaxNeighborFinder,
)
from geomfum.descriptor.learned import FeatureExtractor, LearnedDescriptor
from geomfum.forward_functional_map import ForwardFunctionalMap


class BaseModel(abc.ABC, nn.Module):
    """Common root class for the learning models defined in this module.

    Combines ``abc.ABC`` with ``torch.nn.Module``, so every concrete model
    built on top of it is a regular PyTorch module.
    """


class FMNet(BaseModel):
    """Functional Map Network Model.

    Learns per-shape descriptors with a feature extractor and turns them into
    a pair of functional maps with ``fmap_module``. In evaluation mode the
    functional maps are also converted to point-to-point correspondences with
    ``converter``.

    Parameters
    ----------
    feature_extractor : FeatureExtractor, optional
        Feature extractor to use for the descriptors. If None, a new
        ``FeatureExtractor.from_registry(which="diffusionnet")`` is built for
        this instance.
    fmap_module : ForwardFunctionalMap, optional
        Functional map module to use for the forward pass. If None, a new
        ``ForwardFunctionalMap()`` is built for this instance.
    converter : P2pFromFmConverter, optional
        Converter to convert functional maps to point-to-point
        correspondences. If None, a new ``P2pFromFmConverter()`` is built for
        this instance.
    """

    def __init__(self, feature_extractor=None, fmap_module=None, converter=None):
        super().__init__()
        # Build defaults per instance. Objects created in the signature would
        # be evaluated once, at class-definition (import) time, and shared by
        # every model constructed without arguments. With a trainable
        # extractor, training one model would then modify the others too.
        if feature_extractor is None:
            feature_extractor = FeatureExtractor.from_registry(which="diffusionnet")
        if fmap_module is None:
            fmap_module = ForwardFunctionalMap()
        if converter is None:
            converter = P2pFromFmConverter()

        self.feature_extractor = feature_extractor
        self.descriptors_module = LearnedDescriptor(
            feature_extractor=self.feature_extractor
        )
        self.fmap_module = fmap_module
        self.converter = converter

    def forward(self, mesh_a, mesh_b, as_dict=True):
        """Compute the functional maps between two shapes.

        Parameters
        ----------
        mesh_a : TriangleMesh or dict
            The first shape, either as a TriangleMesh object or a dictionary
            containing 'basis', 'evals', and 'pinv'.
        mesh_b : TriangleMesh or dict
            The second shape, either as a TriangleMesh object or a dictionary
            containing 'basis', 'evals', and 'pinv'.
        as_dict : bool, optional
            If True, return a dictionary. If False, return a tuple.

        NOTE(review): the evaluation branch reads ``mesh_a.basis`` and
        ``mesh_b.basis`` as attributes. A plain ``dict`` input would fail
        there; confirm how dict inputs are meant to be supported.

        Returns
        -------
        result : dict
            Returned when ``as_dict`` is True. Always has the keys
            ``"fmap12"``, ``"fmap21"``, ``"desc_a"`` and ``"desc_b"``. In
            evaluation mode (``not self.training``) it also has ``"p2p12"``
            and ``"p2p21"``.
        maps : tuple
            Returned when ``as_dict`` is False: ``(fmap12, fmap21)`` in
            training mode, ``(fmap12, fmap21, p2p12, p2p21)`` in evaluation
            mode.

        Notes
        -----
        Entries of the outputs:

        fmap12 : array-like, shape=[..., spectrum_size_b, spectrum_size_a]
            Functional map from shape a to shape b.
        fmap21 : array-like, shape=[..., spectrum_size_a, spectrum_size_b]
            Functional map from shape b to shape a.
        p2p21 : array-like, shape=[..., num_points_b]
            Point-to-point correspondence converted from ``fmap12``.
        p2p12 : array-like, shape=[..., num_points_a]
            Point-to-point correspondence converted from ``fmap21``.
        desc_a, desc_b : array-like
            Learned descriptors of shape a and shape b.
        """
        desc_a = self.descriptors_module(mesh_a)
        desc_b = self.descriptors_module(mesh_b)
        fmap12, fmap21 = self.fmap_module(mesh_a, mesh_b, desc_a, desc_b)

        # Point-to-point conversion is only performed in evaluation mode.
        p2p12 = p2p21 = None
        if not self.training:
            p2p21 = self.converter(fmap12, mesh_a.basis, mesh_b.basis)
            p2p12 = self.converter(fmap21, mesh_b.basis, mesh_a.basis)

        if as_dict:
            result = {
                "fmap12": fmap12,
                "fmap21": fmap21,
                "desc_a": desc_a,
                "desc_b": desc_b,
            }
            if not self.training:
                result.update({"p2p12": p2p12, "p2p21": p2p21})
            return result

        if not self.training:
            return fmap12, fmap21, p2p12, p2p21
        return fmap12, fmap21
class RobustFMNet(BaseModel):
    """Functional map network with additional descriptor-based functional maps.

    Like a plain functional map network, it learns descriptors and computes
    ``fmap12``/``fmap21`` with ``fmap_module``. It also builds a second pair
    of functional maps (``fmap12_desc``/``fmap21_desc``) in three steps:

    1. Normalize the learned descriptors.
    2. Compute softmax correspondence matrices with the converter's neighbor
       finder.
    3. Project those matrices onto the shapes' spectral bases.

    Parameters
    ----------
    feature_extractor : FeatureExtractor, optional
        Feature extractor to use for the descriptors. If None, a new
        ``FeatureExtractor.from_registry(which="diffusionnet")`` is built for
        this instance.
    fmap_module : ForwardFunctionalMap, optional
        Functional map module to use for the forward pass. If None, a new
        ``ForwardFunctionalMap()`` is built for this instance.
    converter : P2pFromFmConverter, optional
        Converter to convert functional maps to point-to-point
        correspondences. Its ``neighbor_finder`` attribute must provide a
        ``softmax_matrix`` method, which ``forward`` uses for the soft
        correspondences. If None, a new
        ``P2pFromFmConverter(SoftmaxNeighborFinder(n_neighbors=1, tau=0.07))``
        is built for this instance.
    """

    def __init__(self, feature_extractor=None, fmap_module=None, converter=None):
        super().__init__()
        # Build defaults per instance. Objects created in the signature would
        # be evaluated once, at class-definition (import) time, and shared by
        # every model constructed without arguments.
        if feature_extractor is None:
            feature_extractor = FeatureExtractor.from_registry(which="diffusionnet")
        if fmap_module is None:
            fmap_module = ForwardFunctionalMap()
        if converter is None:
            converter = P2pFromFmConverter(
                SoftmaxNeighborFinder(n_neighbors=1, tau=0.07)
            )

        self.feature_extractor = feature_extractor
        self.descriptors_module = LearnedDescriptor(
            feature_extractor=self.feature_extractor
        )
        self.fmap_module = fmap_module
        self.converter = converter
        # NOTE(review): not used by ``forward``; kept as a public attribute.
        self.fmap_converter = FmFromP2pConverter(pseudo_inverse=True)
        # The soft maps in ``forward`` use the very same neighbor-finder object
        # as the point-to-point conversion (and therefore its settings).
        self.neighbor_finder = self.converter.neighbor_finder

    def forward(self, mesh_a, mesh_b, as_dict=True):
        """Compute the functional maps between two shapes.

        Parameters
        ----------
        mesh_a : TriangleMesh or dict
            The first shape, either as a TriangleMesh object or a dictionary
            containing 'basis', 'evals', and 'pinv'.
        mesh_b : TriangleMesh or dict
            The second shape, either as a TriangleMesh object or a dictionary
            containing 'basis', 'evals', and 'pinv'.
        as_dict : bool, optional
            If True, return a dictionary. If False, return a tuple.

        NOTE(review): this method always reads ``mesh_x.basis.pinv`` and
        ``mesh_x.basis.vecs`` as attributes. A plain ``dict`` input would fail
        here; confirm how dict inputs are meant to be supported.

        Returns
        -------
        result : dict
            Returned when ``as_dict`` is True. Always has the keys
            ``"fmap12"``, ``"fmap21"``, ``"fmap12_desc"`` and
            ``"fmap21_desc"``. In evaluation mode (``not self.training``) it
            also has ``"p2p12"`` and ``"p2p21"``.
        maps : tuple
            Returned when ``as_dict`` is False:
            ``(fmap12, fmap21, fmap12_desc, fmap21_desc)`` in training mode,
            ``(fmap12, fmap21, p2p12, p2p21)`` in evaluation mode.

        Notes
        -----
        Entries of the outputs:

        fmap12 : array-like, shape=[..., spectrum_size_b, spectrum_size_a]
            Functional map from shape a to shape b, from ``fmap_module``.
        fmap21 : array-like, shape=[..., spectrum_size_a, spectrum_size_b]
            Functional map from shape b to shape a, from ``fmap_module``.
        p2p21 : array-like, shape=[..., num_points_b]
            Point-to-point correspondence converted from ``fmap12`` and moved
            to the CPU.
        p2p12 : array-like, shape=[..., num_points_a]
            Point-to-point correspondence converted from ``fmap21`` and moved
            to the CPU.
        fmap12_desc : array-like, shape=[..., spectrum_size_b, spectrum_size_a]
            Functional map from shape a to shape b, obtained by projecting the
            descriptor softmax matrix ``P21`` onto the spectral bases.
        fmap21_desc : array-like, shape=[..., spectrum_size_a, spectrum_size_b]
            Functional map from shape b to shape a, obtained by projecting the
            descriptor softmax matrix ``P12`` onto the spectral bases.
        """
        desc_a = self.descriptors_module(mesh_a)
        desc_b = self.descriptors_module(mesh_b)
        # The regular functional maps use the raw (unnormalized) descriptors.
        fmap12, fmap21 = self.fmap_module(mesh_a, mesh_b, desc_a, desc_b)

        # Normalize the descriptors (norm taken along axis 0) before building
        # the soft correspondences.
        # NOTE(review): a zero-norm column would produce NaNs here.
        desc_a = desc_a / gs.linalg.norm(desc_a, axis=0, keepdims=True)
        desc_b = desc_b / gs.linalg.norm(desc_b, axis=0, keepdims=True)

        P12 = self.neighbor_finder.softmax_matrix(desc_a.T, desc_b.T)
        P21 = self.neighbor_finder.softmax_matrix(desc_b.T, desc_a.T)

        # Project the soft correspondence matrices onto the spectral bases.
        fmap21_desc = mesh_a.basis.pinv @ (P12 @ mesh_b.basis.vecs)
        fmap12_desc = mesh_b.basis.pinv @ (P21 @ mesh_a.basis.vecs)

        # Point-to-point conversion is only performed in evaluation mode and
        # the results are moved to the CPU.
        p2p12 = p2p21 = None
        if not self.training:
            p2p21 = gs.to_device(
                self.converter(fmap12, mesh_a.basis, mesh_b.basis), "cpu"
            )
            p2p12 = gs.to_device(
                self.converter(fmap21, mesh_b.basis, mesh_a.basis), "cpu"
            )

        if as_dict:
            result = {
                "fmap12": fmap12,
                "fmap21": fmap21,
                "fmap12_desc": fmap12_desc,
                "fmap21_desc": fmap21_desc,
            }
            if not self.training:
                result.update({"p2p12": p2p12, "p2p21": p2p21})
            return result

        if not self.training:
            return fmap12, fmap21, p2p12, p2p21
        return fmap12, fmap21, fmap12_desc, fmap21_desc