mirror of
https://github.com/SheffieldML/GPy.git
synced 2026-05-10 20:42:39 +02:00
fix the SSGPLVM with MPI
This commit is contained in:
parent
08ed72b2f2
commit
cf33808673
7 changed files with 28 additions and 30 deletions
|
|
@ -45,7 +45,6 @@ class SSGPLVM(SparseGP):
|
|||
gamma[:] = 0.5 + 0.1 * np.random.randn(X.shape[0], input_dim)
|
||||
gamma[gamma>1.-1e-9] = 1.-1e-9
|
||||
gamma[gamma<1e-9] = 1e-9
|
||||
gamma[:] = 0.5
|
||||
|
||||
if Z is None:
|
||||
Z = np.random.permutation(X.copy())[:num_inducing]
|
||||
|
|
@ -72,20 +71,19 @@ class SSGPLVM(SparseGP):
|
|||
SparseGP.__init__(self, X, Y, Z, kernel, likelihood, inference_method, name, **kwargs)
|
||||
self.add_parameter(self.X, index=0)
|
||||
self.add_parameter(self.variational_prior)
|
||||
|
||||
if mpi_comm != None:
|
||||
from ..util.mpi import divide_data
|
||||
N_start, N_end, N_list = divide_data(Y.shape[0], mpi_comm)
|
||||
self.N_range = (N_start, N_end)
|
||||
self.N_list = np.array(N_list)
|
||||
self.Y_local = self.Y[N_start:N_end]
|
||||
print 'MPI RANK: '+str(self.mpi_comm.rank)+' with datasize: '+str(self.N_range)
|
||||
mpi_comm.Bcast(self.param_array, root=0)
|
||||
|
||||
if self.group_spike:
|
||||
[self.X.gamma[:,i].tie('tieGamma'+str(i)) for i in xrange(self.X.gamma.shape[1])] # Tie columns together
|
||||
|
||||
if mpi_comm != None:
|
||||
from ..util.mpi import divide_data
|
||||
Y_start, Y_end, Y_list = divide_data(Y.shape[0], mpi_comm)
|
||||
self.Y_local = self.Y[Y_start:Y_end]
|
||||
self.X_local = self.X[Y_start:Y_end]
|
||||
self.Y_range = (Y_start, Y_end)
|
||||
self.Y_list = np.array(Y_list)
|
||||
print self.mpi_comm.rank, self.Y_range
|
||||
mpi_comm.Bcast(self.param_array, root=0)
|
||||
|
||||
def set_X_gradients(self, X, X_grad):
|
||||
"""Set the gradients of the posterior distribution of X in its specific form."""
|
||||
X.mean.gradient, X.variance.gradient, X.binary_prob.gradient = X_grad
|
||||
|
|
@ -124,9 +122,10 @@ class SSGPLVM(SparseGP):
|
|||
dc = super(SSGPLVM, self).__getstate__()
|
||||
dc['mpi_comm'] = None
|
||||
if self.mpi_comm != None:
|
||||
del dc['N_range']
|
||||
del dc['N_list']
|
||||
del dc['Y_local']
|
||||
del dc['X_local']
|
||||
del dc['Y_range']
|
||||
return dc
|
||||
|
||||
def __setstate__(self, state):
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue