Update rv_transformation_tests.py

This commit is contained in:
Max Zwiessele 2015-09-12 19:47:11 +01:00
parent 3f769a72d8
commit f111d60ffc

View file

@ -14,8 +14,8 @@ class TestModel(GPy.core.Model):
""" """
A simple GPy model with one parameter. A simple GPy model with one parameter.
""" """
def __init__(self): def __init__(self, name):
GPy.core.Model.__init__(self, 'test_model') GPy.core.Model.__init__(self, name)
theta = GPy.core.Param('theta', 1.) theta = GPy.core.Param('theta', 1.)
self.link_parameter(theta) self.link_parameter(theta)
@ -26,7 +26,7 @@ class TestModel(GPy.core.Model):
class RVTransformationTestCase(unittest.TestCase): class RVTransformationTestCase(unittest.TestCase):
def _test_trans(self, trans): def _test_trans(self, trans):
m = TestModel() m = TestModel(trans.__class__.__name__)
prior = GPy.priors.LogGaussian(.5, 0.1) prior = GPy.priors.LogGaussian(.5, 0.1)
m.theta.set_prior(prior) m.theta.set_prior(prior)
m.theta.unconstrain() m.theta.unconstrain()
@ -56,12 +56,13 @@ class RVTransformationTestCase(unittest.TestCase):
# The following test cannot be very accurate # The following test cannot be very accurate
self.assertTrue(np.linalg.norm(pdf_phi - kde(phi)) / np.linalg.norm(kde(phi)) <= 1e-1) self.assertTrue(np.linalg.norm(pdf_phi - kde(phi)) / np.linalg.norm(kde(phi)) <= 1e-1)
# Check the gradients at a few random points # Check the gradients at a few random points
for i in range(10): for i in range(5):
m.theta = theta_s[i] m.theta = theta_s[i]
self.assertTrue(m.checkgrad(verbose=True)) self.assertTrue(m.checkgrad(verbose=True))
def test_Logexp(self): def test_Logexp(self):
self._test_trans(GPy.constraints.Logexp()) self._test_trans(GPy.constraints.Logexp())
def test_Exponent(self):
self._test_trans(GPy.constraints.Exponent()) self._test_trans(GPy.constraints.Exponent())