bug fix for mpi SSGPLVM

This commit is contained in:
Zhenwen Dai 2014-05-16 10:40:45 +01:00
parent 4ec8f464e2
commit c568bad4fb
2 changed files with 24 additions and 2 deletions

View file

@ -400,6 +400,10 @@ def update_gradients(model, mpi_comm=None):
mpi_comm.Allreduce([np.float64(KL_div), MPI.DOUBLE], [KL_div_all, MPI.DOUBLE])
KL_div = KL_div_all
[mpi_comm.Allgatherv([pp.copy(), MPI.DOUBLE], [pa, (model.Y_list*pa.shape[-1], None), MPI.DOUBLE]) for pp,pa in zip(model.get_X_gradients(X),model.get_X_gradients(model.X))]
from ...models import SSGPLVM
if isinstance(model, SSGPLVM):
grad_pi = np.array(model.variational_prior.pi.gradient)
mpi_comm.Allreduce([grad_pi.copy(), MPI.DOUBLE], [model.variational_prior.pi.gradient, MPI.DOUBLE])
model._log_marginal_likelihood -= KL_div
# dL_dthetaL