mirror of
https://github.com/SheffieldML/GPy.git
synced 2026-05-21 14:05:14 +02:00
update sparse_gp_mpi for new interface
This commit is contained in:
parent
4543fc3480
commit
4e83501ea5
2 changed files with 8 additions and 5 deletions
|
|
@ -42,10 +42,10 @@ class SparseGP_MPI(SparseGP):
|
|||
assert isinstance(inference_method, VarDTC_minibatch), 'inference_method has to support MPI!'
|
||||
|
||||
super(SparseGP_MPI, self).__init__(X, Y, Z, kernel, likelihood, inference_method=inference_method, name=name, Y_metadata=Y_metadata, normalizer=normalizer)
|
||||
self.updates = False
|
||||
self.add_parameter(self.X, index=0)
|
||||
self.update_model(False)
|
||||
self.link_parameter(self.X, index=0)
|
||||
if variational_prior is not None:
|
||||
self.add_parameter(variational_prior)
|
||||
self.link_parameter(variational_prior)
|
||||
# self.X.fix()
|
||||
|
||||
self.mpi_comm = mpi_comm
|
||||
|
|
@ -58,7 +58,8 @@ class SparseGP_MPI(SparseGP):
|
|||
self.Y_local = self.Y[N_start:N_end]
|
||||
print 'MPI RANK '+str(self.mpi_comm.rank)+' with the data range '+str(self.N_range)
|
||||
mpi_comm.Bcast(self.param_array, root=0)
|
||||
self.updates = True
|
||||
self.update_model(True)
|
||||
|
||||
|
||||
def __getstate__(self):
|
||||
dc = super(SparseGP_MPI, self).__getstate__()
|
||||
|
|
|
|||
|
|
@ -2,7 +2,7 @@
|
|||
# Licensed under the BSD 3-clause license (see LICENSE.txt)
|
||||
|
||||
from posterior import Posterior
|
||||
from ...util.linalg import jitchol, backsub_both_sides, tdot, dtrtrs
|
||||
from ...util.linalg import jitchol, backsub_both_sides, tdot, dtrtrs, dtrtri
|
||||
from ...util import diag
|
||||
from ...core.parameterization.variational import VariationalPosterior
|
||||
import numpy as np
|
||||
|
|
@ -172,7 +172,9 @@ class VarDTC_minibatch(LatentFunctionInference):
|
|||
diag.add(Kmm, self.const_jitter)
|
||||
r1 = checkFullRank(Kmm,name='Kmm')
|
||||
Lm = jitchol(Kmm)
|
||||
LmInv = dtrtri(Lm)
|
||||
|
||||
#LmInvPsi2LmInvT = LmInv.dot(psi2_full).dot(LmInv.T)
|
||||
LmInvPsi2LmInvT = backsub_both_sides(Lm,psi2_full,transpose='right')
|
||||
Lambda = np.eye(Kmm.shape[0])+LmInvPsi2LmInvT
|
||||
r2 = checkFullRank(Lambda,name='Lambda')
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue