mirror of
https://github.com/SheffieldML/GPy.git
synced 2026-05-02 00:02:38 +02:00
SVI now implemented without natural natural gradients or batches
This commit is contained in:
parent
b642360ede
commit
a8b0d60c3e
4 changed files with 16 additions and 17 deletions
|
|
@ -33,9 +33,9 @@ class SVGP(SparseGP):
|
||||||
|
|
||||||
#?? self.set_data(X, Y)
|
#?? self.set_data(X, Y)
|
||||||
|
|
||||||
self.m = Param('q_u_mean', np.zeros(self.num_inducing))
|
self.m = Param('q_u_mean', np.zeros((self.num_inducing, Y.shape[1])))
|
||||||
chol = choleskies.triang_to_flat(np.eye(self.num_inducing)[:,:,None])
|
chol = choleskies.triang_to_flat(np.tile(np.eye(self.num_inducing)[:,:,None], (1,1,Y.shape[1])))
|
||||||
self.chol = Param('q_u_chol', chol.flatten())
|
self.chol = Param('q_u_chol', chol)
|
||||||
self.link_parameter(self.chol)
|
self.link_parameter(self.chol)
|
||||||
self.link_parameter(self.m)
|
self.link_parameter(self.m)
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -158,9 +158,11 @@ class Posterior(object):
|
||||||
#self._woodbury_inv, _ = dpotrs(self.woodbury_chol, np.eye(self.woodbury_chol.shape[0]), lower=1)
|
#self._woodbury_inv, _ = dpotrs(self.woodbury_chol, np.eye(self.woodbury_chol.shape[0]), lower=1)
|
||||||
symmetrify(self._woodbury_inv)
|
symmetrify(self._woodbury_inv)
|
||||||
elif self._covariance is not None:
|
elif self._covariance is not None:
|
||||||
B = self._K - self._covariance
|
B = np.atleast_3d(self._K) - np.atleast_3d(self._covariance)
|
||||||
tmp, _ = dpotrs(self.K_chol, B)
|
self._woodbury_inv = np.empty_like(B)
|
||||||
self._woodbury_inv, _ = dpotrs(self.K_chol, tmp.T)
|
for i in xrange(B.shape[-1]):
|
||||||
|
tmp, _ = dpotrs(self.K_chol, B[:,:,i])
|
||||||
|
self._woodbury_inv[:,:,i], _ = dpotrs(self.K_chol, tmp.T)
|
||||||
return self._woodbury_inv
|
return self._woodbury_inv
|
||||||
|
|
||||||
@property
|
@property
|
||||||
|
|
|
||||||
|
|
@ -37,7 +37,7 @@ class SVGP(LatentFunctionInference):
|
||||||
#compute the KL term
|
#compute the KL term
|
||||||
Kmmim = np.dot(Kmmi, q_u_mean)
|
Kmmim = np.dot(Kmmi, q_u_mean)
|
||||||
#KL = -0.5*logdetS -0.5*num_inducing + 0.5*logdetKmm + 0.5*np.sum(Kmmi*S) + 0.5*q_u_mean.dot(Kmmim)
|
#KL = -0.5*logdetS -0.5*num_inducing + 0.5*logdetKmm + 0.5*np.sum(Kmmi*S) + 0.5*q_u_mean.dot(Kmmim)
|
||||||
KLs = -0.5*logdetS -0.5*self.num_inducing + 0.5*logdetKmm + 0.5*np.einsum('ij,ijk->k', Kmmi, S) + 0.5*np.sum(self.q_u_mean*Kmmim,0)
|
KLs = -0.5*logdetS -0.5*num_inducing + 0.5*logdetKmm + 0.5*np.einsum('ij,ijk->k', Kmmi, S) + 0.5*np.sum(q_u_mean*Kmmim,0)
|
||||||
KL = KLs.sum()
|
KL = KLs.sum()
|
||||||
dKL_dm = Kmmim
|
dKL_dm = Kmmim
|
||||||
#dKL_dS = 0.5*(Kmmi - Si)
|
#dKL_dS = 0.5*(Kmmi - Si)
|
||||||
|
|
@ -58,13 +58,13 @@ class SVGP(LatentFunctionInference):
|
||||||
#derivatives of expected likelihood
|
#derivatives of expected likelihood
|
||||||
Adv = A.T[:,:,None]*dF_dv[None,:,:] # As if dF_Dv is diagonal
|
Adv = A.T[:,:,None]*dF_dv[None,:,:] # As if dF_Dv is diagonal
|
||||||
Admu = A.T.dot(dF_dmu)
|
Admu = A.T.dot(dF_dmu)
|
||||||
#AdvA = np.einsum('ijk,jl->ilk', Adv, A)
|
#AdvA = np.einsum('ijk,jl->ilk', Adv, A)
|
||||||
#AdvA = np.dot(A.T, Adv).swapaxes(0,1)
|
#AdvA = np.dot(A.T, Adv).swapaxes(0,1)
|
||||||
AdvA = np.dstack([np.dot(A.T, Adv[:,:,i].T) for i in range(self.num_classes)])
|
AdvA = np.dstack([np.dot(A.T, Adv[:,:,i].T) for i in range(num_outputs)])
|
||||||
tmp = np.einsum('ijk,jlk->il', AdvA, S).dot(Kmmi)
|
tmp = np.einsum('ijk,jlk->il', AdvA, S).dot(Kmmi)
|
||||||
dF_dKmm = -Admu.dot(Kmmim.T) + AdvA.sum(-1) - tmp - tmp.T
|
dF_dKmm = -Admu.dot(Kmmim.T) + AdvA.sum(-1) - tmp - tmp.T
|
||||||
dF_dKmm = 0.5*(dF_dKmm + dF_dKmm.T) # necessary? GPy bug?
|
dF_dKmm = 0.5*(dF_dKmm + dF_dKmm.T) # necessary? GPy bug?
|
||||||
tmp = 2.*(np.einsum('ij,jlk->ilk', Kmmi,S) - np.eye(self.num_inducing)[:,:,None])
|
tmp = 2.*(np.einsum('ij,jlk->ilk', Kmmi,S) - np.eye(num_inducing)[:,:,None])
|
||||||
dF_dKmn = np.einsum('ijk,jlk->il', tmp, Adv) + Kmmim.dot(dF_dmu.T)
|
dF_dKmn = np.einsum('ijk,jlk->il', tmp, Adv) + Kmmim.dot(dF_dmu.T)
|
||||||
dF_dm = Admu
|
dF_dm = Admu
|
||||||
dF_dS = AdvA
|
dF_dS = AdvA
|
||||||
|
|
@ -74,10 +74,7 @@ class SVGP(LatentFunctionInference):
|
||||||
log_marginal = F.sum() - KL
|
log_marginal = F.sum() - KL
|
||||||
dL_dm, dL_dS, dL_dKmm, dL_dKmn = dF_dm - dKL_dm, dF_dS- dKL_dS, dF_dKmm- dKL_dKmm, dF_dKmn
|
dL_dm, dL_dS, dL_dKmm, dL_dKmn = dF_dm - dKL_dm, dF_dS- dKL_dS, dF_dKmm- dKL_dKmm, dF_dKmn
|
||||||
|
|
||||||
dL_dchol = 2.*np.dot(dL_dS, L)
|
dL_dchol = np.dstack([2.*np.dot(dL_dS[:,:,i], L[:,:,i]) for i in range(num_outputs)])
|
||||||
dL_dchol = choleskies.triang_to_flat(dL_dchol)
|
dL_dchol = choleskies.triang_to_flat(dL_dchol)
|
||||||
|
|
||||||
return Posterior(mean=q_u_mean, cov=S, K=Kmm), log_marginal, {'dL_dKmm':dL_dKmm, 'dL_dKmn':dL_dKmn, 'dL_dKdiag': dF_dv, 'dL_dm':dL_dm, 'dL_dchol':dL_dchol, 'dL_dthetaL':dF_dthetaL}
|
return Posterior(mean=q_u_mean, cov=S, K=Kmm), log_marginal, {'dL_dKmm':dL_dKmm, 'dL_dKmn':dL_dKmn, 'dL_dKdiag': dF_dv, 'dL_dm':dL_dm, 'dL_dchol':dL_dchol, 'dL_dthetaL':dF_dthetaL}
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -318,9 +318,9 @@ class Gaussian(Likelihood):
|
||||||
|
|
||||||
def variational_expectations(self, Y, m, v, gh_points=None):
|
def variational_expectations(self, Y, m, v, gh_points=None):
|
||||||
lik_var = float(self.variance)
|
lik_var = float(self.variance)
|
||||||
F = -0.5*np.log(2*np.pi) -0.5*np.log(lik_var) - 0.5*(np.square(Y) + np.square(m) + v - 2*m.dot(Y))/lik_var
|
F = -0.5*np.log(2*np.pi) -0.5*np.log(lik_var) - 0.5*(np.square(Y) + np.square(m) + v - 2*m*Y)/lik_var
|
||||||
dF_dmu = (Y - m)/lik_var
|
dF_dmu = (Y - m)/lik_var
|
||||||
dF_dv = -0.5/lik_var
|
dF_dv = np.ones_like(v)*(-0.5/lik_var)
|
||||||
dF_dlik_var = -0.5/lik_var + 0.5(np.square(Y) + np.square(m) + v - 2*m.dot(Y))/(lik_var**2)
|
dF_dlik_var = np.sum(-0.5/lik_var + 0.5*(np.square(Y) + np.square(m) + v - 2*m*Y)/(lik_var**2))
|
||||||
dF_dtheta = [dF_dlik_var]
|
dF_dtheta = [dF_dlik_var]
|
||||||
return F, dF_dmu, dF_dv, dF_dtheta
|
return F, dF_dmu, dF_dv, dF_dtheta
|
||||||
|
|
|
||||||
Loading…
Add table
Add a link
Reference in a new issue