mirror of
https://github.com/SheffieldML/GPy.git
synced 2026-06-05 14:55:15 +02:00
clean up parallel framework
This commit is contained in:
parent
9febc73be3
commit
902f8f138b
11 changed files with 135 additions and 550 deletions
|
|
@ -38,7 +38,7 @@ class VarDTC_minibatch(LatentFunctionInference):
|
|||
|
||||
self.midRes = {}
|
||||
self.batch_pos = 0 # the starting position of the current mini-batch
|
||||
self.Y_speedup = False # Replace Y with the cholesky factor of YY.T, but the posterior inference will be wrong
|
||||
self.Y_speedup = False # Replace Y with the cholesky factor of YY.T, but the computation of posterior object will be skipped.
|
||||
|
||||
def __getstate__(self):
|
||||
# has to be overridden, as Cacher objects cannot be pickled.
|
||||
|
|
@ -76,6 +76,8 @@ class VarDTC_minibatch(LatentFunctionInference):
|
|||
def gatherPsiStat(self, kern, X, Z, Y, beta, uncertain_inputs):
|
||||
|
||||
het_noise = beta.size > 1
|
||||
|
||||
assert beta.size == 1
|
||||
|
||||
trYYT = self.get_trYYT(Y)
|
||||
if self.Y_speedup and not het_noise:
|
||||
|
|
@ -83,17 +85,16 @@ class VarDTC_minibatch(LatentFunctionInference):
|
|||
|
||||
num_inducing = Z.shape[0]
|
||||
num_data, output_dim = Y.shape
|
||||
if self.batchsize == None:
|
||||
self.batchsize = num_data
|
||||
batchsize = num_data if self.batchsize is None else self.batchsize
|
||||
|
||||
psi2_full = np.zeros((num_inducing,num_inducing))
|
||||
psi2_full = np.zeros((num_inducing,num_inducing)) # MxM
|
||||
psi1Y_full = np.zeros((output_dim,num_inducing)) # DxM
|
||||
psi0_full = 0.
|
||||
YRY_full = 0.
|
||||
|
||||
for n_start in xrange(0,num_data,self.batchsize):
|
||||
n_end = min(self.batchsize+n_start, num_data)
|
||||
if (n_end-n_start)==num_data:
|
||||
for n_start in xrange(0,num_data,batchsize):
|
||||
n_end = min(batchsize+n_start, num_data)
|
||||
if batchsize==num_data:
|
||||
Y_slice = Y
|
||||
X_slice = X
|
||||
else:
|
||||
|
|
@ -168,16 +169,18 @@ class VarDTC_minibatch(LatentFunctionInference):
|
|||
|
||||
Kmm = kern.K(Z).copy()
|
||||
diag.add(Kmm, self.const_jitter)
|
||||
KmmInv,Lm,LmInv,_ = pdinv(Kmm)
|
||||
Lm = jitchol(Kmm, maxtries=100)
|
||||
|
||||
LmInvPsi2LmInvT = LmInv.dot(psi2_full).dot(LmInv.T)
|
||||
LmInvPsi2LmInvT = backsub_both_sides(Lm,psi2_full,transpose='right')
|
||||
Lambda = np.eye(Kmm.shape[0])+LmInvPsi2LmInvT
|
||||
LInv,LL,LLInv,logdet_L = pdinv(Lambda)
|
||||
b = LLInv.dot(LmInv.dot(psi1Y_full.T))
|
||||
LL = jitchol(Lambda, maxtries=100)
|
||||
logdet_L = 2.*np.sum(np.log(np.diag(LL)))
|
||||
b = dtrtrs(LL,dtrtrs(Lm,psi1Y_full.T)[0])[0]
|
||||
bbt = np.square(b).sum()
|
||||
v = LmInv.T.dot(LLInv.T.dot(b))
|
||||
v = dtrtrs(Lm,dtrtrs(LL,b,trans=1)[0],trans=1)[0]
|
||||
|
||||
dL_dpsi2R = LmInv.T.dot(-LLInv.T.dot(tdot(b)+output_dim*np.eye(input_dim)).dot(LLInv)+output_dim*np.eye(input_dim)).dot(LmInv)/2.
|
||||
tmp = -backsub_both_sides(LL, tdot(b)+output_dim*np.eye(input_dim), transpose='left')
|
||||
dL_dpsi2R = backsub_both_sides(Lm, tmp+output_dim*np.eye(input_dim), transpose='left')/2.
|
||||
|
||||
# Cache intermediate results
|
||||
self.midRes['dL_dpsi2R'] = dL_dpsi2R
|
||||
|
|
@ -196,14 +199,15 @@ class VarDTC_minibatch(LatentFunctionInference):
|
|||
# Compute dL_dKmm
|
||||
#======================================================================
|
||||
|
||||
dL_dKmm = dL_dpsi2R - output_dim*KmmInv.dot(psi2_full).dot(KmmInv)/2.
|
||||
dL_dKmm = dL_dpsi2R - output_dim*backsub_both_sides(Lm, LmInvPsi2LmInvT, transpose='left')/2.
|
||||
|
||||
#======================================================================
|
||||
# Compute the Posterior distribution of inducing points p(u|Y)
|
||||
#======================================================================
|
||||
|
||||
if not self.Y_speedup or het_noise:
|
||||
post = Posterior(woodbury_inv=LmInv.T.dot(np.eye(input_dim)-LInv).dot(LmInv), woodbury_vector=v, K=Kmm, mean=None, cov=None, K_chol=Lm)
|
||||
wd_inv = backsub_both_sides(Lm, np.eye(input_dim)- backsub_both_sides(LL, np.identity(input_dim), transpose='left'), transpose='left')
|
||||
post = Posterior(woodbury_inv=wd_inv, woodbury_vector=v, K=Kmm, mean=None, cov=None, K_chol=Lm)
|
||||
else:
|
||||
post = None
|
||||
|
||||
|
|
@ -242,7 +246,8 @@ class VarDTC_minibatch(LatentFunctionInference):
|
|||
YYT_factor = Y
|
||||
|
||||
n_start = self.batch_pos
|
||||
n_end = min(self.batchsize+n_start, num_data)
|
||||
batchsize = num_data if self.batchsize is None else self.batchsize
|
||||
n_end = min(batchsize+n_start, num_data)
|
||||
if n_end==num_data:
|
||||
isEnd = True
|
||||
self.batch_pos = 0
|
||||
|
|
@ -250,8 +255,12 @@ class VarDTC_minibatch(LatentFunctionInference):
|
|||
isEnd = False
|
||||
self.batch_pos = n_end
|
||||
|
||||
Y_slice = YYT_factor[n_start:n_end]
|
||||
X_slice = X[n_start:n_end]
|
||||
if batchsize==num_data:
|
||||
Y_slice = YYT_factor
|
||||
X_slice =X
|
||||
else:
|
||||
Y_slice = YYT_factor[n_start:n_end]
|
||||
X_slice = X[n_start:n_end]
|
||||
|
||||
if not uncertain_inputs:
|
||||
psi0 = kern.Kdiag(X_slice)
|
||||
|
|
@ -405,3 +414,66 @@ def update_gradients(model, mpi_comm=None):
|
|||
|
||||
# dL_dthetaL
|
||||
model.likelihood.update_gradients(dL_dthetaL)
|
||||
|
||||
def update_gradients_sparsegp(model, mpi_comm=None):
|
||||
if mpi_comm == None:
|
||||
Y = model.Y
|
||||
X = model.X
|
||||
else:
|
||||
Y = model.Y_local
|
||||
X = model.X[model.N_range[0]:model.N_range[1]]
|
||||
|
||||
model._log_marginal_likelihood, dL_dKmm, model.posterior = model.inference_method.inference_likelihood(model.kern, X, model.Z, model.likelihood, Y)
|
||||
|
||||
het_noise = model.likelihood.variance.size > 1
|
||||
|
||||
if het_noise:
|
||||
dL_dthetaL = np.empty((model.Y.shape[0],))
|
||||
else:
|
||||
dL_dthetaL = np.float64(0.)
|
||||
|
||||
kern_grad = model.kern.gradient.copy()
|
||||
kern_grad[:] = 0.
|
||||
model.Z.gradient = 0.
|
||||
|
||||
isEnd = False
|
||||
while not isEnd:
|
||||
isEnd, n_range, grad_dict = model.inference_method.inference_minibatch(model.kern, X, model.Z, model.likelihood, Y)
|
||||
|
||||
if (n_range[1]-n_range[0])==X.shape[0]:
|
||||
X_slice = X
|
||||
elif mpi_comm ==None:
|
||||
X_slice = model.X[n_range[0]:n_range[1]]
|
||||
else:
|
||||
X_slice = model.X[model.N_range[0]+n_range[0]:model.N_range[0]+n_range[1]]
|
||||
|
||||
model.kern.update_gradients_diag(grad_dict['dL_dKdiag'], X_slice)
|
||||
kern_grad += model.kern.gradient
|
||||
model.kern.update_gradients_full(grad_dict['dL_dKnm'], X_slice, model.Z)
|
||||
kern_grad += model.kern.gradient
|
||||
|
||||
model.Z.gradient += model.kern.gradients_X(grad_dict['dL_dKnm'].T, model.Z, X_slice)
|
||||
|
||||
if het_noise:
|
||||
dL_dthetaL[n_range[0]:n_range[1]] = grad_dict['dL_dthetaL']
|
||||
else:
|
||||
dL_dthetaL += grad_dict['dL_dthetaL']
|
||||
|
||||
# Gather the gradients from multiple MPI nodes
|
||||
if mpi_comm != None:
|
||||
if het_noise:
|
||||
raise "het_noise not implemented!"
|
||||
kern_grad_all = kern_grad.copy()
|
||||
Z_grad_all = model.Z.gradient.copy()
|
||||
mpi_comm.Allreduce([kern_grad, MPI.DOUBLE], [kern_grad_all, MPI.DOUBLE])
|
||||
mpi_comm.Allreduce([model.Z.gradient, MPI.DOUBLE], [Z_grad_all, MPI.DOUBLE])
|
||||
kern_grad = kern_grad_all
|
||||
model.Z.gradient = Z_grad_all
|
||||
|
||||
model.kern.update_gradients_full(dL_dKmm, model.Z, None)
|
||||
model.kern.gradient += kern_grad
|
||||
|
||||
model.Z.gradient += model.kern.gradients_X(dL_dKmm, model.Z)
|
||||
|
||||
# dL_dthetaL
|
||||
model.likelihood.update_gradients(dL_dthetaL)
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue