Tidied up laplace warnings

This commit is contained in:
Alan Saul 2014-11-03 13:36:26 +00:00
parent eea6c15802
commit 02c903c4eb

View file

@ -14,6 +14,9 @@ import numpy as np
from ...util.linalg import mdot, jitchol, dpotrs, dtrtrs, dpotri, symmetrify, pdinv from ...util.linalg import mdot, jitchol, dpotrs, dtrtrs, dpotri, symmetrify, pdinv
from posterior import Posterior from posterior import Posterior
import warnings import warnings
def warning_on_one_line(message, category, filename, lineno, file=None, line=None):
return ' %s:%s: %s:%s\n' % (filename, lineno, category.__name__, message)
warnings.formatwarning = warning_on_one_line
from scipy import optimize from scipy import optimize
from . import LatentFunctionInference from . import LatentFunctionInference
@ -29,8 +32,11 @@ class Laplace(LatentFunctionInference):
""" """
self._mode_finding_tolerance = 1e-7 self._mode_finding_tolerance = 1e-7
self._mode_finding_max_iter = 40 self._mode_finding_max_iter = 60
self.bad_fhat = True self.bad_fhat = False
#Store whether it is the first run of the inference so that we can choose whether we need
#to calculate things or reuse old variables
self.first_run = True
self._previous_Ki_fhat = None self._previous_Ki_fhat = None
def inference(self, kern, X, likelihood, Y, Y_metadata=None): def inference(self, kern, X, likelihood, Y, Y_metadata=None):
@ -42,8 +48,9 @@ class Laplace(LatentFunctionInference):
K = kern.K(X) K = kern.K(X)
#Find mode #Find mode
if self.bad_fhat: if self.bad_fhat or self.first_run:
Ki_f_init = np.zeros_like(Y) Ki_f_init = np.zeros_like(Y)
first_run = False
else: else:
Ki_f_init = self._previous_Ki_fhat Ki_f_init = self._previous_Ki_fhat
@ -123,11 +130,11 @@ class Laplace(LatentFunctionInference):
#Warn of bad fits #Warn of bad fits
if difference > self._mode_finding_tolerance: if difference > self._mode_finding_tolerance:
if not self.bad_fhat: if not self.bad_fhat:
warnings.warn("Not perfect f_hat fit difference: {}".format(difference)) warnings.warn("Not perfect mode found (f_hat). difference: {}, iteration: {} out of max {}".format(difference, iteration, self._mode_finding_max_iter))
self.bad_fhat = True self.bad_fhat = True
elif self.bad_fhat: elif self.bad_fhat:
self.bad_fhat = False self.bad_fhat = False
warnings.warn("f_hat now fine again") warnings.warn("f_hat now fine again. difference: {}, iteration: {} out of max {}".format(difference, iteration, self._mode_finding_max_iter))
return f, Ki_f return f, Ki_f