# Source code for geomfum.evaluation

"""Correspondence evaluation metrics and manager.

Notation
--------
- ``p2p21``: point-to-point map from shape_b to shape_a.
  ``p2p21[i]`` = vertex in A that vertex i in B maps to.
- ``dist_a``: geodesic distance matrix on shape A.
- ``corr_a``, ``corr_b``: ground-truth correspondence indices into A and B.
- ``mask_a``: binary overlap mask on A (1 = in overlap region).

Inputs dict
-----------
``EvaluationManager.compute`` assembles a flat dict from:

1. ``CorrespondenceResult.to_dict()`` (or any plain dict) —
   keys: ``p2p21``, ``fmap12``, ``soft_perm_ba``, ``overlap_ab``, …
2. Explicit context kwargs:
   ``shape_a``, ``shape_b``, ``corr_a``, ``corr_b``, ``dist_a``, ``mask_a``.

Each metric declares ``required_inputs`` listing the keys it needs.
Metrics whose required inputs are absent are silently skipped.
"""

import abc

import gsops.backend as gs

import geomfum.linalg as la
from geomfum.metric._base import VertexEuclideanMetric

# ---------------------------------------------------------------------------
# Base class
# ---------------------------------------------------------------------------


class CorrespondenceMetric(abc.ABC):
    """Abstract base for post-hoc correspondence evaluation metrics.

    The design mirrors the ``required_inputs`` convention of the loss classes
    in ``geomfum.learning.losses``. Unlike those losses, a metric works on
    plain backend arrays through ``gsops.backend`` (no gradient tracking) and
    reads everything it needs from one flat ``inputs`` dict.

    Concrete subclasses set ``required_inputs`` and implement ``__call__``.
    """

    # Keys that must be present in ``inputs`` for the metric to be evaluated.
    required_inputs: list = []

    @abc.abstractmethod
    def __call__(self, inputs: dict) -> float:
        """Evaluate the metric on a flat inputs dict.

        Parameters
        ----------
        inputs : dict
            Flat dict built by ``EvaluationManager``; every key named in
            ``required_inputs`` is guaranteed to be present.

        Returns
        -------
        float
            Scalar score.
        """
# --------------------------------------------------------------------------- # Full-shape metrics # ---------------------------------------------------------------------------
class GeodesicErrorMetric(CorrespondenceMetric):
    """Mean geodesic error of ``p2p21``, normalised by the diameter of A.

    error = mean(dist_a[p2p21[corr_b], corr_a]) / max(dist_a)

    If either ``corr_a`` or ``corr_b`` is missing, the identity
    correspondence (vertex ``i`` of B matches vertex ``i`` of A) is used.
    A zero diameter yields 0.0.

    Required inputs: ``p2p21``, ``dist_a``.
    Optional inputs: ``corr_a``, ``corr_b``.
    """

    required_inputs = ["p2p21", "dist_a"]

    def __call__(self, inputs):
        """Compute the normalised mean geodesic error."""
        dist_a = gs.asarray(inputs["dist_a"])
        p2p21 = gs.asarray(inputs["p2p21"])
        corr_a = inputs.get("corr_a")
        corr_b = inputs.get("corr_b")

        if corr_a is not None and corr_b is not None:
            predicted_idx = p2p21[gs.asarray(corr_b)]
            target_idx = gs.asarray(corr_a)
        else:
            # Identity fallback: B-vertex i is expected to land on A-vertex i.
            predicted_idx = p2p21
            target_idx = gs.arange(p2p21.shape[0])

        error = gs.mean(dist_a[predicted_idx, target_idx])
        diam = gs.amax(dist_a)
        if diam > 0:
            return float(error / diam)
        return 0.0
class EuclideanErrorMetric(CorrespondenceMetric):
    """Mean Euclidean error of ``p2p21``, normalised by A's Euclidean diameter.

    error = mean(||v_a[p2p21[corr_b]] - v_a[corr_a]||) / eucl_diam_a

    If either ``corr_a`` or ``corr_b`` is missing, the identity
    correspondence is used. A zero diameter yields 0.0.

    Required inputs: ``p2p21``, ``shape_a``.
    Optional inputs: ``corr_a``, ``corr_b``.
    """

    required_inputs = ["p2p21", "shape_a"]

    def __call__(self, inputs):
        """Compute the normalised mean Euclidean error."""
        p2p21 = gs.asarray(inputs["p2p21"])
        shape_a = inputs["shape_a"]
        corr_a = inputs.get("corr_a")
        corr_b = inputs.get("corr_b")
        vertices = shape_a.vertices

        if corr_a is not None and corr_b is not None:
            predicted = vertices[p2p21[gs.asarray(corr_b)]]
            target = vertices[gs.asarray(corr_a)]
        else:
            # Identity fallback: B-vertex i is expected to land on A-vertex i.
            predicted = vertices[p2p21]
            target = vertices[gs.arange(p2p21.shape[0])]

        offsets = predicted - target
        # Per-pair Euclidean length, taken over the last (coordinate) axis.
        error = gs.mean(gs.linalg.norm(offsets, axis=offsets.ndim - 1))
        # Euclidean diameter of A: the largest entry of its vertex distance matrix.
        diam = gs.amax(VertexEuclideanMetric(shape_a).dist_matrix())
        if diam > 0:
            return float(error / diam)
        return 0.0
class DirichletEnergyMetric(CorrespondenceMetric):
    """Dirichlet energy of the pulled-back embedding (smoothness proxy).

    Pulls shape A's vertex coordinates onto B through ``p2p21`` and measures
    their Dirichlet energy on B, normalised to be scale- and resolution-
    invariant:

        E = tr(Xᵀ W X) / tr(Xᵀ M X),

    where ``X = v_a[p2p21]``, ``W`` is B's stiffness (cotan) matrix and ``M``
    its mass matrix. Lower is smoother.

    Note: Dirichlet energy is minimised by *collapsed* maps (all of B sent to
    one point of A → 0), so read it alongside ``CoverageMetric``.

    If the denominator is not positive (e.g. every pulled-back coordinate is
    the origin, so X = 0 and the ratio would be 0/0), the metric returns 0.0
    instead of NaN. This matches the diameter guards used by the other metrics.

    Required inputs: ``p2p21``, ``shape_a``, ``shape_b``.
    """

    required_inputs = ["p2p21", "shape_a", "shape_b"]

    def __call__(self, inputs):
        """Compute the (mass-normalised) Dirichlet energy.

        Parameters
        ----------
        inputs : dict
            Must contain ``p2p21`` (indices into A, one per vertex of B),
            ``shape_a`` and ``shape_b``.

        Returns
        -------
        float
            ``tr(Xᵀ W X) / tr(Xᵀ M X)``, or 0.0 when the denominator is not
            positive.
        """
        p2p21 = gs.asarray(inputs["p2p21"])
        shape_a = inputs["shape_a"]
        shape_b = inputs["shape_b"]

        # NOTE(review): X uses A's raw coordinates, so tr(Xᵀ M X) (and hence E)
        # depends on where A's origin lies; centre X first if translation
        # invariance is wanted.
        x = shape_a.vertices[p2p21]  # [n_b, 3] pulled-back coordinates
        xt = gs.transpose(x)  # [3, n_b]

        stiffness = shape_b.laplacian.stiffness_matrix
        mass = shape_b.laplacian.mass_matrix

        # la.matvecmul(A, xt) returns (A X)^T with shape [3, n_b].
        wx = la.matvecmul(stiffness, xt)
        mx = la.matvecmul(mass, xt)

        numerator = gs.sum(wx * xt)  # tr(Xᵀ W X)
        denominator = gs.sum(mx * xt)  # tr(Xᵀ M X)
        # Assuming a positive-definite mass matrix, the denominator vanishes
        # only for X = 0. Then W X = 0 as well, so the energy is 0, not NaN.
        return float(numerator / denominator) if denominator > 0 else 0.0
