mirror of
https://github.com/SheffieldML/GPy.git
synced 2026-05-10 04:22:38 +02:00
Tidied up laplace warnings
This commit is contained in:
parent
eea6c15802
commit
02c903c4eb
1 changed files with 12 additions and 5 deletions
|
|
@ -14,6 +14,9 @@ import numpy as np
|
|||
from ...util.linalg import mdot, jitchol, dpotrs, dtrtrs, dpotri, symmetrify, pdinv
|
||||
from posterior import Posterior
|
||||
import warnings
|
||||
def warning_on_one_line(message, category, filename, lineno, file=None, line=None):
|
||||
return ' %s:%s: %s:%s\n' % (filename, lineno, category.__name__, message)
|
||||
warnings.formatwarning = warning_on_one_line
|
||||
from scipy import optimize
|
||||
from . import LatentFunctionInference
|
||||
|
||||
|
|
@ -29,8 +32,11 @@ class Laplace(LatentFunctionInference):
|
|||
"""
|
||||
|
||||
self._mode_finding_tolerance = 1e-7
|
||||
self._mode_finding_max_iter = 40
|
||||
self.bad_fhat = True
|
||||
self._mode_finding_max_iter = 60
|
||||
self.bad_fhat = False
|
||||
#Store whether it is the first run of the inference so that we can choose whether we need
|
||||
#to calculate things or reuse old variables
|
||||
self.first_run = True
|
||||
self._previous_Ki_fhat = None
|
||||
|
||||
def inference(self, kern, X, likelihood, Y, Y_metadata=None):
|
||||
|
|
@ -42,8 +48,9 @@ class Laplace(LatentFunctionInference):
|
|||
K = kern.K(X)
|
||||
|
||||
#Find mode
|
||||
if self.bad_fhat:
|
||||
if self.bad_fhat or self.first_run:
|
||||
Ki_f_init = np.zeros_like(Y)
|
||||
first_run = False
|
||||
else:
|
||||
Ki_f_init = self._previous_Ki_fhat
|
||||
|
||||
|
|
@ -123,11 +130,11 @@ class Laplace(LatentFunctionInference):
|
|||
#Warn of bad fits
|
||||
if difference > self._mode_finding_tolerance:
|
||||
if not self.bad_fhat:
|
||||
warnings.warn("Not perfect f_hat fit difference: {}".format(difference))
|
||||
warnings.warn("Not perfect mode found (f_hat). difference: {}, iteration: {} out of max {}".format(difference, iteration, self._mode_finding_max_iter))
|
||||
self.bad_fhat = True
|
||||
elif self.bad_fhat:
|
||||
self.bad_fhat = False
|
||||
warnings.warn("f_hat now fine again")
|
||||
warnings.warn("f_hat now fine again. difference: {}, iteration: {} out of max {}".format(difference, iteration, self._mode_finding_max_iter))
|
||||
|
||||
return f, Ki_f
|
||||
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue