diff --git a/GPy/inference/latent_function_inference/var_dtc_parallel.py b/GPy/inference/latent_function_inference/var_dtc_parallel.py index 43e52e2b..a4803460 100644 --- a/GPy/inference/latent_function_inference/var_dtc_parallel.py +++ b/GPy/inference/latent_function_inference/var_dtc_parallel.py @@ -388,10 +388,18 @@ def update_gradients(model, mpi_comm=None): #gradients w.r.t. Z model.Z.gradient += model.kern.gradients_X(dL_dKmm, model.Z) - # Update Log-likelihood - KL_div = model.variational_prior.KL_divergence(X, N=model.Y.shape[0]) - # update for the KL divergence - model.variational_prior.update_gradients_KL(X, N=model.Y.shape[0]) + from ...core.parameterization.variational import SpikeAndSlabPrior + if isinstance(model.variational_prior, SpikeAndSlabPrior): + # Update Log-likelihood + KL_div = model.variational_prior.KL_divergence(X, N=model.Y.shape[0]) + # update for the KL divergence + model.variational_prior.update_gradients_KL(X, N=model.Y.shape[0]) + else: + # Update Log-likelihood + KL_div = model.variational_prior.KL_divergence(X) + # update for the KL divergence + model.variational_prior.update_gradients_KL(X) + if mpi_comm != None: KL_div_all = np.array(KL_div)