# Mirror of https://github.com/SheffieldML/GPy.git
# Synced 2026-05-08 03:22:38 +02:00
# 143 lines, 4.8 KiB, Python
'''
Created on 11 Mar 2014

@author: maxz
'''
|
|
from ...core.parameterization.parameterized import ParametersChangedMeta
|
|
import numpy as np
|
|
from functools import wraps
|
|
|
|
def put_clean(dct, name, func):
    """Wrap ``dct[name]`` with ``func`` and keep the unwrapped original.

    When ``name`` is a key of ``dct``, the current value is stored under
    ``'_clean_<name>'`` and ``dct[name]`` is replaced by ``func(value)``.
    When ``name`` is absent, ``dct`` is left untouched.

    :param dct: mapping to modify in place
    :param name: key whose value should be wrapped
    :param func: callable applied to the existing value
    """
    if name not in dct:
        return
    original = dct[name]
    dct['_clean_{}'.format(name)] = original
    dct[name] = func(original)
|
|
|
|
class KernCallsViaSlicerMeta(ParametersChangedMeta):
    """Metaclass that wraps selected methods of a class body with slicing decorators.

    For every method name listed in ``__new__`` that appears in the class
    namespace ``dct``, ``put_clean`` replaces the method by its decorated
    version and stores the undecorated one under ``_clean_<name>``.
    """

    def __new__(cls, name, bases, dct):
        # (method name, decorator) pairs, applied in this order.  The table
        # lives inside __new__ because the decorators are defined further
        # down in the module and only exist once a class is being created.
        slicers = (
            ('K', _slice_K),
            ('Kdiag', _slice_Kdiag),
            ('update_gradients_full', _slice_update_gradients_full),
            ('update_gradients_diag', _slice_update_gradients_diag),
            ('gradients_X', _slice_gradients_X),
            ('gradients_X_diag', _slice_gradients_X_diag),
            ('psi0', _slice_psi),
            ('psi1', _slice_psi),
            ('psi2', _slice_psi),
            ('update_gradients_expectations', _slice_update_gradients_expectations),
            ('gradients_Z_expectations', _slice_gradients_Z_expectations),
            ('gradients_qX_expectations', _slice_gradients_qX_expectations),
        )
        for method_name, decorator in slicers:
            put_clean(dct, method_name, decorator)
        return super(KernCallsViaSlicerMeta, cls).__new__(cls, name, bases, dct)
|
|
|
|
class _Slice_wrap(object):
    """Context manager that hands a kernel its inputs, sliced when required.

    On construction ``X`` (and ``X2`` if given) must be two-dimensional.
    If the kernel ``k`` has ``active_dims`` set and ``k._sliced_X`` is zero,
    the inputs are reduced through ``k._slice_X`` and ``ret`` is set to
    ``True``; otherwise the inputs are stored unchanged and ``ret`` is
    ``False``.

    Entering the context increments ``k._sliced_X`` and leaving it
    decrements the counter again, so a ``_Slice_wrap`` created for the same
    kernel inside the ``with`` body takes the non-slicing path.
    """

    def __init__(self, k, X, X2=None):
        self.k = k
        # Shape of the *unsliced* X; used to size the array built by
        # handle_return_array.
        self.shape = X.shape
        assert X.ndim == 2, "only matrices are allowed as inputs to kernels for now, given X.shape={!s}".format(X.shape)
        if X2 is not None:
            assert X2.ndim == 2, "only matrices are allowed as inputs to kernels for now, given X2.shape={!s}".format(X2.shape)

        slicing_needed = (k.active_dims is not None) and (k._sliced_X == 0)
        if not slicing_needed:
            # Either no active dims or an outer call already sliced.
            k._check_input_dim(X)
            self.X = X
            self.X2 = X2
            self.ret = False
            return

        k._check_active_dims(X)
        self.X = k._slice_X(X)
        if X2 is None:
            self.X2 = X2
        else:
            self.X2 = k._slice_X(X2)
        self.ret = True

    def __enter__(self):
        self.k._sliced_X += 1
        return self

    def __exit__(self, *a):
        # Returns None, so exceptions raised inside the body propagate.
        self.k._sliced_X -= 1

    def handle_return_array(self, return_val):
        """Expand ``return_val`` to the shape of the original ``X``.

        If slicing took place, a zero array of shape ``self.shape`` is
        created and ``return_val`` is written into the columns given by
        ``k.active_dims``.  Otherwise ``return_val`` is returned unchanged.
        """
        if not self.ret:
            return return_val
        expanded = np.zeros(self.shape)
        expanded[:, self.k.active_dims] = return_val
        return expanded
