mirror of
https://github.com/SheffieldML/GPy.git
synced 2026-05-13 14:03:20 +02:00
parameter adding and removing now fully functional according to tests, including fixes
This commit is contained in:
parent
f71af38505
commit
13212abd0b
4 changed files with 123 additions and 8 deletions
|
|
@ -57,6 +57,7 @@ class ParameterIndexOperations(object):
|
||||||
You can give an offset to set an offset for the given indices in the
|
You can give an offset to set an offset for the given indices in the
|
||||||
index array, for multi-param handling.
|
index array, for multi-param handling.
|
||||||
'''
|
'''
|
||||||
|
_offset = 0
|
||||||
def __init__(self, constraints=None):
|
def __init__(self, constraints=None):
|
||||||
self._properties = IntArrayDict()
|
self._properties = IntArrayDict()
|
||||||
if constraints is not None:
|
if constraints is not None:
|
||||||
|
|
@ -120,6 +121,14 @@ class ParameterIndexOperations(object):
|
||||||
return removed.astype(int)
|
return removed.astype(int)
|
||||||
return numpy.array([]).astype(int)
|
return numpy.array([]).astype(int)
|
||||||
|
|
||||||
|
def update(self, parameter_index_view, offset=0):
|
||||||
|
for i, v in parameter_index_view.iteritems():
|
||||||
|
self.add(i, v+offset)
|
||||||
|
|
||||||
|
|
||||||
|
def copy(self):
|
||||||
|
return ParameterIndexOperations(dict(self.iteritems()))
|
||||||
|
|
||||||
def __getitem__(self, prop):
|
def __getitem__(self, prop):
|
||||||
return self._properties[prop]
|
return self._properties[prop]
|
||||||
|
|
||||||
|
|
@ -223,9 +232,9 @@ class ParameterIndexOperationsView(object):
|
||||||
import pprint
|
import pprint
|
||||||
return pprint.pformat(dict(self.iteritems()))
|
return pprint.pformat(dict(self.iteritems()))
|
||||||
|
|
||||||
def update(self, parameter_index_view):
|
def update(self, parameter_index_view, offset=0):
|
||||||
for i, v in parameter_index_view.iteritems():
|
for i, v in parameter_index_view.iteritems():
|
||||||
self.add(i, v)
|
self.add(i, v+offset)
|
||||||
|
|
||||||
|
|
||||||
def copy(self):
|
def copy(self):
|
||||||
|
|
|
||||||
|
|
@ -72,6 +72,13 @@ class Parentable(object):
|
||||||
def has_parent(self):
|
def has_parent(self):
|
||||||
return self._direct_parent_ is not None
|
return self._direct_parent_ is not None
|
||||||
|
|
||||||
|
def _notify_parent_change(self):
|
||||||
|
for p in self._parameters_:
|
||||||
|
p._parent_changed(self)
|
||||||
|
|
||||||
|
def _parent_changed(self):
|
||||||
|
raise NotImplementedError, "shouldnt happen, Parentable objects need to be able to change their parent"
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def _highest_parent_(self):
|
def _highest_parent_(self):
|
||||||
if self._direct_parent_ is None:
|
if self._direct_parent_ is None:
|
||||||
|
|
@ -182,11 +189,9 @@ class Constrainable(Nameable, Indexable, Parameterizable):
|
||||||
# Constrain operations -> done
|
# Constrain operations -> done
|
||||||
#===========================================================================
|
#===========================================================================
|
||||||
def _parent_changed(self, parent):
|
def _parent_changed(self, parent):
|
||||||
c = self.constraints
|
|
||||||
from index_operations import ParameterIndexOperationsView
|
from index_operations import ParameterIndexOperationsView
|
||||||
self.constraints = ParameterIndexOperationsView(parent.constraints, parent._offset_for(self), self.size)
|
self.constraints = ParameterIndexOperationsView(parent.constraints, parent._offset_for(self), self.size)
|
||||||
self.constraints.update(c)
|
self._fixes_ = None
|
||||||
del c
|
|
||||||
for p in self._parameters_:
|
for p in self._parameters_:
|
||||||
p._parent_changed(parent)
|
p._parent_changed(parent)
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -87,17 +87,20 @@ class Parameterized(Constrainable, Pickleable, Observable, Gradcheckable):
|
||||||
elif param not in self._parameters_:
|
elif param not in self._parameters_:
|
||||||
# make sure the size is set
|
# make sure the size is set
|
||||||
if index is None:
|
if index is None:
|
||||||
|
self.constraints.update(param.constraints, self.size)
|
||||||
self._parameters_.append(param)
|
self._parameters_.append(param)
|
||||||
else:
|
else:
|
||||||
start = sum(p.size for p in self._parameters_[:index])
|
start = sum(p.size for p in self._parameters_[:index])
|
||||||
self.constraints.shift(start, param.size)
|
self.constraints.shift(start, param.size)
|
||||||
|
self.constraints.update(param.constraints, start)
|
||||||
self._parameters_.insert(index, param)
|
self._parameters_.insert(index, param)
|
||||||
self.size += param.size
|
self.size += param.size
|
||||||
else:
|
else:
|
||||||
raise RuntimeError, """Parameter exists already added and no copy made"""
|
raise RuntimeError, """Parameter exists already added and no copy made"""
|
||||||
self._connect_parameters()
|
self._connect_parameters()
|
||||||
for p in self._parameters_:
|
self._notify_parent_change()
|
||||||
p._parent_changed(self)
|
self._connect_fixes()
|
||||||
|
|
||||||
|
|
||||||
def add_parameters(self, *parameters):
|
def add_parameters(self, *parameters):
|
||||||
"""
|
"""
|
||||||
|
|
@ -120,7 +123,13 @@ class Parameterized(Constrainable, Pickleable, Observable, Gradcheckable):
|
||||||
param._direct_parent_ = None
|
param._direct_parent_ = None
|
||||||
param._parent_index_ = None
|
param._parent_index_ = None
|
||||||
param._connect_fixes()
|
param._connect_fixes()
|
||||||
|
param._notify_parent_change()
|
||||||
|
pname = adjust_name_for_printing(param.name)
|
||||||
|
if pname in self._added_names_:
|
||||||
|
del self.__dict__[pname]
|
||||||
self._connect_parameters()
|
self._connect_parameters()
|
||||||
|
#self._notify_parent_change()
|
||||||
|
self._connect_fixes()
|
||||||
|
|
||||||
def _connect_parameters(self):
|
def _connect_parameters(self):
|
||||||
# connect parameterlist to this parameterized object
|
# connect parameterlist to this parameterized object
|
||||||
|
|
@ -149,7 +158,6 @@ class Parameterized(Constrainable, Pickleable, Observable, Gradcheckable):
|
||||||
elif not (pname in not_unique):
|
elif not (pname in not_unique):
|
||||||
self.__dict__[pname] = p
|
self.__dict__[pname] = p
|
||||||
self._added_names_.add(pname)
|
self._added_names_.add(pname)
|
||||||
self._connect_fixes()
|
|
||||||
|
|
||||||
#===========================================================================
|
#===========================================================================
|
||||||
# Pickling operations
|
# Pickling operations
|
||||||
|
|
|
||||||
93
GPy/testing/parameterized_tests.py
Normal file
93
GPy/testing/parameterized_tests.py
Normal file
|
|
@ -0,0 +1,93 @@
|
||||||
|
'''
|
||||||
|
Created on Feb 13, 2014
|
||||||
|
|
||||||
|
@author: maxzwiessele
|
||||||
|
'''
|
||||||
|
import unittest
|
||||||
|
import GPy
|
||||||
|
import numpy as np
|
||||||
|
|
||||||
|
class Test(unittest.TestCase):
|
||||||
|
|
||||||
|
def setUp(self):
|
||||||
|
self.rbf = GPy.kern.rbf(1)
|
||||||
|
self.white = GPy.kern.white(1)
|
||||||
|
from GPy.core.parameterization import Param
|
||||||
|
from GPy.core.parameterization.transformations import Logistic
|
||||||
|
self.param = Param('param', np.random.rand(25,2), Logistic(0, 1))
|
||||||
|
|
||||||
|
self.test1 = GPy.core.Parameterized("test model")
|
||||||
|
self.test1.add_parameter(self.white)
|
||||||
|
self.test1.add_parameter(self.rbf, 0)
|
||||||
|
self.test1.add_parameter(self.param)
|
||||||
|
|
||||||
|
def test_add_parameter(self):
|
||||||
|
self.assertEquals(self.rbf._parent_index_, 0)
|
||||||
|
self.assertEquals(self.white._parent_index_, 1)
|
||||||
|
pass
|
||||||
|
|
||||||
|
def test_fixes(self):
|
||||||
|
self.white.fix(warning=False)
|
||||||
|
self.test1.remove_parameter(self.test1.param)
|
||||||
|
self.assertTrue(self.test1._has_fixes())
|
||||||
|
|
||||||
|
from GPy.core.parameterization.transformations import FIXED, UNFIXED
|
||||||
|
self.assertListEqual(self.test1._fixes_.tolist(),[UNFIXED,UNFIXED,FIXED])
|
||||||
|
|
||||||
|
self.test1.add_parameter(self.white, 0)
|
||||||
|
self.assertListEqual(self.test1._fixes_.tolist(),[FIXED,UNFIXED,UNFIXED])
|
||||||
|
|
||||||
|
|
||||||
|
def test_remove_parameter(self):
|
||||||
|
from GPy.core.parameterization.transformations import FIXED, UNFIXED, __fixed__
|
||||||
|
self.white.fix()
|
||||||
|
self.test1.remove_parameter(self.white)
|
||||||
|
self.assertIs(self.test1._fixes_,None)
|
||||||
|
|
||||||
|
self.assertListEqual(self.white._fixes_.tolist(), [FIXED])
|
||||||
|
self.assertIs(self.white.constraints,self.white.white.constraints._param_index_ops)
|
||||||
|
self.assertEquals(self.white.white.constraints._offset, 0)
|
||||||
|
self.assertIs(self.test1.constraints, self.rbf.constraints._param_index_ops)
|
||||||
|
self.assertIs(self.test1.constraints, self.param.constraints._param_index_ops)
|
||||||
|
|
||||||
|
self.test1.add_parameter(self.white, 0)
|
||||||
|
self.assertIs(self.test1.constraints, self.white.constraints._param_index_ops)
|
||||||
|
self.assertIs(self.test1.constraints, self.rbf.constraints._param_index_ops)
|
||||||
|
self.assertIs(self.test1.constraints, self.param.constraints._param_index_ops)
|
||||||
|
self.assertListEqual(self.test1.constraints[__fixed__].tolist(), [0])
|
||||||
|
self.assertIs(self.white._fixes_,None)
|
||||||
|
self.assertListEqual(self.test1._fixes_.tolist(),[FIXED] + [UNFIXED] * 52)
|
||||||
|
self.test1.remove_parameter(self.white)
|
||||||
|
self.assertIs(self.test1._fixes_,None)
|
||||||
|
self.assertListEqual(self.white._fixes_.tolist(), [FIXED])
|
||||||
|
self.assertIs(self.white.constraints,self.white.white.constraints._param_index_ops)
|
||||||
|
self.assertIs(self.test1.constraints, self.rbf.constraints._param_index_ops)
|
||||||
|
self.assertIs(self.test1.constraints, self.param.constraints._param_index_ops)
|
||||||
|
|
||||||
|
def test_add_parameter_already_in_hirarchy(self):
|
||||||
|
self.test1.add_parameter(self.white._parameters_[0])
|
||||||
|
|
||||||
|
def test_default_constraints(self):
|
||||||
|
self.assertIs(self.rbf.rbf.variance.constraints._param_index_ops, self.rbf.constraints._param_index_ops)
|
||||||
|
self.assertIs(self.test1.constraints, self.rbf.constraints._param_index_ops)
|
||||||
|
self.assertListEqual(self.rbf.constraints.indices()[0].tolist(), range(2))
|
||||||
|
from GPy.core.parameterization.transformations import Logexp
|
||||||
|
kern = self.rbf+self.white
|
||||||
|
self.assertListEqual(kern.constraints[Logexp()].tolist(), range(3))
|
||||||
|
|
||||||
|
def test_constraints(self):
|
||||||
|
self.rbf.constrain(GPy.transformations.Square(), False)
|
||||||
|
self.assertListEqual(self.test1.constraints[GPy.transformations.Square()].tolist(), range(2))
|
||||||
|
self.assertListEqual(self.test1.constraints[GPy.transformations.Logexp()].tolist(), [2])
|
||||||
|
|
||||||
|
self.test1.remove_parameter(self.rbf)
|
||||||
|
self.assertListEqual(self.test1.constraints[GPy.transformations.Square()].tolist(), [])
|
||||||
|
|
||||||
|
def test_constraints_views(self):
|
||||||
|
self.assertEqual(self.white.constraints._offset, 2)
|
||||||
|
self.assertEqual(self.rbf.constraints._offset, 0)
|
||||||
|
self.assertEqual(self.param.constraints._offset, 3)
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
#import sys;sys.argv = ['', 'Test.test_add_parameter']
|
||||||
|
unittest.main()
|
||||||
Loading…
Add table
Add a link
Reference in a new issue