epsilon and power_ep now are parameters of update_likelihood.

This commit is contained in:
Ricardo 2013-09-20 13:22:38 +01:00
parent c8fec98071
commit a51af5b8c4
5 changed files with 62 additions and 30 deletions

View file

@ -538,22 +538,16 @@ class Model(Parameterized):
return k.variances
def pseudo_EM(self, epsilon=.1, **kwargs):
def pseudo_EM(self, stop_crit=.1, **kwargs):
"""
TODO: Should this not bein the GP class?
EM - like algorithm for Expectation Propagation and Laplace approximation
kwargs are passed to the optimize function. They can be:
:stop_crit: convergence criterion
:type stop_crit: float
:epsilon: convergence criterion
:max_f_eval: maximum number of function evaluations
:messages: whether to display during optimisation
:param optimzer: whice optimizer to use (defaults to self.preferred optimizer)
:type optimzer: string TODO: valid strings?
"""
..Note: kwargs are passed to update_likelihood and optimize functions. """
assert isinstance(self.likelihood, likelihoods.EP) or isinstance(self.likelihood, likelihoods.EP_Mixed_Noise), "pseudo_EM is only available for EP likelihoods"
ll_change = epsilon + 1.
ll_change = stop_crit + 1.
iteration = 0
last_ll = -np.inf
@ -561,10 +555,25 @@ class Model(Parameterized):
alpha = 0
stop = False
#Handle **kwargs
ep_args = {}
for arg in kwargs.keys():
if arg in ('epsilon','power_ep'):
ep_args[arg] = kwargs[arg]
del kwargs[arg]
while not stop:
last_approximation = self.likelihood.copy()
last_params = self._get_params()
self.update_likelihood_approximation()
if len(ep_args) == 2:
self.update_likelihood_approximation(epsilon=ep_args['epsilon'],power_ep=ep_args['power_ep'])
elif len(ep_args) == 1:
if ep_args.keys()[0] == 'epsilon':
self.update_likelihood_approximation(epsilon=ep_args['epsilon'])
elif ep_args.keys()[0] == 'power_ep':
self.update_likelihood_approximation(power_ep=ep_args['power_ep'])
else:
self.update_likelihood_approximation()
new_ll = self.log_likelihood()
ll_change = new_ll - last_ll
@ -576,7 +585,7 @@ class Model(Parameterized):
else:
self.optimize(**kwargs)
last_ll = self.log_likelihood()
if ll_change < epsilon:
if ll_change < stop_crit:
stop = True
iteration += 1
if stop: