mirror of
https://github.com/SheffieldML/GPy.git
synced 2026-05-07 02:52:40 +02:00
bug fix for mpi SSGPLVM
This commit is contained in:
parent
4ec8f464e2
commit
c568bad4fb
2 changed files with 24 additions and 2 deletions
|
|
@ -400,6 +400,10 @@ def update_gradients(model, mpi_comm=None):
|
|||
mpi_comm.Allreduce([np.float64(KL_div), MPI.DOUBLE], [KL_div_all, MPI.DOUBLE])
|
||||
KL_div = KL_div_all
|
||||
[mpi_comm.Allgatherv([pp.copy(), MPI.DOUBLE], [pa, (model.Y_list*pa.shape[-1], None), MPI.DOUBLE]) for pp,pa in zip(model.get_X_gradients(X),model.get_X_gradients(model.X))]
|
||||
from ...models import SSGPLVM
|
||||
if isinstance(model, SSGPLVM):
|
||||
grad_pi = np.array(model.variational_prior.pi.gradient)
|
||||
mpi_comm.Allreduce([grad_pi.copy(), MPI.DOUBLE], [model.variational_prior.pi.gradient, MPI.DOUBLE])
|
||||
model._log_marginal_likelihood -= KL_div
|
||||
|
||||
# dL_dthetaL
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue