Tidied up laplace warnings

This commit is contained in:
Alan Saul 2014-11-03 13:36:26 +00:00
parent eea6c15802
commit 02c903c4eb

View file

@ -14,6 +14,9 @@ import numpy as np
from ...util.linalg import mdot, jitchol, dpotrs, dtrtrs, dpotri, symmetrify, pdinv
from posterior import Posterior
import warnings
def warning_on_one_line(message, category, filename, lineno, file=None, line=None):
return ' %s:%s: %s:%s\n' % (filename, lineno, category.__name__, message)
warnings.formatwarning = warning_on_one_line
from scipy import optimize
from . import LatentFunctionInference
@ -29,8 +32,11 @@ class Laplace(LatentFunctionInference):
"""
self._mode_finding_tolerance = 1e-7
self._mode_finding_max_iter = 40
self.bad_fhat = True
self._mode_finding_max_iter = 60
self.bad_fhat = False
#Store whether it is the first run of the inference so that we can choose whether we need
#to calculate things or reuse old variables
self.first_run = True
self._previous_Ki_fhat = None
def inference(self, kern, X, likelihood, Y, Y_metadata=None):
@ -42,8 +48,9 @@ class Laplace(LatentFunctionInference):
K = kern.K(X)
#Find mode
if self.bad_fhat:
if self.bad_fhat or self.first_run:
Ki_f_init = np.zeros_like(Y)
first_run = False
else:
Ki_f_init = self._previous_Ki_fhat
@ -123,11 +130,11 @@ class Laplace(LatentFunctionInference):
#Warn of bad fits
if difference > self._mode_finding_tolerance:
if not self.bad_fhat:
warnings.warn("Not perfect f_hat fit difference: {}".format(difference))
warnings.warn("Not perfect mode found (f_hat). difference: {}, iteration: {} out of max {}".format(difference, iteration, self._mode_finding_max_iter))
self.bad_fhat = True
elif self.bad_fhat:
self.bad_fhat = False
warnings.warn("f_hat now fine again")
warnings.warn("f_hat now fine again. difference: {}, iteration: {} out of max {}".format(difference, iteration, self._mode_finding_max_iter))
return f, Ki_f