diff --git a/GPy/inference/latent_function_inference/expectation_propagation_dtc.py b/GPy/inference/latent_function_inference/expectation_propagation_dtc.py
index 3625a5bf..be1efd3a 100644
--- a/GPy/inference/latent_function_inference/expectation_propagation_dtc.py
+++ b/GPy/inference/latent_function_inference/expectation_propagation_dtc.py
@@ -5,7 +5,13 @@ from posterior import Posterior
 log_2_pi = np.log(2*np.pi)
 
 class EPDTC(EP):
-    #def __init__(self, epsilon=1e-6, eta=1., delta=1.):
+    def __init__(self, epsilon=1e-6, eta=1., delta=1.):
+        self.epsilon, self.eta, self.delta = epsilon, eta, delta
+        self.reset()
+
+    def reset(self):
+        self.old_mutilde, self.old_vtilde = None, None
+        self._ep_approximation = None
 
     def inference(self, kern, X, Z, likelihood, Y, Y_metadata=None):
         num_data, output_dim = X.shape
@@ -20,8 +26,10 @@ class EPDTC(EP):
 
         KmmiKmn = np.dot(Kmmi,Kmn)
         K = np.dot(Kmn.T,KmmiKmn)
-
-        mu, Sigma, mu_tilde, tau_tilde, Z_hat = self.expectation_propagation(Kmm, Kmn, Y, likelihood, Y_metadata)
+        if self._ep_approximation is None:
+            mu, Sigma, mu_tilde, tau_tilde, Z_hat = self._ep_approximation = self.expectation_propagation(Kmm, Kmn, Y, likelihood, Y_metadata)
+        else:
+            mu, Sigma, mu_tilde, tau_tilde, Z_hat = self._ep_approximation
 
         Wi, LW, LWi, W_logdet = pdinv(K + np.diag(1./tau_tilde))
 
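The patch above makes `EPDTC` cache the result of the expensive EP loop in `self._ep_approximation`, so repeated calls to `inference` (e.g. during hyperparameter optimisation) reuse the stored approximation until `reset()` clears it. Below is a minimal, self-contained sketch of the same compute-once/reset pattern, independent of GPy: the class name `CachedApproximation` and the placeholder `expensive_approximation` are illustrative only, not part of the library.

```python
import numpy as np


class CachedApproximation(object):
    """Illustrative only: cache an expensive computation until reset() is
    called, mirroring the _ep_approximation caching added to EPDTC above."""

    def __init__(self, epsilon=1e-6, eta=1., delta=1.):
        self.epsilon, self.eta, self.delta = epsilon, eta, delta
        self.reset()

    def reset(self):
        # Forget any previously computed approximation.
        self._approximation = None

    def expensive_approximation(self, Y):
        # Stand-in for the EP loop; here just a cheap placeholder computation.
        return Y.mean(axis=0), np.cov(Y, rowvar=False)

    def inference(self, Y):
        if self._approximation is None:
            # First call (or first call after reset()): compute and store.
            self._approximation = self.expensive_approximation(Y)
        # Subsequent calls reuse the stored result.
        mu, Sigma = self._approximation
        return mu, Sigma


if __name__ == "__main__":
    Y = np.random.randn(100, 2)
    approx = CachedApproximation()
    mu1, _ = approx.inference(Y)   # computes and caches
    mu2, _ = approx.inference(Y)   # reuses the cache
    approx.reset()                 # e.g. when the approximation must be refreshed
    mu3, _ = approx.inference(Y)   # recomputes
```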