fix linear kernel with NxMxM psi2

This commit is contained in:
Zhenwen Dai 2015-09-07 11:43:58 +01:00
parent e906da0309
commit 276330d1d1
2 changed files with 64 additions and 26 deletions

View file

@ -63,19 +63,39 @@ def _psi2computations(dL_dpsi2, variance, Z, mu, S):
variance2 = np.square(variance)
common_sum = np.dot(mu,(variance*Z).T)
Z_expect = (np.dot(dL_dpsi2,Z)*Z).sum(axis=0)
dL_dpsi2T = dL_dpsi2+dL_dpsi2.T
common_expect = np.dot(common_sum,np.dot(dL_dpsi2T,Z))
Z2_expect = np.inner(common_sum,dL_dpsi2T)
Z1_expect = np.dot(dL_dpsi2T,Z)
dL_dvar = 2.*S.sum(axis=0)*variance*Z_expect+(common_expect*mu).sum(axis=0)
dL_dmu = common_expect*variance
dL_dS = np.empty(S.shape)
dL_dS[:] = Z_expect*variance2
dL_dZ = variance2*S.sum(axis=0)*Z1_expect+np.dot(Z2_expect.T,variance*mu)
if len(dL_dpsi2.shape)==2:
Z_expect = (np.dot(dL_dpsi2,Z)*Z).sum(axis=0)
dL_dpsi2T = dL_dpsi2+dL_dpsi2.T
common_expect = np.dot(common_sum,np.dot(dL_dpsi2T,Z))
Z2_expect = np.inner(common_sum,dL_dpsi2T)
Z1_expect = np.dot(dL_dpsi2T,Z)
dL_dvar = 2.*S.sum(axis=0)*variance*Z_expect+(common_expect*mu).sum(axis=0)
dL_dmu = common_expect*variance
dL_dS = np.empty(S.shape)
dL_dS[:] = Z_expect*variance2
dL_dZ = variance2*S.sum(axis=0)*Z1_expect+np.dot(Z2_expect.T,variance*mu)
else:
N,M,Q = mu.shape[0],Z.shape[0],mu.shape[1]
dL_dpsi2_ = dL_dpsi2.sum(axis=0)
Z_expect = (np.dot(dL_dpsi2.reshape(N*M,M),Z).reshape(N,M,Q)*Z[None,:,:]).sum(axis=1)
dL_dpsi2T = dL_dpsi2_+dL_dpsi2_.T
dL_dpsi2T_ = dL_dpsi2+np.swapaxes(dL_dpsi2, 1, 2)
common_expect = np.dot(common_sum,np.dot(dL_dpsi2T,Z))
common_expect_ = (common_sum[:,:,None]*np.dot(dL_dpsi2T_.reshape(N*M,M),Z).reshape(N,M,Q)).sum(axis=1)
Z2_expect = (common_sum[:,:,None]*dL_dpsi2T_).sum(axis=1)
Z1_expect = np.dot(dL_dpsi2T_.reshape(N*M,M),Z).reshape(N,M,Q)
dL_dvar = 2.*variance*(S*Z_expect).sum(axis=0)+(common_expect_*mu).sum(axis=0)
dL_dmu = common_expect_*variance
dL_dS = np.empty(S.shape)
dL_dS[:] = variance2* Z_expect
dL_dZ = variance2*(S[:,None,:]*Z1_expect).sum(axis=0)+np.dot(Z2_expect.T,variance*mu)
return dL_dvar, dL_dmu, dL_dS, dL_dZ

View file

