bug fix for vardtc parallel

This commit is contained in:
Zhenwen Dai 2014-10-20 13:17:15 +01:00
parent 3d5e738ce8
commit f7b1994fd4

View file

@ -388,10 +388,18 @@ def update_gradients(model, mpi_comm=None):
#gradients w.r.t. Z
model.Z.gradient += model.kern.gradients_X(dL_dKmm, model.Z)
# Update Log-likelihood
KL_div = model.variational_prior.KL_divergence(X, N=model.Y.shape[0])
# update for the KL divergence
model.variational_prior.update_gradients_KL(X, N=model.Y.shape[0])
from ...core.parameterization.variational import SpikeAndSlabPrior
if isinstance(model.variational_prior, SpikeAndSlabPrior):
# Update Log-likelihood
KL_div = model.variational_prior.KL_divergence(X, N=model.Y.shape[0])
# update for the KL divergence
model.variational_prior.update_gradients_KL(X, N=model.Y.shape[0])
else:
# Update Log-likelihood
KL_div = model.variational_prior.KL_divergence(X)
# update for the KL divergence
model.variational_prior.update_gradients_KL(X)
if mpi_comm != None:
KL_div_all = np.array(KL_div)