From 7737cecf6db40188ceaf626e2287d380c6705e0e Mon Sep 17 00:00:00 2001 From: Ricardo Andrade Date: Mon, 28 Jan 2013 18:01:55 +0000 Subject: [PATCH] EM algorithm --- GPy/examples/ep_fix.py | 1 + GPy/models/GP.py | 27 +++++++++++++++++++++++++++ 2 files changed, 28 insertions(+) diff --git a/GPy/examples/ep_fix.py b/GPy/examples/ep_fix.py index c4e025dd..49ebd5aa 100644 --- a/GPy/examples/ep_fix.py +++ b/GPy/examples/ep_fix.py @@ -35,5 +35,6 @@ print m.checkgrad() m.optimize() #m.em(plot_all=False) # EM algorithm m.plot(samples=3) +m.EM() print(m) diff --git a/GPy/models/GP.py b/GPy/models/GP.py index 3a9f6de8..51da0490 100644 --- a/GPy/models/GP.py +++ b/GPy/models/GP.py @@ -229,6 +229,33 @@ class GP(model): phi = None if not self.EP else self.likelihood.predictive_mean(mu,var) return mu, var, phi + def EM(self,max_f_eval=20,epsilon=.1,plot_all=False): #TODO check this makes sense + """ + Fits sparse_EP and optimizes the hyperparametes iteratively until convergence is achieved. + """ + self.epsilon_em = epsilon + log_likelihood_change = self.epsilon_em + 1. + self.parameters_path = [self._get_params()] + self.approximate_likelihood() + self.site_approximations_path = [[self.ep_approx.tau_tilde,self.ep_approx.v_tilde]] + self.log_likelihood_path = [self.log_likelihood()] + iteration = 0 + while log_likelihood_change > self.epsilon_em: + print 'EM iteration', iteration + self.optimize(max_f_eval = max_f_eval) + log_likelihood_new = self.log_likelihood() + log_likelihood_change = log_likelihood_new - self.log_likelihood_path[-1] + if log_likelihood_change < 0: + print 'log_likelihood decrement' + self._set_params(self.parameters_path[-1]) + self.kern._set_params(self.parameters_path[-1]) + else: + self.approximate_likelihood() + self.log_likelihood_path.append(self.log_likelihood()) + self.parameters_path.append(self._get_params()) + self.site_approximations_path.append([self.ep_approx.tau_tilde,self.ep_approx.v_tilde]) + iteration += 1 + def plot(self,samples=0,plot_limits=None,which_data='all',which_functions='all',resolution=None): """ :param samples: the number of a posteriori samples to plot