mirror of
https://github.com/SheffieldML/GPy.git
synced 2026-05-05 17:52:39 +02:00
fix the SSGPLVM with MPI
This commit is contained in:
parent
08ed72b2f2
commit
cf33808673
7 changed files with 28 additions and 30 deletions
|
|
@ -328,7 +328,7 @@ def update_gradients(model, mpi_comm=None):
|
|||
X = model.X
|
||||
else:
|
||||
Y = model.Y_local
|
||||
X = model.X_local
|
||||
X = model.X[model.N_range[0]:model.N_range[1]]
|
||||
|
||||
model._log_marginal_likelihood, dL_dKmm, model.posterior = model.inference_method.inference_likelihood(model.kern, X, model.Z, model.likelihood, Y)
|
||||
|
||||
|
|
@ -350,7 +350,7 @@ def update_gradients(model, mpi_comm=None):
|
|||
if mpi_comm ==None:
|
||||
X_slice = model.X[n_range[0]:n_range[1]]
|
||||
else:
|
||||
X_slice = model.X[model.Y_range[0]+n_range[0]:model.Y_range[0]+n_range[1]]
|
||||
X_slice = model.X[model.N_range[0]+n_range[0]:model.N_range[0]+n_range[1]]
|
||||
|
||||
#gradients w.r.t. kernel
|
||||
model.kern.update_gradients_expectations(variational_posterior=X_slice, Z=model.Z, dL_dpsi0=grad_dict['dL_dpsi0'], dL_dpsi1=grad_dict['dL_dpsi1'], dL_dpsi2=grad_dict['dL_dpsi2'])
|
||||
|
|
@ -396,7 +396,7 @@ def update_gradients(model, mpi_comm=None):
|
|||
KL_div_all = np.array(KL_div)
|
||||
mpi_comm.Allreduce([np.float64(KL_div), MPI.DOUBLE], [KL_div_all, MPI.DOUBLE])
|
||||
KL_div = KL_div_all
|
||||
[mpi_comm.Allgatherv([pp.copy(), MPI.DOUBLE], [pa, (model.Y_list*pa.shape[-1], None), MPI.DOUBLE]) for pp,pa in zip(model.get_X_gradients(X),model.get_X_gradients(model.X))]
|
||||
[mpi_comm.Allgatherv([pp.copy(), MPI.DOUBLE], [pa, (model.N_list*pa.shape[-1], None), MPI.DOUBLE]) for pp,pa in zip(model.get_X_gradients(X),model.get_X_gradients(model.X))]
|
||||
from ...models import SSGPLVM
|
||||
if isinstance(model, SSGPLVM):
|
||||
grad_pi = np.array(model.variational_prior.pi.gradient)
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue