parameter adding and removing now fully functional according to tests, including fixes

This commit is contained in:
Max Zwiessele 2014-02-13 16:44:45 +00:00
parent f71af38505
commit 13212abd0b
4 changed files with 123 additions and 8 deletions

View file

@ -57,6 +57,7 @@ class ParameterIndexOperations(object):
You can give an offset to set an offset for the given indices in the You can give an offset to set an offset for the given indices in the
index array, for multi-param handling. index array, for multi-param handling.
''' '''
_offset = 0
def __init__(self, constraints=None): def __init__(self, constraints=None):
self._properties = IntArrayDict() self._properties = IntArrayDict()
if constraints is not None: if constraints is not None:
@ -120,6 +121,14 @@ class ParameterIndexOperations(object):
return removed.astype(int) return removed.astype(int)
return numpy.array([]).astype(int) return numpy.array([]).astype(int)
def update(self, parameter_index_view, offset=0):
for i, v in parameter_index_view.iteritems():
self.add(i, v+offset)
def copy(self):
return ParameterIndexOperations(dict(self.iteritems()))
def __getitem__(self, prop): def __getitem__(self, prop):
return self._properties[prop] return self._properties[prop]
@ -223,9 +232,9 @@ class ParameterIndexOperationsView(object):
import pprint import pprint
return pprint.pformat(dict(self.iteritems())) return pprint.pformat(dict(self.iteritems()))
def update(self, parameter_index_view): def update(self, parameter_index_view, offset=0):
for i, v in parameter_index_view.iteritems(): for i, v in parameter_index_view.iteritems():
self.add(i, v) self.add(i, v+offset)
def copy(self): def copy(self):

View file

@ -72,6 +72,13 @@ class Parentable(object):
def has_parent(self): def has_parent(self):
return self._direct_parent_ is not None return self._direct_parent_ is not None
def _notify_parent_change(self):
for p in self._parameters_:
p._parent_changed(self)
def _parent_changed(self):
raise NotImplementedError, "shouldnt happen, Parentable objects need to be able to change their parent"
@property @property
def _highest_parent_(self): def _highest_parent_(self):
if self._direct_parent_ is None: if self._direct_parent_ is None:
@ -182,11 +189,9 @@ class Constrainable(Nameable, Indexable, Parameterizable):
# Constrain operations -> done # Constrain operations -> done
#=========================================================================== #===========================================================================
def _parent_changed(self, parent): def _parent_changed(self, parent):
c = self.constraints
from index_operations import ParameterIndexOperationsView from index_operations import ParameterIndexOperationsView
self.constraints = ParameterIndexOperationsView(parent.constraints, parent._offset_for(self), self.size) self.constraints = ParameterIndexOperationsView(parent.constraints, parent._offset_for(self), self.size)
self.constraints.update(c) self._fixes_ = None
del c
for p in self._parameters_: for p in self._parameters_:
p._parent_changed(parent) p._parent_changed(parent)

View file

@ -87,17 +87,20 @@ class Parameterized(Constrainable, Pickleable, Observable, Gradcheckable):
elif param not in self._parameters_: elif param not in self._parameters_:
# make sure the size is set # make sure the size is set
if index is None: if index is None:
self.constraints.update(param.constraints, self.size)
self._parameters_.append(param) self._parameters_.append(param)
else: else:
start = sum(p.size for p in self._parameters_[:index]) start = sum(p.size for p in self._parameters_[:index])
self.constraints.shift(start, param.size) self.constraints.shift(start, param.size)
self.constraints.update(param.constraints, start)
self._parameters_.insert(index, param) self._parameters_.insert(index, param)
self.size += param.size self.size += param.size
else: else:
raise RuntimeError, """Parameter exists already added and no copy made""" raise RuntimeError, """Parameter exists already added and no copy made"""
self._connect_parameters() self._connect_parameters()
for p in self._parameters_: self._notify_parent_change()
p._parent_changed(self) self._connect_fixes()
def add_parameters(self, *parameters): def add_parameters(self, *parameters):
""" """
@ -120,7 +123,13 @@ class Parameterized(Constrainable, Pickleable, Observable, Gradcheckable):
param._direct_parent_ = None param._direct_parent_ = None
param._parent_index_ = None param._parent_index_ = None
param._connect_fixes() param._connect_fixes()
param._notify_parent_change()
pname = adjust_name_for_printing(param.name)
if pname in self._added_names_:
del self.__dict__[pname]
self._connect_parameters() self._connect_parameters()
#self._notify_parent_change()
self._connect_fixes()
def _connect_parameters(self): def _connect_parameters(self):
# connect parameterlist to this parameterized object # connect parameterlist to this parameterized object
@ -149,7 +158,6 @@ class Parameterized(Constrainable, Pickleable, Observable, Gradcheckable):
elif not (pname in not_unique): elif not (pname in not_unique):
self.__dict__[pname] = p self.__dict__[pname] = p
self._added_names_.add(pname) self._added_names_.add(pname)
self._connect_fixes()
#=========================================================================== #===========================================================================
# Pickling operations # Pickling operations