class CoverageMetric(CorrespondenceMetric):
    """Area-weighted fraction of shape A covered by ``p2p21``.

    Every vertex of A that appears at least once in ``p2p21`` contributes its
    vertex area. The covered area is divided by A's total vertex area.

    Required inputs: ``p2p21``, ``shape_a``.
    """

    required_inputs = ["p2p21", "shape_a"]

    def __call__(self, inputs):
        """Compute the coverage fraction."""
        hit_vertices = gs.unique(gs.asarray(inputs["p2p21"]))
        vertex_areas = gs.asarray(inputs["shape_a"].vertex_areas)
        covered_area = gs.sum(vertex_areas[hit_vertices])
        total_area = gs.sum(vertex_areas)
        return float(covered_area / total_area)
class CoverageCountMetric(CorrespondenceMetric):
    """Count-based fraction of shape A vertices reached by ``p2p21``.

    The number of distinct A-vertices appearing in ``p2p21`` is divided by
    ``shape_a.n_vertices``.

    Required inputs: ``p2p21``, ``shape_a``.
    """

    required_inputs = ["p2p21", "shape_a"]

    def __call__(self, inputs):
        """Compute the count-based coverage fraction."""
        distinct_targets = gs.unique(gs.asarray(inputs["p2p21"]))
        n_reached = distinct_targets.shape[0]
        return float(n_reached / inputs["shape_a"].n_vertices)
class SoftGeodesicErrorMetric(CorrespondenceMetric):
    """Expected geodesic error under a soft permutation ``soft_perm_ba``.

    Rows of ``soft_perm_ba`` are indexed by B-vertices and columns by
    A-vertices. For ground-truth pair ``k`` the expected error is

        sum_j soft_perm_ba[corr_b[k], j] * dist_a[j, corr_a[k]],

    and the mean over pairs is divided by the geodesic diameter of A. If
    ``corr_a`` or ``corr_b`` is missing, the identity correspondence is used.
    A zero diameter yields 0.0.

    Required inputs: ``soft_perm_ba``, ``dist_a``.
    Optional inputs: ``corr_a``, ``corr_b``.
    """

    required_inputs = ["soft_perm_ba", "dist_a"]

    def __call__(self, inputs):
        """Compute the expected geodesic error under a soft permutation."""
        soft_perm = gs.asarray(inputs["soft_perm_ba"])  # [n_b, n_a]
        dist_a = gs.asarray(inputs["dist_a"])  # [n_a, n_a]
        corr_a = inputs.get("corr_a")
        corr_b = inputs.get("corr_b")
        diam = gs.amax(dist_a)

        if corr_a is not None and corr_b is not None:
            weights = soft_perm[gs.asarray(corr_b), :]  # [n_corr, n_a]
            costs = gs.transpose(dist_a[:, gs.asarray(corr_a)])  # [n_corr, n_a]
        else:
            # Identity fallback: row i is scored against column i of dist_a.
            weights = soft_perm
            costs = gs.transpose(dist_a)
        expected = gs.sum(weights * costs, axis=1)

        if diam > 0:
            return float(gs.mean(expected) / diam)
        return 0.0
# --------------------------------------------------------------------------- # Partial-shape metrics # ---------------------------------------------------------------------------
class PartialGeodesicErrorMetric(CorrespondenceMetric):
    """Geodesic error restricted to the GT overlap region (mask_a).

    Follows the filtered protocol from EchoMatch / SHREC16: only ground-truth
    pairs whose A-side vertex lies in the overlap (``mask_a > 0.5``) are
    scored.

        error = mean(dist_a[corr_a[v], p2p21[corr_b[v]]]) / max(dist_a)

    where ``v`` selects the valid pairs. Returns 0.0 when no pair is valid.

    Parameters
    ----------
    normalize : bool
        If True (default), divide the mean error by the geodesic diameter of
        A, like ``GeodesicErrorMetric`` and ``PCKAucMetric``. A zero diameter
        then yields 0.0. Set to False to get the raw mean geodesic distance.

    Required inputs: ``p2p21``, ``dist_a``, ``corr_a``, ``corr_b``, ``mask_a``.
    """

    required_inputs = ["p2p21", "dist_a", "corr_a", "corr_b", "mask_a"]

    def __init__(self, normalize: bool = True):
        self.normalize = normalize

    def __call__(self, inputs):
        """Compute the (by default normalised) mean geodesic error in the overlap."""
        dist_a = gs.asarray(inputs["dist_a"])
        p2p21 = gs.asarray(inputs["p2p21"])
        corr_a = gs.asarray(inputs["corr_a"])
        corr_b = gs.asarray(inputs["corr_b"])
        mask_a = gs.asarray(inputs["mask_a"])

        # Keep only GT pairs whose A-side vertex is inside the overlap mask.
        valid = mask_a[corr_a] > 0.5
        if int(gs.sum(valid)) == 0:
            return 0.0

        error = gs.mean(dist_a[corr_a[valid], p2p21[corr_b[valid]]])
        if not self.normalize:
            return float(error)

        diam = gs.amax(dist_a)
        return float(error / diam) if diam > 0 else 0.0
