mirror of
https://github.com/SheffieldML/GPy.git
synced 2026-05-05 01:32:40 +02:00
correcting linearCF, mu to go
This commit is contained in:
parent
42474f0044
commit
5051a2fc89
3 changed files with 59 additions and 37 deletions
|
|
@ -144,26 +144,24 @@ class linear(kernpart):
|
|||
# psi2_old = self.ZZ * np.square(self.variances) * self.mu2_S[:, None, None, :]
|
||||
# target += psi2.sum(-1)
|
||||
# slow way of doing it, but right
|
||||
psi2_real = np.zeros((mu.shape[0], Z.shape[0], Z.shape[0]))
|
||||
for n in range(mu.shape[0]):
|
||||
for m_prime in range(Z.shape[0]):
|
||||
for m in range(Z.shape[0]):
|
||||
tmp = self._Z[m:m + 1] * self.variances
|
||||
tmp = np.dot(tmp, (tdot(self._mu[n:n + 1].T) + np.diag(S[n:n + 1])))
|
||||
psi2_real[n, m, m_prime] = np.dot(tmp, (
|
||||
self._Z[m_prime:m_prime + 1] * self.variances).T)
|
||||
|
||||
psi2_inner = mdot(self.ZA, self.inner, self.ZA.T)
|
||||
mu2_S = (self._mu[:, None] * self._mu[:, :, None]) + self._S[:, :, None]
|
||||
psi2 = (self.ZA[None, :, None, :] * mu2_S[:, None]).sum(-1)
|
||||
psi2 = (psi2[:, :, None] * self.ZA[None, None]).sum(-1)
|
||||
# psi2_real = np.zeros((mu.shape[0], Z.shape[0], Z.shape[0]))
|
||||
# for n in range(mu.shape[0]):
|
||||
# for m_prime in range(Z.shape[0]):
|
||||
# for m in range(Z.shape[0]):
|
||||
# tmp = self._Z[m:m + 1] * self.variances
|
||||
# tmp = np.dot(tmp, (tdot(self._mu[n:n + 1].T) + np.diag(S[n])))
|
||||
# psi2_real[n, m, m_prime] = np.dot(tmp, (
|
||||
# self._Z[m_prime:m_prime + 1] * self.variances).T)
|
||||
# mu2_S = (self._mu[:, None, :] * self._mu[:, :, None])
|
||||
# mu2_S[:, np.arange(self.D), np.arange(self.D)] += self._S
|
||||
# psi2 = (self.ZA[None, :, None, :] * mu2_S[:, None]).sum(-1)
|
||||
# psi2 = (psi2[:, :, None] * self.ZA[None, None]).sum(-1)
|
||||
# psi2_tensor = np.tensordot(self.ZZ[None, :, :, :] * np.square(self.variances), self.mu2_S[:, None, None, :], ((3), (3))).squeeze().T
|
||||
# import ipdb;ipdb.set_trace()
|
||||
target += psi2_real
|
||||
target += self._psi2
|
||||
|
||||
def dpsi2_dtheta(self, dL_dpsi2, Z, mu, S, target):
|
||||
self._psi_computations(Z, mu, S)
|
||||
tmp = (dL_dpsi2[:, :, :, None] * (2.*self.ZZ * self.mu2_S[:, None, None, :] * self.variances))
|
||||
tmp = dL_dpsi2[:, :, :, None] * (self.ZAinner[:, :, None, :] * (2 * Z)[None, None, :, :])
|
||||
if self.ARD:
|
||||
target += tmp.sum(0).sum(0).sum(0)
|
||||
else:
|
||||
|
|
@ -173,19 +171,34 @@ class linear(kernpart):
|
|||
"""Think N,M,M,Q """
|
||||
self._psi_computations(Z, mu, S)
|
||||
tmp = self.ZZ * np.square(self.variances) # M,M,Q
|
||||
# import ipdb;ipdb.set_trace()
|
||||
dS_old = (dL_dpsi2[:, :, :, None] * tmp).sum(1).sum(1)
|
||||
import ipdb;ipdb.set_trace()
|
||||
target_S += dS_old
|
||||
target_mu += (dL_dpsi2[:, :, :, None] * tmp * 2.*mu[:, None, None, :]).sum(1).sum(1)
|
||||
target_S += (dL_dpsi2[:, :, :, None] * tmp).sum(1).sum(1) * S.shape[0]
|
||||
|
||||
def dpsi2_dZ(self, dL_dpsi2, Z, mu, S, target):
|
||||
self._psi_computations(Z, mu, S)
|
||||
# mu2_S = np.sum(self.mu2_S, 0) # Q,
|
||||
# import ipdb;ipdb.set_trace()
|
||||
# prod = (np.eye(Z.shape[0])[:, None, :, None] * (np.dot(self.ZA, self.inner) * self.variances)[None, :, None])
|
||||
# psi2_dZ = prod.swapaxes(0, 1) + prod
|
||||
psi2_dZ_old = (dL_dpsi2[:, :, :, None] * (self.mu2_S[:, None, None, :] * (Z * np.square(self.variances)[None, :])[None, None, :, :])).sum(0).sum(1)
|
||||
target += psi2_dZ_old # .sum(0).sum(1)
|
||||
# TODO: tensordot would gain some time here
|
||||
# psi2_dZ_real = np.zeros((mu.shape[0], Z.shape[0], Z.shape[1]))
|
||||
# for n in range(mu.shape[0]):
|
||||
# for m in range(Z.shape[0]):
|
||||
# tmp = self.variances * (tdot(self._mu[n:n + 1].T) + np.diag(S[n]))
|
||||
# psi2_dZ_real[n, m, :] = np.dot(tmp, (
|
||||
# self._Z[m:m + 1] * self.variances).T).T
|
||||
# tmp = self._Z[m:m + 1] * self.variances
|
||||
# tmp = np.dot(tmp, (tdot(self._mu[n:n + 1].T) + np.diag(S[n])))
|
||||
# psi2_dZ_real[n, m, :] = tmp * self.variances
|
||||
# for m_prime in range(Z.shape[0]):
|
||||
# if m == m_prime:
|
||||
# psi2_dZ_real[n, m, :] *= 2
|
||||
# prod = (dL_dpsi2[:, :, :, None] * np.eye(Z.shape[0])[None, :, :, None] * (self.ZAinner * self.variances).swapaxes(0, 1)[:, :, None, :])
|
||||
# psi2_dZ = prod.swapaxes(1, 2) + prod
|
||||
psi2_dZ = dL_dpsi2[:, :, :, None] * self.variances * self.ZAinner[:, :, None, :]
|
||||
target += psi2_dZ.sum(0).sum(0)
|
||||
# import ipdb;ipdb.set_trace()
|
||||
# psi2_dZ_old = (dL_dpsi2[:, :, :, None] * (self.mu2_S[:, None, None, :] * (Z * np.square(self.variances)[None, :])[None, None, :, :])).sum(0).sum(1)
|
||||
# target += (dL_dpsi2[:, :, :, None] * psi2_dZ_real[:, :, None, :]).sum(0).sum(0) * 2 # (self.variances * np.dot(self.inner, self.ZA.T)).sum(1)
|
||||
|
||||
#---------------------------------------#
|
||||
# Precomputations #
|
||||
|
|
@ -203,14 +216,22 @@ class linear(kernpart):
|
|||
|
||||
def _psi_computations(self, Z, mu, S):
|
||||
# here are the "statistics" for psi1 and psi2
|
||||
if not np.all(Z == self._Z):
|
||||
Zv_changed = not (np.array_equal(Z, self._Z) and np.array_equal(self.variances, self._variances))
|
||||
muS_changed = not (np.array_equal(mu, self._mu) and np.array_equal(S, self._S))
|
||||
if Zv_changed:
|
||||
# Z has changed, compute Z specific stuff
|
||||
# self.ZZ = Z[:,None,:]*Z[None,:,:] # M,M,Q
|
||||
self.ZZ = np.empty((Z.shape[0], Z.shape[0], Z.shape[1]), order='F')
|
||||
[tdot(Z[:, i:i + 1], self.ZZ[:, :, i].T) for i in xrange(Z.shape[1])]
|
||||
self._Z = Z.copy()
|
||||
self.ZA = Z * self.variances
|
||||
if not (np.all(mu == self._mu) and np.all(S == self._S)):
|
||||
self._Z = Z.copy()
|
||||
self._variances = self.variances.copy()
|
||||
if muS_changed:
|
||||
self.mu2_S = np.square(mu) + S
|
||||
self.inner = tdot(mu.T) + (np.diag(S.sum(0)))
|
||||
self.inner = (mu[:, None, :] * mu[:, :, None])
|
||||
diag_indices = np.diag_indices(mu.shape[1], 2)
|
||||
self.inner[:, diag_indices[0], diag_indices[1]] += S
|
||||
self._mu, self._S = mu.copy(), S.copy()
|
||||
if Zv_changed or muS_changed:
|
||||
self.ZAinner = np.dot(self.ZA, self.inner).swapaxes(0, 1) # NOTE: self.ZAinner \in [M x N x Q]!
|
||||
self._psi2 = np.dot(self.ZAinner, self.ZA.T)
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue