mirror of
https://github.com/SheffieldML/GPy.git
synced 2026-05-21 14:05:14 +02:00
[parallel vardtc] minor adjustments to work with current implementation of
psi stats
This commit is contained in:
parent
c62dd85418
commit
1e006f63b5
1 changed files with 84 additions and 67 deletions
|
|
@ -22,21 +22,21 @@ class VarDTC_minibatch(LatentFunctionInference):
|
||||||
"""
|
"""
|
||||||
const_jitter = 1e-6
|
const_jitter = 1e-6
|
||||||
def __init__(self, batchsize, limit=1):
|
def __init__(self, batchsize, limit=1):
|
||||||
|
|
||||||
self.batchsize = batchsize
|
self.batchsize = batchsize
|
||||||
|
|
||||||
# Cache functions
|
# Cache functions
|
||||||
from ...util.caching import Cacher
|
from ...util.caching import Cacher
|
||||||
self.get_trYYT = Cacher(self._get_trYYT, limit)
|
self.get_trYYT = Cacher(self._get_trYYT, limit)
|
||||||
self.get_YYTfactor = Cacher(self._get_YYTfactor, limit)
|
self.get_YYTfactor = Cacher(self._get_YYTfactor, limit)
|
||||||
|
|
||||||
self.midRes = {}
|
self.midRes = {}
|
||||||
self.batch_pos = 0 # the starting position of the current mini-batch
|
self.batch_pos = 0 # the starting position of the current mini-batch
|
||||||
|
|
||||||
def set_limit(self, limit):
|
def set_limit(self, limit):
|
||||||
self.get_trYYT.limit = limit
|
self.get_trYYT.limit = limit
|
||||||
self.get_YYTfactor.limit = limit
|
self.get_YYTfactor.limit = limit
|
||||||
|
|
||||||
def _get_trYYT(self, Y):
|
def _get_trYYT(self, Y):
|
||||||
return param_to_array(np.sum(np.square(Y)))
|
return param_to_array(np.sum(np.square(Y)))
|
||||||
|
|
||||||
|
|
@ -51,23 +51,23 @@ class VarDTC_minibatch(LatentFunctionInference):
|
||||||
return param_to_array(Y)
|
return param_to_array(Y)
|
||||||
else:
|
else:
|
||||||
return jitchol(tdot(Y))
|
return jitchol(tdot(Y))
|
||||||
|
|
||||||
def inference_likelihood(self, kern, X, Z, likelihood, Y):
|
def inference_likelihood(self, kern, X, Z, likelihood, Y):
|
||||||
"""
|
"""
|
||||||
The first phase of inference:
|
The first phase of inference:
|
||||||
Compute: log-likelihood, dL_dKmm
|
Compute: log-likelihood, dL_dKmm
|
||||||
|
|
||||||
Cached intermediate results: Kmm, KmmInv,
|
Cached intermediate results: Kmm, KmmInv,
|
||||||
"""
|
"""
|
||||||
|
|
||||||
num_inducing = Z.shape[0]
|
num_inducing = Z.shape[0]
|
||||||
num_data, output_dim = Y.shape
|
num_data, output_dim = Y.shape
|
||||||
|
|
||||||
if isinstance(X, VariationalPosterior):
|
if isinstance(X, VariationalPosterior):
|
||||||
uncertain_inputs = True
|
uncertain_inputs = True
|
||||||
else:
|
else:
|
||||||
uncertain_inputs = False
|
uncertain_inputs = False
|
||||||
|
|
||||||
#see whether we've got a different noise variance for each datum
|
#see whether we've got a different noise variance for each datum
|
||||||
beta = 1./np.fmax(likelihood.variance, 1e-6)
|
beta = 1./np.fmax(likelihood.variance, 1e-6)
|
||||||
het_noise = beta.size > 1
|
het_noise = beta.size > 1
|
||||||
|
|
@ -77,19 +77,19 @@ class VarDTC_minibatch(LatentFunctionInference):
|
||||||
#self.YYTfactor = beta*self.get_YYTfactor(Y)
|
#self.YYTfactor = beta*self.get_YYTfactor(Y)
|
||||||
YYT_factor = Y
|
YYT_factor = Y
|
||||||
trYYT = self.get_trYYT(Y)
|
trYYT = self.get_trYYT(Y)
|
||||||
|
|
||||||
psi2_full = np.zeros((num_inducing,num_inducing))
|
psi2_full = np.zeros((num_inducing,num_inducing))
|
||||||
psi1Y_full = np.zeros((output_dim,num_inducing)) # DxM
|
psi1Y_full = np.zeros((output_dim,num_inducing)) # DxM
|
||||||
psi0_full = 0
|
psi0_full = 0
|
||||||
YRY_full = 0
|
YRY_full = 0
|
||||||
|
|
||||||
for n_start in xrange(0,num_data,self.batchsize):
|
for n_start in xrange(0,num_data,self.batchsize):
|
||||||
|
|
||||||
n_end = min(self.batchsize+n_start, num_data)
|
n_end = min(self.batchsize+n_start, num_data)
|
||||||
|
|
||||||
Y_slice = YYT_factor[n_start:n_end]
|
Y_slice = YYT_factor[n_start:n_end]
|
||||||
X_slice = X[n_start:n_end]
|
X_slice = X[n_start:n_end]
|
||||||
|
|
||||||
if uncertain_inputs:
|
if uncertain_inputs:
|
||||||
psi0 = kern.psi0(Z, X_slice)
|
psi0 = kern.psi0(Z, X_slice)
|
||||||
psi1 = kern.psi1(Z, X_slice)
|
psi1 = kern.psi1(Z, X_slice)
|
||||||
|
|
@ -98,7 +98,7 @@ class VarDTC_minibatch(LatentFunctionInference):
|
||||||
psi0 = kern.Kdiag(X_slice)
|
psi0 = kern.Kdiag(X_slice)
|
||||||
psi1 = kern.K(X_slice, Z)
|
psi1 = kern.K(X_slice, Z)
|
||||||
psi2 = None
|
psi2 = None
|
||||||
|
|
||||||
if het_noise:
|
if het_noise:
|
||||||
beta_slice = beta[n_start:n_end]
|
beta_slice = beta[n_start:n_end]
|
||||||
psi0_full += (beta_slice*psi0).sum()
|
psi0_full += (beta_slice*psi0).sum()
|
||||||
|
|
@ -106,33 +106,33 @@ class VarDTC_minibatch(LatentFunctionInference):
|
||||||
YRY_full += (beta_slice*np.square(Y_slice).sum(axis=-1)).sum()
|
YRY_full += (beta_slice*np.square(Y_slice).sum(axis=-1)).sum()
|
||||||
else:
|
else:
|
||||||
psi0_full += psi0.sum()
|
psi0_full += psi0.sum()
|
||||||
psi1Y_full += np.dot(Y_slice.T,psi1) # DxM
|
psi1Y_full += np.dot(Y_slice.T,psi1) # DxM
|
||||||
|
|
||||||
if uncertain_inputs:
|
if uncertain_inputs:
|
||||||
if het_noise:
|
if het_noise:
|
||||||
psi2_full += beta_slice*psi2
|
psi2_full += beta_slice*psi2
|
||||||
else:
|
else:
|
||||||
psi2_full += psi2
|
psi2_full += psi2.sum(0)
|
||||||
else:
|
else:
|
||||||
if het_noise:
|
if het_noise:
|
||||||
psi2_full += beta_slice*np.outer(psi1,psi1)
|
psi2_full += beta_slice*np.outer(psi1,psi1)
|
||||||
else:
|
else:
|
||||||
psi2_full += np.outer(psi1,psi1)
|
psi2_full += np.einsum('nm,jk->mk',psi1,psi1)
|
||||||
|
|
||||||
if not het_noise:
|
if not het_noise:
|
||||||
psi0_full *= beta
|
psi0_full *= beta
|
||||||
psi1Y_full *= beta
|
psi1Y_full *= beta
|
||||||
psi2_full *= beta
|
psi2_full *= beta
|
||||||
YRY_full = trYYT*beta
|
YRY_full = trYYT*beta
|
||||||
|
|
||||||
#======================================================================
|
#======================================================================
|
||||||
# Compute Common Components
|
# Compute Common Components
|
||||||
#======================================================================
|
#======================================================================
|
||||||
|
self.psi1Y = psi1Y_full
|
||||||
Kmm = kern.K(Z).copy()
|
Kmm = kern.K(Z).copy()
|
||||||
diag.add(Kmm, self.const_jitter)
|
diag.add(Kmm, self.const_jitter)
|
||||||
Lm = jitchol(Kmm)
|
Lm = jitchol(Kmm)
|
||||||
|
|
||||||
Lambda = Kmm+psi2_full
|
Lambda = Kmm+psi2_full
|
||||||
LL = jitchol(Lambda)
|
LL = jitchol(Lambda)
|
||||||
b,_ = dtrtrs(LL, psi1Y_full.T)
|
b,_ = dtrtrs(LL, psi1Y_full.T)
|
||||||
|
|
@ -140,18 +140,18 @@ class VarDTC_minibatch(LatentFunctionInference):
|
||||||
v,_ = dtrtrs(LL.T,b,lower=False)
|
v,_ = dtrtrs(LL.T,b,lower=False)
|
||||||
vvt = np.einsum('md,od->mo',v,v)
|
vvt = np.einsum('md,od->mo',v,v)
|
||||||
LmInvPsi2LmInvT = backsub_both_sides(Lm,psi2_full,transpose='right')
|
LmInvPsi2LmInvT = backsub_both_sides(Lm,psi2_full,transpose='right')
|
||||||
|
|
||||||
Psi2LLInvT = dtrtrs(LL,psi2_full)[0].T
|
Psi2LLInvT = dtrtrs(LL,psi2_full)[0].T
|
||||||
LmInvPsi2LLInvT= dtrtrs(Lm,Psi2LLInvT)[0]
|
LmInvPsi2LLInvT= dtrtrs(Lm,Psi2LLInvT)[0]
|
||||||
KmmInvPsi2LLInvT = dtrtrs(Lm,LmInvPsi2LLInvT,trans=True)[0]
|
KmmInvPsi2LLInvT = dtrtrs(Lm,LmInvPsi2LLInvT,trans=True)[0]
|
||||||
KmmInvPsi2P = dtrtrs(LL,KmmInvPsi2LLInvT.T, trans=True)[0].T
|
KmmInvPsi2P = dtrtrs(LL,KmmInvPsi2LLInvT.T, trans=True)[0].T
|
||||||
|
|
||||||
dL_dpsi2R = (output_dim*KmmInvPsi2P - vvt)/2. # dL_dpsi2 with R inside psi2
|
dL_dpsi2R = (output_dim*KmmInvPsi2P - vvt)/2. # dL_dpsi2 with R inside psi2
|
||||||
|
|
||||||
# Cache intermediate results
|
# Cache intermediate results
|
||||||
self.midRes['dL_dpsi2R'] = dL_dpsi2R
|
self.midRes['dL_dpsi2R'] = dL_dpsi2R
|
||||||
self.midRes['v'] = v
|
self.midRes['v'] = v
|
||||||
|
|
||||||
#======================================================================
|
#======================================================================
|
||||||
# Compute log-likelihood
|
# Compute log-likelihood
|
||||||
#======================================================================
|
#======================================================================
|
||||||
|
|
@ -159,30 +159,33 @@ class VarDTC_minibatch(LatentFunctionInference):
|
||||||
logL_R = -np.log(beta).sum()
|
logL_R = -np.log(beta).sum()
|
||||||
else:
|
else:
|
||||||
logL_R = -num_data*np.log(beta)
|
logL_R = -num_data*np.log(beta)
|
||||||
logL = -(output_dim*(num_data*log_2_pi+logL_R+psi0_full-np.trace(LmInvPsi2LmInvT))+YRY_full-bbt)/2.-output_dim*(-np.log(np.diag(Lm)).sum()+np.log(np.diag(LL)).sum())
|
logL = (
|
||||||
|
-(output_dim*(num_data*log_2_pi+logL_R+psi0_full-np.trace(LmInvPsi2LmInvT))+YRY_full-bbt)/2.
|
||||||
|
-output_dim*(-np.log(np.diag(Lm)).sum()+np.log(np.diag(LL)).sum())
|
||||||
|
)
|
||||||
|
|
||||||
#======================================================================
|
#======================================================================
|
||||||
# Compute dL_dKmm
|
# Compute dL_dKmm
|
||||||
#======================================================================
|
#======================================================================
|
||||||
|
|
||||||
dL_dKmm = -(output_dim*np.einsum('md,od->mo',KmmInvPsi2LLInvT,KmmInvPsi2LLInvT) + vvt)/2.
|
dL_dKmm = -(output_dim*np.einsum('md,od->mo',KmmInvPsi2LLInvT,KmmInvPsi2LLInvT) + vvt)/2.
|
||||||
|
|
||||||
#======================================================================
|
#======================================================================
|
||||||
# Compute the Posterior distribution of inducing points p(u|Y)
|
# Compute the Posterior distribution of inducing points p(u|Y)
|
||||||
#======================================================================
|
#======================================================================
|
||||||
|
|
||||||
# phi_u_mean = np.dot(Kmm,v)
|
# phi_u_mean = np.dot(Kmm,v)
|
||||||
# LLInvKmm,_ = dtrtrs(LL,Kmm)
|
# LLInvKmm,_ = dtrtrs(LL,Kmm)
|
||||||
# # phi_u_var = np.einsum('ma,mb->ab',LLInvKmm,LLInvKmm)
|
# # phi_u_var = np.einsum('ma,mb->ab',LLInvKmm,LLInvKmm)
|
||||||
# phi_u_var = Kmm - np.dot(LLInvKmm.T,LLInvKmm)
|
# phi_u_var = Kmm - np.dot(LLInvKmm.T,LLInvKmm)
|
||||||
|
|
||||||
post = Posterior(woodbury_inv=KmmInvPsi2P, woodbury_vector=v, K=Kmm, mean=None, cov=None, K_chol=Lm)
|
post = Posterior(woodbury_inv=KmmInvPsi2P, woodbury_vector=v, K=Kmm, mean=None, cov=None, K_chol=Lm)
|
||||||
|
|
||||||
return logL, dL_dKmm, post
|
return logL, dL_dKmm, post
|
||||||
|
|
||||||
def inference_minibatch(self, kern, X, Z, likelihood, Y):
|
def inference_minibatch(self, kern, X, Z, likelihood, Y):
|
||||||
"""
|
"""
|
||||||
The second phase of inference: Computing the derivatives over a minibatch of Y
|
The second phase of inference: Computing the derivatives over a minibatch of Y
|
||||||
Compute: dL_dpsi0, dL_dpsi1, dL_dpsi2, dL_dthetaL
|
Compute: dL_dpsi0, dL_dpsi1, dL_dpsi2, dL_dthetaL
|
||||||
return a flag showing whether it reached the end of Y (isEnd)
|
return a flag showing whether it reached the end of Y (isEnd)
|
||||||
"""
|
"""
|
||||||
|
|
@ -193,14 +196,14 @@ class VarDTC_minibatch(LatentFunctionInference):
|
||||||
uncertain_inputs = True
|
uncertain_inputs = True
|
||||||
else:
|
else:
|
||||||
uncertain_inputs = False
|
uncertain_inputs = False
|
||||||
|
|
||||||
#see whether we've got a different noise variance for each datum
|
#see whether we've got a different noise variance for each datum
|
||||||
beta = 1./np.fmax(likelihood.variance, 1e-6)
|
beta = 1./np.fmax(likelihood.variance, 1e-6)
|
||||||
het_noise = beta.size > 1
|
het_noise = beta.size > 1
|
||||||
# VVT_factor is a matrix such that tdot(VVT_factor) = VVT...this is for efficiency!
|
# VVT_factor is a matrix such that tdot(VVT_factor) = VVT...this is for efficiency!
|
||||||
#self.YYTfactor = beta*self.get_YYTfactor(Y)
|
#self.YYTfactor = beta*self.get_YYTfactor(Y)
|
||||||
YYT_factor = Y
|
YYT_factor = Y
|
||||||
|
|
||||||
n_start = self.batch_pos
|
n_start = self.batch_pos
|
||||||
n_end = min(self.batchsize+n_start, num_data)
|
n_end = min(self.batchsize+n_start, num_data)
|
||||||
if n_end==num_data:
|
if n_end==num_data:
|
||||||
|
|
@ -209,11 +212,11 @@ class VarDTC_minibatch(LatentFunctionInference):
|
||||||
else:
|
else:
|
||||||
isEnd = False
|
isEnd = False
|
||||||
self.batch_pos = n_end
|
self.batch_pos = n_end
|
||||||
|
|
||||||
num_slice = n_end-n_start
|
num_slice = n_end-n_start
|
||||||
Y_slice = YYT_factor[n_start:n_end]
|
Y_slice = YYT_factor[n_start:n_end]
|
||||||
X_slice = X[n_start:n_end]
|
X_slice = X[n_start:n_end]
|
||||||
|
|
||||||
if uncertain_inputs:
|
if uncertain_inputs:
|
||||||
psi0 = kern.psi0(Z, X_slice)
|
psi0 = kern.psi0(Z, X_slice)
|
||||||
psi1 = kern.psi1(Z, X_slice)
|
psi1 = kern.psi1(Z, X_slice)
|
||||||
|
|
@ -222,51 +225,51 @@ class VarDTC_minibatch(LatentFunctionInference):
|
||||||
psi0 = kern.Kdiag(X_slice)
|
psi0 = kern.Kdiag(X_slice)
|
||||||
psi1 = kern.K(X_slice, Z)
|
psi1 = kern.K(X_slice, Z)
|
||||||
psi2 = None
|
psi2 = None
|
||||||
|
|
||||||
if het_noise:
|
if het_noise:
|
||||||
beta = beta[n_start] # assuming batchsize==1
|
beta = beta[n_start] # assuming batchsize==1
|
||||||
|
|
||||||
betaY = beta*Y_slice
|
betaY = beta*Y_slice
|
||||||
betapsi1 = np.einsum('n,nm->nm',beta,psi1)
|
betapsi1 = np.einsum('n,nm->nm',beta,psi1)
|
||||||
|
|
||||||
#======================================================================
|
#======================================================================
|
||||||
# Load Intermediate Results
|
# Load Intermediate Results
|
||||||
#======================================================================
|
#======================================================================
|
||||||
|
|
||||||
dL_dpsi2R = self.midRes['dL_dpsi2R']
|
dL_dpsi2R = self.midRes['dL_dpsi2R']
|
||||||
v = self.midRes['v']
|
v = self.midRes['v']
|
||||||
|
|
||||||
#======================================================================
|
#======================================================================
|
||||||
# Compute dL_dpsi
|
# Compute dL_dpsi
|
||||||
#======================================================================
|
#======================================================================
|
||||||
|
|
||||||
dL_dpsi0 = -0.5 * output_dim * (beta * np.ones((n_end-n_start,)))
|
dL_dpsi0 = -0.5 * output_dim * (beta * np.ones((n_end-n_start,)))
|
||||||
|
|
||||||
dL_dpsi1 = np.dot(betaY,v.T)
|
dL_dpsi1 = np.dot(betaY,v.T)
|
||||||
|
|
||||||
if uncertain_inputs:
|
if uncertain_inputs:
|
||||||
dL_dpsi2 = beta* dL_dpsi2R
|
dL_dpsi2 = beta* dL_dpsi2R
|
||||||
else:
|
else:
|
||||||
dL_dpsi1 += np.dot(betapsi1,dL_dpsi2R)*2.
|
dL_dpsi1 += np.dot(betapsi1,dL_dpsi2R)*2.
|
||||||
dL_dpsi2 = None
|
dL_dpsi2 = None
|
||||||
|
|
||||||
#======================================================================
|
#======================================================================
|
||||||
# Compute dL_dthetaL
|
# Compute dL_dthetaL
|
||||||
#======================================================================
|
#======================================================================
|
||||||
|
|
||||||
if het_noise:
|
if het_noise:
|
||||||
if uncertain_inputs:
|
if uncertain_inputs:
|
||||||
psiR = np.einsum('mo,nmo->n',dL_dpsi2R,psi2)
|
psiR = np.einsum('mo,nmo->',dL_dpsi2R,psi2)
|
||||||
else:
|
else:
|
||||||
psiR = np.einsum('nm,no,mo->n',psi1,psi1,dL_dpsi2R)
|
psiR = np.einsum('nm,no,mo->',psi1,psi1,dL_dpsi2R)
|
||||||
|
|
||||||
dL_dthetaL = ((np.square(betaY)).sum(axis=-1) + np.square(beta)*(output_dim*psi0)-output_dim*beta)/2. - np.square(beta)*psiR- (betaY*np.dot(betapsi1,v)).sum(axis=-1)
|
dL_dthetaL = ((np.square(betaY)).sum(axis=-1) + np.square(beta)*(output_dim*psi0)-output_dim*beta)/2. - np.square(beta)*psiR- (betaY*np.dot(betapsi1,v)).sum(axis=-1)
|
||||||
else:
|
else:
|
||||||
if uncertain_inputs:
|
if uncertain_inputs:
|
||||||
psiR = np.einsum('mo,mo->',dL_dpsi2R,psi2)
|
psiR = np.einsum('mo,nmo->',dL_dpsi2R,psi2)
|
||||||
else:
|
else:
|
||||||
psiR = np.einsum('nm,no,mo->',psi1,psi1,dL_dpsi2R)
|
psiR = np.einsum('nm,no,mo->',psi1,psi1,dL_dpsi2R)
|
||||||
|
|
||||||
dL_dthetaL = ((np.square(betaY)).sum() + beta*beta*output_dim*(psi0.sum())-num_slice*output_dim*beta)/2. - beta*beta*psiR- (betaY*np.dot(betapsi1,v)).sum()
|
dL_dthetaL = ((np.square(betaY)).sum() + beta*beta*output_dim*(psi0.sum())-num_slice*output_dim*beta)/2. - beta*beta*psiR- (betaY*np.dot(betapsi1,v)).sum()
|
||||||
|
|
||||||
if uncertain_inputs:
|
if uncertain_inputs:
|
||||||
|
|
@ -278,15 +281,15 @@ class VarDTC_minibatch(LatentFunctionInference):
|
||||||
grad_dict = {'dL_dKdiag':dL_dpsi0,
|
grad_dict = {'dL_dKdiag':dL_dpsi0,
|
||||||
'dL_dKnm':dL_dpsi1,
|
'dL_dKnm':dL_dpsi1,
|
||||||
'dL_dthetaL':dL_dthetaL}
|
'dL_dthetaL':dL_dthetaL}
|
||||||
|
|
||||||
return isEnd, (n_start,n_end), grad_dict
|
return isEnd, (n_start,n_end), grad_dict
|
||||||
|
|
||||||
|
|
||||||
def update_gradients(model):
|
def update_gradients(model):
|
||||||
model._log_marginal_likelihood, dL_dKmm, model.posterior = model.inference_method.inference_likelihood(model.kern, model.X, model.Z, model.likelihood, model.Y)
|
model._log_marginal_likelihood, dL_dKmm, model.posterior = model.inference_method.inference_likelihood(model.kern, model.X, model.Z, model.likelihood, model.Y)
|
||||||
|
|
||||||
het_noise = model.likelihood.variance.size > 1
|
het_noise = model.likelihood.variance.size > 1
|
||||||
|
|
||||||
if het_noise:
|
if het_noise:
|
||||||
dL_dthetaL = np.empty((model.Y.shape[0],))
|
dL_dthetaL = np.empty((model.Y.shape[0],))
|
||||||
else:
|
else:
|
||||||
|
|
@ -295,40 +298,54 @@ def update_gradients(model):
|
||||||
#gradients w.r.t. kernel
|
#gradients w.r.t. kernel
|
||||||
model.kern.update_gradients_full(dL_dKmm, model.Z, None)
|
model.kern.update_gradients_full(dL_dKmm, model.Z, None)
|
||||||
kern_grad = model.kern.gradient.copy()
|
kern_grad = model.kern.gradient.copy()
|
||||||
|
|
||||||
#gradients w.r.t. Z
|
#gradients w.r.t. Z
|
||||||
model.Z.gradient = model.kern.gradients_X(dL_dKmm, model.Z)
|
model.Z.gradient = model.kern.gradients_X(dL_dKmm, model.Z)
|
||||||
|
|
||||||
isEnd = False
|
isEnd = False
|
||||||
while not isEnd:
|
while not isEnd:
|
||||||
isEnd, n_range, grad_dict = model.inference_method.inference_minibatch(model.kern, model.X, model.Z, model.likelihood, model.Y)
|
isEnd, n_range, grad_dict = model.inference_method.inference_minibatch(model.kern, model.X, model.Z, model.likelihood, model.Y)
|
||||||
if isinstance(model.X, VariationalPosterior):
|
if isinstance(model.X, VariationalPosterior):
|
||||||
X_slice = model.X[n_range[0]:n_range[1]]
|
X_slice = model.X[n_range[0]:n_range[1]]
|
||||||
|
|
||||||
|
dL_dpsi1 = grad_dict['dL_dpsi1']#[None, :]
|
||||||
|
dL_dpsi2 = grad_dict['dL_dpsi2'][None, :, :]
|
||||||
#gradients w.r.t. kernel
|
#gradients w.r.t. kernel
|
||||||
model.kern.update_gradients_expectations(variational_posterior=X_slice, Z=model.Z, dL_dpsi0=grad_dict['dL_dpsi0'], dL_dpsi1=grad_dict['dL_dpsi1'], dL_dpsi2=grad_dict['dL_dpsi2'])
|
model.kern.update_gradients_expectations(variational_posterior=X_slice,Z=model.Z,dL_dpsi0=grad_dict['dL_dpsi0'],dL_dpsi1=dL_dpsi1,dL_dpsi2=dL_dpsi2)
|
||||||
kern_grad += model.kern.gradient
|
kern_grad += model.kern.gradient
|
||||||
|
|
||||||
#gradients w.r.t. Z
|
#gradients w.r.t. Z
|
||||||
model.Z.gradient += model.kern.gradients_Z_expectations(
|
model.Z.gradient += model.kern.gradients_Z_expectations(
|
||||||
dL_dpsi0=grad_dict['dL_dpsi0'], dL_dpsi1=grad_dict['dL_dpsi1'], dL_dpsi2=grad_dict['dL_dpsi2'], Z=model.Z, variational_posterior=X_slice)
|
dL_dpsi0=grad_dict['dL_dpsi0'],
|
||||||
|
dL_dpsi1=dL_dpsi1,
|
||||||
|
dL_dpsi2=dL_dpsi2,
|
||||||
|
Z=model.Z, variational_posterior=X_slice)
|
||||||
|
|
||||||
#gradients w.r.t. posterior parameters of X
|
#gradients w.r.t. posterior parameters of X
|
||||||
X_grad = model.kern.gradients_qX_expectations(variational_posterior=X_slice, Z=model.Z, dL_dpsi0=grad_dict['dL_dpsi0'], dL_dpsi1=grad_dict['dL_dpsi1'], dL_dpsi2=grad_dict['dL_dpsi2'])
|
X_grad = model.kern.gradients_qX_expectations(
|
||||||
model.set_X_gradients(X_slice, X_grad)
|
variational_posterior=X_slice,
|
||||||
|
Z=model.Z,
|
||||||
|
dL_dpsi0=grad_dict['dL_dpsi0'],
|
||||||
|
dL_dpsi1=dL_dpsi1,
|
||||||
|
dL_dpsi2=dL_dpsi2)
|
||||||
|
|
||||||
|
model.X.mean[n_range[0]:n_range[1]].gradient = X_grad[0]
|
||||||
|
model.X.variance[n_range[0]:n_range[1]].gradient = X_grad[1]
|
||||||
|
|
||||||
if het_noise:
|
if het_noise:
|
||||||
dL_dthetaL[n_range[0]:n_range[1]] = grad_dict['dL_dthetaL']
|
dL_dthetaL[n_range[0]:n_range[1]] = grad_dict['dL_dthetaL']
|
||||||
else:
|
else:
|
||||||
dL_dthetaL += grad_dict['dL_dthetaL']
|
dL_dthetaL += grad_dict['dL_dthetaL']
|
||||||
|
#import ipdb;ipdb.set_trace()
|
||||||
|
model.grad_dict = grad_dict
|
||||||
|
if isinstance(model.X, VariationalPosterior):
|
||||||
|
# Update Log-likelihood
|
||||||
|
model._log_marginal_likelihood -= model.variational_prior.KL_divergence(model.X)
|
||||||
|
# update for the KL divergence
|
||||||
|
model.variational_prior.update_gradients_KL(model.X)
|
||||||
|
|
||||||
# Set the gradients w.r.t. kernel
|
# Set the gradients w.r.t. kernel
|
||||||
model.kern.gradient = kern_grad
|
model.kern.gradient = kern_grad
|
||||||
|
|
||||||
# Update Log-likelihood
|
|
||||||
model._log_marginal_likelihood -= model.variational_prior.KL_divergence(model.X)
|
|
||||||
# update for the KL divergence
|
|
||||||
model.variational_prior.update_gradients_KL(model.X)
|
|
||||||
|
|
||||||
# dL_dthetaL
|
# dL_dthetaL
|
||||||
model.likelihood.update_gradients(dL_dthetaL)
|
model.likelihood.update_gradients(dL_dthetaL)
|
||||||
|
|
|
||||||
Loading…
Add table
Add a link
Reference in a new issue