product kernel and combination kernel updates

This commit is contained in:
Max Zwiessele 2014-03-13 11:01:48 +00:00
parent 5fb0acbdb4
commit e9c96632ba
4 changed files with 58 additions and 54 deletions

View file

@@ -3,7 +3,6 @@
import numpy as np import numpy as np
import itertools import itertools
from ...core.parameterization import Parameterized
from ...util.caching import Cache_this from ...util.caching import Cache_this
from kern import CombinationKernel from kern import CombinationKernel

View file

@@ -156,7 +156,7 @@ class Kern(Parameterized):
other.active_dims += self.input_dim other.active_dims += self.input_dim
return self.prod(other) return self.prod(other)
def prod(self, other, name=None): def prod(self, other, name='prod'):
""" """
Multiply two kernels (either on the same space, or on the tensor Multiply two kernels (either on the same space, or on the tensor
product of the input space). product of the input space).
@@ -169,12 +169,12 @@ class Kern(Parameterized):
""" """
assert isinstance(other, Kern), "only kernels can be added to kernels..." assert isinstance(other, Kern), "only kernels can be added to kernels..."
from prod import Prod from prod import Prod
kernels = [] #kernels = []
if isinstance(self, Prod): kernels.extend(self._parameters_) #if isinstance(self, Prod): kernels.extend(self._parameters_)
else: kernels.append(self) #else: kernels.append(self)
if isinstance(other, Prod): kernels.extend(other._parameters_) #if isinstance(other, Prod): kernels.extend(other._parameters_)
else: kernels.append(other) #else: kernels.append(other)
return Prod(self, other, name) return Prod([self, other], name)
def _getstate(self): def _getstate(self):
""" """
@@ -195,8 +195,10 @@
class CombinationKernel(Kern): class CombinationKernel(Kern):
def __init__(self, kernels, name): def __init__(self, kernels, name):
assert all([isinstance(k, Kern) for k in kernels]) assert all([isinstance(k, Kern) for k in kernels])
# make sure the active dimensions of all underlying kernels are covered:
ma = reduce(lambda a,b: max(a, max(b)), (x.active_dims for x in kernels), 0) ma = reduce(lambda a,b: max(a, max(b)), (x.active_dims for x in kernels), 0)
input_dim = np.r_[0:ma+1] input_dim = np.r_[0:ma+1]
# initialize the kernel with the full input_dim
super(CombinationKernel, self).__init__(input_dim, name) super(CombinationKernel, self).__init__(input_dim, name)
self.add_parameters(*kernels) self.add_parameters(*kernels)

View file

@@ -9,17 +9,17 @@ class KernCallsViaSlicerMeta(ParametersChangedMeta):
def __call__(self, *args, **kw): def __call__(self, *args, **kw):
instance = super(ParametersChangedMeta, self).__call__(*args, **kw) instance = super(ParametersChangedMeta, self).__call__(*args, **kw)
instance.K = _slice_wrapper(instance, instance.K) instance.K = _slice_wrapper(instance, instance.K)
instance.Kdiag = _slice_wrapper(instance, instance.Kdiag, True) instance.Kdiag = _slice_wrapper(instance, instance.Kdiag, diag=True)
instance.update_gradients_full = _slice_wrapper(instance, instance.update_gradients_full, False, True) instance.update_gradients_full = _slice_wrapper(instance, instance.update_gradients_full, diag=False, derivative=True)
instance.update_gradients_diag = _slice_wrapper(instance, instance.update_gradients_diag, True, True) instance.update_gradients_diag = _slice_wrapper(instance, instance.update_gradients_diag, diag=True, derivative=True)
instance.gradients_X = _slice_wrapper(instance, instance.gradients_X, False, True) instance.gradients_X = _slice_wrapper(instance, instance.gradients_X, diag=False, derivative=True)
instance.gradients_X_diag = _slice_wrapper(instance, instance.gradients_X_diag, True, True) instance.gradients_X_diag = _slice_wrapper(instance, instance.gradients_X_diag, diag=True, derivative=True)
instance.psi0 = _slice_wrapper(instance, instance.psi0, False, False) instance.psi0 = _slice_wrapper(instance, instance.psi0, diag=False, derivative=False)
instance.psi1 = _slice_wrapper(instance, instance.psi1, False, False) instance.psi1 = _slice_wrapper(instance, instance.psi1, diag=False, derivative=False)
instance.psi2 = _slice_wrapper(instance, instance.psi2, False, False) instance.psi2 = _slice_wrapper(instance, instance.psi2, diag=False, derivative=False)
instance.update_gradients_expectations = _slice_wrapper(instance, instance.update_gradients_expectations, psi_stat=True) instance.update_gradients_expectations = _slice_wrapper(instance, instance.update_gradients_expectations, derivative=True, psi_stat=True)
instance.gradients_Z_expectations = _slice_wrapper(instance, instance.gradients_Z_expectations, psi_stat_Z=True) instance.gradients_Z_expectations = _slice_wrapper(instance, instance.gradients_Z_expectations, derivative=True, psi_stat_Z=True)
instance.gradients_qX_expectations = _slice_wrapper(instance, instance.gradients_qX_expectations, psi_stat=True) instance.gradients_qX_expectations = _slice_wrapper(instance, instance.gradients_qX_expectations, derivative=True, psi_stat=True)
instance.parameters_changed() instance.parameters_changed()
return instance return instance
@@ -44,17 +44,6 @@ def _slice_wrapper(kern, operation, diag=False, derivative=False, psi_stat=False
finally: finally:
kern._sliced_X -= 1 kern._sliced_X -= 1
return ret return ret
else:
def x_slice_wrapper(dL_dK, X, X2=None):
X, X2 = kern._slice_X(X) if not kern._sliced_X else X, kern._slice_X(X2) if X2 is not None and not kern._sliced_X else X2
kern._sliced_X += 1
try:
ret = operation(dL_dK, X, X2)
except:
raise
finally:
kern._sliced_X -= 1
return ret
elif psi_stat: elif psi_stat:
def x_slice_wrapper(dL_dpsi0, dL_dpsi1, dL_dpsi2, Z, variational_posterior): def x_slice_wrapper(dL_dpsi0, dL_dpsi1, dL_dpsi2, Z, variational_posterior):
Z, variational_posterior = kern._slice_X(Z) if not kern._sliced_X else Z, kern._slice_X(variational_posterior) if not kern._sliced_X else variational_posterior Z, variational_posterior = kern._slice_X(Z) if not kern._sliced_X else Z, kern._slice_X(variational_posterior) if not kern._sliced_X else variational_posterior
@@ -77,6 +66,17 @@ def _slice_wrapper(kern, operation, diag=False, derivative=False, psi_stat=False
finally: finally:
kern._sliced_X -= 1 kern._sliced_X -= 1
return ret return ret
else:
def x_slice_wrapper(dL_dK, X, X2=None):
X, X2 = kern._slice_X(X) if not kern._sliced_X else X, kern._slice_X(X2) if X2 is not None and not kern._sliced_X else X2
kern._sliced_X += 1
try:
ret = operation(dL_dK, X, X2)
except:
raise
finally:
kern._sliced_X -= 1
return ret
else: else:
if diag: if diag:
def x_slice_wrapper(X, *args, **kw): def x_slice_wrapper(X, *args, **kw):

View file

@@ -18,6 +18,7 @@ class Prod(CombinationKernel):
""" """
def __init__(self, kernels, name='prod'): def __init__(self, kernels, name='prod'):
assert len(kernels) == 2, 'only implemented for two kernels as of yet'
super(Prod, self).__init__(kernels, name) super(Prod, self).__init__(kernels, name)
@Cache_this(limit=2, force_kwargs=['which_parts']) @Cache_this(limit=2, force_kwargs=['which_parts'])
@@ -37,26 +38,28 @@ class Prod(CombinationKernel):
which_parts = self.parts which_parts = self.parts
return reduce(np.multiply, (p.Kdiag(X) for p in which_parts)) return reduce(np.multiply, (p.Kdiag(X) for p in which_parts))
def update_gradients_full(self, dL_dK, X): def update_gradients_full(self, dL_dK, X, X2=None):
for k1,k2 in itertools.combinations(self.parts, 2): for k1,k2 in itertools.combinations(self.parts, 2):
k1._sliced_X = k1._sliced_X2 = k2._sliced_X = k2._sliced_X2 = True k1.update_gradients_full(dL_dK*k2.K(X, X2), X, X2)
k1.update_gradients_full(dL_dK*k2.K(X, X)) k2.update_gradients_full(dL_dK*k1.K(X, X2), X, X2)
self.k2.update_gradients_full(dL_dK*self.k1.K(X[:,self.slice1]), X[:,self.slice2])
def update_gradients_diag(self, dL_dKdiag, X):
for k1,k2 in itertools.combinations(self.parts, 2):
k1.update_gradients_diag(dL_dKdiag*k2.Kdiag(X), X)
k2.update_gradients_diag(dL_dKdiag*k1.Kdiag(X), X)
def gradients_X(self, dL_dK, X, X2=None): def gradients_X(self, dL_dK, X, X2=None):
target = np.zeros(X.shape) target = np.zeros(X.shape)
if X2 is None: for k1,k2 in itertools.combinations(self.parts, 2):
target[:,self.slice1] += self.k1.gradients_X(dL_dK*self.k2.K(X[:,self.slice2]), X[:,self.slice1], None) target[:,k1.active_dims] += k1.gradients_X(dL_dK*k2.K(X, X2), X, X2)
target[:,self.slice2] += self.k2.gradients_X(dL_dK*self.k1.K(X[:,self.slice1]), X[:,self.slice2], None) target[:,k2.active_dims] += k2.gradients_X(dL_dK*k1.K(X, X2), X, X2)
else:
target[:,self.slice1] += self.k1.gradients_X(dL_dK*self.k2.K(X[:,self.slice2], X2[:,self.slice2]), X[:,self.slice1], X2[:,self.slice1])
target[:,self.slice2] += self.k2.gradients_X(dL_dK*self.k1.K(X[:,self.slice1], X2[:,self.slice1]), X[:,self.slice2], X2[:,self.slice2])
return target return target
def gradients_X_diag(self, dL_dKdiag, X): def gradients_X_diag(self, dL_dKdiag, X):
target = np.zeros(X.shape) target = np.zeros(X.shape)
target[:,self.slice1] = self.k1.gradients_X(dL_dKdiag*self.k2.Kdiag(X[:,self.slice2]), X[:,self.slice1]) for k1,k2 in itertools.combinations(self.parts, 2):
target[:,self.slice2] += self.k2.gradients_X(dL_dKdiag*self.k1.Kdiag(X[:,self.slice1]), X[:,self.slice2]) target[:,k1.active_dims] += k1.gradients_X(dL_dKdiag*k2.Kdiag(X), X)
target[:,k2.active_dims] += k2.gradients_X(dL_dKdiag*k1.Kdiag(X), X)
return target return target