mirror of
https://github.com/SheffieldML/GPy.git
synced 2026-05-09 12:02:38 +02:00
new slicing done and first attempts at copy and pickling full models
This commit is contained in:
parent
c65a1e3544
commit
e26e737014
2 changed files with 103 additions and 86 deletions
|
|
@ -5,118 +5,134 @@ Created on 11 Mar 2014
|
|||
'''
|
||||
from ...core.parameterization.parameterized import ParametersChangedMeta
|
||||
import numpy as np
|
||||
from functools import wraps
|
||||
|
||||
def put_clean(dct, name, *args, **kw):
|
||||
def put_clean(dct, name, func):
|
||||
if name in dct:
|
||||
dct['_clean_{}'.format(name)] = dct[name]
|
||||
dct[name] = _slice_wrapper(None, dct[name], *args, **kw)
|
||||
dct[name] = func(dct[name])
|
||||
|
||||
class KernCallsViaSlicerMeta(ParametersChangedMeta):
|
||||
def __new__(cls, name, bases, dct):
|
||||
put_clean(dct, 'K')
|
||||
put_clean(dct, 'Kdiag', diag=True)
|
||||
put_clean(dct, 'update_gradients_full', diag=False, derivative=True)
|
||||
put_clean(dct, 'gradients_X', diag=False, derivative=True, ret_X=True)
|
||||
put_clean(dct, 'gradients_X_diag', diag=True, derivative=True, ret_X=True)
|
||||
put_clean(dct, 'psi0', diag=False, derivative=False)
|
||||
put_clean(dct, 'psi1', diag=False, derivative=False)
|
||||
put_clean(dct, 'psi2', diag=False, derivative=False)
|
||||
put_clean(dct, 'update_gradients_expectations', derivative=True, psi_stat=True)
|
||||
put_clean(dct, 'gradients_Z_expectations', derivative=True, psi_stat_Z=True, ret_X=True)
|
||||
put_clean(dct, 'gradients_qX_expectations', derivative=True, psi_stat=True, ret_X=True)
|
||||
put_clean(dct, 'K', _slice_K)
|
||||
put_clean(dct, 'Kdiag', _slice_Kdiag)
|
||||
put_clean(dct, 'update_gradients_full', _slice_update_gradients_full)
|
||||
put_clean(dct, 'update_gradients_diag', _slice_update_gradients_diag)
|
||||
put_clean(dct, 'gradients_X', _slice_gradients_X)
|
||||
put_clean(dct, 'gradients_X_diag', _slice_gradients_X_diag)
|
||||
|
||||
put_clean(dct, 'psi0', _slice_psi)
|
||||
put_clean(dct, 'psi1', _slice_psi)
|
||||
put_clean(dct, 'psi2', _slice_psi)
|
||||
put_clean(dct, 'update_gradients_expectations', _slice_update_gradients_expectations)
|
||||
put_clean(dct, 'gradients_Z_expectations', _slice_gradients_Z_expectations)
|
||||
put_clean(dct, 'gradients_qX_expectations', _slice_gradients_qX_expectations)
|
||||
return super(KernCallsViaSlicerMeta, cls).__new__(cls, name, bases, dct)
|
||||
|
||||
class _Slice_wrap(object):
|
||||
def __init__(self, instance, f):
|
||||
self.k = instance
|
||||
self.f = f
|
||||
def copy_to(self, new_instance):
|
||||
return self.__class__(new_instance, self.f)
|
||||
def _slice_X(self, X):
|
||||
return self.k._slice_X(X) if not self.k._sliced_X else X
|
||||
def _slice_X_X2(self, X, X2):
|
||||
return self.k._slice_X(X) if not self.k._sliced_X else X, self.k._slice_X(X2) if X2 is not None and not self.k._sliced_X else X2
|
||||
def __init__(self, k, X, X2=None):
|
||||
self.k = k
|
||||
self.shape = X.shape
|
||||
if self.k._sliced_X == 0:
|
||||
self.X = self.k._slice_X(X)
|
||||
self.X2 = self.k._slice_X(X2) if X2 is not None else None
|
||||
self.ret = True
|
||||
else:
|
||||
self.X = X
|
||||
self.X2 = X2
|
||||
self.ret = False
|
||||
def __enter__(self):
|
||||
self.k._sliced_X += 1
|
||||
return self
|
||||
def __exit__(self, *a):
|
||||
self.k._sliced_X -= 1
|
||||
def handle_return_array(self, return_val):
|
||||
if self.ret:
|
||||
ret = np.zeros(self.shape)
|
||||
ret[:, self.k.active_dims] = return_val
|
||||
return ret
|
||||
return return_val
|
||||
|
||||
class _Slice_wrapper(_Slice_wrap):
|
||||
def __call__(self, X, X2 = None, *a, **kw):
|
||||
X, X2 = self._slice_X_X2(X, X2)
|
||||
with self:
|
||||
ret = self.f(X, X2, *a, **kw)
|
||||
def _slice_K(f):
|
||||
@wraps(f)
|
||||
def wrap(self, X, X2 = None, *a, **kw):
|
||||
with _Slice_wrap(self, X, X2) as s:
|
||||
ret = f(self, s.X, s.X2, *a, **kw)
|
||||
return ret
|
||||
return wrap
|
||||
|
||||
class _Slice_wrapper_diag(_Slice_wrap):
|
||||
def __call__(self, X, *a, **kw):
|
||||
X = self._slice_X(X)
|
||||
with self:
|
||||
ret = self.f(X, *a, **kw)
|
||||
def _slice_Kdiag(f):
|
||||
@wraps(f)
|
||||
def wrap(self, X, *a, **kw):
|
||||
with _Slice_wrap(self, X, None) as s:
|
||||
ret = f(self, s.X, *a, **kw)
|
||||
return ret
|
||||
return wrap
|
||||
|
||||
class _Slice_wrapper_derivative(_Slice_wrap):
|
||||
def __call__(self, dL_dK, X, X2=None):
|
||||
self._slice_X(X)
|
||||
with self:
|
||||
ret = self.f(dL_dK, X, X2)
|
||||
def _slice_update_gradients_full(f):
|
||||
@wraps(f)
|
||||
def wrap(self, dL_dK, X, X2=None):
|
||||
with _Slice_wrap(self, X, X2) as s:
|
||||
ret = f(self, dL_dK, s.X, s.X2)
|
||||
return ret
|
||||
return wrap
|
||||
|
||||
class _Slice_wrapper_diag_derivative(_Slice_wrap):
|
||||
def __call__(self, dL_dKdiag, X):
|
||||
X = self._slice_X(X)
|
||||
with self:
|
||||
ret = self.f(dL_dKdiag, X)
|
||||
def _slice_update_gradients_diag(f):
|
||||
@wraps(f)
|
||||
def wrap(self, dL_dKdiag, X):
|
||||
with _Slice_wrap(self, X, None) as s:
|
||||
ret = f(self, dL_dKdiag, s.X)
|
||||
return ret
|
||||
return wrap
|
||||
|
||||
class _Slice_wrapper_grad_X(_Slice_wrap):
|
||||
def __call__(self, dL_dK, X, X2=None):
|
||||
ret = np.zeros(X.shape)
|
||||
X, X2 = self._slice_X_X2(X, X2)
|
||||
with self:
|
||||
ret[:, self.k.active_dims] = self.f(dL_dK, X, X2)
|
||||
def _slice_gradients_X(f):
|
||||
@wraps(f)
|
||||
def wrap(self, dL_dK, X, X2=None):
|
||||
with _Slice_wrap(self, X, X2) as s:
|
||||
ret = s.handle_return_array(f(self, dL_dK, s.X, s.X2))
|
||||
return ret
|
||||
return wrap
|
||||
|
||||
class _Slice_wrapper_grad_X_diag(_Slice_wrap):
|
||||
def __call__(self, dL_dKdiag, X):
|
||||
ret = np.zeros(X.shape)
|
||||
X = self._slice_X(X)
|
||||
with self:
|
||||
ret[:, self.k.active_dims] = self.f(dL_dKdiag, X)
|
||||
def _slice_gradients_X_diag(f):
|
||||
@wraps(f)
|
||||
def wrap(self, dL_dKdiag, X):
|
||||
with _Slice_wrap(self, X, None) as s:
|
||||
ret = s.handle_return_array(f(self, dL_dKdiag, s.X))
|
||||
return ret
|
||||
return wrap
|
||||
|
||||
class _Slice_wrapper_psi_stat_derivative_no_ret(_Slice_wrap):
|
||||
def __call__(self, dL_dpsi0, dL_dpsi1, dL_dpsi2, Z, variational_posterior):
|
||||
Z, variational_posterior = self._slice_X_X2(Z, variational_posterior)
|
||||
with self:
|
||||
ret = self.f(dL_dpsi0, dL_dpsi1, dL_dpsi2, Z, variational_posterior)
|
||||
def _slice_psi(f):
|
||||
@wraps(f)
|
||||
def wrap(self, Z, variational_posterior):
|
||||
with _Slice_wrap(self, Z, variational_posterior) as s:
|
||||
ret = f(self, s.X, s.X2)
|
||||
return ret
|
||||
return wrap
|
||||
|
||||
class _Slice_wrapper_psi_stat_derivative(_Slice_wrap):
|
||||
def __call__(self, dL_dpsi0, dL_dpsi1, dL_dpsi2, Z, variational_posterior):
|
||||
ret1, ret2 = np.zeros(variational_posterior.shape), np.zeros(variational_posterior.shape)
|
||||
Z, variational_posterior = self._slice_X_X2(Z, variational_posterior)
|
||||
with self:
|
||||
ret = list(self.f(dL_dpsi0, dL_dpsi1, dL_dpsi2, Z, variational_posterior))
|
||||
def _slice_update_gradients_expectations(f):
|
||||
@wraps(f)
|
||||
def wrap(self, dL_dpsi0, dL_dpsi1, dL_dpsi2, Z, variational_posterior):
|
||||
with _Slice_wrap(self, Z, variational_posterior) as s:
|
||||
ret = f(self, dL_dpsi0, dL_dpsi1, dL_dpsi2, s.X, s.X2)
|
||||
return ret
|
||||
return wrap
|
||||
|
||||
def _slice_gradients_Z_expectations(f):
|
||||
@wraps(f)
|
||||
def wrap(self, dL_dpsi1, dL_dpsi2, Z, variational_posterior):
|
||||
with _Slice_wrap(self, Z, variational_posterior) as s:
|
||||
ret = s.handle_return_array(f(self, dL_dpsi1, dL_dpsi2, s.X, s.X2))
|
||||
return ret
|
||||
return wrap
|
||||
|
||||
def _slice_gradients_qX_expectations(f):
|
||||
@wraps(f)
|
||||
def wrap(self, dL_dpsi0, dL_dpsi1, dL_dpsi2, Z, variational_posterior):
|
||||
with _Slice_wrap(self, variational_posterior, Z) as s:
|
||||
ret = list(f(self, dL_dpsi0, dL_dpsi1, dL_dpsi2, s.X2, s.X))
|
||||
r2 = ret[:2]
|
||||
ret[0] = ret1
|
||||
ret[1] = ret2
|
||||
ret[0][:, self.k.active_dims] = r2[0]
|
||||
ret[1][:, self.k.active_dims] = r2[1]
|
||||
del r2
|
||||
return ret
|
||||
|
||||
class _Slice_wrapper_psi_stat_derivative_Z(_Slice_wrap):
|
||||
def __call__(self, dL_dpsi0, dL_dpsi1, dL_dpsi2, Z, variational_posterior):
|
||||
ret1, ret2 = np.zeros(variational_posterior.shape), np.zeros(variational_posterior.shape)
|
||||
Z, variational_posterior = self._slice_X_X2(Z, variational_posterior)
|
||||
with self:
|
||||
ret = list(self.f(dL_dpsi0, dL_dpsi1, dL_dpsi2, Z, variational_posterior))
|
||||
r2 = ret[:2]
|
||||
ret[0] = ret1
|
||||
ret[1] = ret2
|
||||
ret[0][:, self.k.active_dims] = r2[0]
|
||||
ret[1][:, self.k.active_dims] = r2[1]
|
||||
ret[0] = s.handle_return_array(r2[0])
|
||||
ret[1] = s.handle_return_array(r2[1])
|
||||
del r2
|
||||
return ret
|
||||
return wrap
|
||||
|
|
|
|||
|
|
@ -101,7 +101,7 @@ class Cacher(object):
|
|||
def __name__(self):
|
||||
return self.operation.__name__
|
||||
|
||||
from functools import partial
|
||||
from functools import partial, update_wrapper
|
||||
|
||||
class Cacher_wrap(object):
|
||||
def __init__(self, f, limit, ignore_args, force_kwargs):
|
||||
|
|
@ -109,6 +109,7 @@ class Cacher_wrap(object):
|
|||
self.ignore_args = ignore_args
|
||||
self.force_kwargs = force_kwargs
|
||||
self.f = f
|
||||
update_wrapper(self, self.f)
|
||||
def __get__(self, obj, objtype=None):
|
||||
return partial(self, obj)
|
||||
def __call__(self, *args, **kwargs):
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue