update sparse_gp_mpi for new interface

This commit is contained in:
Zhenwen Dai 2014-09-08 17:22:37 +01:00
parent 4543fc3480
commit 4e83501ea5
2 changed files with 8 additions and 5 deletions

View file

@ -42,10 +42,10 @@ class SparseGP_MPI(SparseGP):
assert isinstance(inference_method, VarDTC_minibatch), 'inference_method has to support MPI!'
super(SparseGP_MPI, self).__init__(X, Y, Z, kernel, likelihood, inference_method=inference_method, name=name, Y_metadata=Y_metadata, normalizer=normalizer)
self.updates = False
self.add_parameter(self.X, index=0)
self.update_model(False)
self.link_parameter(self.X, index=0)
if variational_prior is not None:
self.add_parameter(variational_prior)
self.link_parameter(variational_prior)
# self.X.fix()
self.mpi_comm = mpi_comm
@ -58,7 +58,8 @@ class SparseGP_MPI(SparseGP):
self.Y_local = self.Y[N_start:N_end]
print 'MPI RANK '+str(self.mpi_comm.rank)+' with the data range '+str(self.N_range)
mpi_comm.Bcast(self.param_array, root=0)
self.updates = True
self.update_model(True)
def __getstate__(self):
dc = super(SparseGP_MPI, self).__getstate__()

View file

@ -2,7 +2,7 @@
# Licensed under the BSD 3-clause license (see LICENSE.txt)
from posterior import Posterior
from ...util.linalg import jitchol, backsub_both_sides, tdot, dtrtrs
from ...util.linalg import jitchol, backsub_both_sides, tdot, dtrtrs, dtrtri
from ...util import diag
from ...core.parameterization.variational import VariationalPosterior
import numpy as np
@ -172,7 +172,9 @@ class VarDTC_minibatch(LatentFunctionInference):
diag.add(Kmm, self.const_jitter)
r1 = checkFullRank(Kmm,name='Kmm')
Lm = jitchol(Kmm)
LmInv = dtrtri(Lm)
#LmInvPsi2LmInvT = LmInv.dot(psi2_full).dot(LmInv.T)
LmInvPsi2LmInvT = backsub_both_sides(Lm,psi2_full,transpose='right')
Lambda = np.eye(Kmm.shape[0])+LmInvPsi2LmInvT
r2 = checkFullRank(Lambda,name='Lambda')