mirror of
https://github.com/SheffieldML/GPy.git
synced 2026-05-18 13:55:14 +02:00
EM algorithm
This commit is contained in:
parent
29ec128c9d
commit
7737cecf6d
2 changed files with 28 additions and 0 deletions
|
|
@ -35,5 +35,6 @@ print m.checkgrad()
|
||||||
m.optimize()
|
m.optimize()
|
||||||
#m.em(plot_all=False) # EM algorithm
|
#m.em(plot_all=False) # EM algorithm
|
||||||
m.plot(samples=3)
|
m.plot(samples=3)
|
||||||
|
m.EM()
|
||||||
|
|
||||||
print(m)
|
print(m)
|
||||||
|
|
|
||||||
|
|
@ -229,6 +229,33 @@ class GP(model):
|
||||||
phi = None if not self.EP else self.likelihood.predictive_mean(mu,var)
|
phi = None if not self.EP else self.likelihood.predictive_mean(mu,var)
|
||||||
return mu, var, phi
|
return mu, var, phi
|
||||||
|
|
||||||
|
def EM(self,max_f_eval=20,epsilon=.1,plot_all=False): #TODO check this makes sense
|
||||||
|
"""
|
||||||
|
Fits sparse_EP and optimizes the hyperparametes iteratively until convergence is achieved.
|
||||||
|
"""
|
||||||
|
self.epsilon_em = epsilon
|
||||||
|
log_likelihood_change = self.epsilon_em + 1.
|
||||||
|
self.parameters_path = [self._get_params()]
|
||||||
|
self.approximate_likelihood()
|
||||||
|
self.site_approximations_path = [[self.ep_approx.tau_tilde,self.ep_approx.v_tilde]]
|
||||||
|
self.log_likelihood_path = [self.log_likelihood()]
|
||||||
|
iteration = 0
|
||||||
|
while log_likelihood_change > self.epsilon_em:
|
||||||
|
print 'EM iteration', iteration
|
||||||
|
self.optimize(max_f_eval = max_f_eval)
|
||||||
|
log_likelihood_new = self.log_likelihood()
|
||||||
|
log_likelihood_change = log_likelihood_new - self.log_likelihood_path[-1]
|
||||||
|
if log_likelihood_change < 0:
|
||||||
|
print 'log_likelihood decrement'
|
||||||
|
self._set_params(self.parameters_path[-1])
|
||||||
|
self.kern._set_params(self.parameters_path[-1])
|
||||||
|
else:
|
||||||
|
self.approximate_likelihood()
|
||||||
|
self.log_likelihood_path.append(self.log_likelihood())
|
||||||
|
self.parameters_path.append(self._get_params())
|
||||||
|
self.site_approximations_path.append([self.ep_approx.tau_tilde,self.ep_approx.v_tilde])
|
||||||
|
iteration += 1
|
||||||
|
|
||||||
def plot(self,samples=0,plot_limits=None,which_data='all',which_functions='all',resolution=None):
|
def plot(self,samples=0,plot_limits=None,which_data='all',which_functions='all',resolution=None):
|
||||||
"""
|
"""
|
||||||
:param samples: the number of a posteriori samples to plot
|
:param samples: the number of a posteriori samples to plot
|
||||||
|
|
|
||||||
Loading…
Add table
Add a link
Reference in a new issue