@ -452,6 +452,8 @@ class Kernel_Psi_statistics_GradientTests(unittest.TestCase):
self.w2 = np.random.randn(N,M)
self.w3 = np.random.randn(M,M)
self.w3 = self.w3+self.w3.T
self.w3n = np.random.randn(N,M,M)
self.w3n = self.w3n+np.swapaxes(self.w3n, 1,2)
def test_kernels(self):
from GPy.kern import RBF,Linear
@ -463,54 +465,70 @@ class Kernel_Psi_statistics_GradientTests(unittest.TestCase):
self._test_kernel_param(k)
self._test_Z(k)
self._test_qX(k)
self._test_kernel_param(k, psi2n=True)
self._test_Z(k, psi2n=True)
self._test_qX(k, psi2n=True)
def _test_kernel_param(self, kernel, psi2n=False):
def _test_kernel_param(self, kernel):
def f(p):
kernel.param_array[:] = p
psi0 = kernel.psi0(self.Z, self.qX)
psi1 = kernel.psi1(self.Z, self.qX)
psi2 = kernel.psi2(self.Z, self.qX)
return (self.w1*psi0).sum() + (self.w2*psi1).sum() + (self.w3*psi2).sum()
if not psi2n:
psi2 = kernel.psi2(self.Z, self.qX)
return (self.w1*psi0).sum() + (self.w2*psi1).sum() + (self.w3*psi2).sum()
else:
psi2 = kernel.psi2n(self.Z, self.qX)
return (self.w1*psi0).sum() + (self.w2*psi1).sum() + (self.w3n*psi2).sum()
def df(p):
kernel.param_array[:] = p
kernel.update_gradients_expectations(self.w1, self.w2, self.w3, self.Z, self.qX)
kernel.update_gradients_expectations(self.w1, self.w2, self.w3 if not psi2n else self.w3n, self.Z, self.qX)
return kernel.gradient.copy()
from GPy.models import GradientChecker
m = GradientChecker(f, df, kernel.param_array.copy())
self.assertTrue(m.checkgrad())
def _test_Z(self, kernel):
def _test_Z(self, kernel, psi2n=False):
def f(p):
psi0 = kernel.psi0(p, self.qX)
psi1 = kernel.psi1(p, self.qX)
psi2 = kernel.psi2(p, self.qX)
return (self.w1*psi0).sum() + (self.w2*psi1).sum() + (self.w3*psi2).sum()
if not psi2n:
psi2 = kernel.psi2(p, self.qX)
return (self.w1*psi0).sum() + (self.w2*psi1).sum() + (self.w3*psi2).sum()
else:
psi2 = kernel.psi2n(p, self.qX)
return (self.w1*psi0).sum() + (self.w2*psi1).sum() + (self.w3n*psi2).sum()
def df(p):
return kernel.gradients_Z_expectations(self.w1, self.w2, self.w3, p, self.qX)
return kernel.gradients_Z_expectations(self.w1, self.w2, self.w3 if not psi2n else self.w3n, p, self.qX)
from GPy.models import GradientChecker
m = GradientChecker(f, df, self.Z.copy())
self.assertTrue(m.checkgrad())
def _test_qX(self, kernel):
def _test_qX(self, kernel, psi2n=False):
def f(p):
self.qX.param_array[:] = p
self.qX._trigger_params_changed()
psi0 = kernel.psi0(self.Z, self.qX)
psi1 = kernel.psi1(self.Z, self.qX)
psi2 = kernel.psi2(self.Z, self.qX)
return (self.w1*psi0).sum() + (self.w2*psi1).sum() + (self.w3*psi2).sum()
if not psi2n:
psi2 = kernel.psi2(self.Z, self.qX)
return (self.w1*psi0).sum() + (self.w2*psi1).sum() + (self.w3*psi2).sum()
else:
psi2 = kernel.psi2n(self.Z, self.qX)
return (self.w1*psi0).sum() + (self.w2*psi1).sum() + (self.w3n*psi2).sum()
def df(p):
self.qX.param_array[:] = p
self.qX._trigger_params_changed()
grad = kernel.gradients_qX_expectations(self.w1, self.w2, self.w3, self.Z, self.qX)
grad = kernel.gradients_qX_expectations(self.w1, self.w2, self.w3 if not psi2n else self.w3n, self.Z, self.qX)
self.qX.set_gradients(grad)
return self.qX.gradient.copy()