|
|
|
|
def _slice_K(f):
    """Decorate a ``f(self, X, X2, ...)`` method so it is called via ``_Slice_wrap``.

    The wrapper builds a ``_Slice_wrap`` from ``X`` and ``X2`` and calls
    ``f`` with the ``X``/``X2`` attributes of that object.  Any additional
    positional and keyword arguments are forwarded unchanged, and the
    result of ``f`` is returned as is.
    """
    @wraps(f)
    def sliced_K(self, X, X2=None, *args, **kwargs):
        with _Slice_wrap(self, X, X2) as sliced:
            return f(self, sliced.X, sliced.X2, *args, **kwargs)
    return sliced_K
|
|
|
|
def _slice_Kdiag(f):
    """Decorate a ``f(self, X, ...)`` method so it is called via ``_Slice_wrap``.

    The wrapper builds a ``_Slice_wrap`` from ``X`` (no second input) and
    calls ``f`` with that object's ``X`` attribute.  Extra positional and
    keyword arguments are forwarded unchanged; the result of ``f`` is
    returned as is.
    """
    @wraps(f)
    def sliced_Kdiag(self, X, *args, **kwargs):
        with _Slice_wrap(self, X, None) as sliced:
            return f(self, sliced.X, *args, **kwargs)
    return sliced_Kdiag
|
|
|
|
def _slice_update_gradients_full(f):
    """Decorate a ``f(self, dL_dK, X, X2)`` method so it is called via ``_Slice_wrap``.

    ``dL_dK`` is passed through untouched; ``X`` and ``X2`` are replaced by
    the ``X``/``X2`` attributes of a ``_Slice_wrap`` built from them.  The
    result of ``f`` is returned as is.
    """
    @wraps(f)
    def sliced_update(self, dL_dK, X, X2=None):
        with _Slice_wrap(self, X, X2) as sliced:
            return f(self, dL_dK, sliced.X, sliced.X2)
    return sliced_update
|
|
|
|
def _slice_update_gradients_diag(f):
    """Decorate a ``f(self, dL_dKdiag, X)`` method so it is called via ``_Slice_wrap``.

    ``dL_dKdiag`` is passed through untouched; ``X`` is replaced by the
    ``X`` attribute of a ``_Slice_wrap`` built from it.  The result of
    ``f`` is returned as is.
    """
    @wraps(f)
    def sliced_update_diag(self, dL_dKdiag, X):
        with _Slice_wrap(self, X, None) as sliced:
            return f(self, dL_dKdiag, sliced.X)
    return sliced_update_diag
|
|
|
|
def _slice_gradients_X(f):
    """Decorate a ``f(self, dL_dK, X, X2)`` method so it is called via ``_Slice_wrap``.

    ``f`` receives the ``X``/``X2`` attributes of a ``_Slice_wrap`` built
    from the inputs.  Its result is passed through
    ``_Slice_wrap.handle_return_array`` before being returned, which
    expands it to the shape of the original ``X`` when slicing took place.
    """
    @wraps(f)
    def sliced_gradients(self, dL_dK, X, X2=None):
        with _Slice_wrap(self, X, X2) as sliced:
            partial_grad = f(self, dL_dK, sliced.X, sliced.X2)
            return sliced.handle_return_array(partial_grad)
    return sliced_gradients
|
|
|
|
def _slice_gradients_X_diag(f):
    """Decorate a ``f(self, dL_dKdiag, X)`` method so it is called via ``_Slice_wrap``.

    ``f`` receives the ``X`` attribute of a ``_Slice_wrap`` built from the
    input.  Its result is passed through
    ``_Slice_wrap.handle_return_array`` before being returned, which
    expands it to the shape of the original ``X`` when slicing took place.
    """
    @wraps(f)
    def sliced_gradients_diag(self, dL_dKdiag, X):
        with _Slice_wrap(self, X, None) as sliced:
            partial_grad = f(self, dL_dKdiag, sliced.X)
            return sliced.handle_return_array(partial_grad)
    return sliced_gradients_diag
