slice operations now bound functions, not added after the fact

This commit is contained in:
mzwiessele 2014-03-26 14:59:38 +00:00
parent ebb919bb8b
commit a126f288d2

View file

@ -6,24 +6,26 @@ Created on 11 Mar 2014
from ...core.parameterization.parameterized import ParametersChangedMeta from ...core.parameterization.parameterized import ParametersChangedMeta
import numpy as np import numpy as np
def put_clean(dct, name, *args, **kw):
if name in dct:
dct['_clean_{}'.format(name)] = dct[name]
dct[name] = _slice_wrapper(None, dct[name], *args, **kw)
class KernCallsViaSlicerMeta(ParametersChangedMeta): class KernCallsViaSlicerMeta(ParametersChangedMeta):
def __call__(self, *args, **kw): def __new__(cls, name, bases, dct):
instance = super(ParametersChangedMeta, self).__call__(*args, **kw) put_clean(dct, 'K')
instance.K = _slice_wrapper(instance, instance.K) put_clean(dct, 'Kdiag', diag=True)
instance.Kdiag = _slice_wrapper(instance, instance.Kdiag, diag=True) put_clean(dct, 'update_gradients_full', diag=False, derivative=True)
instance.update_gradients_full = _slice_wrapper(instance, instance.update_gradients_full, diag=False, derivative=True) put_clean(dct, 'gradients_X', diag=False, derivative=True, ret_X=True)
instance.update_gradients_diag = _slice_wrapper(instance, instance.update_gradients_diag, diag=True, derivative=True) put_clean(dct, 'gradients_X_diag', diag=True, derivative=True, ret_X=True)
instance.gradients_X = _slice_wrapper(instance, instance.gradients_X, diag=False, derivative=True, ret_X=True) put_clean(dct, 'psi0', diag=False, derivative=False)
instance.gradients_X_diag = _slice_wrapper(instance, instance.gradients_X_diag, diag=True, derivative=True, ret_X=True) put_clean(dct, 'psi1', diag=False, derivative=False)
instance.psi0 = _slice_wrapper(instance, instance.psi0, diag=False, derivative=False) put_clean(dct, 'psi2', diag=False, derivative=False)
instance.psi1 = _slice_wrapper(instance, instance.psi1, diag=False, derivative=False) put_clean(dct, 'update_gradients_expectations', derivative=True, psi_stat=True)
instance.psi2 = _slice_wrapper(instance, instance.psi2, diag=False, derivative=False) put_clean(dct, 'gradients_Z_expectations', derivative=True, psi_stat_Z=True, ret_X=True)
instance.update_gradients_expectations = _slice_wrapper(instance, instance.update_gradients_expectations, derivative=True, psi_stat=True) put_clean(dct, 'gradients_qX_expectations', derivative=True, psi_stat=True, ret_X=True)
instance.gradients_Z_expectations = _slice_wrapper(instance, instance.gradients_Z_expectations, derivative=True, psi_stat_Z=True, ret_X=True) return super(KernCallsViaSlicerMeta, cls).__new__(cls, name, bases, dct)
instance.gradients_qX_expectations = _slice_wrapper(instance, instance.gradients_qX_expectations, derivative=True, psi_stat=True, ret_X=True)
instance.parameters_changed()
return instance
def _slice_wrapper(kern, operation, diag=False, derivative=False, psi_stat=False, psi_stat_Z=False, ret_X=False): def _slice_wrapper(kern, operation, diag=False, derivative=False, psi_stat=False, psi_stat_Z=False, ret_X=False):
""" """
This method wraps the functions in kernel to make sure all kernels allways see their respective input dimension. This method wraps the functions in kernel to make sure all kernels allways see their respective input dimension.
@ -35,7 +37,7 @@ def _slice_wrapper(kern, operation, diag=False, derivative=False, psi_stat=False
""" """
if derivative: if derivative:
if diag: if diag:
def x_slice_wrapper(dL_dKdiag, X): def x_slice_wrapper(kern, dL_dKdiag, X):
ret_X_not_sliced = ret_X and kern._sliced_X == 0 ret_X_not_sliced = ret_X and kern._sliced_X == 0
if ret_X_not_sliced: if ret_X_not_sliced:
ret = np.zeros(X.shape) ret = np.zeros(X.shape)
@ -43,15 +45,15 @@ def _slice_wrapper(kern, operation, diag=False, derivative=False, psi_stat=False
# if the return value is of shape X.shape, we need to make sure to return the right shape # if the return value is of shape X.shape, we need to make sure to return the right shape
kern._sliced_X += 1 kern._sliced_X += 1
try: try:
if ret_X_not_sliced: ret[:, kern.active_dims] = operation(dL_dKdiag, X) if ret_X_not_sliced: ret[:, kern.active_dims] = operation(kern, dL_dKdiag, X)
else: ret = operation(dL_dKdiag, X) else: ret = operation(kern, dL_dKdiag, X)
except: except:
raise raise
finally: finally:
kern._sliced_X -= 1 kern._sliced_X -= 1
return ret return ret
elif psi_stat: elif psi_stat:
def x_slice_wrapper(dL_dpsi0, dL_dpsi1, dL_dpsi2, Z, variational_posterior): def x_slice_wrapper(kern, dL_dpsi0, dL_dpsi1, dL_dpsi2, Z, variational_posterior):
ret_X_not_sliced = ret_X and kern._sliced_X == 0 ret_X_not_sliced = ret_X and kern._sliced_X == 0
if ret_X_not_sliced: if ret_X_not_sliced:
ret1, ret2 = np.zeros(variational_posterior.shape), np.zeros(variational_posterior.shape) ret1, ret2 = np.zeros(variational_posterior.shape), np.zeros(variational_posterior.shape)
@ -60,44 +62,44 @@ def _slice_wrapper(kern, operation, diag=False, derivative=False, psi_stat=False
# if the return value is of shape X.shape, we need to make sure to return the right shape # if the return value is of shape X.shape, we need to make sure to return the right shape
try: try:
if ret_X_not_sliced: if ret_X_not_sliced:
ret = list(operation(dL_dpsi0, dL_dpsi1, dL_dpsi2, Z, variational_posterior)) ret = list(operation(kern, dL_dpsi0, dL_dpsi1, dL_dpsi2, Z, variational_posterior))
r2 = ret[:2] r2 = ret[:2]
ret[0] = ret1 ret[0] = ret1
ret[1] = ret2 ret[1] = ret2
ret[0][:, kern.active_dims] = r2[0] ret[0][:, kern.active_dims] = r2[0]
ret[1][:, kern.active_dims] = r2[1] ret[1][:, kern.active_dims] = r2[1]
del r2 del r2
else: ret = operation(dL_dpsi0, dL_dpsi1, dL_dpsi2, Z, variational_posterior) else: ret = operation(kern, dL_dpsi0, dL_dpsi1, dL_dpsi2, Z, variational_posterior)
except: except:
raise raise
finally: finally:
kern._sliced_X -= 1 kern._sliced_X -= 1
return ret return ret
elif psi_stat_Z: elif psi_stat_Z:
def x_slice_wrapper(dL_dpsi1, dL_dpsi2, Z, variational_posterior): def x_slice_wrapper(kern, dL_dpsi1, dL_dpsi2, Z, variational_posterior):
ret_X_not_sliced = ret_X and kern._sliced_X == 0 ret_X_not_sliced = ret_X and kern._sliced_X == 0
if ret_X_not_sliced: ret = np.zeros(Z.shape) if ret_X_not_sliced: ret = np.zeros(Z.shape)
Z, variational_posterior = kern._slice_X(Z) if not kern._sliced_X else Z, kern._slice_X(variational_posterior) if not kern._sliced_X else variational_posterior Z, variational_posterior = kern._slice_X(Z) if not kern._sliced_X else Z, kern._slice_X(variational_posterior) if not kern._sliced_X else variational_posterior
kern._sliced_X += 1 kern._sliced_X += 1
try: try:
if ret_X_not_sliced: if ret_X_not_sliced:
ret[:, kern.active_dims] = operation(dL_dpsi1, dL_dpsi2, Z, variational_posterior) ret[:, kern.active_dims] = operation(kern, dL_dpsi1, dL_dpsi2, Z, variational_posterior)
else: ret = operation(dL_dpsi1, dL_dpsi2, Z, variational_posterior) else: ret = operation(kern, dL_dpsi1, dL_dpsi2, Z, variational_posterior)
except: except:
raise raise
finally: finally:
kern._sliced_X -= 1 kern._sliced_X -= 1
return ret return ret
else: else:
def x_slice_wrapper(dL_dK, X, X2=None): def x_slice_wrapper(kern, dL_dK, X, X2=None):
ret_X_not_sliced = ret_X and kern._sliced_X == 0 ret_X_not_sliced = ret_X and kern._sliced_X == 0
if ret_X_not_sliced: if ret_X_not_sliced:
ret = np.zeros(X.shape) ret = np.zeros(X.shape)
X, X2 = kern._slice_X(X) if not kern._sliced_X else X, kern._slice_X(X2) if X2 is not None and not kern._sliced_X else X2 X, X2 = kern._slice_X(X) if not kern._sliced_X else X, kern._slice_X(X2) if X2 is not None and not kern._sliced_X else X2
kern._sliced_X += 1 kern._sliced_X += 1
try: try:
if ret_X_not_sliced: ret[:, kern.active_dims] = operation(dL_dK, X, X2) if ret_X_not_sliced: ret[:, kern.active_dims] = operation(kern, dL_dK, X, X2)
else: ret = operation(dL_dK, X, X2) else: ret = operation(kern, dL_dK, X, X2)
except: except:
raise raise
finally: finally:
@ -105,30 +107,30 @@ def _slice_wrapper(kern, operation, diag=False, derivative=False, psi_stat=False
return ret return ret
else: else:
if diag: if diag:
def x_slice_wrapper(X, *args, **kw): def x_slice_wrapper(kern, X, *args, **kw):
X = kern._slice_X(X) if not kern._sliced_X else X X = kern._slice_X(X) if not kern._sliced_X else X
kern._sliced_X += 1 kern._sliced_X += 1
try: try:
ret = operation(X, *args, **kw) ret = operation(kern, X, *args, **kw)
except: except:
raise raise
finally: finally:
kern._sliced_X -= 1 kern._sliced_X -= 1
return ret return ret
else: else:
def x_slice_wrapper(X, X2=None, *args, **kw): def x_slice_wrapper(kern, X, X2=None, *args, **kw):
X, X2 = kern._slice_X(X) if not kern._sliced_X else X, kern._slice_X(X2) if X2 is not None and not kern._sliced_X else X2 X, X2 = kern._slice_X(X) if not kern._sliced_X else X, kern._slice_X(X2) if X2 is not None and not kern._sliced_X else X2
kern._sliced_X += 1 kern._sliced_X += 1
try: try:
ret = operation(X, X2, *args, **kw) ret = operation(kern, X, X2, *args, **kw)
except: raise except: raise
finally: finally:
kern._sliced_X -= 1 kern._sliced_X -= 1
return ret return ret
x_slice_wrapper._operation = operation x_slice_wrapper._operation = operation
x_slice_wrapper.__name__ = ("slicer("+str(operation) x_slice_wrapper.__name__ = ("slicer("+str(operation)
+(","+str(bool(diag)) if diag else'') +(","+str('diag') if diag else'')
+(','+str(bool(derivative)) if derivative else '') +(','+str('derivative') if derivative else '')
+')') +')')
x_slice_wrapper.__doc__ = "**sliced**\n" + (operation.__doc__ or "") x_slice_wrapper.__doc__ = "**sliced**\n" + (operation.__doc__ or "")
return x_slice_wrapper return x_slice_wrapper