SVI now implemented without natural natural gradients or batches

This commit is contained in:
Alan Saul 2014-12-22 13:35:56 +00:00
parent b642360ede
commit a8b0d60c3e
4 changed files with 16 additions and 17 deletions

View file

@ -37,7 +37,7 @@ class SVGP(LatentFunctionInference):
#compute the KL term
Kmmim = np.dot(Kmmi, q_u_mean)
#KL = -0.5*logdetS -0.5*num_inducing + 0.5*logdetKmm + 0.5*np.sum(Kmmi*S) + 0.5*q_u_mean.dot(Kmmim)
KLs = -0.5*logdetS -0.5*self.num_inducing + 0.5*logdetKmm + 0.5*np.einsum('ij,ijk->k', Kmmi, S) + 0.5*np.sum(self.q_u_mean*Kmmim,0)
KLs = -0.5*logdetS -0.5*num_inducing + 0.5*logdetKmm + 0.5*np.einsum('ij,ijk->k', Kmmi, S) + 0.5*np.sum(q_u_mean*Kmmim,0)
KL = KLs.sum()
dKL_dm = Kmmim
#dKL_dS = 0.5*(Kmmi - Si)
@ -58,13 +58,13 @@ class SVGP(LatentFunctionInference):
#derivatives of expected likelihood
Adv = A.T[:,:,None]*dF_dv[None,:,:] # As if dF_Dv is diagonal
Admu = A.T.dot(dF_dmu)
#AdvA = np.einsum('ijk,jl->ilk', Adv, A)
#AdvA = np.einsum('ijk,jl->ilk', Adv, A)
#AdvA = np.dot(A.T, Adv).swapaxes(0,1)
AdvA = np.dstack([np.dot(A.T, Adv[:,:,i].T) for i in range(self.num_classes)])
AdvA = np.dstack([np.dot(A.T, Adv[:,:,i].T) for i in range(num_outputs)])
tmp = np.einsum('ijk,jlk->il', AdvA, S).dot(Kmmi)
dF_dKmm = -Admu.dot(Kmmim.T) + AdvA.sum(-1) - tmp - tmp.T
dF_dKmm = 0.5*(dF_dKmm + dF_dKmm.T) # necessary? GPy bug?
tmp = 2.*(np.einsum('ij,jlk->ilk', Kmmi,S) - np.eye(self.num_inducing)[:,:,None])
tmp = 2.*(np.einsum('ij,jlk->ilk', Kmmi,S) - np.eye(num_inducing)[:,:,None])
dF_dKmn = np.einsum('ijk,jlk->il', tmp, Adv) + Kmmim.dot(dF_dmu.T)
dF_dm = Admu
dF_dS = AdvA
@ -74,10 +74,7 @@ class SVGP(LatentFunctionInference):
log_marginal = F.sum() - KL
dL_dm, dL_dS, dL_dKmm, dL_dKmn = dF_dm - dKL_dm, dF_dS- dKL_dS, dF_dKmm- dKL_dKmm, dF_dKmn
dL_dchol = 2.*np.dot(dL_dS, L)
dL_dchol = np.dstack([2.*np.dot(dL_dS[:,:,i], L[:,:,i]) for i in range(num_outputs)])
dL_dchol = choleskies.triang_to_flat(dL_dchol)
return Posterior(mean=q_u_mean, cov=S, K=Kmm), log_marginal, {'dL_dKmm':dL_dKmm, 'dL_dKmn':dL_dKmn, 'dL_dKdiag': dF_dv, 'dL_dm':dL_dm, 'dL_dchol':dL_dchol, 'dL_dthetaL':dF_dthetaL}