Source code for geomfum._registry

import abc
import inspect
import re
import sys

from geomfum._utils import has_package


class Registry(abc.ABC):
    """Abstract base class for registries of component implementations.

    Each concrete subclass owns its own ``MAP`` dict. ``register`` stores an
    ``(obj_name, missing_package)`` pair under a key, and ``get`` lazily
    imports ``obj_name`` from the ``geomfum.wrap.<key>`` submodule. The
    special key ``"geomfum"`` designates the implementation shipped by
    geomfum itself, which exists only when ``has_internal`` is true.
    """

    # Flag: does geomfum ship its own implementation of this component?
    has_internal = False

    @classmethod
    def register(cls, key, obj_name, requires=(), as_default=False):
        """Register an implementation, remembering the first missing dependency.

        Parameters
        ----------
        key : str
            Key under which the implementation is stored; also the name of
            the ``geomfum.wrap`` submodule that defines it.
        obj_name : str
            Name of the object to fetch from that submodule.
        requires : str or iterable of str
            Package(s) the implementation depends on.
        as_default : bool
            Whether ``key`` becomes ``cls.default``.
        """
        packages = (requires,) if isinstance(requires, str) else requires
        # Only the first unavailable package is kept; ``get`` reports it.
        missing_package = next(
            (name for name in packages if not has_package(name)), None
        )

        if as_default:
            cls.default = key

        cls.MAP[key] = (obj_name, missing_package)

    @classmethod
    def get(cls, key=None):
        """Retrieve the registered implementation for ``key``.

        Parameters
        ----------
        key : str or None
            Registered key. ``None`` falls back to ``cls.default``; the key
            ``"geomfum"`` selects the internal implementation.

        Returns
        -------
        Obj : class or None
            Registered object, or ``None`` when ``"geomfum"`` is requested
            and an internal implementation exists.

        Raises
        ------
        ValueError
            If ``"geomfum"`` is requested but there is no internal
            implementation.
        KeyError
            If ``key`` was never registered.
        ModuleNotFoundError
            If a package required by the implementation is not installed.
        """
        if key is None:
            key = cls.default

        if key == "geomfum":
            if not cls.has_internal:
                cls.raise_if_no_internal()
            else:
                return None

        obj_name, missing_package = cls.MAP[key]
        if missing_package:
            raise ModuleNotFoundError(missing_package)

        # A non-empty ``fromlist`` makes ``__import__`` return the submodule
        # itself rather than the top-level ``geomfum`` package.
        wrap_module = __import__(f"geomfum.wrap.{key}", fromlist=[""])
        return getattr(wrap_module, obj_name)

    @classmethod
    def list_available(cls):
        """Return the registered keys in insertion order.

        Returns
        -------
        keys : list
            Registered keys.
        """
        return [*cls.MAP]

    @classmethod
    def only_from_registry(cls):
        """Build the error message used when no internal implementation exists.

        Returns
        -------
        msg : str
            Message pointing to ``from_registry`` and listing the
            registered implementations.
        """
        # The parameters of ``get`` (without parentheses) are shown as a hint.
        params = str(inspect.signature(cls.get))[1:-1]
        available = ", ".join(str(name) for name in cls.list_available())
        return (
            f"No internal implementation. Use `.from_registry({params}, **kwargs)`. "
            f"Available implementations: {available}."
        )

    @classmethod
    def raise_if_no_internal(cls):
        """Raise ValueError if no internal implementation exists."""
        if cls.has_internal:
            return
        raise ValueError(cls.only_from_registry())


class NestedRegistry(abc.ABC):
    """Two-level registry for managing implementations with nested categorization.

    Subclasses define ``Registries``, a dict mapping an outer key (e.g. a
    ``mesh`` flag) to a ``Registry`` subclass holding the implementations of
    that category. Every classmethod picks an inner registry by outer key and
    delegates to it.
    """

    @classmethod
    def _outer_registry(cls, key=None):
        """Retrieve the inner registry stored under an outer key.

        Parameters
        ----------
        key : Hashable
            Key for outer register dict.
            Defaults to the first key of ``Registries`` if ``None``.

        Returns
        -------
        registry : Registry
        """
        if key is None:
            # Fall back to the first outer key (dict insertion order).
            key = list(cls.Registries)[0]
        return cls.Registries[key]

    @classmethod
    def register(cls, key_out, key_in, obj_name, requires=(), as_default=False):
        """Register implementation in nested registry structure.

        Parameters
        ----------
        key_out : Hashable
            Key for outer register dict.
        key_in : str
            Key for object in inner register.
        obj_name : class name
            Name of the object to register.
        requires : str or tuple
            Required packages.
        as_default : bool
            Whether to set it as default.
        """
        return cls._outer_registry(key_out).register(
            key_in, obj_name, requires, as_default
        )

    @classmethod
    def get(cls, key_out, key_in):
        """Retrieve registered implementation using nested keys.

        Parameters
        ----------
        key_out : Hashable
            Key for outer register dict.
        key_in : str
            Key for object in inner register.

        Returns
        -------
        Obj : class
            Registered object.
        """
        return cls._outer_registry(key_out).get(key_in)

    @classmethod
    def list_available(cls, key_out=None):
        """List available implementation keys for specified or all outer categories.

        Parameters
        ----------
        key_out : Hashable
            Key for outer register dict. If ``None``, keys of every inner
            registry are returned, grouped by outer key.

        Returns
        -------
        keys : list or dict[list]
            Registered keys.
        """
        if key_out is not None:
            return cls._outer_registry(key_out).list_available()

        return {
            key: registry.list_available()
            for key, registry in cls.Registries.items()
        }

    @classmethod
    def only_from_registry(cls, key_out=None):
        """Generate error message for missing internal implementation.

        Parameters
        ----------
        key_out : Hashable
            Key for outer register dict.
            If ``None``, defaults to the first key of ``Registries``.

        Returns
        -------
        msg : str
            Message for no internal implementation with available
            implementations.
        """
        return cls._outer_registry(key_out).only_from_registry()

    @classmethod
    def raise_if_no_internal(cls, key_out=None):
        """Raise ValueError if no internal implementation exists.

        Parameters
        ----------
        key_out : Hashable
            Key for outer register dict.
            If ``None``, defaults to the first key of ``Registries``.

        Raises
        ------
        ValueError
            If the selected inner registry has no internal implementation.
        """
        # Delegate the check itself; previously this only built the message
        # and returned it, so nothing was ever raised.
        cls._outer_registry(key_out).raise_if_no_internal()


class WhichRegistryMixins:
    """Mixin enabling registry-based instantiation via 'which' parameter.

    Classes using this mixin set ``_Registry`` to a registry class exposing
    ``get`` and ``raise_if_no_internal``.
    """

    def __init__(self, *args, **kwargs):
        # The registry decides whether direct construction is allowed
        # (it raises when geomfum has no internal implementation).
        self._Registry.raise_if_no_internal()
        super().__init__(*args, **kwargs)

    @classmethod
    def from_registry(cls, *args, which=None, **kwargs):
        """Create an instance from a registered implementation.

        Parameters
        ----------
        which : str
            A registered implementation. ``None`` uses the registry default.

        Returns
        -------
        obj : object
            Instance of the registered implementation, or of ``cls`` itself
            when the registry returns ``None`` (internal implementation).
        """
        implementation = cls._Registry.get(which)
        if implementation is not None:
            return implementation(*args, **kwargs)

        # ``None`` means "use the internal implementation": build ``cls``.
        instance = cls.__new__(cls)
        instance.__init__(*args, **kwargs)
        return instance
class MeshWhichRegistryMixins:
    """Mixin for registry-based instantiation with mesh/point cloud distinction.

    Classes using this mixin set ``_Registry`` to a nested registry whose
    ``get`` takes the ``mesh`` flag first and the implementation key second.
    """

    def __init__(self, *args, **kwargs):
        # TODO: has to be improved
        # NOTE(review): no outer key is passed, so a NestedRegistry checks its
        # first outer registry regardless of the actual shape type.
        self._Registry.raise_if_no_internal()
        super().__init__(*args, **kwargs)

    @classmethod
    def from_registry(cls, *args, mesh=True, which=None, **kwargs):
        """Create an instance from a registered implementation for a shape type.

        Parameters
        ----------
        mesh : bool
            Whether a mesh or point cloud.
        which : str
            A registered implementation. ``None`` uses the registry default.

        Returns
        -------
        obj : Obj
            Instance of the registered implementation, or of ``cls`` itself
            when the registry returns ``None`` (internal implementation).
        """
        implementation = cls._Registry.get(mesh, which)
        if implementation is not None:
            return implementation(*args, **kwargs)

        # ``None`` means "use the internal implementation": build ``cls``.
        instance = cls.__new__(cls)
        instance.__init__(*args, **kwargs)
        return instance
class _MeshLaplacianFinderRegistry(Registry):
    """Laplacian finder implementations for meshes (internal one available)."""

    MAP = {}
    has_internal = True


class _PointSetLaplacianFinderRegistry(Registry):
    """Laplacian finder implementations for point sets (no internal one)."""

    MAP = {}
    has_internal = False


class LaplacianFinderRegistry(NestedRegistry):
    """Laplacian finder registry keyed by the ``mesh`` flag."""

    Registries = {
        True: _MeshLaplacianFinderRegistry,
        False: _PointSetLaplacianFinderRegistry,
    }
class HeatKernelSignatureRegistry(Registry):
    """Registry of heat kernel signature implementations (internal one available)."""

    MAP = {}
    has_internal = True


class LandmarkHeatKernelSignatureRegistry(Registry):
    """Registry of landmark heat kernel signature implementations (internal one available)."""

    MAP = {}
    has_internal = True


class WaveKernelSignatureRegistry(Registry):
    """Registry of wave kernel signature implementations (internal one available)."""

    MAP = {}
    has_internal = True


class LandmarkWaveKernelSignatureRegistry(Registry):
    """Registry of landmark wave kernel signature implementations (internal one available)."""

    MAP = {}
    has_internal = True


# The registries below inherit ``has_internal = False`` from ``Registry``.


class FaceValuedGradientRegistry(Registry):
    """Registry of face-valued gradient implementations."""

    MAP = {}


class FaceDivergenceOperatorRegistry(Registry):
    """Registry of face divergence operator implementations."""

    MAP = {}


class FaceOrientationOperatorRegistry(Registry):
    """Registry of face orientation operator implementations."""

    MAP = {}


class HierarchicalMeshRegistry(Registry):
    """Registry of hierarchical mesh implementations."""

    MAP = {}


class PoissonSamplerRegistry(Registry):
    """Registry of Poisson sampler implementations."""

    MAP = {}


class FarthestPointSamplerRegistry(Registry):
    """Registry of farthest point sampler implementations."""

    MAP = {}


class FeatureExtractorRegistry(Registry):
    """Registry of feature extractor implementations."""

    MAP = {}


class NeighborFinderRegistry(Registry):
    """Registry of neighbor finder implementations (internal one available)."""

    MAP = {}
    has_internal = True


class MeshPlotterRegistry(Registry):
    """Registry of mesh plotter implementations."""

    MAP = {}


class PointCloudPlotterRegistry(Registry):
    """Registry of point cloud plotter implementations."""

    MAP = {}
class _MeshHeatDistanceMetricRegistry(Registry):
    """Heat distance metric implementations for meshes (no internal one)."""

    MAP = {}
    has_internal = False


class _PointSetHeatDistanceMetricRegistry(Registry):
    """Heat distance metric implementations for point sets (no internal one)."""

    MAP = {}
    has_internal = False


class HeatDistanceMetricRegistry(NestedRegistry):
    """Heat distance metric registry keyed by the ``mesh`` flag."""

    Registries = {
        True: _MeshHeatDistanceMetricRegistry,
        False: _PointSetHeatDistanceMetricRegistry,
    }


def _create_register_funcs(module):
    """Expose a ``register_*`` function for every registry class in ``module``.

    For a class named ``<Words>Registry`` (e.g. ``LaplacianFinderRegistry``)
    the module gains ``register_<words>`` in snake case (e.g.
    ``register_laplacian_finder``), bound to the class's ``register``. The
    abstract bases (direct ``abc.ABC`` subclasses) are skipped. Leading
    underscores are ignored by the word split, so private registries such as
    ``_MeshHeatDistanceMetricRegistry`` also get a function
    (``register_mesh_heat_distance_metric``). These functions are widely used
    within ``geomfum.wrap``.
    """
    for attr_name, candidate in inspect.getmembers(module):
        is_concrete_registry = (
            hasattr(candidate, "__bases__")
            and abc.ABC not in candidate.__bases__
            and attr_name.endswith("Registry")
        )
        if not is_concrete_registry:
            continue

        # Split the CamelCase name into words and drop the trailing "Registry".
        words = re.findall("[A-Z][^A-Z]*", attr_name)[:-1]
        func_name = "_".join(["register", *(word.lower() for word in words)])
        setattr(module, func_name, candidate.register)


_create_register_funcs(sys.modules[__name__])