# Source code for geomfum.descriptor.spectral

"""Spectral descriptors."""

import abc

import gsops.backend as gs

import geomfum.linalg as la
from geomfum._registry import (
    HeatKernelSignatureRegistry,
    LandmarkHeatKernelSignatureRegistry,
    LandmarkWaveKernelSignatureRegistry,
    WaveKernelSignatureRegistry,
    WhichRegistryMixins,
)

from ._base import SpectralDescriptor


def hks_default_domain(shape, n_domain):
    """Compute HKS default domain.

    The domain is a set of ``n_domain`` time points sampled geometrically
    between ``4 * log(10) / nonzero_vals[-1]`` and
    ``4 * log(10) / nonzero_vals[0]``, where ``nonzero_vals`` are the nonzero
    eigenvalues of the shape's basis (presumably sorted in ascending order, so
    these are the smallest and largest times respectively).

    Parameters
    ----------
    shape : Shape.
        Shape with basis.
    n_domain : int
        Number of time points.

    Returns
    -------
    domain : array-like, shape=[n_domain]
        Time points, moved to the same device as the eigenvalues.
    sigma : None
        Placeholder so that the return value has the same ``(domain, sigma)``
        form as the WKS default domain; the heat kernel filter does not use
        sigma.
    """
    nonzero_vals = shape.basis.nonzero_vals
    device = getattr(nonzero_vals, "device", None)

    # At t = 4 * ln(10) / lambda, the heat kernel term exp(-t * lambda) equals
    # 1e-4, so each bound ties the time scale to one end of the spectrum.
    decay = 4 * gs.log(10)
    t_start = decay / nonzero_vals[-1]
    t_stop = decay / nonzero_vals[0]

    domain = gs.to_device(gs.geomspace(t_start, t_stop, n_domain), device)
    return domain, None


class WksDefaultDomain:
    """Default domain generator for Wave Kernel Signature using logarithmic energy sampling.

    Energies are sampled uniformly in log-eigenvalue space, between the logs of
    the first and last nonzero eigenvalues of the shape's basis. Both bounds
    are pulled inwards by ``n_trans * sigma``.

    Parameters
    ----------
    n_domain : int
        Number of energy points to use.
    sigma : float, optional
        Standard deviation of the Gaussian. If None, it is computed from the
        spectrum as ``n_overlap * (e_max - e_min) / n_domain``, where ``e_min``
        and ``e_max`` are the logs of the first and last nonzero eigenvalues.
    n_overlap : int
        Controls Gaussian overlap. Ignored if ``sigma`` is not None.
    n_trans : int
        Number of standard deviations to translate energy bound by.

    Notes
    -----
    With the automatic ``sigma``, the two bounds move inwards by a total of
    ``2 * n_trans * n_overlap / n_domain`` times ``e_max - e_min``. With the
    default ``n_overlap=7`` and ``n_trans=2``, this exceeds the full range
    when ``n_domain < 28``. The bounds then cross and the returned energies
    are in decreasing order. For ``n_domain < 14``, the endpoints also fall
    outside ``[e_min, e_max]``.
    """

    def __init__(self, n_domain, sigma=None, n_overlap=7, n_trans=2):
        self.n_domain = n_domain
        self.sigma = sigma
        self.n_overlap = n_overlap
        self.n_trans = n_trans

    def __call__(self, shape):
        """Compute WKS domain.

        Parameters
        ----------
        shape : Shape.
            Shape with basis.

        Returns
        -------
        domain : array-like, shape=[n_domain]
            Energy points (log-eigenvalue space), moved to the same device as
            the eigenvalues.
        sigma : float
            Standard deviation: ``self.sigma`` if set, otherwise the value
            computed from the spectrum.
        """
        nonzero_vals = shape.basis.nonzero_vals
        device = getattr(nonzero_vals, "device", None)

        e_min, e_max = gs.log(nonzero_vals[0]), gs.log(nonzero_vals[-1])
        sigma = (
            self.n_overlap * (e_max - e_min) / self.n_domain
            if self.sigma is None
            else self.sigma
        )

        # Shrink the interval so the Gaussians centred at the ends do not
        # sit right on the spectrum boundaries.
        e_min += self.n_trans * sigma
        e_max -= self.n_trans * sigma

        energy = gs.to_device(gs.linspace(e_min, e_max, self.n_domain), device)

        return energy, sigma


