SVI now working with minibatches

This commit is contained in:
Alan Saul 2014-12-22 15:40:49 +00:00
parent a8b0d60c3e
commit 1b27337e7c
2 changed files with 45 additions and 40 deletions

View file

@ -5,11 +5,14 @@ import numpy as np
from posterior import Posterior
class SVGP(LatentFunctionInference):
def inference(self, q_u_mean, q_u_chol, kern, X, Z, likelihood, Y, Y_metadata=None):
assert Y.shape[1]==1, "multi outputs not implemented"
def __init__(self, KL_scale=1., batch_scale=1.):
self.KL_scale = KL_scale
self.batch_scale = batch_scale
def inference(self, q_u_mean, q_u_chol, kern, X, Z, likelihood, Y, Y_metadata=None):
num_inducing = Z.shape[0]
num_data, num_outputs = Y.shape
#expand cholesky representation
L = choleskies.flat_to_triang(q_u_chol)
S = np.einsum('ijk,ljk->ilk', L, L) #L.dot(L.T)
@ -31,29 +34,25 @@ class SVGP(LatentFunctionInference):
#compute the marginal means and variances of q(f)
A = np.dot(Knm, Kmmi)
mu = np.dot(A, q_u_mean)
#v = Knn_diag - np.sum(A*Knm,1) + np.sum(A*A.dot(S),1)
v = Knn_diag[:,None] - np.sum(A*Knm,1)[:,None] + np.sum(A[:,:,None] * np.einsum('ij,jkl->ikl', A, S),1)
#compute the KL term
Kmmim = np.dot(Kmmi, q_u_mean)
#KL = -0.5*logdetS -0.5*num_inducing + 0.5*logdetKmm + 0.5*np.sum(Kmmi*S) + 0.5*q_u_mean.dot(Kmmim)
KLs = -0.5*logdetS -0.5*num_inducing + 0.5*logdetKmm + 0.5*np.einsum('ij,ijk->k', Kmmi, S) + 0.5*np.sum(q_u_mean*Kmmim,0)
KL = KLs.sum()
dKL_dm = Kmmim
#dKL_dS = 0.5*(Kmmi - Si)
dKL_dS = 0.5*(Kmmi[:,:,None] - Si)
#dKL_dKmm = 0.5*Kmmi - 0.5*Kmmi.dot(S).dot(Kmmi) - 0.5*Kmmim[:,None]*Kmmim[None,:]
dKL_dKmm = 0.5*num_outputs*Kmmi - 0.5*Kmmi.dot(S.sum(-1)).dot(Kmmi) - 0.5*Kmmim.dot(Kmmim.T)
#if self.KL_scale:
#scale = 1./np.float64(self.mpi_comm.size)
#KL, dKL_dKmm, dKL_dS, dKL_dm = scale*KL, scale*dKL_dKmm, scale*dKL_dS, scale*dKL_dm
KL_scale = self.KL_scale
batch_scale = self.batch_scale
KL, dKL_dKmm, dKL_dS, dKL_dm = KL_scale*KL, KL_scale*dKL_dKmm, KL_scale*dKL_dS, KL_scale*dKL_dm
#quadrature for the likelihood
F, dF_dmu, dF_dv, dF_dthetaL = likelihood.variational_expectations(Y, mu, v)
#rescale the F term if working on a batch
#F, dF_dmu, dF_dv = F*batch_scale, dF_dmu*batch_scale, dF_dv*batch_scale
F, dF_dmu, dF_dv = F*batch_scale, dF_dmu*batch_scale, dF_dv*batch_scale
#derivatives of expected likelihood
Adv = A.T[:,:,None]*dF_dv[None,:,:] # As if dF_Dv is diagonal
@ -69,7 +68,6 @@ class SVGP(LatentFunctionInference):
dF_dm = Admu
dF_dS = AdvA
#sum (gradients of) expected likelihood and KL part
log_marginal = F.sum() - KL
dL_dm, dL_dS, dL_dKmm, dL_dKmn = dF_dm - dKL_dm, dF_dS- dKL_dS, dF_dKmm- dKL_dKmm, dF_dKmn