From f7b1994fd452123930bc9a754466e060b69034bb Mon Sep 17 00:00:00 2001 From: Zhenwen Dai Date: Mon, 20 Oct 2014 13:17:15 +0100 Subject: [PATCH] bug fix for vardtc parallel --- .../var_dtc_parallel.py | 16 ++++++++++++---- 1 file changed, 12 insertions(+), 4 deletions(-) diff --git a/GPy/inference/latent_function_inference/var_dtc_parallel.py b/GPy/inference/latent_function_inference/var_dtc_parallel.py index 43e52e2b..a4803460 100644 --- a/GPy/inference/latent_function_inference/var_dtc_parallel.py +++ b/GPy/inference/latent_function_inference/var_dtc_parallel.py @@ -388,10 +388,18 @@ def update_gradients(model, mpi_comm=None): #gradients w.r.t. Z model.Z.gradient += model.kern.gradients_X(dL_dKmm, model.Z) - # Update Log-likelihood - KL_div = model.variational_prior.KL_divergence(X, N=model.Y.shape[0]) - # update for the KL divergence - model.variational_prior.update_gradients_KL(X, N=model.Y.shape[0]) + from ...core.parameterization.variational import SpikeAndSlabPrior + if isinstance(model.variational_prior, SpikeAndSlabPrior): + # Update Log-likelihood + KL_div = model.variational_prior.KL_divergence(X, N=model.Y.shape[0]) + # update for the KL divergence + model.variational_prior.update_gradients_KL(X, N=model.Y.shape[0]) + else: + # Update Log-likelihood + KL_div = model.variational_prior.KL_divergence(X) + # update for the KL divergence + model.variational_prior.update_gradients_KL(X) + if mpi_comm != None: KL_div_all = np.array(KL_div)