class SpectralFilter(abc.ABC):
    """Abstract spectral filter turning eigenvalues into per-domain-point weights.

    Concrete filters implement ``__call__`` and return one row of
    coefficients per domain point, with one column per eigenvalue.
    """

    @abc.abstractmethod
    def __call__(self, vals, domain, sigma):
        """
        Evaluate the filter for every eigenvalue at every domain point.

        Parameters
        ----------
        vals : array-like, shape=[n_eigen]
            Eigenvalues to filter.
        domain : array-like, shape=[n_domain]
            Points at which the filter is evaluated (time for HKS, log-energy
            for WKS).
        sigma : float or None
            Extra filter parameter, if the filter needs one (the WKS Gaussian
            width, for instance).

        Returns
        -------
        coefs : array-like, shape=[n_domain, n_eigen]
            Coefficient matrix.
        """


class HeatKernelFilter(SpectralFilter):
    """Heat kernel filter producing ``exp(-t * lambda)`` coefficients for HKS."""

    def __call__(self, vals, domain, sigma):
        """
        Evaluate the heat kernel decay for each time and eigenvalue.

        Parameters
        ----------
        vals : array-like, shape=[n_eigen]
            Eigenvalues.
        domain : array-like, shape=[n_domain]
            Time points.
        sigma : float or None
            Ignored by this filter.

        Returns
        -------
        coefs : array-like, shape=[n_domain, n_eigen]
            Filter coefficients.
        """
        # scalarvecmul presumably forms the [n_domain, n_eigen] products
        # t_i * lambda_j; negate them and exponentiate elementwise.
        return gs.exp(-la.scalarvecmul(domain, vals))


class WaveKernelFilter(SpectralFilter):
    """Wave kernel filter: Gaussian weights in log-eigenvalue space for WKS."""

    def __call__(self, vals, domain, sigma):
        """
        Evaluate a Gaussian in log-eigenvalue space at each energy point.

        Eigenvalues close to zero (``gs.isclose(vals, 0.0)``) get a zero
        coefficient instead of a Gaussian weight, since their log is undefined.

        Parameters
        ----------
        vals : array-like, shape=[n_eigen]
            Eigenvalues. Near-zero eigenvalues are assumed to come first.
        domain : array-like, shape=[n_domain]
            Energy points (log-space).
        sigma : float
            Standard deviation for the Gaussian.

        Returns
        -------
        coefs : array-like, shape=[n_domain, n_eigen]
            Filter coefficients.
        """
        # Skip the leading near-zero eigenvalues.
        n_zero = gs.sum(gs.isclose(vals, 0.0))
        nonzero_vals = vals[n_zero:]

        log_gap = gs.log(nonzero_vals) - domain[:, None]
        coefs = gs.exp(-gs.square(log_gap) / (2 * gs.square(sigma)))

        n_pad = vals.shape[0] - nonzero_vals.shape[0]
        if n_pad > 0:
            # Zero columns for the skipped eigenvalues keep the output at
            # n_eigen columns.
            padding = gs.to_device(
                gs.zeros((domain.shape[0], n_pad)),
                device=getattr(nonzero_vals, "device", None),
            )
            coefs = gs.concatenate([padding, coefs], axis=1)

        return coefs


class HeatKernelSignature(WhichRegistryMixins, SpectralDescriptor):
    """Heat Kernel Signature descriptor using heat diffusion over time.

    Parameters
    ----------
    scale : bool
        Whether to scale weights to sum to one.
    n_domain : int
        Number of domain points. Ignored if ``domain`` is not None.
    domain : callable or array-like, shape=[n_domain], optional
        Method to compute time domain points (``f(shape)``) or time
        domain points. If None, ``hks_default_domain`` is used.
    k : int, optional
        Number of eigenfunctions to use. If None, all eigenfunctions are used.
    """

    _Registry = HeatKernelSignatureRegistry

    def __init__(self, scale=True, n_domain=3, domain=None, k=None):
        # Use an explicit None check rather than ``domain or ...``. ``domain``
        # may be array-like, and the truth value of a multi-element
        # array/tensor is ambiguous, so ``bool()`` on it raises.
        if domain is None:

            def _default_domain(shape):
                # Bind n_domain now; the time points depend on the shape's
                # spectrum, so they are computed only when called.
                return hks_default_domain(shape, n_domain=n_domain)

            domain = _default_domain

        super().__init__(
            spectral_filter=HeatKernelFilter(),
            domain=domain,
            scale=scale,
            # Not used by the heat kernel filter.
            sigma=1,
            landmarks=False,
            k=k,
        )


