Source code for geomstats.numerics._common
import functools
from types import MethodType
import numpy as np
from scipy.interpolate import PPoly
import geomstats.backend as gs
[docs]
def result_to_backend_type(result):
"""Convert np.array to gs.array within result object."""
if gs.__name__.endswith("numpy") or gs.__name__.endswith("autograd"):
return result
for key, value in result.items():
if type(value) is np.ndarray:
result[key] = gs.from_numpy(value)
if isinstance(value, PPoly):
new_ppoly = _InstanceConvertOutputWrapper(value)
result[key] = new_ppoly
return result
def _convert_np_output(func):
@functools.wraps(func)
def _wrapped(*args, **kwargs):
out = func(*args, **kwargs)
if type(out) is np.ndarray:
return gs.from_numpy(out)
return out
return _wrapped
class _InstanceConvertOutputWrapper:
"""Dynamic wrapper for an instance to convert method output to gs.array."""
def __init__(self, instance):
self._instance = instance
self._dict = {}
def __getattr__(self, attr_name):
if attr_name in self._dict:
return self._dict[attr_name]
attr = getattr(self._instance, attr_name)
if isinstance(attr, MethodType):
attr = _convert_np_output(attr)
self._dict[attr_name] = attr
return attr
def __call__(self, *args, **kwargs):
out = self._instance(*args, **kwargs)
if type(out) is np.ndarray:
return gs.from_numpy(out)
return out
def __dir__(self):
return dir(self._instance)
def __repr__(self):
return repr(self._instance)
def __str__(self):
return str(self._instance)