[studentT] prior by @mathDR

This commit is contained in:
mzwiessele 2016-08-25 08:28:58 +01:00
parent fe61a97905
commit 7b63a195b8
2 changed files with 12 additions and 3 deletions

View file

@ -142,9 +142,9 @@ class LogisticBasisFuncKernel(BasisFuncKernel):
self.centers = np.atleast_2d(centers) self.centers = np.atleast_2d(centers)
self.ARD_slope = ARD_slope self.ARD_slope = ARD_slope
if self.ARD_slope: if self.ARD_slope:
self.slope = Param('slope', slope * np.ones(self.centers.size), Logexp()) self.slope = Param('slope', slope * np.ones(self.centers.size))
else: else:
self.slope = Param('slope', slope, Logexp()) self.slope = Param('slope', slope)
super(LogisticBasisFuncKernel, self).__init__(input_dim, variance, active_dims, ARD, name) super(LogisticBasisFuncKernel, self).__init__(input_dim, variance, active_dims, ARD, name)
self.link_parameter(self.slope) self.link_parameter(self.slope)

View file

@ -18,6 +18,15 @@ class PriorTests(unittest.TestCase):
# setting a StudentT prior on non-negative parameters # setting a StudentT prior on non-negative parameters
# should raise an assertionerror. # should raise an assertionerror.
self.assertRaises(AssertionError, m.rbf.set_prior, studentT) self.assertRaises(AssertionError, m.rbf.set_prior, studentT)
m = GPy.models.SparseGPRegression(X, y)
gaussian = GPy.priors.Gaussian(1, 1)
m.Z.set_prior(studentT)
# setting a Gaussian prior on non-negative parameters
# should raise an assertionerror.
#self.assertRaises(AssertionError, m.Z.set_prior, gaussian)
self.assertTrue(m.checkgrad())
def test_lognormal(self): def test_lognormal(self):
xmin, xmax = 1, 2.5*np.pi xmin, xmax = 1, 2.5*np.pi
@ -87,7 +96,7 @@ class PriorTests(unittest.TestCase):
# setting a Gaussian prior on non-negative parameters # setting a Gaussian prior on non-negative parameters
# should raise an assertionerror. # should raise an assertionerror.
#self.assertRaises(AssertionError, m.Z.set_prior, gaussian) #self.assertRaises(AssertionError, m.Z.set_prior, gaussian)
self.assertTrue(m.checkgrad())
def test_fixed_domain_check(self): def test_fixed_domain_check(self):