kern merge commencing

This commit is contained in:
Max Zwiessele 2014-03-27 08:05:22 +00:00
parent 1294c24a28
commit f8ff2c7df2

View file

@ -5,130 +5,121 @@ Created on 11 Mar 2014
''' '''
from ...core.parameterization.parameterized import ParametersChangedMeta from ...core.parameterization.parameterized import ParametersChangedMeta
import numpy as np import numpy as np
import functools
class KernCallsViaSlicerMeta(ParametersChangedMeta): class KernCallsViaSlicerMeta(ParametersChangedMeta):
def __call__(self, *args, **kw): def __call__(self, *args, **kw):
instance = super(ParametersChangedMeta, self).__call__(*args, **kw) instance = super(ParametersChangedMeta, self).__call__(*args, **kw)
instance.K = _slice_wrapper(instance, instance.K) instance.K = _Slice_wrapper(instance, instance.K)
instance.Kdiag = _slice_wrapper(instance, instance.Kdiag, diag=True) instance.Kdiag = _Slice_wrapper_diag(instance, instance.Kdiag)
instance.update_gradients_full = _slice_wrapper(instance, instance.update_gradients_full, diag=False, derivative=True)
instance.update_gradients_diag = _slice_wrapper(instance, instance.update_gradients_diag, diag=True, derivative=True) instance.update_gradients_full = _Slice_wrapper_derivative(instance, instance.update_gradients_full)
instance.gradients_X = _slice_wrapper(instance, instance.gradients_X, diag=False, derivative=True, ret_X=True) instance.update_gradients_diag = _Slice_wrapper_diag_derivative(instance, instance.update_gradients_diag)
instance.gradients_X_diag = _slice_wrapper(instance, instance.gradients_X_diag, diag=True, derivative=True, ret_X=True)
instance.psi0 = _slice_wrapper(instance, instance.psi0, diag=False, derivative=False) instance.gradients_X = _Slice_wrapper_grad_X(instance, instance.gradients_X)
instance.psi1 = _slice_wrapper(instance, instance.psi1, diag=False, derivative=False) instance.gradients_X_diag = _Slice_wrapper_grad_X_diag(instance, instance.gradients_X_diag)
instance.psi2 = _slice_wrapper(instance, instance.psi2, diag=False, derivative=False)
instance.update_gradients_expectations = _slice_wrapper(instance, instance.update_gradients_expectations, derivative=True, psi_stat=True) instance.psi0 = _Slice_wrapper(instance, instance.psi0)
instance.gradients_Z_expectations = _slice_wrapper(instance, instance.gradients_Z_expectations, derivative=True, psi_stat_Z=True, ret_X=True) instance.psi1 = _Slice_wrapper(instance, instance.psi1)
instance.gradients_qX_expectations = _slice_wrapper(instance, instance.gradients_qX_expectations, derivative=True, psi_stat=True, ret_X=True) instance.psi2 = _Slice_wrapper(instance, instance.psi2)
instance.update_gradients_expectations = _Slice_wrapper_psi_stat_derivative_no_ret(instance, instance.update_gradients_expectations)
instance.gradients_Z_expectations = _Slice_wrapper_psi_stat_derivative_Z(instance, instance.gradients_Z_expectations)
instance.gradients_qX_expectations = _Slice_wrapper_psi_stat_derivative(instance, instance.gradients_qX_expectations)
instance.parameters_changed() instance.parameters_changed()
return instance return instance
def _slice_wrapper(kern, operation, diag=False, derivative=False, psi_stat=False, psi_stat_Z=False, ret_X=False): class _Slice_wrap(object):
""" def __init__(self, instance, f):
This method wraps the functions in kernel to make sure all kernels allways see their respective input dimension. self.k = instance
The different switches are: self.f = f
diag: if X2 exists def copy_to(self, new_instance):
derivative: if first arg is dL_dK return self.__class__(new_instance, self.f)
psi_stat: if first 3 args are dL_dpsi0..2 def _slice_X(self, X):
psi_stat_Z: if first 2 args are dL_dpsi1..2 return self.k._slice_X(X) if not self.k._sliced_X else X
""" def _slice_X_X2(self, X, X2):
if derivative: return self.k._slice_X(X) if not self.k._sliced_X else X, self.k._slice_X(X2) if X2 is not None and not self.k._sliced_X else X2
if diag: def __enter__(self):
def x_slice_wrapper(dL_dKdiag, X): self.k._sliced_X += 1
ret_X_not_sliced = ret_X and kern._sliced_X == 0 return self
if ret_X_not_sliced: def __exit__(self, *a):
ret = np.zeros(X.shape) self.k._sliced_X -= 1
X = kern._slice_X(X) if not kern._sliced_X else X
# if the return value is of shape X.shape, we need to make sure to return the right shape class _Slice_wrapper(_Slice_wrap):
kern._sliced_X += 1 def __call__(self, X, X2 = None, *a, **kw):
try: X, X2 = self._slice_X_X2(X, X2)
if ret_X_not_sliced: ret[:, kern.active_dims] = operation(dL_dKdiag, X) with self:
else: ret = operation(dL_dKdiag, X) ret = self.f(X, X2, *a, **kw)
except: return ret
raise
finally: class _Slice_wrapper_diag(_Slice_wrap):
kern._sliced_X -= 1 def __call__(self, X, *a, **kw):
return ret X = self._slice_X(X)
elif psi_stat: with self:
def x_slice_wrapper(dL_dpsi0, dL_dpsi1, dL_dpsi2, Z, variational_posterior): ret = self.f(X, *a, **kw)
ret_X_not_sliced = ret_X and kern._sliced_X == 0 return ret
if ret_X_not_sliced:
ret1, ret2 = np.zeros(variational_posterior.shape), np.zeros(variational_posterior.shape) class _Slice_wrapper_derivative(_Slice_wrap):
Z, variational_posterior = kern._slice_X(Z) if not kern._sliced_X else Z, kern._slice_X(variational_posterior) if not kern._sliced_X else variational_posterior def __call__(self, dL_dK, X, X2=None):
kern._sliced_X += 1 self._slice_X(X)
# if the return value is of shape X.shape, we need to make sure to return the right shape with self:
try: ret = self.f(dL_dK, X, X2)
if ret_X_not_sliced: return ret
ret = list(operation(dL_dpsi0, dL_dpsi1, dL_dpsi2, Z, variational_posterior))
r2 = ret[:2] class _Slice_wrapper_diag_derivative(_Slice_wrap):
ret[0] = ret1 def __call__(self, dL_dKdiag, X):
ret[1] = ret2 X = self._slice_X(X)
ret[0][:, kern.active_dims] = r2[0] with self:
ret[1][:, kern.active_dims] = r2[1] ret = self.f(dL_dKdiag, X)
del r2 return ret
else: ret = operation(dL_dpsi0, dL_dpsi1, dL_dpsi2, Z, variational_posterior)
except: class _Slice_wrapper_grad_X(_Slice_wrap):
raise def __call__(self, dL_dK, X, X2=None):
finally: ret = np.zeros(X.shape)
kern._sliced_X -= 1 X, X2 = self._slice_X_X2(X, X2)
return ret with self:
elif psi_stat_Z: ret[:, self.k.active_dims] = self.f(dL_dK, X, X2)
def x_slice_wrapper(dL_dpsi1, dL_dpsi2, Z, variational_posterior): return ret
ret_X_not_sliced = ret_X and kern._sliced_X == 0
if ret_X_not_sliced: ret = np.zeros(Z.shape) class _Slice_wrapper_grad_X_diag(_Slice_wrap):
Z, variational_posterior = kern._slice_X(Z) if not kern._sliced_X else Z, kern._slice_X(variational_posterior) if not kern._sliced_X else variational_posterior def __call__(self, dL_dKdiag, X):
kern._sliced_X += 1 ret = np.zeros(X.shape)
try: X = self._slice_X(X)
if ret_X_not_sliced: with self:
ret[:, kern.active_dims] = operation(dL_dpsi1, dL_dpsi2, Z, variational_posterior) ret[:, self.k.active_dims] = self.f(dL_dKdiag, X)
else: ret = operation(dL_dpsi1, dL_dpsi2, Z, variational_posterior) return ret
except:
raise class _Slice_wrapper_psi_stat_derivative_no_ret(_Slice_wrap):
finally: def __call__(self, dL_dpsi0, dL_dpsi1, dL_dpsi2, Z, variational_posterior):
kern._sliced_X -= 1 Z, variational_posterior = self._slice_X_X2(Z, variational_posterior)
return ret with self:
else: ret = self.f(dL_dpsi0, dL_dpsi1, dL_dpsi2, Z, variational_posterior)
def x_slice_wrapper(dL_dK, X, X2=None): return ret
ret_X_not_sliced = ret_X and kern._sliced_X == 0
if ret_X_not_sliced: class _Slice_wrapper_psi_stat_derivative(_Slice_wrap):
ret = np.zeros(X.shape) def __call__(self, dL_dpsi0, dL_dpsi1, dL_dpsi2, Z, variational_posterior):
X, X2 = kern._slice_X(X) if not kern._sliced_X else X, kern._slice_X(X2) if X2 is not None and not kern._sliced_X else X2 ret1, ret2 = np.zeros(variational_posterior.shape), np.zeros(variational_posterior.shape)
kern._sliced_X += 1 Z, variational_posterior = self._slice_X_X2(Z, variational_posterior)
try: with self:
if ret_X_not_sliced: ret[:, kern.active_dims] = operation(dL_dK, X, X2) ret = list(self.f(dL_dpsi0, dL_dpsi1, dL_dpsi2, Z, variational_posterior))
else: ret = operation(dL_dK, X, X2) r2 = ret[:2]
except: ret[0] = ret1
raise ret[1] = ret2
finally: ret[0][:, self.k.active_dims] = r2[0]
kern._sliced_X -= 1 ret[1][:, self.k.active_dims] = r2[1]
return ret del r2
else: return ret
if diag:
def x_slice_wrapper(X, *args, **kw): class _Slice_wrapper_psi_stat_derivative_Z(_Slice_wrap):
X = kern._slice_X(X) if not kern._sliced_X else X def __call__(self, dL_dpsi0, dL_dpsi1, dL_dpsi2, Z, variational_posterior):
kern._sliced_X += 1 ret1, ret2 = np.zeros(variational_posterior.shape), np.zeros(variational_posterior.shape)
try: Z, variational_posterior = self._slice_X_X2(Z, variational_posterior)
ret = operation(X, *args, **kw) with self:
except: ret = list(self.f(dL_dpsi0, dL_dpsi1, dL_dpsi2, Z, variational_posterior))
raise r2 = ret[:2]
finally: ret[0] = ret1
kern._sliced_X -= 1 ret[1] = ret2
return ret ret[0][:, self.k.active_dims] = r2[0]
else: ret[1][:, self.k.active_dims] = r2[1]
def x_slice_wrapper(X, X2=None, *args, **kw): del r2
X, X2 = kern._slice_X(X) if not kern._sliced_X else X, kern._slice_X(X2) if X2 is not None and not kern._sliced_X else X2 return ret
kern._sliced_X += 1
try:
ret = operation(X, X2, *args, **kw)
except: raise
finally:
kern._sliced_X -= 1
return ret
x_slice_wrapper._operation = operation
x_slice_wrapper.__name__ = ("slicer("+str(operation)
+(","+str(bool(diag)) if diag else'')
+(','+str(bool(derivative)) if derivative else '')
+')')
x_slice_wrapper.__doc__ = "**sliced**\n" + (operation.__doc__ or "")
return x_slice_wrapper