mirror of
https://github.com/SheffieldML/GPy.git
synced 2026-05-09 20:12:38 +02:00
product kernel and combination kernel updates
This commit is contained in:
parent
5fb0acbdb4
commit
e9c96632ba
4 changed files with 58 additions and 54 deletions
|
|
@ -9,17 +9,17 @@ class KernCallsViaSlicerMeta(ParametersChangedMeta):
|
|||
def __call__(self, *args, **kw):
|
||||
instance = super(ParametersChangedMeta, self).__call__(*args, **kw)
|
||||
instance.K = _slice_wrapper(instance, instance.K)
|
||||
instance.Kdiag = _slice_wrapper(instance, instance.Kdiag, True)
|
||||
instance.update_gradients_full = _slice_wrapper(instance, instance.update_gradients_full, False, True)
|
||||
instance.update_gradients_diag = _slice_wrapper(instance, instance.update_gradients_diag, True, True)
|
||||
instance.gradients_X = _slice_wrapper(instance, instance.gradients_X, False, True)
|
||||
instance.gradients_X_diag = _slice_wrapper(instance, instance.gradients_X_diag, True, True)
|
||||
instance.psi0 = _slice_wrapper(instance, instance.psi0, False, False)
|
||||
instance.psi1 = _slice_wrapper(instance, instance.psi1, False, False)
|
||||
instance.psi2 = _slice_wrapper(instance, instance.psi2, False, False)
|
||||
instance.update_gradients_expectations = _slice_wrapper(instance, instance.update_gradients_expectations, psi_stat=True)
|
||||
instance.gradients_Z_expectations = _slice_wrapper(instance, instance.gradients_Z_expectations, psi_stat_Z=True)
|
||||
instance.gradients_qX_expectations = _slice_wrapper(instance, instance.gradients_qX_expectations, psi_stat=True)
|
||||
instance.Kdiag = _slice_wrapper(instance, instance.Kdiag, diag=True)
|
||||
instance.update_gradients_full = _slice_wrapper(instance, instance.update_gradients_full, diag=False, derivative=True)
|
||||
instance.update_gradients_diag = _slice_wrapper(instance, instance.update_gradients_diag, diag=True, derivative=True)
|
||||
instance.gradients_X = _slice_wrapper(instance, instance.gradients_X, diag=False, derivative=True)
|
||||
instance.gradients_X_diag = _slice_wrapper(instance, instance.gradients_X_diag, diag=True, derivative=True)
|
||||
instance.psi0 = _slice_wrapper(instance, instance.psi0, diag=False, derivative=False)
|
||||
instance.psi1 = _slice_wrapper(instance, instance.psi1, diag=False, derivative=False)
|
||||
instance.psi2 = _slice_wrapper(instance, instance.psi2, diag=False, derivative=False)
|
||||
instance.update_gradients_expectations = _slice_wrapper(instance, instance.update_gradients_expectations, derivative=True, psi_stat=True)
|
||||
instance.gradients_Z_expectations = _slice_wrapper(instance, instance.gradients_Z_expectations, derivative=True, psi_stat_Z=True)
|
||||
instance.gradients_qX_expectations = _slice_wrapper(instance, instance.gradients_qX_expectations, derivative=True, psi_stat=True)
|
||||
instance.parameters_changed()
|
||||
return instance
|
||||
|
||||
|
|
@ -44,7 +44,29 @@ def _slice_wrapper(kern, operation, diag=False, derivative=False, psi_stat=False
|
|||
finally:
|
||||
kern._sliced_X -= 1
|
||||
return ret
|
||||
else:
|
||||
elif psi_stat:
|
||||
def x_slice_wrapper(dL_dpsi0, dL_dpsi1, dL_dpsi2, Z, variational_posterior):
|
||||
Z, variational_posterior = kern._slice_X(Z) if not kern._sliced_X else Z, kern._slice_X(variational_posterior) if not kern._sliced_X else variational_posterior
|
||||
kern._sliced_X += 1
|
||||
try:
|
||||
ret = operation(dL_dpsi0, dL_dpsi1, dL_dpsi2, Z, variational_posterior)
|
||||
except:
|
||||
raise
|
||||
finally:
|
||||
kern._sliced_X -= 1
|
||||
return ret
|
||||
elif psi_stat_Z:
|
||||
def x_slice_wrapper(dL_dpsi1, dL_dpsi2, Z, variational_posterior):
|
||||
Z, variational_posterior = kern._slice_X(Z) if not kern._sliced_X else Z, kern._slice_X(variational_posterior) if not kern._sliced_X else variational_posterior
|
||||
kern._sliced_X += 1
|
||||
try:
|
||||
ret = operation(dL_dpsi1, dL_dpsi2, Z, variational_posterior)
|
||||
except:
|
||||
raise
|
||||
finally:
|
||||
kern._sliced_X -= 1
|
||||
return ret
|
||||
else:
|
||||
def x_slice_wrapper(dL_dK, X, X2=None):
|
||||
X, X2 = kern._slice_X(X) if not kern._sliced_X else X, kern._slice_X(X2) if X2 is not None and not kern._sliced_X else X2
|
||||
kern._sliced_X += 1
|
||||
|
|
@ -55,28 +77,6 @@ def _slice_wrapper(kern, operation, diag=False, derivative=False, psi_stat=False
|
|||
finally:
|
||||
kern._sliced_X -= 1
|
||||
return ret
|
||||
elif psi_stat:
|
||||
def x_slice_wrapper(dL_dpsi0, dL_dpsi1, dL_dpsi2, Z, variational_posterior):
|
||||
Z, variational_posterior = kern._slice_X(Z) if not kern._sliced_X else Z, kern._slice_X(variational_posterior) if not kern._sliced_X else variational_posterior
|
||||
kern._sliced_X += 1
|
||||
try:
|
||||
ret = operation(dL_dpsi0, dL_dpsi1, dL_dpsi2, Z, variational_posterior)
|
||||
except:
|
||||
raise
|
||||
finally:
|
||||
kern._sliced_X -= 1
|
||||
return ret
|
||||
elif psi_stat_Z:
|
||||
def x_slice_wrapper(dL_dpsi1, dL_dpsi2, Z, variational_posterior):
|
||||
Z, variational_posterior = kern._slice_X(Z) if not kern._sliced_X else Z, kern._slice_X(variational_posterior) if not kern._sliced_X else variational_posterior
|
||||
kern._sliced_X += 1
|
||||
try:
|
||||
ret = operation(dL_dpsi1, dL_dpsi2, Z, variational_posterior)
|
||||
except:
|
||||
raise
|
||||
finally:
|
||||
kern._sliced_X -= 1
|
||||
return ret
|
||||
else:
|
||||
if diag:
|
||||
def x_slice_wrapper(X, *args, **kw):
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue