mirror of
https://github.com/SheffieldML/GPy.git
synced 2026-06-08 15:05:15 +02:00
linear psi2 statistics done, all gradients working
This commit is contained in:
parent
914bdc73d8
commit
ecf0dc0680
1 changed files with 6 additions and 12 deletions
|
|
@ -5,7 +5,6 @@
|
|||
from kernpart import kernpart
|
||||
import numpy as np
|
||||
from ..util.linalg import tdot
|
||||
from GPy.util.linalg import mdot
|
||||
|
||||
class linear(kernpart):
|
||||
"""
|
||||
|
|
@ -144,7 +143,7 @@ class linear(kernpart):
|
|||
# psi2_old = self.ZZ * np.square(self.variances) * self.mu2_S[:, None, None, :]
|
||||
# target += psi2.sum(-1)
|
||||
# slow way of doing it, but right
|
||||
# psi2_real = np.zeros((mu.shape[0], Z.shape[0], Z.shape[0]))
|
||||
# psi2_real = rm np.zeros((mu.shape[0], Z.shape[0], Z.shape[0]))
|
||||
# for n in range(mu.shape[0]):
|
||||
# for m_prime in range(Z.shape[0]):
|
||||
# for m in range(Z.shape[0]):
|
||||
|
|
@ -171,14 +170,9 @@ class linear(kernpart):
|
|||
"""Think N,M,M,Q """
|
||||
self._psi_computations(Z, mu, S)
|
||||
AZZA = self.ZA.T[:, None, :, None] * self.ZA[None, :, None, :]
|
||||
AZZA += AZZA.swapaxes(1, 2)
|
||||
tmp = self.ZZ * np.square(self.variances) # M,M,Q
|
||||
dS_old = (dL_dpsi2[:, :, :, None] * tmp).sum(1).sum(1)
|
||||
import ipdb;ipdb.set_trace()
|
||||
target_S += dS_old
|
||||
dpsi2_dmu = (dL_dpsi2[:, :, :, None] * np.tensordot(mu, AZZA, ((-1), (0)))).sum(1).sum(1)
|
||||
# twomu = mu[:,None,None,:,None] + mu[:,None,None,None,:]
|
||||
# t = (dL_dpsi2[:, :, :, None, None] * tmp[None, :, :, :, None] * twomu).sum(1).sum(1).sum(1)
|
||||
AZZA = AZZA + AZZA.swapaxes(1, 2)
|
||||
target_S += (dL_dpsi2[:, :, :, None] * self.ZA[None, :, None, :] * self.ZA[None, None, :, :]).sum(1).sum(1)
|
||||
dpsi2_dmu = (dL_dpsi2[:, :, :, None] * np.tensordot(mu, AZZA, (-1, 0))).sum(1).sum(1)
|
||||
target_mu += dpsi2_dmu
|
||||
|
||||
def dpsi2_dZ(self, dL_dpsi2, Z, mu, S, target):
|
||||
|
|
@ -226,8 +220,8 @@ class linear(kernpart):
|
|||
if Zv_changed:
|
||||
# Z has changed, compute Z specific stuff
|
||||
# self.ZZ = Z[:,None,:]*Z[None,:,:] # M,M,Q
|
||||
self.ZZ = np.empty((Z.shape[0], Z.shape[0], Z.shape[1]), order='F')
|
||||
[tdot(Z[:, i:i + 1], self.ZZ[:, :, i].T) for i in xrange(Z.shape[1])]
|
||||
# self.ZZ = np.empty((Z.shape[0], Z.shape[0], Z.shape[1]), order='F')
|
||||
# [tdot(Z[:, i:i + 1], self.ZZ[:, :, i].T) for i in xrange(Z.shape[1])]
|
||||
self.ZA = Z * self.variances
|
||||
self._Z = Z.copy()
|
||||
self._variances = self.variances.copy()
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue