From 02c903c4eb732399bf5aaa98abc6e6a2fdd8db8b Mon Sep 17 00:00:00 2001 From: Alan Saul Date: Mon, 3 Nov 2014 13:36:26 +0000 Subject: [PATCH] Tidied up laplace warnings --- .../latent_function_inference/laplace.py | 17 ++++++++++++----- 1 file changed, 12 insertions(+), 5 deletions(-) diff --git a/GPy/inference/latent_function_inference/laplace.py b/GPy/inference/latent_function_inference/laplace.py index 01d2f997..2c741b9d 100644 --- a/GPy/inference/latent_function_inference/laplace.py +++ b/GPy/inference/latent_function_inference/laplace.py @@ -14,6 +14,9 @@ import numpy as np from ...util.linalg import mdot, jitchol, dpotrs, dtrtrs, dpotri, symmetrify, pdinv from posterior import Posterior import warnings +def warning_on_one_line(message, category, filename, lineno, file=None, line=None): + return ' %s:%s: %s:%s\n' % (filename, lineno, category.__name__, message) +warnings.formatwarning = warning_on_one_line from scipy import optimize from . import LatentFunctionInference @@ -29,8 +32,11 @@ class Laplace(LatentFunctionInference): """ self._mode_finding_tolerance = 1e-7 - self._mode_finding_max_iter = 40 - self.bad_fhat = True + self._mode_finding_max_iter = 60 + self.bad_fhat = False + #Store whether it is the first run of the inference so that we can choose whether we need + #to calculate things or reuse old variables + self.first_run = True self._previous_Ki_fhat = None def inference(self, kern, X, likelihood, Y, Y_metadata=None): @@ -42,8 +48,9 @@ class Laplace(LatentFunctionInference): K = kern.K(X) #Find mode - if self.bad_fhat: + if self.bad_fhat or self.first_run: Ki_f_init = np.zeros_like(Y) + first_run = False else: Ki_f_init = self._previous_Ki_fhat @@ -123,11 +130,11 @@ class Laplace(LatentFunctionInference): #Warn of bad fits if difference > self._mode_finding_tolerance: if not self.bad_fhat: - warnings.warn("Not perfect f_hat fit difference: {}".format(difference)) + warnings.warn("Not perfect mode found (f_hat). difference: {}, iteration: {} out of max {}".format(difference, iteration, self._mode_finding_max_iter)) self.bad_fhat = True elif self.bad_fhat: self.bad_fhat = False - warnings.warn("f_hat now fine again") + warnings.warn("f_hat now fine again. difference: {}, iteration: {} out of max {}".format(difference, iteration, self._mode_finding_max_iter)) return f, Ki_f