View file

@ -0,0 +1,93 @@
'''
Created on Feb 13, 2014
@author: maxzwiessele
'''
import unittest
import GPy
import numpy as np
class Test(unittest.TestCase):
def setUp(self):
self.rbf = GPy.kern.rbf(1)
self.white = GPy.kern.white(1)
from GPy.core.parameterization import Param
from GPy.core.parameterization.transformations import Logistic
self.param = Param('param', np.random.rand(25,2), Logistic(0, 1))
self.test1 = GPy.core.Parameterized("test model")
self.test1.add_parameter(self.white)
self.test1.add_parameter(self.rbf, 0)
self.test1.add_parameter(self.param)
def test_add_parameter(self):
self.assertEquals(self.rbf._parent_index_, 0)
self.assertEquals(self.white._parent_index_, 1)
pass
def test_fixes(self):
self.white.fix(warning=False)
self.test1.remove_parameter(self.test1.param)
self.assertTrue(self.test1._has_fixes())
from GPy.core.parameterization.transformations import FIXED, UNFIXED
self.assertListEqual(self.test1._fixes_.tolist(),[UNFIXED,UNFIXED,FIXED])
self.test1.add_parameter(self.white, 0)
self.assertListEqual(self.test1._fixes_.tolist(),[FIXED,UNFIXED,UNFIXED])
def test_remove_parameter(self):
from GPy.core.parameterization.transformations import FIXED, UNFIXED, __fixed__
self.white.fix()
self.test1.remove_parameter(self.white)
self.assertIs(self.test1._fixes_,None)
self.assertListEqual(self.white._fixes_.tolist(), [FIXED])
self.assertIs(self.white.constraints,self.white.white.constraints._param_index_ops)
self.assertEquals(self.white.white.constraints._offset, 0)
self.assertIs(self.test1.constraints, self.rbf.constraints._param_index_ops)
self.assertIs(self.test1.constraints, self.param.constraints._param_index_ops)
self.test1.add_parameter(self.white, 0)
self.assertIs(self.test1.constraints, self.white.constraints._param_index_ops)
self.assertIs(self.test1.constraints, self.rbf.constraints._param_index_ops)
self.assertIs(self.test1.constraints, self.param.constraints._param_index_ops)
self.assertListEqual(self.test1.constraints[__fixed__].tolist(), [0])
self.assertIs(self.white._fixes_,None)
self.assertListEqual(self.test1._fixes_.tolist(),[FIXED] + [UNFIXED] * 52)
self.test1.remove_parameter(self.white)
self.assertIs(self.test1._fixes_,None)
self.assertListEqual(self.white._fixes_.tolist(), [FIXED])
self.assertIs(self.white.constraints,self.white.white.constraints._param_index_ops)
self.assertIs(self.test1.constraints, self.rbf.constraints._param_index_ops)
self.assertIs(self.test1.constraints, self.param.constraints._param_index_ops)
def test_add_parameter_already_in_hirarchy(self):
self.test1.add_parameter(self.white._parameters_[0])
def test_default_constraints(self):
self.assertIs(self.rbf.rbf.variance.constraints._param_index_ops, self.rbf.constraints._param_index_ops)
self.assertIs(self.test1.constraints, self.rbf.constraints._param_index_ops)
self.assertListEqual(self.rbf.constraints.indices()[0].tolist(), range(2))
from GPy.core.parameterization.transformations import Logexp
kern = self.rbf+self.white
self.assertListEqual(kern.constraints[Logexp()].tolist(), range(3))
def test_constraints(self):
self.rbf.constrain(GPy.transformations.Square(), False)
self.assertListEqual(self.test1.constraints[GPy.transformations.Square()].tolist(), range(2))
self.assertListEqual(self.test1.constraints[GPy.transformations.Logexp()].tolist(), [2])
self.test1.remove_parameter(self.rbf)
self.assertListEqual(self.test1.constraints[GPy.transformations.Square()].tolist(), [])
def test_constraints_views(self):
self.assertEqual(self.white.constraints._offset, 2)
self.assertEqual(self.rbf.constraints._offset, 0)
self.assertEqual(self.param.constraints._offset, 3)
if __name__ == "__main__":
#import sys;sys.argv = ['', 'Test.test_add_parameter']
unittest.main()