correcting linearCF, mu to go

This commit is contained in:
Max Zwiessele 2013-05-02 16:37:47 +01:00
parent 42474f0044
commit 5051a2fc89
3 changed files with 59 additions and 37 deletions

View file

@ -144,26 +144,24 @@ class linear(kernpart):
# psi2_old = self.ZZ * np.square(self.variances) * self.mu2_S[:, None, None, :]
# target += psi2.sum(-1)
# slow way of doing it, but right
psi2_real = np.zeros((mu.shape[0], Z.shape[0], Z.shape[0]))
for n in range(mu.shape[0]):
for m_prime in range(Z.shape[0]):
for m in range(Z.shape[0]):
tmp = self._Z[m:m + 1] * self.variances
tmp = np.dot(tmp, (tdot(self._mu[n:n + 1].T) + np.diag(S[n:n + 1])))
psi2_real[n, m, m_prime] = np.dot(tmp, (
self._Z[m_prime:m_prime + 1] * self.variances).T)
psi2_inner = mdot(self.ZA, self.inner, self.ZA.T)
mu2_S = (self._mu[:, None] * self._mu[:, :, None]) + self._S[:, :, None]
psi2 = (self.ZA[None, :, None, :] * mu2_S[:, None]).sum(-1)
psi2 = (psi2[:, :, None] * self.ZA[None, None]).sum(-1)
# psi2_real = np.zeros((mu.shape[0], Z.shape[0], Z.shape[0]))
# for n in range(mu.shape[0]):
# for m_prime in range(Z.shape[0]):
# for m in range(Z.shape[0]):
# tmp = self._Z[m:m + 1] * self.variances
# tmp = np.dot(tmp, (tdot(self._mu[n:n + 1].T) + np.diag(S[n])))
# psi2_real[n, m, m_prime] = np.dot(tmp, (
# self._Z[m_prime:m_prime + 1] * self.variances).T)
# mu2_S = (self._mu[:, None, :] * self._mu[:, :, None])
# mu2_S[:, np.arange(self.D), np.arange(self.D)] += self._S
# psi2 = (self.ZA[None, :, None, :] * mu2_S[:, None]).sum(-1)
# psi2 = (psi2[:, :, None] * self.ZA[None, None]).sum(-1)
# psi2_tensor = np.tensordot(self.ZZ[None, :, :, :] * np.square(self.variances), self.mu2_S[:, None, None, :], ((3), (3))).squeeze().T
# import ipdb;ipdb.set_trace()
target += psi2_real
target += self._psi2
def dpsi2_dtheta(self, dL_dpsi2, Z, mu, S, target):
self._psi_computations(Z, mu, S)
tmp = (dL_dpsi2[:, :, :, None] * (2.*self.ZZ * self.mu2_S[:, None, None, :] * self.variances))
tmp = dL_dpsi2[:, :, :, None] * (self.ZAinner[:, :, None, :] * (2 * Z)[None, None, :, :])
if self.ARD:
target += tmp.sum(0).sum(0).sum(0)
else:
@ -173,19 +171,34 @@ class linear(kernpart):
"""Think N,M,M,Q """
self._psi_computations(Z, mu, S)
tmp = self.ZZ * np.square(self.variances) # M,M,Q
# import ipdb;ipdb.set_trace()
dS_old = (dL_dpsi2[:, :, :, None] * tmp).sum(1).sum(1)
import ipdb;ipdb.set_trace()
target_S += dS_old
target_mu += (dL_dpsi2[:, :, :, None] * tmp * 2.*mu[:, None, None, :]).sum(1).sum(1)
target_S += (dL_dpsi2[:, :, :, None] * tmp).sum(1).sum(1) * S.shape[0]
def dpsi2_dZ(self, dL_dpsi2, Z, mu, S, target):
self._psi_computations(Z, mu, S)
# mu2_S = np.sum(self.mu2_S, 0) # Q,
# import ipdb;ipdb.set_trace()
# prod = (np.eye(Z.shape[0])[:, None, :, None] * (np.dot(self.ZA, self.inner) * self.variances)[None, :, None])
# psi2_dZ = prod.swapaxes(0, 1) + prod
psi2_dZ_old = (dL_dpsi2[:, :, :, None] * (self.mu2_S[:, None, None, :] * (Z * np.square(self.variances)[None, :])[None, None, :, :])).sum(0).sum(1)
target += psi2_dZ_old # .sum(0).sum(1)
# TODO: tensordot would gain some time here
# psi2_dZ_real = np.zeros((mu.shape[0], Z.shape[0], Z.shape[1]))
# for n in range(mu.shape[0]):
# for m in range(Z.shape[0]):
# tmp = self.variances * (tdot(self._mu[n:n + 1].T) + np.diag(S[n]))
# psi2_dZ_real[n, m, :] = np.dot(tmp, (
# self._Z[m:m + 1] * self.variances).T).T
# tmp = self._Z[m:m + 1] * self.variances
# tmp = np.dot(tmp, (tdot(self._mu[n:n + 1].T) + np.diag(S[n])))
# psi2_dZ_real[n, m, :] = tmp * self.variances
# for m_prime in range(Z.shape[0]):
# if m == m_prime:
# psi2_dZ_real[n, m, :] *= 2
# prod = (dL_dpsi2[:, :, :, None] * np.eye(Z.shape[0])[None, :, :, None] * (self.ZAinner * self.variances).swapaxes(0, 1)[:, :, None, :])
# psi2_dZ = prod.swapaxes(1, 2) + prod
psi2_dZ = dL_dpsi2[:, :, :, None] * self.variances * self.ZAinner[:, :, None, :]
target += psi2_dZ.sum(0).sum(0)
# import ipdb;ipdb.set_trace()
# psi2_dZ_old = (dL_dpsi2[:, :, :, None] * (self.mu2_S[:, None, None, :] * (Z * np.square(self.variances)[None, :])[None, None, :, :])).sum(0).sum(1)
# target += (dL_dpsi2[:, :, :, None] * psi2_dZ_real[:, :, None, :]).sum(0).sum(0) * 2 # (self.variances * np.dot(self.inner, self.ZA.T)).sum(1)
#---------------------------------------#
# Precomputations #
@ -203,14 +216,22 @@ class linear(kernpart):
def _psi_computations(self, Z, mu, S):
# here are the "statistics" for psi1 and psi2
if not np.all(Z == self._Z):
Zv_changed = not (np.array_equal(Z, self._Z) and np.array_equal(self.variances, self._variances))
muS_changed = not (np.array_equal(mu, self._mu) and np.array_equal(S, self._S))
if Zv_changed:
# Z has changed, compute Z specific stuff
# self.ZZ = Z[:,None,:]*Z[None,:,:] # M,M,Q
self.ZZ = np.empty((Z.shape[0], Z.shape[0], Z.shape[1]), order='F')
[tdot(Z[:, i:i + 1], self.ZZ[:, :, i].T) for i in xrange(Z.shape[1])]
self._Z = Z.copy()
self.ZA = Z * self.variances
if not (np.all(mu == self._mu) and np.all(S == self._S)):
self._Z = Z.copy()
self._variances = self.variances.copy()
if muS_changed:
self.mu2_S = np.square(mu) + S
self.inner = tdot(mu.T) + (np.diag(S.sum(0)))
self.inner = (mu[:, None, :] * mu[:, :, None])
diag_indices = np.diag_indices(mu.shape[1], 2)
self.inner[:, diag_indices[0], diag_indices[1]] += S
self._mu, self._S = mu.copy(), S.copy()
if Zv_changed or muS_changed:
self.ZAinner = np.dot(self.ZA, self.inner).swapaxes(0, 1) # NOTE: self.ZAinner \in [M x N x Q]!
self._psi2 = np.dot(self.ZAinner, self.ZA.T)