mirror of
https://github.com/SheffieldML/GPy.git
synced 2026-06-05 14:55:15 +02:00
generalize the interface of mpi
This commit is contained in:
parent
a853d060fb
commit
f29753e9e6
4 changed files with 143 additions and 80 deletions
|
|
@ -379,7 +379,7 @@ def update_gradients(model, mpi_comm=None):
|
|||
# Gather the gradients from multiple MPI nodes
|
||||
if mpi_comm != None:
|
||||
if het_noise:
|
||||
assert False, "Not implemented!"
|
||||
raise "het_noise not implemented!"
|
||||
kern_grad_all = kern_grad.copy()
|
||||
Z_grad_all = model.Z.gradient.copy()
|
||||
mpi_comm.Allreduce([kern_grad, MPI.DOUBLE], [kern_grad_all, MPI.DOUBLE])
|
||||
|
|
@ -404,10 +404,10 @@ def update_gradients(model, mpi_comm=None):
|
|||
mpi_comm.Allreduce([np.float64(KL_div), MPI.DOUBLE], [KL_div_all, MPI.DOUBLE])
|
||||
KL_div = KL_div_all
|
||||
[mpi_comm.Allgatherv([pp.copy(), MPI.DOUBLE], [pa, (model.N_list*pa.shape[-1], None), MPI.DOUBLE]) for pp,pa in zip(model.get_X_gradients(X),model.get_X_gradients(model.X))]
|
||||
from ...models import SSGPLVM
|
||||
if isinstance(model, SSGPLVM):
|
||||
grad_pi = np.array(model.variational_prior.pi.gradient)
|
||||
mpi_comm.Allreduce([grad_pi.copy(), MPI.DOUBLE], [model.variational_prior.pi.gradient, MPI.DOUBLE])
|
||||
# from ...models import SSGPLVM
|
||||
# if isinstance(model, SSGPLVM):
|
||||
# grad_pi = np.array(model.variational_prior.pi.gradient)
|
||||
# mpi_comm.Allreduce([grad_pi.copy(), MPI.DOUBLE], [model.variational_prior.pi.gradient, MPI.DOUBLE])
|
||||
model._log_marginal_likelihood -= KL_div
|
||||
|
||||
# dL_dthetaL
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue