mirror of
https://github.com/SheffieldML/GPy.git
synced 2026-05-30 14:35:15 +02:00
bug fix for vardtc parallel
This commit is contained in:
parent
3d5e738ce8
commit
f7b1994fd4
1 changed files with 12 additions and 4 deletions
|
|
@ -388,10 +388,18 @@ def update_gradients(model, mpi_comm=None):
|
|||
#gradients w.r.t. Z
|
||||
model.Z.gradient += model.kern.gradients_X(dL_dKmm, model.Z)
|
||||
|
||||
# Update Log-likelihood
|
||||
KL_div = model.variational_prior.KL_divergence(X, N=model.Y.shape[0])
|
||||
# update for the KL divergence
|
||||
model.variational_prior.update_gradients_KL(X, N=model.Y.shape[0])
|
||||
from ...core.parameterization.variational import SpikeAndSlabPrior
|
||||
if isinstance(model.variational_prior, SpikeAndSlabPrior):
|
||||
# Update Log-likelihood
|
||||
KL_div = model.variational_prior.KL_divergence(X, N=model.Y.shape[0])
|
||||
# update for the KL divergence
|
||||
model.variational_prior.update_gradients_KL(X, N=model.Y.shape[0])
|
||||
else:
|
||||
# Update Log-likelihood
|
||||
KL_div = model.variational_prior.KL_divergence(X)
|
||||
# update for the KL divergence
|
||||
model.variational_prior.update_gradients_KL(X)
|
||||
|
||||
|
||||
if mpi_comm != None:
|
||||
KL_div_all = np.array(KL_div)
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue