mirror of
https://github.com/SheffieldML/GPy.git
synced 2026-06-11 15:15:15 +02:00
tests updated
This commit is contained in:
parent
3132a0362a
commit
4a155dba1e
1 changed files with 23 additions and 13 deletions
|
|
@ -32,7 +32,8 @@ class Test(unittest.TestCase):
|
|||
self.bgplvm = BayesianGPLVM(Gaussian(self.Y, variance=self.noise_variance), self.Q, self.X, self.X_variance, kernel=self.kern)
|
||||
self.bgplvm.ensure_default_constraints()
|
||||
self.bgplvm.tie_params("noise_variance|white_variance")
|
||||
|
||||
self.bgplvm.constrain_fixed("rbf_var")
|
||||
|
||||
self.parameter = Parameterized([
|
||||
Parameterized([
|
||||
Param('X', self.X),
|
||||
|
|
@ -50,23 +51,24 @@ class Test(unittest.TestCase):
|
|||
self.parameter['.*variance'].constrain_positive()
|
||||
self.parameter['.*length'].constrain_positive()
|
||||
self.parameter.white.tie_to(self.parameter.noise)
|
||||
self.parameter.rbf_var.constrain_fixed()
|
||||
|
||||
def tearDown(self):
|
||||
pass
|
||||
|
||||
|
||||
def testGrepParamNamesTest(self):
|
||||
assert(self.bgplvm.grep_param_names('X_\d') == self.parameter.grep_param_names('X_\d'))
|
||||
assert(self.bgplvm.grep_param_names('X_\d+_1') == self.parameter.grep_param_names('X_\d+_1'))
|
||||
assert(self.bgplvm.grep_param_names('X_\d_1') == self.parameter.grep_param_names('X_\d_1'))
|
||||
assert(self.bgplvm.grep_param_names('X_.+_1') == self.parameter.grep_param_names('X_.+_1'))
|
||||
assert(self.bgplvm.grep_param_names('X_1_1') == self.parameter.grep_param_names('X_1_1'))
|
||||
assert(self.bgplvm.grep_param_names('X') == self.parameter.grep_param_names('X'))
|
||||
assert(self.bgplvm.grep_param_names('rbf') == self.parameter.grep_param_names('rbf'))
|
||||
assert(self.bgplvm.grep_param_names('rbf_l.*_1') == self.parameter.grep_param_names('rbf_l.*_1'))
|
||||
assert(self.bgplvm.grep_param_names('l') == self.parameter.grep_param_names('l'))
|
||||
assert(self.bgplvm.grep_param_names('dont_match') == self.parameter.grep_param_names('dont_match'))
|
||||
assert(self.bgplvm.grep_param_names('.*') == self.parameter.grep_param_names('.*'))
|
||||
# def testGrepParamNamesTest(self):
|
||||
# assert(self.bgplvm.grep_param_names('X_\d') == self.parameter.grep_param_names('X_\d'))
|
||||
# assert(self.bgplvm.grep_param_names('X_\d+_1') == self.parameter.grep_param_names('X_\d+_1'))
|
||||
# assert(self.bgplvm.grep_param_names('X_\d_1') == self.parameter.grep_param_names('X_\d_1'))
|
||||
# assert(self.bgplvm.grep_param_names('X_.+_1') == self.parameter.grep_param_names('X_.+_1'))
|
||||
# assert(self.bgplvm.grep_param_names('X_1_1') == self.parameter.grep_param_names('X_1_1'))
|
||||
# assert(self.bgplvm.grep_param_names('X') == self.parameter.grep_param_names('X'))
|
||||
# assert(self.bgplvm.grep_param_names('rbf') == self.parameter.grep_param_names('rbf'))
|
||||
# assert(self.bgplvm.grep_param_names('rbf_l.*_1') == self.parameter.grep_param_names('rbf_l.*_1'))
|
||||
# assert(self.bgplvm.grep_param_names('l') == self.parameter.grep_param_names('l'))
|
||||
# assert(self.bgplvm.grep_param_names('dont_match') == self.parameter.grep_param_names('dont_match'))
|
||||
# assert(self.bgplvm.grep_param_names('.*') == self.parameter.grep_param_names('.*'))
|
||||
|
||||
def testGetParams(self):
|
||||
assert(numpy.allclose(self.bgplvm._get_params(), self.parameter._get_params()))
|
||||
|
|
@ -86,12 +88,19 @@ class Test(unittest.TestCase):
|
|||
assert(numpy.alltrue(self.parameter.X[:,1] == self.X[:,1]))
|
||||
assert(numpy.alltrue(self.parameter.X[:,1] == self.X[:,1]))
|
||||
assert(numpy.alltrue(self.parameter.X_variance[1,1] == self.X_variance[1,1]))
|
||||
import ipdb;ipdb.set_trace()
|
||||
assert(numpy.alltrue(self.parameter.X_variance[:] == self.X_variance[:]))
|
||||
assert(numpy.alltrue(self.parameter.X[:,:][:,0:2][:,1] == self.X[:,1]))
|
||||
assert(numpy.alltrue(self.parameter.X[:,1] == self.X[:,1]))
|
||||
assert(numpy.alltrue(self.parameter.X_variance[1,1] == self.X_variance[1,1]))
|
||||
assert(numpy.alltrue(self.parameter.X_variance[:] == self.X_variance[:]))
|
||||
|
||||
def testConstraints(self):
|
||||
self.parameter[''].unconstrain()
|
||||
self.parameter.X.constrain_positive()
|
||||
self.parameter.X[:,0].unconstrain_positive()
|
||||
assert(numpy.alltrue(self.parameter._constraints.indices()[0] == numpy.r_[1:self.N*self.Q:2]))
|
||||
|
||||
def testNdarrayFunc(self):
|
||||
assert(numpy.alltrue(self.parameter.X * self.parameter.X == self.X * self.X))
|
||||
assert(numpy.alltrue(self.parameter.X * self.parameter.X == self.X * self.X))
|
||||
|
|
@ -105,5 +114,6 @@ if __name__ == "__main__":
|
|||
'Test.testGetParams',
|
||||
'Test.testNdarrayFunc',
|
||||
'Test.testSetParams',
|
||||
'Test.testConstraints',
|
||||
]
|
||||
unittest.main()
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue