mirror of
https://github.com/SheffieldML/GPy.git
synced 2026-05-14 22:42:37 +02:00
parameter adding and removing now fully functional according to tests, including fixes
This commit is contained in:
parent
f71af38505
commit
13212abd0b
4 changed files with 123 additions and 8 deletions
|
|
@ -57,6 +57,7 @@ class ParameterIndexOperations(object):
|
|||
You can give an offset to set an offset for the given indices in the
|
||||
index array, for multi-param handling.
|
||||
'''
|
||||
_offset = 0
|
||||
def __init__(self, constraints=None):
|
||||
self._properties = IntArrayDict()
|
||||
if constraints is not None:
|
||||
|
|
@ -120,6 +121,14 @@ class ParameterIndexOperations(object):
|
|||
return removed.astype(int)
|
||||
return numpy.array([]).astype(int)
|
||||
|
||||
def update(self, parameter_index_view, offset=0):
|
||||
for i, v in parameter_index_view.iteritems():
|
||||
self.add(i, v+offset)
|
||||
|
||||
|
||||
def copy(self):
|
||||
return ParameterIndexOperations(dict(self.iteritems()))
|
||||
|
||||
def __getitem__(self, prop):
|
||||
return self._properties[prop]
|
||||
|
||||
|
|
@ -223,9 +232,9 @@ class ParameterIndexOperationsView(object):
|
|||
import pprint
|
||||
return pprint.pformat(dict(self.iteritems()))
|
||||
|
||||
def update(self, parameter_index_view):
|
||||
def update(self, parameter_index_view, offset=0):
|
||||
for i, v in parameter_index_view.iteritems():
|
||||
self.add(i, v)
|
||||
self.add(i, v+offset)
|
||||
|
||||
|
||||
def copy(self):
|
||||
|
|
|
|||
|
|
@ -72,6 +72,13 @@ class Parentable(object):
|
|||
def has_parent(self):
|
||||
return self._direct_parent_ is not None
|
||||
|
||||
def _notify_parent_change(self):
|
||||
for p in self._parameters_:
|
||||
p._parent_changed(self)
|
||||
|
||||
def _parent_changed(self):
|
||||
raise NotImplementedError, "shouldnt happen, Parentable objects need to be able to change their parent"
|
||||
|
||||
@property
|
||||
def _highest_parent_(self):
|
||||
if self._direct_parent_ is None:
|
||||
|
|
@ -182,11 +189,9 @@ class Constrainable(Nameable, Indexable, Parameterizable):
|
|||
# Constrain operations -> done
|
||||
#===========================================================================
|
||||
def _parent_changed(self, parent):
|
||||
c = self.constraints
|
||||
from index_operations import ParameterIndexOperationsView
|
||||
self.constraints = ParameterIndexOperationsView(parent.constraints, parent._offset_for(self), self.size)
|
||||
self.constraints.update(c)
|
||||
del c
|
||||
self._fixes_ = None
|
||||
for p in self._parameters_:
|
||||
p._parent_changed(parent)
|
||||
|
||||
|
|
|
|||
|
|
@ -87,17 +87,20 @@ class Parameterized(Constrainable, Pickleable, Observable, Gradcheckable):
|
|||
elif param not in self._parameters_:
|
||||
# make sure the size is set
|
||||
if index is None:
|
||||
self.constraints.update(param.constraints, self.size)
|
||||
self._parameters_.append(param)
|
||||
else:
|
||||
start = sum(p.size for p in self._parameters_[:index])
|
||||
self.constraints.shift(start, param.size)
|
||||
self.constraints.update(param.constraints, start)
|
||||
self._parameters_.insert(index, param)
|
||||
self.size += param.size
|
||||
else:
|
||||
raise RuntimeError, """Parameter exists already added and no copy made"""
|
||||
self._connect_parameters()
|
||||
for p in self._parameters_:
|
||||
p._parent_changed(self)
|
||||
self._notify_parent_change()
|
||||
self._connect_fixes()
|
||||
|
||||
|
||||
def add_parameters(self, *parameters):
|
||||
"""
|
||||
|
|
@ -120,7 +123,13 @@ class Parameterized(Constrainable, Pickleable, Observable, Gradcheckable):
|
|||
param._direct_parent_ = None
|
||||
param._parent_index_ = None
|
||||
param._connect_fixes()
|
||||
param._notify_parent_change()
|
||||
pname = adjust_name_for_printing(param.name)
|
||||
if pname in self._added_names_:
|
||||
del self.__dict__[pname]
|
||||
self._connect_parameters()
|
||||
#self._notify_parent_change()
|
||||
self._connect_fixes()
|
||||
|
||||
def _connect_parameters(self):
|
||||
# connect parameterlist to this parameterized object
|
||||
|
|
@ -149,7 +158,6 @@ class Parameterized(Constrainable, Pickleable, Observable, Gradcheckable):
|
|||
elif not (pname in not_unique):
|
||||
self.__dict__[pname] = p
|
||||
self._added_names_.add(pname)
|
||||
self._connect_fixes()
|
||||
|
||||
#===========================================================================
|
||||
# Pickling operations
|
||||
|
|
|
|||
93
GPy/testing/parameterized_tests.py
Normal file
93
GPy/testing/parameterized_tests.py
Normal file
|
|
@ -0,0 +1,93 @@
|
|||
'''
|
||||
Created on Feb 13, 2014
|
||||
|
||||
@author: maxzwiessele
|
||||
'''
|
||||
import unittest
|
||||
import GPy
|
||||
import numpy as np
|
||||
|
||||
class Test(unittest.TestCase):
|
||||
|
||||
def setUp(self):
|
||||
self.rbf = GPy.kern.rbf(1)
|
||||
self.white = GPy.kern.white(1)
|
||||
from GPy.core.parameterization import Param
|
||||
from GPy.core.parameterization.transformations import Logistic
|
||||
self.param = Param('param', np.random.rand(25,2), Logistic(0, 1))
|
||||
|
||||
self.test1 = GPy.core.Parameterized("test model")
|
||||
self.test1.add_parameter(self.white)
|
||||
self.test1.add_parameter(self.rbf, 0)
|
||||
self.test1.add_parameter(self.param)
|
||||
|
||||
def test_add_parameter(self):
|
||||
self.assertEquals(self.rbf._parent_index_, 0)
|
||||
self.assertEquals(self.white._parent_index_, 1)
|
||||
pass
|
||||
|
||||
def test_fixes(self):
|
||||
self.white.fix(warning=False)
|
||||
self.test1.remove_parameter(self.test1.param)
|
||||
self.assertTrue(self.test1._has_fixes())
|
||||
|
||||
from GPy.core.parameterization.transformations import FIXED, UNFIXED
|
||||
self.assertListEqual(self.test1._fixes_.tolist(),[UNFIXED,UNFIXED,FIXED])
|
||||
|
||||
self.test1.add_parameter(self.white, 0)
|
||||
self.assertListEqual(self.test1._fixes_.tolist(),[FIXED,UNFIXED,UNFIXED])
|
||||
|
||||
|
||||
def test_remove_parameter(self):
|
||||
from GPy.core.parameterization.transformations import FIXED, UNFIXED, __fixed__
|
||||
self.white.fix()
|
||||
self.test1.remove_parameter(self.white)
|
||||
self.assertIs(self.test1._fixes_,None)
|
||||
|
||||
self.assertListEqual(self.white._fixes_.tolist(), [FIXED])
|
||||
self.assertIs(self.white.constraints,self.white.white.constraints._param_index_ops)
|
||||
self.assertEquals(self.white.white.constraints._offset, 0)
|
||||
self.assertIs(self.test1.constraints, self.rbf.constraints._param_index_ops)
|
||||
self.assertIs(self.test1.constraints, self.param.constraints._param_index_ops)
|
||||
|
||||
self.test1.add_parameter(self.white, 0)
|
||||
self.assertIs(self.test1.constraints, self.white.constraints._param_index_ops)
|
||||
self.assertIs(self.test1.constraints, self.rbf.constraints._param_index_ops)
|
||||
self.assertIs(self.test1.constraints, self.param.constraints._param_index_ops)
|
||||
self.assertListEqual(self.test1.constraints[__fixed__].tolist(), [0])
|
||||
self.assertIs(self.white._fixes_,None)
|
||||
self.assertListEqual(self.test1._fixes_.tolist(),[FIXED] + [UNFIXED] * 52)
|
||||
self.test1.remove_parameter(self.white)
|
||||
self.assertIs(self.test1._fixes_,None)
|
||||
self.assertListEqual(self.white._fixes_.tolist(), [FIXED])
|
||||
self.assertIs(self.white.constraints,self.white.white.constraints._param_index_ops)
|
||||
self.assertIs(self.test1.constraints, self.rbf.constraints._param_index_ops)
|
||||
self.assertIs(self.test1.constraints, self.param.constraints._param_index_ops)
|
||||
|
||||
def test_add_parameter_already_in_hirarchy(self):
|
||||
self.test1.add_parameter(self.white._parameters_[0])
|
||||
|
||||
def test_default_constraints(self):
|
||||
self.assertIs(self.rbf.rbf.variance.constraints._param_index_ops, self.rbf.constraints._param_index_ops)
|
||||
self.assertIs(self.test1.constraints, self.rbf.constraints._param_index_ops)
|
||||
self.assertListEqual(self.rbf.constraints.indices()[0].tolist(), range(2))
|
||||
from GPy.core.parameterization.transformations import Logexp
|
||||
kern = self.rbf+self.white
|
||||
self.assertListEqual(kern.constraints[Logexp()].tolist(), range(3))
|
||||
|
||||
def test_constraints(self):
|
||||
self.rbf.constrain(GPy.transformations.Square(), False)
|
||||
self.assertListEqual(self.test1.constraints[GPy.transformations.Square()].tolist(), range(2))
|
||||
self.assertListEqual(self.test1.constraints[GPy.transformations.Logexp()].tolist(), [2])
|
||||
|
||||
self.test1.remove_parameter(self.rbf)
|
||||
self.assertListEqual(self.test1.constraints[GPy.transformations.Square()].tolist(), [])
|
||||
|
||||
def test_constraints_views(self):
|
||||
self.assertEqual(self.white.constraints._offset, 2)
|
||||
self.assertEqual(self.rbf.constraints._offset, 0)
|
||||
self.assertEqual(self.param.constraints._offset, 3)
|
||||
|
||||
if __name__ == "__main__":
|
||||
#import sys;sys.argv = ['', 'Test.test_add_parameter']
|
||||
unittest.main()
|
||||
Loading…
Add table
Add a link
Reference in a new issue