Source code for geomstats.numerics._common
import functools
import inspect
from types import MethodType
import numpy as np
from scipy.interpolate import PPoly
import geomstats.backend as gs
[docs]
def result_to_backend_type(result):
    """Convert np.array to gs.array within result object.

    When the active backend module name ends with ``"numpy"`` or
    ``"autograd"`` the result is returned untouched. Otherwise the
    entries of ``result`` are updated in place: values whose type is
    exactly ``np.ndarray`` are passed through ``gs.from_numpy`` and
    ``PPoly`` values are wrapped in ``_InstanceConvertOutputWrapper``.

    Parameters
    ----------
    result : dict-like
        Mapping exposing ``items()`` and item assignment.

    Returns
    -------
    result : dict-like
        The same object that was passed in.
    """
    if gs.__name__.endswith(("numpy", "autograd")):
        return result

    # Only existing keys are reassigned, so the mapping keeps its size
    # while being iterated.
    for name, entry in result.items():
        if type(entry) is np.ndarray:
            result[name] = gs.from_numpy(entry)
        elif isinstance(entry, PPoly):
            result[name] = _InstanceConvertOutputWrapper(entry)

    return result
def _convert_np_output(func):
    """Decorate ``func`` so that a returned np.ndarray becomes a gs.array.

    Outputs whose type is not exactly ``np.ndarray`` are returned as is.
    """

    @functools.wraps(func)
    def _wrapped(*args, **kwargs):
        result = func(*args, **kwargs)
        return gs.from_numpy(result) if type(result) is np.ndarray else result

    return _wrapped
class _InstanceConvertOutputWrapper:
    """Dynamic wrapper for an instance to convert method output to gs.array.

    Attribute access is forwarded to the wrapped instance. Bound methods
    are decorated with ``_convert_np_output`` and calling the wrapper
    calls the wrapped instance, converting an ``np.ndarray`` output with
    ``gs.from_numpy``.

    Parameters
    ----------
    instance : object
        Object to wrap.
    """

    def __init__(self, instance):
        self._instance = instance
        # Cache of attributes already fetched (and possibly decorated).
        self._dict = {}

    def __getattr__(self, attr_name):
        # ``__getattr__`` only runs when normal lookup fails. If one of the
        # wrapper's own storage names is missing, the instance was created
        # without ``__init__`` (as ``copy`` and ``pickle`` do before restoring
        # state). Fail cleanly instead of recursing through ``self._dict``.
        if attr_name in ("_instance", "_dict"):
            raise AttributeError(
                f"{type(self).__name__!r} object has no attribute {attr_name!r}"
            )

        if attr_name in self._dict:
            return self._dict[attr_name]

        attr = getattr(self._instance, attr_name)
        if isinstance(attr, MethodType):
            attr = _convert_np_output(attr)

        self._dict[attr_name] = attr
        return attr

    def __call__(self, *args, **kwargs):
        out = self._instance(*args, **kwargs)
        if type(out) is np.ndarray:
            return gs.from_numpy(out)
        return out

    def __dir__(self):
        return dir(self._instance)

    def __repr__(self):
        return repr(self._instance)

    def __str__(self):
        return str(self._instance)
def params_to_kwargs(obj, ignore=(), renamings=None, ignore_private=False, func=None):
    """Get dict with selected object attributes.

    Steps are applied in this order: removal of ignored attributes
    (explicit ones and, if ``func`` is given, those absent from its
    signature), renaming, then removal of private attributes.

    Parameters
    ----------
    obj : object
        Object with desired attributes.
    ignore : tuple[str]
        Attributes to ignore. Names that are not attributes of ``obj``
        are skipped silently.
    renamings: dict
        Attribute renamings (old name to new name).
    ignore_private: bool
        Whether to ignore private attributes, i.e. keys starting
        with an underscore after renaming.
    func : callable
        Function to get signature from. Attributes
        not in the signature are ignored.

    Returns
    -------
    kwargs : dict
        Filtered and renamed shallow copy of ``obj.__dict__``;
        ``obj`` itself is not modified.

    Raises
    ------
    KeyError
        If an old name in ``renamings`` is not among the remaining
        attributes.
    """
    kwargs = obj.__dict__.copy()

    # Use a set so that an attribute listed in ``ignore`` and also missing
    # from the signature of ``func`` is removed only once.
    ignored = set(ignore)
    if func is not None:
        params = inspect.signature(func).parameters
        ignored.update(key for key in kwargs if key not in params)

    for key in ignored:
        kwargs.pop(key, None)

    if renamings is not None:
        for old_key, new_key in renamings.items():
            kwargs[new_key] = kwargs.pop(old_key)

    if ignore_private:
        kwargs = {
            key: value for key, value in kwargs.items() if not key.startswith("_")
        }

    return kwargs