mirror of
https://github.com/SheffieldML/GPy.git
synced 2026-05-07 11:02:38 +02:00
bug fix for mpi SSGPLVM
This commit is contained in:
parent
4ec8f464e2
commit
c568bad4fb
2 changed files with 24 additions and 2 deletions
|
|
@ -400,6 +400,10 @@ def update_gradients(model, mpi_comm=None):
|
||||||
mpi_comm.Allreduce([np.float64(KL_div), MPI.DOUBLE], [KL_div_all, MPI.DOUBLE])
|
mpi_comm.Allreduce([np.float64(KL_div), MPI.DOUBLE], [KL_div_all, MPI.DOUBLE])
|
||||||
KL_div = KL_div_all
|
KL_div = KL_div_all
|
||||||
[mpi_comm.Allgatherv([pp.copy(), MPI.DOUBLE], [pa, (model.Y_list*pa.shape[-1], None), MPI.DOUBLE]) for pp,pa in zip(model.get_X_gradients(X),model.get_X_gradients(model.X))]
|
[mpi_comm.Allgatherv([pp.copy(), MPI.DOUBLE], [pa, (model.Y_list*pa.shape[-1], None), MPI.DOUBLE]) for pp,pa in zip(model.get_X_gradients(X),model.get_X_gradients(model.X))]
|
||||||
|
from ...models import SSGPLVM
|
||||||
|
if isinstance(model, SSGPLVM):
|
||||||
|
grad_pi = np.array(model.variational_prior.pi.gradient)
|
||||||
|
mpi_comm.Allreduce([grad_pi.copy(), MPI.DOUBLE], [model.variational_prior.pi.gradient, MPI.DOUBLE])
|
||||||
model._log_marginal_likelihood -= KL_div
|
model._log_marginal_likelihood -= KL_div
|
||||||
|
|
||||||
# dL_dthetaL
|
# dL_dthetaL
|
||||||
|
|
|
||||||
|
|
@ -45,9 +45,9 @@ class SSGPLVM(SparseGP):
|
||||||
|
|
||||||
gamma = np.empty_like(X, order='F') # The posterior probabilities of the binary variable in the variational approximation
|
gamma = np.empty_like(X, order='F') # The posterior probabilities of the binary variable in the variational approximation
|
||||||
gamma[:] = 0.5 + 0.1 * np.random.randn(X.shape[0], input_dim)
|
gamma[:] = 0.5 + 0.1 * np.random.randn(X.shape[0], input_dim)
|
||||||
gamma[gamma>=1. - 1e-9] = 1e-9
|
gamma[gamma>1.-1e-9] = 1.-1e-9
|
||||||
gamma[gamma<1e-9] = 1e-9
|
gamma[gamma<1e-9] = 1e-9
|
||||||
#gamma[:] = 0.5
|
gamma[:] = 0.5
|
||||||
|
|
||||||
if group_spike:
|
if group_spike:
|
||||||
gamma[:] = gamma.mean(axis=0)
|
gamma[:] = gamma.mean(axis=0)
|
||||||
|
|
@ -142,3 +142,21 @@ class SSGPLVM(SparseGP):
|
||||||
state['Y_local'] = state['Y'][Y_range[0]:Y_range[1]]
|
state['Y_local'] = state['Y'][Y_range[0]:Y_range[1]]
|
||||||
state['X_local'] = state['X'][Y_range[0]:Y_range[1]]
|
state['X_local'] = state['X'][Y_range[0]:Y_range[1]]
|
||||||
return super(SSGPLVM, self).__setstate__(state)
|
return super(SSGPLVM, self).__setstate__(state)
|
||||||
|
|
||||||
|
def _grads(self, x):
|
||||||
|
if self.mpi_comm != None:
|
||||||
|
self.mpi_comm.Bcast(x, root=0)
|
||||||
|
obj_grads = super(SSGPLVM, self)._grads(x)
|
||||||
|
return obj_grads
|
||||||
|
|
||||||
|
def _objective(self, x):
|
||||||
|
if self.mpi_comm != None:
|
||||||
|
self.mpi_comm.Bcast(x, root=0)
|
||||||
|
obj = super(SSGPLVM, self)._objective(x)
|
||||||
|
return obj
|
||||||
|
|
||||||
|
def _objective_grads(self, x):
|
||||||
|
if self.mpi_comm != None:
|
||||||
|
self.mpi_comm.Bcast(x, root=0)
|
||||||
|
obj_f, obj_grads = super(SSGPLVM, self)._objective_grads(x)
|
||||||
|
return obj_f, obj_grads
|
||||||
|
|
|
||||||
Loading…
Add table
Add a link
Reference in a new issue