From e26e7370141113c5ac5ecbac34824819a6c941ab Mon Sep 17 00:00:00 2001 From: Max Zwiessele Date: Thu, 27 Mar 2014 09:28:44 +0000 Subject: [PATCH] new slicing done and first attempts at copy and pickling full models --- GPy/kern/_src/kernel_slice_operations.py | 186 ++++++++++++----------- GPy/util/caching.py | 3 +- 2 files changed, 103 insertions(+), 86 deletions(-) diff --git a/GPy/kern/_src/kernel_slice_operations.py b/GPy/kern/_src/kernel_slice_operations.py index 6620f28c..21421cc0 100644 --- a/GPy/kern/_src/kernel_slice_operations.py +++ b/GPy/kern/_src/kernel_slice_operations.py @@ -5,118 +5,134 @@ Created on 11 Mar 2014 ''' from ...core.parameterization.parameterized import ParametersChangedMeta import numpy as np +from functools import wraps -def put_clean(dct, name, *args, **kw): +def put_clean(dct, name, func): if name in dct: dct['_clean_{}'.format(name)] = dct[name] - dct[name] = _slice_wrapper(None, dct[name], *args, **kw) - + dct[name] = func(dct[name]) + class KernCallsViaSlicerMeta(ParametersChangedMeta): def __new__(cls, name, bases, dct): - put_clean(dct, 'K') - put_clean(dct, 'Kdiag', diag=True) - put_clean(dct, 'update_gradients_full', diag=False, derivative=True) - put_clean(dct, 'gradients_X', diag=False, derivative=True, ret_X=True) - put_clean(dct, 'gradients_X_diag', diag=True, derivative=True, ret_X=True) - put_clean(dct, 'psi0', diag=False, derivative=False) - put_clean(dct, 'psi1', diag=False, derivative=False) - put_clean(dct, 'psi2', diag=False, derivative=False) - put_clean(dct, 'update_gradients_expectations', derivative=True, psi_stat=True) - put_clean(dct, 'gradients_Z_expectations', derivative=True, psi_stat_Z=True, ret_X=True) - put_clean(dct, 'gradients_qX_expectations', derivative=True, psi_stat=True, ret_X=True) + put_clean(dct, 'K', _slice_K) + put_clean(dct, 'Kdiag', _slice_Kdiag) + put_clean(dct, 'update_gradients_full', _slice_update_gradients_full) + put_clean(dct, 'update_gradients_diag', _slice_update_gradients_diag) + put_clean(dct, 'gradients_X', _slice_gradients_X) + put_clean(dct, 'gradients_X_diag', _slice_gradients_X_diag) + + put_clean(dct, 'psi0', _slice_psi) + put_clean(dct, 'psi1', _slice_psi) + put_clean(dct, 'psi2', _slice_psi) + put_clean(dct, 'update_gradients_expectations', _slice_update_gradients_expectations) + put_clean(dct, 'gradients_Z_expectations', _slice_gradients_Z_expectations) + put_clean(dct, 'gradients_qX_expectations', _slice_gradients_qX_expectations) return super(KernCallsViaSlicerMeta, cls).__new__(cls, name, bases, dct) - + class _Slice_wrap(object): - def __init__(self, instance, f): - self.k = instance - self.f = f - def copy_to(self, new_instance): - return self.__class__(new_instance, self.f) - def _slice_X(self, X): - return self.k._slice_X(X) if not self.k._sliced_X else X - def _slice_X_X2(self, X, X2): - return self.k._slice_X(X) if not self.k._sliced_X else X, self.k._slice_X(X2) if X2 is not None and not self.k._sliced_X else X2 + def __init__(self, k, X, X2=None): + self.k = k + self.shape = X.shape + if self.k._sliced_X == 0: + self.X = self.k._slice_X(X) + self.X2 = self.k._slice_X(X2) if X2 is not None else None + self.ret = True + else: + self.X = X + self.X2 = X2 + self.ret = False def __enter__(self): self.k._sliced_X += 1 return self def __exit__(self, *a): self.k._sliced_X -= 1 + def handle_return_array(self, return_val): + if self.ret: + ret = np.zeros(self.shape) + ret[:, self.k.active_dims] = return_val + return ret + return return_val -class _Slice_wrapper(_Slice_wrap): - def __call__(self, X, X2 = None, *a, **kw): - X, X2 = self._slice_X_X2(X, X2) - with self: - ret = self.f(X, X2, *a, **kw) +def _slice_K(f): + @wraps(f) + def wrap(self, X, X2 = None, *a, **kw): + with _Slice_wrap(self, X, X2) as s: + ret = f(self, s.X, s.X2, *a, **kw) return ret + return wrap -class _Slice_wrapper_diag(_Slice_wrap): - def __call__(self, X, *a, **kw): - X = self._slice_X(X) - with self: - ret = self.f(X, *a, **kw) +def _slice_Kdiag(f): + @wraps(f) + def wrap(self, X, *a, **kw): + with _Slice_wrap(self, X, None) as s: + ret = f(self, s.X, *a, **kw) return ret + return wrap -class _Slice_wrapper_derivative(_Slice_wrap): - def __call__(self, dL_dK, X, X2=None): - self._slice_X(X) - with self: - ret = self.f(dL_dK, X, X2) +def _slice_update_gradients_full(f): + @wraps(f) + def wrap(self, dL_dK, X, X2=None): + with _Slice_wrap(self, X, X2) as s: + ret = f(self, dL_dK, s.X, s.X2) return ret + return wrap -class _Slice_wrapper_diag_derivative(_Slice_wrap): - def __call__(self, dL_dKdiag, X): - X = self._slice_X(X) - with self: - ret = self.f(dL_dKdiag, X) +def _slice_update_gradients_diag(f): + @wraps(f) + def wrap(self, dL_dKdiag, X): + with _Slice_wrap(self, X, None) as s: + ret = f(self, dL_dKdiag, s.X) return ret + return wrap -class _Slice_wrapper_grad_X(_Slice_wrap): - def __call__(self, dL_dK, X, X2=None): - ret = np.zeros(X.shape) - X, X2 = self._slice_X_X2(X, X2) - with self: - ret[:, self.k.active_dims] = self.f(dL_dK, X, X2) +def _slice_gradients_X(f): + @wraps(f) + def wrap(self, dL_dK, X, X2=None): + with _Slice_wrap(self, X, X2) as s: + ret = s.handle_return_array(f(self, dL_dK, s.X, s.X2)) return ret + return wrap -class _Slice_wrapper_grad_X_diag(_Slice_wrap): - def __call__(self, dL_dKdiag, X): - ret = np.zeros(X.shape) - X = self._slice_X(X) - with self: - ret[:, self.k.active_dims] = self.f(dL_dKdiag, X) +def _slice_gradients_X_diag(f): + @wraps(f) + def wrap(self, dL_dKdiag, X): + with _Slice_wrap(self, X, None) as s: + ret = s.handle_return_array(f(self, dL_dKdiag, s.X)) return ret + return wrap -class _Slice_wrapper_psi_stat_derivative_no_ret(_Slice_wrap): - def __call__(self, dL_dpsi0, dL_dpsi1, dL_dpsi2, Z, variational_posterior): - Z, variational_posterior = self._slice_X_X2(Z, variational_posterior) - with self: - ret = self.f(dL_dpsi0, dL_dpsi1, dL_dpsi2, Z, variational_posterior) +def _slice_psi(f): + @wraps(f) + def wrap(self, Z, variational_posterior): + with _Slice_wrap(self, Z, variational_posterior) as s: + ret = f(self, s.X, s.X2) return ret + return wrap -class _Slice_wrapper_psi_stat_derivative(_Slice_wrap): - def __call__(self, dL_dpsi0, dL_dpsi1, dL_dpsi2, Z, variational_posterior): - ret1, ret2 = np.zeros(variational_posterior.shape), np.zeros(variational_posterior.shape) - Z, variational_posterior = self._slice_X_X2(Z, variational_posterior) - with self: - ret = list(self.f(dL_dpsi0, dL_dpsi1, dL_dpsi2, Z, variational_posterior)) +def _slice_update_gradients_expectations(f): + @wraps(f) + def wrap(self, dL_dpsi0, dL_dpsi1, dL_dpsi2, Z, variational_posterior): + with _Slice_wrap(self, Z, variational_posterior) as s: + ret = f(self, dL_dpsi0, dL_dpsi1, dL_dpsi2, s.X, s.X2) + return ret + return wrap + +def _slice_gradients_Z_expectations(f): + @wraps(f) + def wrap(self, dL_dpsi1, dL_dpsi2, Z, variational_posterior): + with _Slice_wrap(self, Z, variational_posterior) as s: + ret = s.handle_return_array(f(self, dL_dpsi1, dL_dpsi2, s.X, s.X2)) + return ret + return wrap + +def _slice_gradients_qX_expectations(f): + @wraps(f) + def wrap(self, dL_dpsi0, dL_dpsi1, dL_dpsi2, Z, variational_posterior): + with _Slice_wrap(self, variational_posterior, Z) as s: + ret = list(f(self, dL_dpsi0, dL_dpsi1, dL_dpsi2, s.X2, s.X)) r2 = ret[:2] - ret[0] = ret1 - ret[1] = ret2 - ret[0][:, self.k.active_dims] = r2[0] - ret[1][:, self.k.active_dims] = r2[1] - del r2 - return ret - -class _Slice_wrapper_psi_stat_derivative_Z(_Slice_wrap): - def __call__(self, dL_dpsi0, dL_dpsi1, dL_dpsi2, Z, variational_posterior): - ret1, ret2 = np.zeros(variational_posterior.shape), np.zeros(variational_posterior.shape) - Z, variational_posterior = self._slice_X_X2(Z, variational_posterior) - with self: - ret = list(self.f(dL_dpsi0, dL_dpsi1, dL_dpsi2, Z, variational_posterior)) - r2 = ret[:2] - ret[0] = ret1 - ret[1] = ret2 - ret[0][:, self.k.active_dims] = r2[0] - ret[1][:, self.k.active_dims] = r2[1] + ret[0] = s.handle_return_array(r2[0]) + ret[1] = s.handle_return_array(r2[1]) del r2 return ret + return wrap diff --git a/GPy/util/caching.py b/GPy/util/caching.py index fcb0b726..0886d0c6 100644 --- a/GPy/util/caching.py +++ b/GPy/util/caching.py @@ -101,7 +101,7 @@ class Cacher(object): def __name__(self): return self.operation.__name__ -from functools import partial +from functools import partial, update_wrapper class Cacher_wrap(object): def __init__(self, f, limit, ignore_args, force_kwargs): @@ -109,6 +109,7 @@ class Cacher_wrap(object): self.ignore_args = ignore_args self.force_kwargs = force_kwargs self.f = f + update_wrapper(self, self.f) def __get__(self, obj, objtype=None): return partial(self, obj) def __call__(self, *args, **kwargs):