Fixed laplace seed, added debugging for misc tests

This commit is contained in:
Alan Saul 2015-08-18 17:41:25 +01:00
parent 33d8441ac8
commit 624d65493c
3 changed files with 7 additions and 3 deletions

View file

@ -171,7 +171,9 @@ class Laplace(LatentFunctionInference):
#define the objective function (to be maximised) #define the objective function (to be maximised)
def obj(Ki_f, f): def obj(Ki_f, f):
ll = -0.5*np.sum(np.dot(Ki_f.T, f)) + np.sum(likelihood.logpdf(f, Y, Y_metadata=Y_metadata)) ll = -0.5*np.sum(np.dot(Ki_f.T, f)) + np.sum(likelihood.logpdf(f, Y, Y_metadata=Y_metadata))
print ll
if np.isnan(ll): if np.isnan(ll):
import ipdb; ipdb.set_trace() # XXX BREAKPOINT
return -np.inf return -np.inf
else: else:
return ll return ll

View file

@ -9,8 +9,7 @@ import inspect
from GPy.likelihoods import link_functions from GPy.likelihoods import link_functions
from GPy.core.parameterization import Param from GPy.core.parameterization import Param
from functools import partial from functools import partial
#np.random.seed(300) fixed_seed = 0
#np.random.seed(4)
#np.seterr(divide='raise') #np.seterr(divide='raise')
def dparam_partial(inst_func, *args): def dparam_partial(inst_func, *args):
@ -105,7 +104,7 @@ class TestNoiseModels(object):
Generic model checker Generic model checker
""" """
def setUp(self): def setUp(self):
np.random.seed(0) np.random.seed(fixed_seed)
self.N = 15 self.N = 15
self.D = 3 self.D = 3
self.X = np.random.rand(self.N, self.D)*10 self.X = np.random.rand(self.N, self.D)*10
@ -704,6 +703,7 @@ class LaplaceTests(unittest.TestCase):
""" """
def setUp(self): def setUp(self):
np.random.seed(fixed_seed)
self.N = 15 self.N = 15
self.D = 1 self.D = 1
self.X = np.random.rand(self.N, self.D)*10 self.X = np.random.rand(self.N, self.D)*10

View file

@ -17,6 +17,8 @@ class MiscTests(np.testing.TestCase):
assert np.isinf(np.exp(self._lim_val_exp + 1)) assert np.isinf(np.exp(self._lim_val_exp + 1))
assert np.isfinite(GPy.util.misc.safe_exp(self._lim_val_exp + 1)) assert np.isfinite(GPy.util.misc.safe_exp(self._lim_val_exp + 1))
print w
print len(w)
assert len(w)==1 # should have one overflow warning assert len(w)==1 # should have one overflow warning
def test_safe_exp_lower(self): def test_safe_exp_lower(self):