product kernel and combination kernel updates

This commit is contained in:
Max Zwiessele 2014-03-13 11:01:48 +00:00
parent 5fb0acbdb4
commit e9c96632ba
4 changed files with 58 additions and 54 deletions

View file

@ -3,7 +3,6 @@
import numpy as np
import itertools
from ...core.parameterization import Parameterized
from ...util.caching import Cache_this
from kern import CombinationKernel

View file

@ -156,7 +156,7 @@ class Kern(Parameterized):
other.active_dims += self.input_dim
return self.prod(other)
def prod(self, other, name=None):
def prod(self, other, name='prod'):
"""
Multiply two kernels (either on the same space, or on the tensor
product of the input space).
@ -169,12 +169,12 @@ class Kern(Parameterized):
"""
assert isinstance(other, Kern), "only kernels can be added to kernels..."
from prod import Prod
kernels = []
if isinstance(self, Prod): kernels.extend(self._parameters_)
else: kernels.append(self)
if isinstance(other, Prod): kernels.extend(other._parameters_)
else: kernels.append(other)
return Prod(self, other, name)
#kernels = []
#if isinstance(self, Prod): kernels.extend(self._parameters_)
#else: kernels.append(self)
#if isinstance(other, Prod): kernels.extend(other._parameters_)
#else: kernels.append(other)
return Prod([self, other], name)
def _getstate(self):
"""
@ -195,8 +195,10 @@ class Kern(Parameterized):
class CombinationKernel(Kern):
def __init__(self, kernels, name):
assert all([isinstance(k, Kern) for k in kernels])
# make sure the active dimensions of all underlying kernels are covered:
ma = reduce(lambda a,b: max(a, max(b)), (x.active_dims for x in kernels), 0)
input_dim = np.r_[0:ma+1]
# initialize the kernel with the full input_dim
super(CombinationKernel, self).__init__(input_dim, name)
self.add_parameters(*kernels)

View file

@ -9,17 +9,17 @@ class KernCallsViaSlicerMeta(ParametersChangedMeta):
def __call__(self, *args, **kw):
instance = super(ParametersChangedMeta, self).__call__(*args, **kw)
instance.K = _slice_wrapper(instance, instance.K)
instance.Kdiag = _slice_wrapper(instance, instance.Kdiag, True)
instance.update_gradients_full = _slice_wrapper(instance, instance.update_gradients_full, False, True)
instance.update_gradients_diag = _slice_wrapper(instance, instance.update_gradients_diag, True, True)
instance.gradients_X = _slice_wrapper(instance, instance.gradients_X, False, True)
instance.gradients_X_diag = _slice_wrapper(instance, instance.gradients_X_diag, True, True)
instance.psi0 = _slice_wrapper(instance, instance.psi0, False, False)
instance.psi1 = _slice_wrapper(instance, instance.psi1, False, False)
instance.psi2 = _slice_wrapper(instance, instance.psi2, False, False)
instance.update_gradients_expectations = _slice_wrapper(instance, instance.update_gradients_expectations, psi_stat=True)
instance.gradients_Z_expectations = _slice_wrapper(instance, instance.gradients_Z_expectations, psi_stat_Z=True)
instance.gradients_qX_expectations = _slice_wrapper(instance, instance.gradients_qX_expectations, psi_stat=True)
instance.Kdiag = _slice_wrapper(instance, instance.Kdiag, diag=True)
instance.update_gradients_full = _slice_wrapper(instance, instance.update_gradients_full, diag=False, derivative=True)
instance.update_gradients_diag = _slice_wrapper(instance, instance.update_gradients_diag, diag=True, derivative=True)
instance.gradients_X = _slice_wrapper(instance, instance.gradients_X, diag=False, derivative=True)
instance.gradients_X_diag = _slice_wrapper(instance, instance.gradients_X_diag, diag=True, derivative=True)
instance.psi0 = _slice_wrapper(instance, instance.psi0, diag=False, derivative=False)
instance.psi1 = _slice_wrapper(instance, instance.psi1, diag=False, derivative=False)
instance.psi2 = _slice_wrapper(instance, instance.psi2, diag=False, derivative=False)
instance.update_gradients_expectations = _slice_wrapper(instance, instance.update_gradients_expectations, derivative=True, psi_stat=True)
instance.gradients_Z_expectations = _slice_wrapper(instance, instance.gradients_Z_expectations, derivative=True, psi_stat_Z=True)
instance.gradients_qX_expectations = _slice_wrapper(instance, instance.gradients_qX_expectations, derivative=True, psi_stat=True)
instance.parameters_changed()
return instance
@ -44,17 +44,6 @@ def _slice_wrapper(kern, operation, diag=False, derivative=False, psi_stat=False
finally:
kern._sliced_X -= 1
return ret
else:
def x_slice_wrapper(dL_dK, X, X2=None):
X, X2 = kern._slice_X(X) if not kern._sliced_X else X, kern._slice_X(X2) if X2 is not None and not kern._sliced_X else X2
kern._sliced_X += 1
try:
ret = operation(dL_dK, X, X2)
except:
raise
finally:
kern._sliced_X -= 1
return ret
elif psi_stat:
def x_slice_wrapper(dL_dpsi0, dL_dpsi1, dL_dpsi2, Z, variational_posterior):
Z, variational_posterior = kern._slice_X(Z) if not kern._sliced_X else Z, kern._slice_X(variational_posterior) if not kern._sliced_X else variational_posterior
@ -77,6 +66,17 @@ def _slice_wrapper(kern, operation, diag=False, derivative=False, psi_stat=False
finally:
kern._sliced_X -= 1
return ret
else:
def x_slice_wrapper(dL_dK, X, X2=None):
X, X2 = kern._slice_X(X) if not kern._sliced_X else X, kern._slice_X(X2) if X2 is not None and not kern._sliced_X else X2
kern._sliced_X += 1
try:
ret = operation(dL_dK, X, X2)
except:
raise
finally:
kern._sliced_X -= 1
return ret
else:
if diag:
def x_slice_wrapper(X, *args, **kw):

View file

@ -18,6 +18,7 @@ class Prod(CombinationKernel):
"""
def __init__(self, kernels, name='prod'):
assert len(kernels) == 2, 'only implemented for two kernels as of yet'
super(Prod, self).__init__(kernels, name)
@Cache_this(limit=2, force_kwargs=['which_parts'])
@ -37,26 +38,28 @@ class Prod(CombinationKernel):
which_parts = self.parts
return reduce(np.multiply, (p.Kdiag(X) for p in which_parts))
def update_gradients_full(self, dL_dK, X):
def update_gradients_full(self, dL_dK, X, X2=None):
for k1,k2 in itertools.combinations(self.parts, 2):
k1._sliced_X = k1._sliced_X2 = k2._sliced_X = k2._sliced_X2 = True
k1.update_gradients_full(dL_dK*k2.K(X, X))
self.k2.update_gradients_full(dL_dK*self.k1.K(X[:,self.slice1]), X[:,self.slice2])
k1.update_gradients_full(dL_dK*k2.K(X, X2), X, X2)
k2.update_gradients_full(dL_dK*k1.K(X, X2), X, X2)
def update_gradients_diag(self, dL_dKdiag, X):
for k1,k2 in itertools.combinations(self.parts, 2):
k1.update_gradients_diag(dL_dKdiag*k2.Kdiag(X), X)
k2.update_gradients_diag(dL_dKdiag*k1.Kdiag(X), X)
def gradients_X(self, dL_dK, X, X2=None):
target = np.zeros(X.shape)
if X2 is None:
target[:,self.slice1] += self.k1.gradients_X(dL_dK*self.k2.K(X[:,self.slice2]), X[:,self.slice1], None)
target[:,self.slice2] += self.k2.gradients_X(dL_dK*self.k1.K(X[:,self.slice1]), X[:,self.slice2], None)
else:
target[:,self.slice1] += self.k1.gradients_X(dL_dK*self.k2.K(X[:,self.slice2], X2[:,self.slice2]), X[:,self.slice1], X2[:,self.slice1])
target[:,self.slice2] += self.k2.gradients_X(dL_dK*self.k1.K(X[:,self.slice1], X2[:,self.slice1]), X[:,self.slice2], X2[:,self.slice2])
for k1,k2 in itertools.combinations(self.parts, 2):
target[:,k1.active_dims] += k1.gradients_X(dL_dK*k2.K(X, X2), X, X2)
target[:,k2.active_dims] += k2.gradients_X(dL_dK*k1.K(X, X2), X, X2)
return target
def gradients_X_diag(self, dL_dKdiag, X):
target = np.zeros(X.shape)
target[:,self.slice1] = self.k1.gradients_X(dL_dKdiag*self.k2.Kdiag(X[:,self.slice2]), X[:,self.slice1])
target[:,self.slice2] += self.k2.gradients_X(dL_dKdiag*self.k1.Kdiag(X[:,self.slice1]), X[:,self.slice2])
for k1,k2 in itertools.combinations(self.parts, 2):
target[:,k1.active_dims] += k1.gradients_X(dL_dKdiag*k2.Kdiag(X), X)
target[:,k2.active_dims] += k2.gradients_X(dL_dKdiag*k1.Kdiag(X), X)
return target