|
|
|
|
def _slice_psi(f):
    """Decorate a ``f(self, Z, variational_posterior)`` method to run via ``_Slice_wrap``.

    ``Z`` takes the role of ``X`` and ``variational_posterior`` the role of
    ``X2`` in the ``_Slice_wrap``; ``f`` is then called with that object's
    ``X`` and ``X2`` attributes, in that order.  The result of ``f`` is
    returned as is.
    """
    @wraps(f)
    def sliced_psi(self, Z, variational_posterior):
        with _Slice_wrap(self, Z, variational_posterior) as sliced:
            return f(self, sliced.X, sliced.X2)
    return sliced_psi
|
|
|
|
def _slice_update_gradients_expectations(f):
    """Decorate ``f(self, dL_dpsi0, dL_dpsi1, dL_dpsi2, Z, variational_posterior)``.

    The three ``dL_dpsi*`` arguments are passed through untouched.  ``Z``
    takes the role of ``X`` and ``variational_posterior`` the role of
    ``X2`` in a ``_Slice_wrap``; ``f`` receives that object's ``X`` and
    ``X2`` attributes in their place.  The result of ``f`` is returned as
    is.
    """
    @wraps(f)
    def sliced_update_expectations(self, dL_dpsi0, dL_dpsi1, dL_dpsi2, Z, variational_posterior):
        with _Slice_wrap(self, Z, variational_posterior) as sliced:
            return f(self, dL_dpsi0, dL_dpsi1, dL_dpsi2, sliced.X, sliced.X2)
    return sliced_update_expectations
|
|
|
|
def _slice_gradients_Z_expectations(f):
    """Decorate ``f(self, dL_dpsi0, dL_dpsi1, dL_dpsi2, Z, variational_posterior)``.

    ``Z`` takes the role of ``X`` and ``variational_posterior`` the role of
    ``X2`` in a ``_Slice_wrap``; ``f`` receives that object's ``X`` and
    ``X2`` attributes in their place.  The result of ``f`` is passed
    through ``_Slice_wrap.handle_return_array``, which expands it to the
    shape of the original ``Z`` when slicing took place.
    """
    @wraps(f)
    def sliced_gradients_Z(self, dL_dpsi0, dL_dpsi1, dL_dpsi2, Z, variational_posterior):
        with _Slice_wrap(self, Z, variational_posterior) as sliced:
            partial_grad = f(self, dL_dpsi0, dL_dpsi1, dL_dpsi2, sliced.X, sliced.X2)
            return sliced.handle_return_array(partial_grad)
    return sliced_gradients_Z
|
|
|
|
def _slice_gradients_qX_expectations(f):
    """Decorate ``f(self, dL_dpsi0, dL_dpsi1, dL_dpsi2, Z, variational_posterior)``.

    Unlike the other expectation wrappers, the ``_Slice_wrap`` is built
    with ``variational_posterior`` as ``X`` and ``Z`` as ``X2``, so the
    shape used by ``handle_return_array`` is the posterior's.  ``f`` is
    still called in ``(Z, variational_posterior)`` order, i.e. with the
    wrapper's ``X2`` first and ``X`` second.

    The result of ``f`` is converted to a list and its first two entries
    are each passed through ``handle_return_array``; any further entries
    are returned untouched.
    """
    @wraps(f)
    def sliced_gradients_qX(self, dL_dpsi0, dL_dpsi1, dL_dpsi2, Z, variational_posterior):
        with _Slice_wrap(self, variational_posterior, Z) as sliced:
            grads = list(f(self, dL_dpsi0, dL_dpsi1, dL_dpsi2, sliced.X2, sliced.X))
            # Index explicitly so a result with fewer than two entries
            # still raises IndexError.
            for idx in (0, 1):
                grads[idx] = sliced.handle_return_array(grads[idx])
            return grads
    return sliced_gradients_qX
|