bug fix for mpi SSGPLVM

This commit is contained in:
Zhenwen Dai 2014-05-16 10:40:45 +01:00
parent 4ec8f464e2
commit c568bad4fb
2 changed files with 24 additions and 2 deletions

View file

@ -45,9 +45,9 @@ class SSGPLVM(SparseGP):
gamma = np.empty_like(X, order='F') # The posterior probabilities of the binary variable in the variational approximation
gamma[:] = 0.5 + 0.1 * np.random.randn(X.shape[0], input_dim)
gamma[gamma>=1. - 1e-9] = 1e-9
gamma[gamma>1.-1e-9] = 1.-1e-9
gamma[gamma<1e-9] = 1e-9
#gamma[:] = 0.5
gamma[:] = 0.5
if group_spike:
gamma[:] = gamma.mean(axis=0)
@ -142,3 +142,21 @@ class SSGPLVM(SparseGP):
state['Y_local'] = state['Y'][Y_range[0]:Y_range[1]]
state['X_local'] = state['X'][Y_range[0]:Y_range[1]]
return super(SSGPLVM, self).__setstate__(state)
def _grads(self, x):
if self.mpi_comm != None:
self.mpi_comm.Bcast(x, root=0)
obj_grads = super(SSGPLVM, self)._grads(x)
return obj_grads
def _objective(self, x):
if self.mpi_comm != None:
self.mpi_comm.Bcast(x, root=0)
obj = super(SSGPLVM, self)._objective(x)
return obj
def _objective_grads(self, x):
if self.mpi_comm != None:
self.mpi_comm.Bcast(x, root=0)
obj_f, obj_grads = super(SSGPLVM, self)._objective_grads(x)
return obj_f, obj_grads