mirror of
https://github.com/SheffieldML/GPy.git
synced 2026-05-05 01:32:40 +02:00
fixes now hierarchical, maybe need to be restructured as lookup from constraints
This commit is contained in:
parent
79ba989b31
commit
24b43c490c
3 changed files with 62 additions and 33 deletions
|
|
@ -34,9 +34,9 @@ class ParameterizedTest(unittest.TestCase):
|
|||
self.param = Param('param', np.random.rand(25,2), Logistic(0, 1))
|
||||
|
||||
self.test1 = GPy.core.Parameterized("test model")
|
||||
self.test1.add_parameter(self.white)
|
||||
self.test1.add_parameter(self.rbf, 0)
|
||||
self.test1.add_parameter(self.param)
|
||||
self.test1.kern = self.rbf+self.white
|
||||
self.test1.add_parameter(self.test1.kern)
|
||||
self.test1.add_parameter(self.param, 0)
|
||||
|
||||
x = np.linspace(-2,6,4)[:,None]
|
||||
y = np.sin(x)
|
||||
|
|
@ -45,22 +45,24 @@ class ParameterizedTest(unittest.TestCase):
|
|||
def test_add_parameter(self):
|
||||
self.assertEquals(self.rbf._parent_index_, 0)
|
||||
self.assertEquals(self.white._parent_index_, 1)
|
||||
self.assertEquals(self.param._parent_index_, 0)
|
||||
pass
|
||||
|
||||
def test_fixes(self):
|
||||
self.white.fix(warning=False)
|
||||
self.test1.remove_parameter(self.test1.param)
|
||||
self.test1.remove_parameter(self.param)
|
||||
self.assertTrue(self.test1._has_fixes())
|
||||
from GPy.core.parameterization.transformations import FIXED, UNFIXED
|
||||
self.assertListEqual(self.test1._fixes_.tolist(),[UNFIXED,UNFIXED,FIXED])
|
||||
|
||||
self.test1.add_parameter(self.white, 0)
|
||||
self.test1.kern.add_parameter(self.white, 0)
|
||||
self.assertListEqual(self.test1._fixes_.tolist(),[FIXED,UNFIXED,UNFIXED])
|
||||
self.test1.kern.rbf.fix()
|
||||
self.assertListEqual(self.test1._fixes_.tolist(),[FIXED]*3)
|
||||
|
||||
def test_remove_parameter(self):
|
||||
from GPy.core.parameterization.transformations import FIXED, UNFIXED, __fixed__, Logexp
|
||||
self.white.fix()
|
||||
self.test1.remove_parameter(self.white)
|
||||
self.test1.kern.remove_parameter(self.white)
|
||||
self.assertIs(self.test1._fixes_,None)
|
||||
|
||||
self.assertListEqual(self.white._fixes_.tolist(), [FIXED])
|
||||
|
|
@ -81,7 +83,12 @@ class ParameterizedTest(unittest.TestCase):
|
|||
self.assertListEqual(self.white._fixes_.tolist(), [FIXED])
|
||||
self.assertIs(self.test1.constraints, self.rbf.constraints._param_index_ops)
|
||||
self.assertIs(self.test1.constraints, self.param.constraints._param_index_ops)
|
||||
self.assertListEqual(self.test1.constraints[Logexp()].tolist(), [0,1])
|
||||
self.assertListEqual(self.test1.constraints[Logexp()].tolist(), range(self.param.size, self.param.size+self.rbf.size))
|
||||
|
||||
def test_remove_parameter_param_array_grad_array(self):
|
||||
val = self.test1.kern._param_array_.copy()
|
||||
self.test1.kern.remove_parameter(self.white)
|
||||
self.assertListEqual(self.test1.kern._param_array_.tolist(), val[:2].tolist())
|
||||
|
||||
def test_add_parameter_already_in_hirarchy(self):
|
||||
self.assertRaises(HierarchyError, self.test1.add_parameter, self.white._parameters_[0])
|
||||
|
|
@ -91,28 +98,35 @@ class ParameterizedTest(unittest.TestCase):
|
|||
self.assertIs(self.test1.constraints, self.rbf.constraints._param_index_ops)
|
||||
self.assertListEqual(self.rbf.constraints.indices()[0].tolist(), range(2))
|
||||
from GPy.core.parameterization.transformations import Logexp
|
||||
kern = self.rbf+self.white
|
||||
kern = self.test1.kern
|
||||
self.test1.remove_parameter(kern)
|
||||
self.assertListEqual(kern.constraints[Logexp()].tolist(), range(3))
|
||||
|
||||
def test_constraints(self):
|
||||
self.rbf.constrain(GPy.transformations.Square(), False)
|
||||
self.assertListEqual(self.test1.constraints[GPy.transformations.Square()].tolist(), range(2))
|
||||
self.assertListEqual(self.test1.constraints[GPy.transformations.Logexp()].tolist(), [2])
|
||||
self.assertListEqual(self.test1.constraints[GPy.transformations.Square()].tolist(), range(self.param.size, self.param.size+self.rbf.size))
|
||||
self.assertListEqual(self.test1.constraints[GPy.transformations.Logexp()].tolist(), [self.param.size+self.rbf.size])
|
||||
|
||||
self.test1.remove_parameter(self.rbf)
|
||||
self.test1.kern.remove_parameter(self.rbf)
|
||||
self.assertListEqual(self.test1.constraints[GPy.transformations.Square()].tolist(), [])
|
||||
|
||||
def test_constraints_views(self):
|
||||
self.assertEqual(self.white.constraints._offset, 2)
|
||||
self.assertEqual(self.rbf.constraints._offset, 0)
|
||||
self.assertEqual(self.param.constraints._offset, 3)
|
||||
self.assertEqual(self.white.constraints._offset, self.param.size+self.rbf.size)
|
||||
self.assertEqual(self.rbf.constraints._offset, self.param.size)
|
||||
self.assertEqual(self.param.constraints._offset, 0)
|
||||
|
||||
def test_fixing_randomize(self):
|
||||
self.white.fix(warning=True)
|
||||
val = float(self.test1.white.variance)
|
||||
val = float(self.white.variance)
|
||||
self.test1.randomize()
|
||||
self.assertEqual(val, self.white.variance)
|
||||
|
||||
def test_fixing_randomize_parameter_handling(self):
|
||||
self.rbf.fix(warning=True)
|
||||
val = float(self.rbf.variance)
|
||||
self.test1.kern.randomize()
|
||||
self.assertEqual(val, self.rbf.variance)
|
||||
|
||||
def test_fixing_optimize(self):
|
||||
self.testmodel.kern.lengthscale.fix()
|
||||
val = float(self.testmodel.kern.lengthscale)
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue