generalize the interface of mpi

This commit is contained in:
Zhenwen Dai 2014-08-26 18:12:41 +01:00
parent a853d060fb
commit f29753e9e6
4 changed files with 143 additions and 80 deletions

View file

@ -379,7 +379,7 @@ def update_gradients(model, mpi_comm=None):
# Gather the gradients from multiple MPI nodes
if mpi_comm != None:
if het_noise:
assert False, "Not implemented!"
raise "het_noise not implemented!"
kern_grad_all = kern_grad.copy()
Z_grad_all = model.Z.gradient.copy()
mpi_comm.Allreduce([kern_grad, MPI.DOUBLE], [kern_grad_all, MPI.DOUBLE])
@ -404,10 +404,10 @@ def update_gradients(model, mpi_comm=None):
mpi_comm.Allreduce([np.float64(KL_div), MPI.DOUBLE], [KL_div_all, MPI.DOUBLE])
KL_div = KL_div_all
[mpi_comm.Allgatherv([pp.copy(), MPI.DOUBLE], [pa, (model.N_list*pa.shape[-1], None), MPI.DOUBLE]) for pp,pa in zip(model.get_X_gradients(X),model.get_X_gradients(model.X))]
from ...models import SSGPLVM
if isinstance(model, SSGPLVM):
grad_pi = np.array(model.variational_prior.pi.gradient)
mpi_comm.Allreduce([grad_pi.copy(), MPI.DOUBLE], [model.variational_prior.pi.gradient, MPI.DOUBLE])
# from ...models import SSGPLVM
# if isinstance(model, SSGPLVM):
# grad_pi = np.array(model.variational_prior.pi.gradient)
# mpi_comm.Allreduce([grad_pi.copy(), MPI.DOUBLE], [model.variational_prior.pi.gradient, MPI.DOUBLE])
model._log_marginal_likelihood -= KL_div
# dL_dthetaL