Sparse EP

This commit is contained in:
Ricardo 2013-01-28 00:16:23 +00:00
parent a8738984b3
commit fad0e07624
7 changed files with 399 additions and 8 deletions

View file

@ -9,7 +9,7 @@ from ..util.plot import gpplot
from .. import kern
class EP:
def __init__(self,covariance,likelihood,Kmn=None,Knn_diag=None,epsilon=1e-3,powerep=[1.,1.]):
def __init__(self,covariance,likelihood,Kmn=None,Knn_diag=None,epsilon=1e-3,power_ep=[1.,1.]):
"""
Expectation Propagation
@ -19,7 +19,7 @@ class EP:
likelihood : Output's likelihood (likelihood class)
kernel : a GPy kernel (kern class)
inducing : Either an array specifying the inducing points location or a sacalar defining their number. None value for using a non-sparse model is used.
powerep : Power-EP parameters (eta,delta) - 2x1 numpy array (floats)
power_ep : Power-EP parameters (eta,delta) - 2x1 numpy array (floats)
epsilon : Convergence criterion, maximum squared difference allowed between mean updates to stop iterations (float)
"""
self.likelihood = likelihood
@ -38,7 +38,7 @@ class EP:
assert len(Knn_diag) == self.N, 'Knn_diagonal has size different from N'
self.epsilon = epsilon
self.eta, self.delta = powerep
self.eta, self.delta = power_ep
self.jitter = 1e-12
"""
@ -110,6 +110,7 @@ class Full(EP):
self.Sigma = self.Sigma - Delta_tau/(1.+ Delta_tau*self.Sigma[i,i])*np.dot(si,si.T)
self.mu = np.dot(self.Sigma,self.v_tilde)
self.iterations += 1
print self.tau_tilde[i] #TODO erase me
#Sigma recomptutation with Cholesky decompositon
Sroot_tilde_K = np.sqrt(self.tau_tilde)[:,None]*(self.K)
B = np.eye(self.N) + np.sqrt(self.tau_tilde)[None,:]*Sroot_tilde_K
@ -206,6 +207,7 @@ class DTC(EP):
epsilon_np2 = sum((self.v_tilde-self.np2[-1])**2)/self.N
self.np1.append(self.tau_tilde.copy())
self.np2.append(self.v_tilde.copy())
return self.tau_tilde[:,None], self.v_tilde[:,None], self.Z_hat[:,None], self.tau_[:,None], self.v_[:,None]
class FITC(EP):
def fit_EP(self):
@ -306,3 +308,4 @@ class FITC(EP):
epsilon_np2 = sum((self.v_tilde-self.np2[-1])**2)/self.N
self.np1.append(self.tau_tilde.copy())
self.np2.append(self.v_tilde.copy())
return self.tau_tilde[:,None], self.v_tilde[:,None], self.Z_hat[:,None], self.tau_[:,None], self.v_[:,None]