class WaveKernelSignature(WhichRegistryMixins, SpectralDescriptor):
    """Wave Kernel Signature descriptor using quantum mechanical wave propagation.

    Parameters
    ----------
    scale : bool
        Whether to scale weights to sum to one.
    sigma : float
        Standard deviation for the Gaussian. If None and ``domain`` is None,
        the default ``WksDefaultDomain`` computes it from the spectrum.
    n_domain : int
        Number of domain points. Ignored if ``domain`` is not None.
    domain : callable or array-like, shape=[n_domain], optional
        Method to compute energy domain points (``f(shape)``) or energy
        domain points. If None, ``WksDefaultDomain`` is used.
    k : int, optional
        Number of eigenfunctions to use. If None, all eigenfunctions are used.
    """

    _Registry = WaveKernelSignatureRegistry

    def __init__(self, scale=True, sigma=None, n_domain=3, domain=None, k=None):
        # Use an explicit None check rather than ``domain or ...``. ``domain``
        # may be array-like, and the truth value of a multi-element
        # array/tensor is ambiguous, so ``bool()`` on it raises.
        if domain is None:
            domain = WksDefaultDomain(n_domain=n_domain, sigma=sigma)

        super().__init__(
            spectral_filter=WaveKernelFilter(),
            domain=domain,
            scale=scale,
            sigma=sigma,
            landmarks=False,
            k=k,
        )


class LandmarkHeatKernelSignature(WhichRegistryMixins, SpectralDescriptor):
    """Heat Kernel Signature computed from landmark points.

    Parameters
    ----------
    scale : bool
        Whether to scale weights to sum to one.
    n_domain : int
        Number of domain points. Ignored if ``domain`` is not None.
    domain : callable or array-like, shape=[n_domain], optional
        Method to compute time domain points (``f(shape)``) or time
        domain points. If None, ``hks_default_domain`` is used.
    k : int, optional
        Number of eigenfunctions to use. If None, all eigenfunctions are used.
    """

    _Registry = LandmarkHeatKernelSignatureRegistry

    def __init__(self, scale=True, n_domain=3, domain=None, k=None):
        # Use an explicit None check rather than ``domain or ...``. ``domain``
        # may be array-like, and the truth value of a multi-element
        # array/tensor is ambiguous, so ``bool()`` on it raises.
        if domain is None:

            def _default_domain(shape):
                # Bind n_domain now; the time points depend on the shape's
                # spectrum, so they are computed only when called.
                return hks_default_domain(shape, n_domain=n_domain)

            domain = _default_domain

        super().__init__(
            spectral_filter=HeatKernelFilter(),
            domain=domain,
            scale=scale,
            # Not used by the heat kernel filter.
            sigma=1,
            landmarks=True,
            k=k,
        )


class LandmarkWaveKernelSignature(WhichRegistryMixins, SpectralDescriptor):
    """Wave Kernel Signature computed from landmark points.

    Parameters
    ----------
    scale : bool
        Whether to scale weights to sum to one.
    sigma : float
        Standard deviation for the Gaussian. If None and ``domain`` is None,
        the default ``WksDefaultDomain`` computes it from the spectrum.
    n_domain : int
        Number of domain points. Ignored if ``domain`` is not None.
    domain : callable or array-like, shape=[n_domain], optional
        Method to compute energy domain points (``f(shape)``) or energy
        domain points. If None, ``WksDefaultDomain`` is used.
    k : int, optional
        Number of eigenfunctions to use. If None, all eigenfunctions are used.
    """

    _Registry = LandmarkWaveKernelSignatureRegistry

    def __init__(self, scale=True, sigma=None, n_domain=3, domain=None, k=None):
        # Use an explicit None check rather than ``domain or ...``. ``domain``
        # may be array-like, and the truth value of a multi-element
        # array/tensor is ambiguous, so ``bool()`` on it raises.
        if domain is None:
            domain = WksDefaultDomain(n_domain=n_domain, sigma=sigma)

        super().__init__(
            spectral_filter=WaveKernelFilter(),
            domain=domain,
            scale=scale,
            sigma=sigma,
            k=k,
            landmarks=True,
        )