fix the SSGPLVM with MPI

This commit is contained in:
Zhenwen Dai 2014-06-25 13:57:25 +01:00
parent 08ed72b2f2
commit cf33808673
7 changed files with 28 additions and 30 deletions

View file

@ -328,7 +328,7 @@ def update_gradients(model, mpi_comm=None):
X = model.X
else:
Y = model.Y_local
X = model.X_local
X = model.X[model.N_range[0]:model.N_range[1]]
model._log_marginal_likelihood, dL_dKmm, model.posterior = model.inference_method.inference_likelihood(model.kern, X, model.Z, model.likelihood, Y)
@ -350,7 +350,7 @@ def update_gradients(model, mpi_comm=None):
if mpi_comm ==None:
X_slice = model.X[n_range[0]:n_range[1]]
else:
X_slice = model.X[model.Y_range[0]+n_range[0]:model.Y_range[0]+n_range[1]]
X_slice = model.X[model.N_range[0]+n_range[0]:model.N_range[0]+n_range[1]]
#gradients w.r.t. kernel
model.kern.update_gradients_expectations(variational_posterior=X_slice, Z=model.Z, dL_dpsi0=grad_dict['dL_dpsi0'], dL_dpsi1=grad_dict['dL_dpsi1'], dL_dpsi2=grad_dict['dL_dpsi2'])
@ -396,7 +396,7 @@ def update_gradients(model, mpi_comm=None):
KL_div_all = np.array(KL_div)
mpi_comm.Allreduce([np.float64(KL_div), MPI.DOUBLE], [KL_div_all, MPI.DOUBLE])
KL_div = KL_div_all
[mpi_comm.Allgatherv([pp.copy(), MPI.DOUBLE], [pa, (model.Y_list*pa.shape[-1], None), MPI.DOUBLE]) for pp,pa in zip(model.get_X_gradients(X),model.get_X_gradients(model.X))]
[mpi_comm.Allgatherv([pp.copy(), MPI.DOUBLE], [pa, (model.N_list*pa.shape[-1], None), MPI.DOUBLE]) for pp,pa in zip(model.get_X_gradients(X),model.get_X_gradients(model.X))]
from ...models import SSGPLVM
if isinstance(model, SSGPLVM):
grad_pi = np.array(model.variational_prior.pi.gradient)