mirror of
https://github.com/SheffieldML/GPy.git
synced 2026-05-24 14:15:14 +02:00
initial messing with svgp to diagonalize
This commit is contained in:
parent
3341b3c56a
commit
fc07abed20
1 changed files with 57 additions and 76 deletions
|
|
@ -7,115 +7,96 @@ from scipy.linalg.blas import dgemm, dsymm, dtrmm
|
||||||
|
|
||||||
class SVGP(LatentFunctionInference):
|
class SVGP(LatentFunctionInference):
|
||||||
|
|
||||||
def inference(self, q_u_mean, q_u_chol, kern, X, Z, likelihood, Y, mean_function=None, Y_metadata=None, KL_scale=1.0, batch_scale=1.0):
|
def inference(self, q_v_mean, q_v_chol, kern, X, Z, likelihood, Y, mean_function=None, Y_metadata=None, KL_scale=1.0, batch_scale=1.0):
|
||||||
|
|
||||||
|
if mean_function is not None:
|
||||||
|
raise NotImplementedError
|
||||||
|
|
||||||
num_data, _ = Y.shape
|
num_data, _ = Y.shape
|
||||||
num_inducing, num_outputs = q_u_mean.shape
|
num_inducing, num_outputs = q_v_mean.shape
|
||||||
|
|
||||||
#expand cholesky representation
|
#expand cholesky representation
|
||||||
L = choleskies.flat_to_triang(q_u_chol)
|
Lv = choleskies.flat_to_triang(q_v_chol)
|
||||||
|
|
||||||
|
#deal with posterior copvariance
|
||||||
|
Sv = np.zeros((num_outputs, num_inducing, num_inducing))
|
||||||
|
for i in range(num_outputs):
|
||||||
|
Sv[i] = Lv[i].dot(Lv[i].T)
|
||||||
|
logdetS = np.array([2.*np.sum(np.log(np.abs(np.diag(Lv[i,:,:])))) for i in range(Lv.shape[0])])
|
||||||
|
traceS = np.array([np.sum(np.square(np.diag(Lv[i,:,:]))) for i in range(Lv.shape[0])])
|
||||||
|
|
||||||
S = np.empty((num_outputs, num_inducing, num_inducing))
|
|
||||||
[np.dot(L[i,:,:], L[i,:,:].T, S[i,:,:]) for i in range(num_outputs)]
|
|
||||||
#Si,_ = linalg.dpotri(np.asfortranarray(L), lower=1)
|
|
||||||
Si = choleskies.multiple_dpotri(L)
|
|
||||||
logdetS = np.array([2.*np.sum(np.log(np.abs(np.diag(L[i,:,:])))) for i in range(L.shape[0])])
|
|
||||||
|
|
||||||
if np.any(np.isinf(Si)):
|
|
||||||
raise ValueError("Cholesky representation unstable")
|
|
||||||
|
|
||||||
#compute mean function stuff
|
|
||||||
if mean_function is not None:
|
|
||||||
prior_mean_u = mean_function.f(Z)
|
|
||||||
prior_mean_f = mean_function.f(X)
|
|
||||||
else:
|
|
||||||
prior_mean_u = np.zeros((num_inducing, num_outputs))
|
|
||||||
prior_mean_f = np.zeros((num_data, num_outputs))
|
|
||||||
|
|
||||||
#compute kernel related stuff
|
#compute kernel related stuff
|
||||||
Kmm = kern.K(Z)
|
Kmm = kern.K(Z)
|
||||||
Kmn = kern.K(Z, X)
|
Kmn = kern.K(Z, X)
|
||||||
Knn_diag = kern.Kdiag(X)
|
Knn_diag = kern.Kdiag(X)
|
||||||
Lm = linalg.jitchol(Kmm)
|
R = linalg.jitchol(Kmm)
|
||||||
logdetKmm = 2.*np.sum(np.log(np.diag(Lm)))
|
|
||||||
Kmmi, _ = linalg.dpotri(Lm)
|
|
||||||
|
|
||||||
#compute the marginal means and variances of q(f)
|
#compute the marginal means and variances of q(f)
|
||||||
A, _ = linalg.dpotrs(Lm, Kmn)
|
AT, _ = linalg.dtrtrs(R, Kmn)
|
||||||
mu = prior_mean_f + np.dot(A.T, q_u_mean - prior_mean_u)
|
A = AT.T
|
||||||
v = np.empty((num_data, num_outputs))
|
mu = np.dot(A, q_v_mean)
|
||||||
|
var = np.empty((num_data, num_outputs))
|
||||||
for i in range(num_outputs):
|
for i in range(num_outputs):
|
||||||
tmp = dtrmm(1.0,L[i].T, A, lower=0, trans_a=0)
|
tmp = dtrmm(1.0,Lv[i].T, A.T, lower=0, trans_a=0)
|
||||||
v[:,i] = np.sum(np.square(tmp),0)
|
var[:,i] = np.sum(np.square(tmp),0)
|
||||||
v += (Knn_diag - np.sum(A*Kmn,0))[:,None]
|
var += (Knn_diag - np.sum(np.square(A),1))[:,None]
|
||||||
|
|
||||||
#compute the KL term
|
#compute the KL term
|
||||||
Kmmim = np.dot(Kmmi, q_u_mean)
|
KL = -0.5*logdetS.sum() + 0.5*np.sum(np.square(q_v_mean)) + 0.5*traceS.sum()
|
||||||
KLs = -0.5*logdetS -0.5*num_inducing + 0.5*logdetKmm + 0.5*np.sum(Kmmi[None,:,:]*S,1).sum(1) + 0.5*np.sum(q_u_mean*Kmmim,0)
|
dL_dmv = q_v_mean*1
|
||||||
KL = KLs.sum()
|
dL_dL = np.zeros_like(Lv)
|
||||||
#gradient of the KL term (assuming zero mean function)
|
for k in range(num_outputs):
|
||||||
dKL_dm = Kmmim.copy()
|
Lii = np.diagonal(Lv[i])
|
||||||
dKL_dS = 0.5*(Kmmi[None,:,:] - Si)
|
diag = np.diagonal(dL_dL[i])
|
||||||
dKL_dKmm = 0.5*num_outputs*Kmmi - 0.5*Kmmi.dot(S.sum(0)).dot(Kmmi) - 0.5*Kmmim.dot(Kmmim.T)
|
diag = Lii - 1./Lii # write in place, need numpy 1.9+
|
||||||
|
|
||||||
if mean_function is not None:
|
|
||||||
#adjust KL term for mean function
|
|
||||||
Kmmi_mfZ = np.dot(Kmmi, prior_mean_u)
|
|
||||||
KL += -np.sum(q_u_mean*Kmmi_mfZ)
|
|
||||||
KL += 0.5*np.sum(Kmmi_mfZ*prior_mean_u)
|
|
||||||
|
|
||||||
#adjust gradient for mean fucntion
|
|
||||||
dKL_dm -= Kmmi_mfZ
|
|
||||||
dKL_dKmm += Kmmim.dot(Kmmi_mfZ.T)
|
|
||||||
dKL_dKmm -= 0.5*Kmmi_mfZ.dot(Kmmi_mfZ.T)
|
|
||||||
|
|
||||||
#compute gradients for mean_function
|
|
||||||
dKL_dmfZ = Kmmi_mfZ - Kmmim
|
|
||||||
|
|
||||||
#quadrature for the likelihood
|
#quadrature for the likelihood
|
||||||
F, dF_dmu, dF_dv, dF_dthetaL = likelihood.variational_expectations(Y, mu, v, Y_metadata=Y_metadata)
|
F, dF_dmu, dF_dv, dF_dthetaL = likelihood.variational_expectations(Y, mu, var, Y_metadata=Y_metadata)
|
||||||
|
|
||||||
#rescale the F term if working on a batch
|
#rescale the F term if working on a batch
|
||||||
F, dF_dmu, dF_dv = F*batch_scale, dF_dmu*batch_scale, dF_dv*batch_scale
|
F, dF_dmu, dF_dv = F*batch_scale, dF_dmu*batch_scale, dF_dv*batch_scale
|
||||||
if dF_dthetaL is not None:
|
if dF_dthetaL is not None:
|
||||||
dF_dthetaL = dF_dthetaL.sum(1).sum(1)*batch_scale
|
dF_dthetaL = dF_dthetaL.sum(1).sum(1)*batch_scale
|
||||||
|
|
||||||
#derivatives of expected likelihood, assuming zero mean function
|
#mv
|
||||||
Adv = A[None,:,:]*dF_dv.T[:,None,:] # As if dF_Dv is diagonal, D, M, N
|
dL_dmv += A.T.dot(dF_dmu)
|
||||||
Admu = A.dot(dF_dmu)
|
|
||||||
Adv = np.ascontiguousarray(Adv) # makes for faster operations later...(inc dsymm)
|
|
||||||
AdvA = np.dot(Adv.reshape(-1, num_data),A.T).reshape(num_outputs, num_inducing, num_inducing )
|
|
||||||
tmp = np.sum([np.dot(a,s) for a, s in zip(AdvA, S)],0).dot(Kmmi)
|
|
||||||
dF_dKmm = -Admu.dot(Kmmim.T) + AdvA.sum(0) - tmp - tmp.T
|
|
||||||
dF_dKmm = 0.5*(dF_dKmm + dF_dKmm.T) # necessary? GPy bug?
|
|
||||||
tmp = S.reshape(-1, num_inducing).dot(Kmmi).reshape(num_outputs, num_inducing , num_inducing )
|
|
||||||
tmp = 2.*(tmp - np.eye(num_inducing)[None, :,:])
|
|
||||||
|
|
||||||
dF_dKmn = Kmmim.dot(dF_dmu.T)
|
#Kfu
|
||||||
for a,b in zip(tmp, Adv):
|
RiTm, _ = linalg.dtrtrs(R, q_v_mean, lower=1, trans=1)
|
||||||
dF_dKmn += np.dot(a.T, b)
|
dL_dKmn = np.zeros((num_inducing, num_data))
|
||||||
|
for i in range(num_outputs):
|
||||||
|
tmp, _ = linalg.dtrtrs(R, np.eye(num_inducing)-Sv[i], trans=1, lower=1)
|
||||||
|
dL_dKmn += -2*np.dot(tmp, A.T*dF_dv[:,i])
|
||||||
|
dL_dKmn += np.dot(RiTm, dF_dmu.T)
|
||||||
|
|
||||||
dF_dm = Admu
|
#L
|
||||||
dF_dS = AdvA
|
for i in range(num_outputs):
|
||||||
|
dL_dL[i] += np.dot(Lv[i].T, A.T).dot(A*dF_dv[:,i][:,None])
|
||||||
|
|
||||||
#adjust gradient to account for mean function
|
#R
|
||||||
if mean_function is not None:
|
dL_dR = np.zeros((num_inducing, num_inducing))
|
||||||
dF_dmfX = dF_dmu.copy()
|
for i in range(num_outputs):
|
||||||
dF_dmfZ = -Admu
|
tmp = np.eye(num_inducing) - Sv[i]
|
||||||
dF_dKmn -= np.dot(Kmmi_mfZ, dF_dmu.T)
|
tmp = np.dot(tmp, A.T)
|
||||||
dF_dKmm += Admu.dot(Kmmi_mfZ.T)
|
tmp = np.dot(tmp, A*dF_dv[:,i][:,None])
|
||||||
|
tmp, _ = linalg.dtrtrs(R, tmp, trans=1, lower=1)
|
||||||
|
dL_dR += 2*tmp.T
|
||||||
|
dL_dR -= A.T.dot(dF_dmu).dot(RiTm.T)
|
||||||
|
|
||||||
|
#backprop dL_dR for dL_dKmm
|
||||||
|
dL_dKmm = choleskies.backprop_gradient(dL_dR, R)
|
||||||
|
|
||||||
|
|
||||||
#sum (gradients of) expected likelihood and KL part
|
#sum (gradients of) expected likelihood and KL part
|
||||||
log_marginal = F.sum() - KL
|
log_marginal = F.sum() - KL
|
||||||
dL_dm, dL_dS, dL_dKmm, dL_dKmn = dF_dm - dKL_dm, dF_dS- dKL_dS, dF_dKmm- dKL_dKmm, dF_dKmn
|
|
||||||
|
|
||||||
dL_dchol = 2.*np.array([np.dot(a,b) for a, b in zip(dL_dS, L) ])
|
dL_dchol = choleskies.triang_to_flat(dL_dL)
|
||||||
dL_dchol = choleskies.triang_to_flat(dL_dchol)
|
|
||||||
|
|
||||||
grad_dict = {'dL_dKmm':dL_dKmm, 'dL_dKmn':dL_dKmn, 'dL_dKdiag': dF_dv.sum(1), 'dL_dm':dL_dm, 'dL_dchol':dL_dchol, 'dL_dthetaL':dF_dthetaL}
|
grad_dict = {'dL_dKmm':dL_dKmm, 'dL_dKmn':dL_dKmn, 'dL_dKdiag': dF_dv.sum(1), 'dL_dm':dL_dmv, 'dL_dchol':dL_dchol, 'dL_dthetaL':dF_dthetaL}
|
||||||
if mean_function is not None:
|
if mean_function is not None:
|
||||||
grad_dict['dL_dmfZ'] = dF_dmfZ - dKL_dmfZ
|
grad_dict['dL_dmfZ'] = dF_dmfZ - dKL_dmfZ
|
||||||
grad_dict['dL_dmfX'] = dF_dmfX
|
grad_dict['dL_dmfX'] = dF_dmfX
|
||||||
return Posterior(mean=q_u_mean, cov=S.T, K=Kmm, prior_mean=prior_mean_u), log_marginal, grad_dict
|
|
||||||
|
q_u_mean = np.dot(R, q_v_mean)
|
||||||
|
return Posterior(mean=q_u_mean, cov=Sv.T, K=Kmm, prior_mean=0.), log_marginal, grad_dict
|
||||||
|
|
|
||||||
Loading…
Add table
Add a link
Reference in a new issue