Source code for geomfum.matcher.deep_fmap

"""Deep functional map matcher via per-pair gradient descent."""

import copy

import torch

from geomfum.learning.losses import (
    BijectivityLoss,
    LossManager,
    OrthonormalityLoss,
)
from geomfum.learning.models import FMNet
from geomfum.matcher.base import BaseMatcher, CorrespondenceResult


class DeepFMMatcher(BaseMatcher):
    """Deep functional map matcher optimized per shape pair.

    Jointly optimizes a deep FM model (feature extractor + functional map
    solver) for a single shape pair using gradient descent, with no
    dataset-level training. The model is deep-copied at the start of each
    call, so every pair is solved independently. Passing a pretrained model
    turns this into a per-pair fine-tuning step.

    The optimization loop is identical to the learning stage:

        model.forward() → LossManager.compute_loss() → loss.backward()

    This means any ``BaseModel`` (``FMNet``, ``RobustFMNet``, …) and any
    combination of losses from ``geomfum.learning.losses`` can be plugged in.

    Parameters
    ----------
    model : BaseModel, optional
        Deep FM model to optimize. Defaults to ``FMNet()`` (random
        DiffusionNet weights). Pass a pretrained model instance for
        warm-start fine-tuning.
    loss_manager : LossManager, optional
        Weighted combination of losses. Defaults to orthonormality +
        bijectivity (both weight 1).
    fmap_size : int or tuple/list of int
        Number of LBO eigenfunctions for the functional map. A pair
        ``(k_b, k_a)`` allows different sizes per shape.
    n_iters : int
        Gradient descent iterations per pair.
    lr : float
        Adam learning rate.
    verbose : bool
        If True, print each loss component every 100 iterations and on the
        final iteration.
    """

    def __init__(
        self,
        model=None,
        loss_manager=None,
        fmap_size=30,
        n_iters=1000,
        lr=1e-3,
        verbose=False,
    ):
        # Compare against None explicitly: an object that defines ``__len__``
        # or ``__bool__`` (e.g. an empty module container) can be falsy, and
        # ``x or default`` would then silently discard the caller's argument.
        self.model = model if model is not None else FMNet()
        self.loss_manager = (
            loss_manager
            if loss_manager is not None
            else self._default_loss_manager()
        )

        # Normalize to a tuple; a scalar size is used for both shapes.
        if isinstance(fmap_size, (tuple, list)):
            self.fmap_size = tuple(fmap_size)
        else:
            self.fmap_size = (fmap_size, fmap_size)

        self.n_iters = n_iters
        self.lr = lr
        self.verbose = verbose

    @staticmethod
    def _default_loss_manager():
        """Build the default loss: orthonormality + bijectivity, weight 1 each."""
        return LossManager(
            [
                OrthonormalityLoss(weight=1.0),
                BijectivityLoss(weight=1.0),
            ]
        )

    def __call__(self, shape_a, shape_b):
        """Optimize per-pair and return correspondences.

        Both shapes must have a precomputed spectral basis (``shape.basis``).
        The basis eigenvectors and mass matrix should already be torch tensors
        (same requirement as for FMNet inference).

        Note that this sets ``basis.use_k`` on both input shapes according to
        ``fmap_size``.

        Parameters
        ----------
        shape_a : Shape
            First shape (target for p2p21).
        shape_b : Shape
            Second shape (source for p2p21).

        Returns
        -------
        result : CorrespondenceResult
            Contains ``fmap12``, ``fmap21``, ``p2p21``, ``p2p12``,
            ``descr_a``, ``descr_b``.
        """
        k_b, k_a = self.fmap_size
        shape_a.basis.use_k = k_a
        shape_b.basis.use_k = k_b

        device = shape_a.basis.vals.device

        # Work on a private copy so ``self.model`` is never modified and every
        # call starts from the same initial weights.
        model = copy.deepcopy(self.model).to(device)
        model.train()
        optimizer = torch.optim.Adam(model.parameters(), lr=self.lr)

        # Re-enable autograd for the optimization loop so that ``backward()``
        # still works when the caller invokes the matcher under
        # ``torch.no_grad()``.
        with torch.enable_grad():
            for i in range(self.n_iters):
                optimizer.zero_grad()

                outputs = model(shape_a, shape_b, as_dict=True)
                # The shapes are added to the model outputs before they are
                # handed to the loss manager.
                outputs["shape_a"] = shape_a
                outputs["shape_b"] = shape_b

                total_loss, loss_dict = self.loss_manager.compute_loss(outputs)
                total_loss.backward()
                optimizer.step()

                if self.verbose and (i % 100 == 0 or i == self.n_iters - 1):
                    parts = " ".join(
                        f"{k}={v:.4f}" for k, v in loss_dict.items()
                    )
                    print(f" iter {i:4d} | {parts}")

        # Final forward pass in eval mode, without gradient tracking, to
        # produce the returned maps.
        model.eval()
        with torch.no_grad():
            outputs = model(shape_a, shape_b, as_dict=True)

        # Only ``fmap12`` is required; the remaining entries fall back to None
        # when the model does not produce them.
        return CorrespondenceResult(
            fmap12=outputs["fmap12"],
            fmap21=outputs.get("fmap21"),
            p2p21=outputs.get("p2p21"),
            p2p12=outputs.get("p2p12"),
            descr_a=outputs.get("desc_a"),
            descr_b=outputs.get("desc_b"),
        )