mirror of
https://github.com/SheffieldML/GPy.git
synced 2026-05-09 12:02:38 +02:00
slice operations now bound functions, not added after the fact
This commit is contained in:
parent
ebb919bb8b
commit
a126f288d2
1 changed files with 37 additions and 35 deletions
|
|
@ -6,23 +6,25 @@ Created on 11 Mar 2014
|
||||||
from ...core.parameterization.parameterized import ParametersChangedMeta
|
from ...core.parameterization.parameterized import ParametersChangedMeta
|
||||||
import numpy as np
|
import numpy as np
|
||||||
|
|
||||||
|
def put_clean(dct, name, *args, **kw):
|
||||||
|
if name in dct:
|
||||||
|
dct['_clean_{}'.format(name)] = dct[name]
|
||||||
|
dct[name] = _slice_wrapper(None, dct[name], *args, **kw)
|
||||||
|
|
||||||
class KernCallsViaSlicerMeta(ParametersChangedMeta):
|
class KernCallsViaSlicerMeta(ParametersChangedMeta):
|
||||||
def __call__(self, *args, **kw):
|
def __new__(cls, name, bases, dct):
|
||||||
instance = super(ParametersChangedMeta, self).__call__(*args, **kw)
|
put_clean(dct, 'K')
|
||||||
instance.K = _slice_wrapper(instance, instance.K)
|
put_clean(dct, 'Kdiag', diag=True)
|
||||||
instance.Kdiag = _slice_wrapper(instance, instance.Kdiag, diag=True)
|
put_clean(dct, 'update_gradients_full', diag=False, derivative=True)
|
||||||
instance.update_gradients_full = _slice_wrapper(instance, instance.update_gradients_full, diag=False, derivative=True)
|
put_clean(dct, 'gradients_X', diag=False, derivative=True, ret_X=True)
|
||||||
instance.update_gradients_diag = _slice_wrapper(instance, instance.update_gradients_diag, diag=True, derivative=True)
|
put_clean(dct, 'gradients_X_diag', diag=True, derivative=True, ret_X=True)
|
||||||
instance.gradients_X = _slice_wrapper(instance, instance.gradients_X, diag=False, derivative=True, ret_X=True)
|
put_clean(dct, 'psi0', diag=False, derivative=False)
|
||||||
instance.gradients_X_diag = _slice_wrapper(instance, instance.gradients_X_diag, diag=True, derivative=True, ret_X=True)
|
put_clean(dct, 'psi1', diag=False, derivative=False)
|
||||||
instance.psi0 = _slice_wrapper(instance, instance.psi0, diag=False, derivative=False)
|
put_clean(dct, 'psi2', diag=False, derivative=False)
|
||||||
instance.psi1 = _slice_wrapper(instance, instance.psi1, diag=False, derivative=False)
|
put_clean(dct, 'update_gradients_expectations', derivative=True, psi_stat=True)
|
||||||
instance.psi2 = _slice_wrapper(instance, instance.psi2, diag=False, derivative=False)
|
put_clean(dct, 'gradients_Z_expectations', derivative=True, psi_stat_Z=True, ret_X=True)
|
||||||
instance.update_gradients_expectations = _slice_wrapper(instance, instance.update_gradients_expectations, derivative=True, psi_stat=True)
|
put_clean(dct, 'gradients_qX_expectations', derivative=True, psi_stat=True, ret_X=True)
|
||||||
instance.gradients_Z_expectations = _slice_wrapper(instance, instance.gradients_Z_expectations, derivative=True, psi_stat_Z=True, ret_X=True)
|
return super(KernCallsViaSlicerMeta, cls).__new__(cls, name, bases, dct)
|
||||||
instance.gradients_qX_expectations = _slice_wrapper(instance, instance.gradients_qX_expectations, derivative=True, psi_stat=True, ret_X=True)
|
|
||||||
instance.parameters_changed()
|
|
||||||
return instance
|
|
||||||
|
|
||||||
def _slice_wrapper(kern, operation, diag=False, derivative=False, psi_stat=False, psi_stat_Z=False, ret_X=False):
|
def _slice_wrapper(kern, operation, diag=False, derivative=False, psi_stat=False, psi_stat_Z=False, ret_X=False):
|
||||||
"""
|
"""
|
||||||
|
|
@ -35,7 +37,7 @@ def _slice_wrapper(kern, operation, diag=False, derivative=False, psi_stat=False
|
||||||
"""
|
"""
|
||||||
if derivative:
|
if derivative:
|
||||||
if diag:
|
if diag:
|
||||||
def x_slice_wrapper(dL_dKdiag, X):
|
def x_slice_wrapper(kern, dL_dKdiag, X):
|
||||||
ret_X_not_sliced = ret_X and kern._sliced_X == 0
|
ret_X_not_sliced = ret_X and kern._sliced_X == 0
|
||||||
if ret_X_not_sliced:
|
if ret_X_not_sliced:
|
||||||
ret = np.zeros(X.shape)
|
ret = np.zeros(X.shape)
|
||||||
|
|
@ -43,15 +45,15 @@ def _slice_wrapper(kern, operation, diag=False, derivative=False, psi_stat=False
|
||||||
# if the return value is of shape X.shape, we need to make sure to return the right shape
|
# if the return value is of shape X.shape, we need to make sure to return the right shape
|
||||||
kern._sliced_X += 1
|
kern._sliced_X += 1
|
||||||
try:
|
try:
|
||||||
if ret_X_not_sliced: ret[:, kern.active_dims] = operation(dL_dKdiag, X)
|
if ret_X_not_sliced: ret[:, kern.active_dims] = operation(kern, dL_dKdiag, X)
|
||||||
else: ret = operation(dL_dKdiag, X)
|
else: ret = operation(kern, dL_dKdiag, X)
|
||||||
except:
|
except:
|
||||||
raise
|
raise
|
||||||
finally:
|
finally:
|
||||||
kern._sliced_X -= 1
|
kern._sliced_X -= 1
|
||||||
return ret
|
return ret
|
||||||
elif psi_stat:
|
elif psi_stat:
|
||||||
def x_slice_wrapper(dL_dpsi0, dL_dpsi1, dL_dpsi2, Z, variational_posterior):
|
def x_slice_wrapper(kern, dL_dpsi0, dL_dpsi1, dL_dpsi2, Z, variational_posterior):
|
||||||
ret_X_not_sliced = ret_X and kern._sliced_X == 0
|
ret_X_not_sliced = ret_X and kern._sliced_X == 0
|
||||||
if ret_X_not_sliced:
|
if ret_X_not_sliced:
|
||||||
ret1, ret2 = np.zeros(variational_posterior.shape), np.zeros(variational_posterior.shape)
|
ret1, ret2 = np.zeros(variational_posterior.shape), np.zeros(variational_posterior.shape)
|
||||||
|
|
@ -60,44 +62,44 @@ def _slice_wrapper(kern, operation, diag=False, derivative=False, psi_stat=False
|
||||||
# if the return value is of shape X.shape, we need to make sure to return the right shape
|
# if the return value is of shape X.shape, we need to make sure to return the right shape
|
||||||
try:
|
try:
|
||||||
if ret_X_not_sliced:
|
if ret_X_not_sliced:
|
||||||
ret = list(operation(dL_dpsi0, dL_dpsi1, dL_dpsi2, Z, variational_posterior))
|
ret = list(operation(kern, dL_dpsi0, dL_dpsi1, dL_dpsi2, Z, variational_posterior))
|
||||||
r2 = ret[:2]
|
r2 = ret[:2]
|
||||||
ret[0] = ret1
|
ret[0] = ret1
|
||||||
ret[1] = ret2
|
ret[1] = ret2
|
||||||
ret[0][:, kern.active_dims] = r2[0]
|
ret[0][:, kern.active_dims] = r2[0]
|
||||||
ret[1][:, kern.active_dims] = r2[1]
|
ret[1][:, kern.active_dims] = r2[1]
|
||||||
del r2
|
del r2
|
||||||
else: ret = operation(dL_dpsi0, dL_dpsi1, dL_dpsi2, Z, variational_posterior)
|
else: ret = operation(kern, dL_dpsi0, dL_dpsi1, dL_dpsi2, Z, variational_posterior)
|
||||||
except:
|
except:
|
||||||
raise
|
raise
|
||||||
finally:
|
finally:
|
||||||
kern._sliced_X -= 1
|
kern._sliced_X -= 1
|
||||||
return ret
|
return ret
|
||||||
elif psi_stat_Z:
|
elif psi_stat_Z:
|
||||||
def x_slice_wrapper(dL_dpsi1, dL_dpsi2, Z, variational_posterior):
|
def x_slice_wrapper(kern, dL_dpsi1, dL_dpsi2, Z, variational_posterior):
|
||||||
ret_X_not_sliced = ret_X and kern._sliced_X == 0
|
ret_X_not_sliced = ret_X and kern._sliced_X == 0
|
||||||
if ret_X_not_sliced: ret = np.zeros(Z.shape)
|
if ret_X_not_sliced: ret = np.zeros(Z.shape)
|
||||||
Z, variational_posterior = kern._slice_X(Z) if not kern._sliced_X else Z, kern._slice_X(variational_posterior) if not kern._sliced_X else variational_posterior
|
Z, variational_posterior = kern._slice_X(Z) if not kern._sliced_X else Z, kern._slice_X(variational_posterior) if not kern._sliced_X else variational_posterior
|
||||||
kern._sliced_X += 1
|
kern._sliced_X += 1
|
||||||
try:
|
try:
|
||||||
if ret_X_not_sliced:
|
if ret_X_not_sliced:
|
||||||
ret[:, kern.active_dims] = operation(dL_dpsi1, dL_dpsi2, Z, variational_posterior)
|
ret[:, kern.active_dims] = operation(kern, dL_dpsi1, dL_dpsi2, Z, variational_posterior)
|
||||||
else: ret = operation(dL_dpsi1, dL_dpsi2, Z, variational_posterior)
|
else: ret = operation(kern, dL_dpsi1, dL_dpsi2, Z, variational_posterior)
|
||||||
except:
|
except:
|
||||||
raise
|
raise
|
||||||
finally:
|
finally:
|
||||||
kern._sliced_X -= 1
|
kern._sliced_X -= 1
|
||||||
return ret
|
return ret
|
||||||
else:
|
else:
|
||||||
def x_slice_wrapper(dL_dK, X, X2=None):
|
def x_slice_wrapper(kern, dL_dK, X, X2=None):
|
||||||
ret_X_not_sliced = ret_X and kern._sliced_X == 0
|
ret_X_not_sliced = ret_X and kern._sliced_X == 0
|
||||||
if ret_X_not_sliced:
|
if ret_X_not_sliced:
|
||||||
ret = np.zeros(X.shape)
|
ret = np.zeros(X.shape)
|
||||||
X, X2 = kern._slice_X(X) if not kern._sliced_X else X, kern._slice_X(X2) if X2 is not None and not kern._sliced_X else X2
|
X, X2 = kern._slice_X(X) if not kern._sliced_X else X, kern._slice_X(X2) if X2 is not None and not kern._sliced_X else X2
|
||||||
kern._sliced_X += 1
|
kern._sliced_X += 1
|
||||||
try:
|
try:
|
||||||
if ret_X_not_sliced: ret[:, kern.active_dims] = operation(dL_dK, X, X2)
|
if ret_X_not_sliced: ret[:, kern.active_dims] = operation(kern, dL_dK, X, X2)
|
||||||
else: ret = operation(dL_dK, X, X2)
|
else: ret = operation(kern, dL_dK, X, X2)
|
||||||
except:
|
except:
|
||||||
raise
|
raise
|
||||||
finally:
|
finally:
|
||||||
|
|
@ -105,30 +107,30 @@ def _slice_wrapper(kern, operation, diag=False, derivative=False, psi_stat=False
|
||||||
return ret
|
return ret
|
||||||
else:
|
else:
|
||||||
if diag:
|
if diag:
|
||||||
def x_slice_wrapper(X, *args, **kw):
|
def x_slice_wrapper(kern, X, *args, **kw):
|
||||||
X = kern._slice_X(X) if not kern._sliced_X else X
|
X = kern._slice_X(X) if not kern._sliced_X else X
|
||||||
kern._sliced_X += 1
|
kern._sliced_X += 1
|
||||||
try:
|
try:
|
||||||
ret = operation(X, *args, **kw)
|
ret = operation(kern, X, *args, **kw)
|
||||||
except:
|
except:
|
||||||
raise
|
raise
|
||||||
finally:
|
finally:
|
||||||
kern._sliced_X -= 1
|
kern._sliced_X -= 1
|
||||||
return ret
|
return ret
|
||||||
else:
|
else:
|
||||||
def x_slice_wrapper(X, X2=None, *args, **kw):
|
def x_slice_wrapper(kern, X, X2=None, *args, **kw):
|
||||||
X, X2 = kern._slice_X(X) if not kern._sliced_X else X, kern._slice_X(X2) if X2 is not None and not kern._sliced_X else X2
|
X, X2 = kern._slice_X(X) if not kern._sliced_X else X, kern._slice_X(X2) if X2 is not None and not kern._sliced_X else X2
|
||||||
kern._sliced_X += 1
|
kern._sliced_X += 1
|
||||||
try:
|
try:
|
||||||
ret = operation(X, X2, *args, **kw)
|
ret = operation(kern, X, X2, *args, **kw)
|
||||||
except: raise
|
except: raise
|
||||||
finally:
|
finally:
|
||||||
kern._sliced_X -= 1
|
kern._sliced_X -= 1
|
||||||
return ret
|
return ret
|
||||||
x_slice_wrapper._operation = operation
|
x_slice_wrapper._operation = operation
|
||||||
x_slice_wrapper.__name__ = ("slicer("+str(operation)
|
x_slice_wrapper.__name__ = ("slicer("+str(operation)
|
||||||
+(","+str(bool(diag)) if diag else'')
|
+(","+str('diag') if diag else'')
|
||||||
+(','+str(bool(derivative)) if derivative else '')
|
+(','+str('derivative') if derivative else '')
|
||||||
+')')
|
+')')
|
||||||
x_slice_wrapper.__doc__ = "**sliced**\n" + (operation.__doc__ or "")
|
x_slice_wrapper.__doc__ = "**sliced**\n" + (operation.__doc__ or "")
|
||||||
return x_slice_wrapper
|
return x_slice_wrapper
|
||||||
Loading…
Add table
Add a link
Reference in a new issue