class OverlapIoUMetric(CorrespondenceMetric):
    """Intersection-over-union of predicted overlap and GT overlap on A.

    The prediction ``overlap_ab`` is binarised with ``>= threshold``, and the
    ground truth ``mask_a`` with ``>= 0.5``. When both binarised masks are
    empty the score is 1.0.

    Parameters
    ----------
    threshold : float
        Binarisation threshold for ``overlap_ab``. Default 0.5.

    Required inputs: ``overlap_ab``, ``mask_a``.
    """

    required_inputs = ["overlap_ab", "mask_a"]

    def __init__(self, threshold: float = 0.5):
        self.threshold = threshold

    def __call__(self, inputs):
        """Compute the IoU between predicted and GT overlap."""
        predicted = gs.asarray(inputs["overlap_ab"]) >= self.threshold
        target = gs.asarray(inputs["mask_a"]) >= 0.5

        n_union = gs.sum(predicted | target)
        if int(n_union) == 0:
            # Nothing predicted and nothing expected: perfect agreement.
            return 1.0
        n_intersection = gs.sum(predicted & target)
        return float(n_intersection) / float(n_union)
class PCKAucMetric(CorrespondenceMetric):
    """AUC of the PCK curve, evaluated over the GT overlap region.

    PCK(t) is the fraction of valid ground-truth pairs (A-side vertex with
    ``mask_a > 0.5``) whose geodesic error, divided by the geodesic diameter
    of A, is at most ``t``. If the diameter is zero the raw error is used.
    The curve is sampled at ``n_steps`` evenly spaced thresholds in
    ``[0, t_max]``, integrated with the trapezoidal rule and divided by
    ``t_max``, so a perfect map scores 1. Returns 0.0 when no pair is valid.

    Parameters
    ----------
    t_max : float
        Maximum normalised geodesic threshold; must be positive. Default 0.20.
    n_steps : int
        Number of threshold steps; must be at least 2. Default 100.

    Raises
    ------
    ValueError
        If ``t_max`` is not positive or ``n_steps`` is smaller than 2.

    Required inputs: ``p2p21``, ``dist_a``, ``corr_a``, ``corr_b``, ``mask_a``.
    """

    required_inputs = ["p2p21", "dist_a", "corr_a", "corr_b", "mask_a"]

    def __init__(self, t_max: float = 0.20, n_steps: int = 100):
        # t_max is the final divisor: zero or negative values give inf/NaN or
        # a sign-flipped score.
        if not t_max > 0:
            raise ValueError(f"t_max must be positive, got {t_max!r}.")
        # A single threshold makes the trapezoidal integral identically 0, so
        # every map would silently score 0.
        if n_steps < 2:
            raise ValueError(f"n_steps must be at least 2, got {n_steps!r}.")
        self.t_max = t_max
        self.n_steps = n_steps

    def __call__(self, inputs):
        """Compute the AUC of the PCK curve over the GT overlap region."""
        dist_a = gs.asarray(inputs["dist_a"])
        p2p21 = gs.asarray(inputs["p2p21"])
        corr_a = gs.asarray(inputs["corr_a"])
        corr_b = gs.asarray(inputs["corr_b"])
        mask_a = gs.asarray(inputs["mask_a"])

        # Keep only GT pairs whose A-side vertex is inside the overlap mask.
        valid = mask_a[corr_a] > 0.5
        if int(gs.sum(valid)) == 0:
            return 0.0

        geo_err = dist_a[corr_a[valid], p2p21[corr_b[valid]]]
        diam = gs.amax(dist_a)
        geo_err_norm = geo_err / diam if diam > 0 else geo_err

        thresholds = gs.linspace(0.0, self.t_max, self.n_steps)  # [n_steps]
        # hits[s, i] = geo_err_norm[i] <= thresholds[s]
        hits = geo_err_norm[None, :] <= thresholds[:, None]
        pck = gs.sum(1.0 * hits, axis=1) / geo_err_norm.shape[0]  # [n_steps]
        return float(gs.trapezoid(pck, thresholds) / self.t_max)
# --------------------------------------------------------------------------- # Manager # ---------------------------------------------------------------------------
class EvaluationManager:
    """Manages a set of correspondence evaluation metrics.

    Analogous to ``geomfum.learning.losses.LossManager`` but for post-hoc
    evaluation: no gradients, no weighting, graceful skipping.

    Assembles a flat inputs dict from a ``CorrespondenceResult`` (or plain
    dict) plus explicit context kwargs, then runs every metric whose
    ``required_inputs`` are satisfied.

    Parameters
    ----------
    metrics : list of CorrespondenceMetric, or dict of str → CorrespondenceMetric
        If a list, metric names are taken from the class name; in that case
        each class may appear only once. Pass a dict to run several metrics
        of the same class (e.g. different thresholds) under distinct names.

    Raises
    ------
    ValueError
        If ``metrics`` is a list containing two metrics of the same class.

    Examples
    --------
    >>> evaluator = EvaluationManager([
    ...     GeodesicErrorMetric(),
    ...     CoverageMetric(),
    ...     SoftGeodesicErrorMetric(),
    ... ])
    >>> result = matcher(shape_a, shape_b)  # CorrespondenceResult
    >>> scores = evaluator.compute(
    ...     result, shape_a=shape_a, shape_b=shape_b,
    ...     corr_a=corr_a, corr_b=corr_b, dist_a=dist_a,
    ... )
    >>> # {"GeodesicErrorMetric": 0.043, "CoverageMetric": 0.97}
    >>> # (SoftGeodesicErrorMetric is skipped when no soft_perm_ba is present)
    """

    def __init__(self, metrics):
        if isinstance(metrics, list):
            named = {}
            for metric in metrics:
                name = type(metric).__name__
                # Silently overwriting would drop one metric's score.
                if name in named:
                    raise ValueError(
                        f"Duplicate metric name {name!r}: pass a dict to give "
                        "several metrics of the same class distinct names."
                    )
                named[name] = metric
            self.metrics = named
        else:
            self.metrics = dict(metrics)

    def compute(
        self,
        result,
        *,
        shape_a=None,
        shape_b=None,
        corr_a=None,
        corr_b=None,
        dist_a=None,
        mask_a=None,
    ) -> dict:
        """Compute all applicable metrics.

        Parameters
        ----------
        result : CorrespondenceResult or dict
            Matcher output. ``CorrespondenceResult.to_dict()`` is called
            automatically. In both cases only non-None fields are added to
            the inputs dict, and the source mapping is never modified.
        shape_a, shape_b : Shape, optional
        corr_a, corr_b : array-like, optional
            Ground-truth correspondence indices.
        dist_a : array-like, optional
            Geodesic distance matrix on shape A.
        mask_a : array-like, optional
            Binary overlap mask on shape A.

        Returns
        -------
        dict of str → float
            Only metrics whose required inputs were present (and not None)
            are included.
        """
        raw = result.to_dict() if hasattr(result, "to_dict") else result
        # Build a fresh dict: never mutate what to_dict() returned (it may be
        # the result's internal state). None-valued fields count as "absent",
        # so metrics needing them are skipped instead of receiving None.
        inputs = {k: v for k, v in raw.items() if v is not None}

        context = {
            "shape_a": shape_a,
            "shape_b": shape_b,
            "corr_a": corr_a,
            "corr_b": corr_b,
            "dist_a": dist_a,
            "mask_a": mask_a,
        }
        # Explicit context kwargs override same-named fields from ``result``.
        inputs.update({k: v for k, v in context.items() if v is not None})

        results = {}
        for name, metric in self.metrics.items():
            if all(k in inputs for k in metric.required_inputs):
                results[name] = float(metric(inputs))
        return results