flags added

This commit is contained in:
Ricardo 2014-05-15 14:11:11 +01:00
parent 75d5209b98
commit bd7a80dde5

View file

@ -21,6 +21,7 @@ class EP(object):
def reset(self): def reset(self):
self.old_mutilde, self.old_vtilde = None, None self.old_mutilde, self.old_vtilde = None, None
self._ep_approximation = None
def inference(self, kern, X, likelihood, Y, Y_metadata=None, Z=None): def inference(self, kern, X, likelihood, Y, Y_metadata=None, Z=None):
num_data, output_dim = X.shape num_data, output_dim = X.shape
@ -28,7 +29,10 @@ class EP(object):
K = kern.K(X) K = kern.K(X)
mu, Sigma, mu_tilde, tau_tilde, Z_hat = self.expectation_propagation(K, Y, likelihood, Y_metadata) if self._ep_approximation is None:
mu, Sigma, mu_tilde, tau_tilde, Z_hat = self._ep_approximation = self.expectation_propagation(K, Y, likelihood, Y_metadata)
else:
mu, Sigma, mu_tilde, tau_tilde, Z_hat = self._ep_approximation
Wi, LW, LWi, W_logdet = pdinv(K + np.diag(1./tau_tilde)) Wi, LW, LWi, W_logdet = pdinv(K + np.diag(1./tau_tilde))