tests updated

This commit is contained in:
Max Zwiessele 2013-10-11 16:51:49 +01:00
parent 3132a0362a
commit 4a155dba1e

View file

@ -32,7 +32,8 @@ class Test(unittest.TestCase):
self.bgplvm = BayesianGPLVM(Gaussian(self.Y, variance=self.noise_variance), self.Q, self.X, self.X_variance, kernel=self.kern)
self.bgplvm.ensure_default_constraints()
self.bgplvm.tie_params("noise_variance|white_variance")
self.bgplvm.constrain_fixed("rbf_var")
self.parameter = Parameterized([
Parameterized([
Param('X', self.X),
@ -50,23 +51,24 @@ class Test(unittest.TestCase):
self.parameter['.*variance'].constrain_positive()
self.parameter['.*length'].constrain_positive()
self.parameter.white.tie_to(self.parameter.noise)
self.parameter.rbf_var.constrain_fixed()
def tearDown(self):
pass
def testGrepParamNamesTest(self):
assert(self.bgplvm.grep_param_names('X_\d') == self.parameter.grep_param_names('X_\d'))
assert(self.bgplvm.grep_param_names('X_\d+_1') == self.parameter.grep_param_names('X_\d+_1'))
assert(self.bgplvm.grep_param_names('X_\d_1') == self.parameter.grep_param_names('X_\d_1'))
assert(self.bgplvm.grep_param_names('X_.+_1') == self.parameter.grep_param_names('X_.+_1'))
assert(self.bgplvm.grep_param_names('X_1_1') == self.parameter.grep_param_names('X_1_1'))
assert(self.bgplvm.grep_param_names('X') == self.parameter.grep_param_names('X'))
assert(self.bgplvm.grep_param_names('rbf') == self.parameter.grep_param_names('rbf'))
assert(self.bgplvm.grep_param_names('rbf_l.*_1') == self.parameter.grep_param_names('rbf_l.*_1'))
assert(self.bgplvm.grep_param_names('l') == self.parameter.grep_param_names('l'))
assert(self.bgplvm.grep_param_names('dont_match') == self.parameter.grep_param_names('dont_match'))
assert(self.bgplvm.grep_param_names('.*') == self.parameter.grep_param_names('.*'))
# def testGrepParamNamesTest(self):
# assert(self.bgplvm.grep_param_names('X_\d') == self.parameter.grep_param_names('X_\d'))
# assert(self.bgplvm.grep_param_names('X_\d+_1') == self.parameter.grep_param_names('X_\d+_1'))
# assert(self.bgplvm.grep_param_names('X_\d_1') == self.parameter.grep_param_names('X_\d_1'))
# assert(self.bgplvm.grep_param_names('X_.+_1') == self.parameter.grep_param_names('X_.+_1'))
# assert(self.bgplvm.grep_param_names('X_1_1') == self.parameter.grep_param_names('X_1_1'))
# assert(self.bgplvm.grep_param_names('X') == self.parameter.grep_param_names('X'))
# assert(self.bgplvm.grep_param_names('rbf') == self.parameter.grep_param_names('rbf'))
# assert(self.bgplvm.grep_param_names('rbf_l.*_1') == self.parameter.grep_param_names('rbf_l.*_1'))
# assert(self.bgplvm.grep_param_names('l') == self.parameter.grep_param_names('l'))
# assert(self.bgplvm.grep_param_names('dont_match') == self.parameter.grep_param_names('dont_match'))
# assert(self.bgplvm.grep_param_names('.*') == self.parameter.grep_param_names('.*'))
def testGetParams(self):
assert(numpy.allclose(self.bgplvm._get_params(), self.parameter._get_params()))
@ -86,12 +88,19 @@ class Test(unittest.TestCase):
assert(numpy.alltrue(self.parameter.X[:,1] == self.X[:,1]))
assert(numpy.alltrue(self.parameter.X[:,1] == self.X[:,1]))
assert(numpy.alltrue(self.parameter.X_variance[1,1] == self.X_variance[1,1]))
import ipdb;ipdb.set_trace()
assert(numpy.alltrue(self.parameter.X_variance[:] == self.X_variance[:]))
assert(numpy.alltrue(self.parameter.X[:,:][:,0:2][:,1] == self.X[:,1]))
assert(numpy.alltrue(self.parameter.X[:,1] == self.X[:,1]))
assert(numpy.alltrue(self.parameter.X_variance[1,1] == self.X_variance[1,1]))
assert(numpy.alltrue(self.parameter.X_variance[:] == self.X_variance[:]))
def testConstraints(self):
self.parameter[''].unconstrain()
self.parameter.X.constrain_positive()
self.parameter.X[:,0].unconstrain_positive()
assert(numpy.alltrue(self.parameter._constraints.indices()[0] == numpy.r_[1:self.N*self.Q:2]))
def testNdarrayFunc(self):
assert(numpy.alltrue(self.parameter.X * self.parameter.X == self.X * self.X))
assert(numpy.alltrue(self.parameter.X * self.parameter.X == self.X * self.X))
@ -105,5 +114,6 @@ if __name__ == "__main__":
'Test.testGetParams',
'Test.testNdarrayFunc',
'Test.testSetParams',
'Test.testConstraints',
]
unittest.main()