From ec89c4efc300b7e3e5622c6cd018d6fe7deda55b Mon Sep 17 00:00:00 2001 From: Ricardo Andrade Date: Tue, 29 Jan 2013 16:45:00 +0000 Subject: [PATCH] _compute_GP_variables --- GPy/inference/EP.py | 27 ++++++++++++--------------- 1 file changed, 12 insertions(+), 15 deletions(-) diff --git a/GPy/inference/EP.py b/GPy/inference/EP.py index 5d571888..5c473a8f 100644 --- a/GPy/inference/EP.py +++ b/GPy/inference/EP.py @@ -48,13 +48,13 @@ class EP: self.tau_tilde = np.zeros(self.N) self.v_tilde = np.zeros(self.N) - def restart_EP(self): - """ - Set the EP approximation to initial state - """ - self.tau_tilde = np.zeros(self.N) - self.v_tilde = np.zeros(self.N) - self.mu = np.zeros(self.N) + def _compute_GP_variables(self): + #Variables to be called from GP + mu_tilde = self.v_tilde/self.tau_tilde #When calling EP, this variable is used instead of Y in the GP model + sigma_sum = 1./self.tau_ + 1./self.tau_tilde + mu_diff_2 = (self.v_/self.tau_ - mu_tilde)**2 + Z_ep = np.sum(np.log(self.Z_hat)) + 0.5*np.sum(np.log(sigma_sum)) + 0.5*np.sum(mu_diff_2/sigma_sum) #Normalization constant + return self.tau_tilde[:,None], mu_tilde[:,None], Z_ep class Full(EP): def fit_EP(self): @@ -122,12 +122,7 @@ class Full(EP): self.np1.append(self.tau_tilde.copy()) self.np2.append(self.v_tilde.copy()) - #Variables to be called from GP - mu_tilde = self.v_tilde/self.tau_tilde #When calling EP, this variable is used instead of Y in the GP model - sigma_sum = 1./self.tau_ + 1./self.tau_tilde - mu_diff_2 = (self.v_/self.tau_ - mu_tilde)**2 - Z_ep = np.sum(np.log(self.Z_hat)) + 0.5*np.sum(np.log(sigma_sum)) + 0.5*np.sum(mu_diff_2/sigma_sum) #Normalization constant - return self.tau_tilde[:,None], mu_tilde[:,None], Z_ep + return self._compute_GP_variables() class DTC(EP): def fit_EP(self): @@ -212,7 +207,8 @@ class DTC(EP): epsilon_np2 = sum((self.v_tilde-self.np2[-1])**2)/self.N self.np1.append(self.tau_tilde.copy()) self.np2.append(self.v_tilde.copy()) - return self.tau_tilde[:,None], self.v_tilde[:,None], self.Z_hat[:,None], self.tau_[:,None], self.v_[:,None] + + return self._compute_GP_variables() class FITC(EP): def fit_EP(self): @@ -313,4 +309,5 @@ class FITC(EP): epsilon_np2 = sum((self.v_tilde-self.np2[-1])**2)/self.N self.np1.append(self.tau_tilde.copy()) self.np2.append(self.v_tilde.copy()) - return self.tau_tilde[:,None], self.v_tilde[:,None], self.Z_hat[:,None], self.tau_[:,None], self.v_[:,None] + + return self._compute_GP_variables()