diff --git a/GPy/core/parameterization/index_operations.py b/GPy/core/parameterization/index_operations.py index 091b6372..b816e05f 100644 --- a/GPy/core/parameterization/index_operations.py +++ b/GPy/core/parameterization/index_operations.py @@ -57,6 +57,7 @@ class ParameterIndexOperations(object): You can give an offset to set an offset for the given indices in the index array, for multi-param handling. ''' + _offset = 0 def __init__(self, constraints=None): self._properties = IntArrayDict() if constraints is not None: @@ -120,6 +121,14 @@ class ParameterIndexOperations(object): return removed.astype(int) return numpy.array([]).astype(int) + def update(self, parameter_index_view, offset=0): + for i, v in parameter_index_view.iteritems(): + self.add(i, v+offset) + + + def copy(self): + return ParameterIndexOperations(dict(self.iteritems())) + def __getitem__(self, prop): return self._properties[prop] @@ -223,9 +232,9 @@ class ParameterIndexOperationsView(object): import pprint return pprint.pformat(dict(self.iteritems())) - def update(self, parameter_index_view): + def update(self, parameter_index_view, offset=0): for i, v in parameter_index_view.iteritems(): - self.add(i, v) + self.add(i, v+offset) def copy(self): diff --git a/GPy/core/parameterization/parameter_core.py b/GPy/core/parameterization/parameter_core.py index cfee60bd..65504652 100644 --- a/GPy/core/parameterization/parameter_core.py +++ b/GPy/core/parameterization/parameter_core.py @@ -72,6 +72,13 @@ class Parentable(object): def has_parent(self): return self._direct_parent_ is not None + def _notify_parent_change(self): + for p in self._parameters_: + p._parent_changed(self) + + def _parent_changed(self): + raise NotImplementedError, "shouldnt happen, Parentable objects need to be able to change their parent" + @property def _highest_parent_(self): if self._direct_parent_ is None: @@ -182,11 +189,9 @@ class Constrainable(Nameable, Indexable, Parameterizable): # Constrain operations -> done #=========================================================================== def _parent_changed(self, parent): - c = self.constraints from index_operations import ParameterIndexOperationsView self.constraints = ParameterIndexOperationsView(parent.constraints, parent._offset_for(self), self.size) - self.constraints.update(c) - del c + self._fixes_ = None for p in self._parameters_: p._parent_changed(parent) diff --git a/GPy/core/parameterization/parameterized.py b/GPy/core/parameterization/parameterized.py index 80bf2959..a976eb93 100644 --- a/GPy/core/parameterization/parameterized.py +++ b/GPy/core/parameterization/parameterized.py @@ -87,17 +87,20 @@ class Parameterized(Constrainable, Pickleable, Observable, Gradcheckable): elif param not in self._parameters_: # make sure the size is set if index is None: + self.constraints.update(param.constraints, self.size) self._parameters_.append(param) else: start = sum(p.size for p in self._parameters_[:index]) self.constraints.shift(start, param.size) + self.constraints.update(param.constraints, start) self._parameters_.insert(index, param) self.size += param.size else: raise RuntimeError, """Parameter exists already added and no copy made""" self._connect_parameters() - for p in self._parameters_: - p._parent_changed(self) + self._notify_parent_change() + self._connect_fixes() + def add_parameters(self, *parameters): """ @@ -120,7 +123,13 @@ class Parameterized(Constrainable, Pickleable, Observable, Gradcheckable): param._direct_parent_ = None param._parent_index_ = None param._connect_fixes() + param._notify_parent_change() + pname = adjust_name_for_printing(param.name) + if pname in self._added_names_: + del self.__dict__[pname] self._connect_parameters() + #self._notify_parent_change() + self._connect_fixes() def _connect_parameters(self): # connect parameterlist to this parameterized object @@ -149,7 +158,6 @@ class Parameterized(Constrainable, Pickleable, Observable, Gradcheckable): elif not (pname in not_unique): self.__dict__[pname] = p self._added_names_.add(pname) - self._connect_fixes() #=========================================================================== # Pickling operations diff --git a/GPy/testing/parameterized_tests.py b/GPy/testing/parameterized_tests.py new file mode 100644 index 00000000..ff57606a --- /dev/null +++ b/GPy/testing/parameterized_tests.py @@ -0,0 +1,93 @@ +''' +Created on Feb 13, 2014 + +@author: maxzwiessele +''' +import unittest +import GPy +import numpy as np + +class Test(unittest.TestCase): + + def setUp(self): + self.rbf = GPy.kern.rbf(1) + self.white = GPy.kern.white(1) + from GPy.core.parameterization import Param + from GPy.core.parameterization.transformations import Logistic + self.param = Param('param', np.random.rand(25,2), Logistic(0, 1)) + + self.test1 = GPy.core.Parameterized("test model") + self.test1.add_parameter(self.white) + self.test1.add_parameter(self.rbf, 0) + self.test1.add_parameter(self.param) + + def test_add_parameter(self): + self.assertEquals(self.rbf._parent_index_, 0) + self.assertEquals(self.white._parent_index_, 1) + pass + + def test_fixes(self): + self.white.fix(warning=False) + self.test1.remove_parameter(self.test1.param) + self.assertTrue(self.test1._has_fixes()) + + from GPy.core.parameterization.transformations import FIXED, UNFIXED + self.assertListEqual(self.test1._fixes_.tolist(),[UNFIXED,UNFIXED,FIXED]) + + self.test1.add_parameter(self.white, 0) + self.assertListEqual(self.test1._fixes_.tolist(),[FIXED,UNFIXED,UNFIXED]) + + + def test_remove_parameter(self): + from GPy.core.parameterization.transformations import FIXED, UNFIXED, __fixed__ + self.white.fix() + self.test1.remove_parameter(self.white) + self.assertIs(self.test1._fixes_,None) + + self.assertListEqual(self.white._fixes_.tolist(), [FIXED]) + self.assertIs(self.white.constraints,self.white.white.constraints._param_index_ops) + self.assertEquals(self.white.white.constraints._offset, 0) + self.assertIs(self.test1.constraints, self.rbf.constraints._param_index_ops) + self.assertIs(self.test1.constraints, self.param.constraints._param_index_ops) + + self.test1.add_parameter(self.white, 0) + self.assertIs(self.test1.constraints, self.white.constraints._param_index_ops) + self.assertIs(self.test1.constraints, self.rbf.constraints._param_index_ops) + self.assertIs(self.test1.constraints, self.param.constraints._param_index_ops) + self.assertListEqual(self.test1.constraints[__fixed__].tolist(), [0]) + self.assertIs(self.white._fixes_,None) + self.assertListEqual(self.test1._fixes_.tolist(),[FIXED] + [UNFIXED] * 52) + self.test1.remove_parameter(self.white) + self.assertIs(self.test1._fixes_,None) + self.assertListEqual(self.white._fixes_.tolist(), [FIXED]) + self.assertIs(self.white.constraints,self.white.white.constraints._param_index_ops) + self.assertIs(self.test1.constraints, self.rbf.constraints._param_index_ops) + self.assertIs(self.test1.constraints, self.param.constraints._param_index_ops) + + def test_add_parameter_already_in_hirarchy(self): + self.test1.add_parameter(self.white._parameters_[0]) + + def test_default_constraints(self): + self.assertIs(self.rbf.rbf.variance.constraints._param_index_ops, self.rbf.constraints._param_index_ops) + self.assertIs(self.test1.constraints, self.rbf.constraints._param_index_ops) + self.assertListEqual(self.rbf.constraints.indices()[0].tolist(), range(2)) + from GPy.core.parameterization.transformations import Logexp + kern = self.rbf+self.white + self.assertListEqual(kern.constraints[Logexp()].tolist(), range(3)) + + def test_constraints(self): + self.rbf.constrain(GPy.transformations.Square(), False) + self.assertListEqual(self.test1.constraints[GPy.transformations.Square()].tolist(), range(2)) + self.assertListEqual(self.test1.constraints[GPy.transformations.Logexp()].tolist(), [2]) + + self.test1.remove_parameter(self.rbf) + self.assertListEqual(self.test1.constraints[GPy.transformations.Square()].tolist(), []) + + def test_constraints_views(self): + self.assertEqual(self.white.constraints._offset, 2) + self.assertEqual(self.rbf.constraints._offset, 0) + self.assertEqual(self.param.constraints._offset, 3) + +if __name__ == "__main__": + #import sys;sys.argv = ['', 'Test.test_add_parameter'] + unittest.main() \ No newline at end of file