diff --git a/GPy/core/sparse_gp_mpi.py b/GPy/core/sparse_gp_mpi.py index 73a37862..cecbe667 100644 --- a/GPy/core/sparse_gp_mpi.py +++ b/GPy/core/sparse_gp_mpi.py @@ -42,10 +42,10 @@ class SparseGP_MPI(SparseGP): assert isinstance(inference_method, VarDTC_minibatch), 'inference_method has to support MPI!' super(SparseGP_MPI, self).__init__(X, Y, Z, kernel, likelihood, inference_method=inference_method, name=name, Y_metadata=Y_metadata, normalizer=normalizer) - self.updates = False - self.add_parameter(self.X, index=0) + self.update_model(False) + self.link_parameter(self.X, index=0) if variational_prior is not None: - self.add_parameter(variational_prior) + self.link_parameter(variational_prior) # self.X.fix() self.mpi_comm = mpi_comm @@ -58,7 +58,8 @@ class SparseGP_MPI(SparseGP): self.Y_local = self.Y[N_start:N_end] print 'MPI RANK '+str(self.mpi_comm.rank)+' with the data range '+str(self.N_range) mpi_comm.Bcast(self.param_array, root=0) - self.updates = True + self.update_model(True) + def __getstate__(self): dc = super(SparseGP_MPI, self).__getstate__() diff --git a/GPy/inference/latent_function_inference/var_dtc_parallel.py b/GPy/inference/latent_function_inference/var_dtc_parallel.py index ab4074f4..a7e2a800 100644 --- a/GPy/inference/latent_function_inference/var_dtc_parallel.py +++ b/GPy/inference/latent_function_inference/var_dtc_parallel.py @@ -2,7 +2,7 @@ # Licensed under the BSD 3-clause license (see LICENSE.txt) from posterior import Posterior -from ...util.linalg import jitchol, backsub_both_sides, tdot, dtrtrs +from ...util.linalg import jitchol, backsub_both_sides, tdot, dtrtrs, dtrtri from ...util import diag from ...core.parameterization.variational import VariationalPosterior import numpy as np @@ -172,7 +172,9 @@ class VarDTC_minibatch(LatentFunctionInference): diag.add(Kmm, self.const_jitter) r1 = checkFullRank(Kmm,name='Kmm') Lm = jitchol(Kmm) + LmInv = dtrtri(Lm) + #LmInvPsi2LmInvT = LmInv.dot(psi2_full).dot(LmInv.T) LmInvPsi2LmInvT = backsub_both_sides(Lm,psi2_full,transpose='right') Lambda = np.eye(Kmm.shape[0])+LmInvPsi2LmInvT r2 = checkFullRank(Lambda,name='Lambda')