update sparse_gp_mpi for new interface

This commit is contained in:
Zhenwen Dai 2014-09-08 17:22:37 +01:00
parent 4543fc3480
commit 4e83501ea5
2 changed files with 8 additions and 5 deletions

View file

@ -42,10 +42,10 @@ class SparseGP_MPI(SparseGP):
assert isinstance(inference_method, VarDTC_minibatch), 'inference_method has to support MPI!'
super(SparseGP_MPI, self).__init__(X, Y, Z, kernel, likelihood, inference_method=inference_method, name=name, Y_metadata=Y_metadata, normalizer=normalizer)
self.updates = False
self.add_parameter(self.X, index=0)
self.update_model(False)
self.link_parameter(self.X, index=0)
if variational_prior is not None:
self.add_parameter(variational_prior)
self.link_parameter(variational_prior)
# self.X.fix()
self.mpi_comm = mpi_comm
@ -58,7 +58,8 @@ class SparseGP_MPI(SparseGP):
self.Y_local = self.Y[N_start:N_end]
print 'MPI RANK '+str(self.mpi_comm.rank)+' with the data range '+str(self.N_range)
mpi_comm.Bcast(self.param_array, root=0)
self.updates = True
self.update_model(True)
def __getstate__(self):
dc = super(SparseGP_MPI, self).__getstate__()