[#133] fix: chainging constraint in __init__

This commit is contained in:
Max Zwiessele 2014-09-22 13:57:27 +01:00
parent bccd8e094a
commit ac4dbb851d
3 changed files with 28 additions and 3 deletions

View file

@ -636,10 +636,18 @@ class Indexable(Nameable, Observable):
""" """
From Parentable: From Parentable:
Called when the parent changed Called when the parent changed
update the constraints and priors view, so that
constraining is automized for the parent.
""" """
from index_operations import ParameterIndexOperationsView from index_operations import ParameterIndexOperationsView
self.constraints = ParameterIndexOperationsView(parent.constraints, parent._offset_for(self), self.size) #if getattr(self, "_in_init_"):
self.priors = ParameterIndexOperationsView(parent.priors, parent._offset_for(self), self.size) #import ipdb;ipdb.set_trace()
#self.constraints.update(param.constraints, start)
#self.priors.update(param.priors, start)
offset = parent._offset_for(self)
self.constraints = ParameterIndexOperationsView(parent.constraints, offset, self.size)
self.priors = ParameterIndexOperationsView(parent.priors, offset, self.size)
self._fixes_ = None self._fixes_ = None
for p in self.parameters: for p in self.parameters:
p._parent_changed(parent) p._parent_changed(parent)

View file

@ -149,6 +149,7 @@ class Parameterized(Parameterizable):
self.priors.update(param.priors, start) self.priors.update(param.priors, start)
self.parameters.insert(index, param) self.parameters.insert(index, param)
self._notify_parent_change()
param.add_observer(self, self._pass_through_notify_observers, -np.inf) param.add_observer(self, self._pass_through_notify_observers, -np.inf)
parent = self parent = self

View file

@ -8,7 +8,9 @@ import GPy
import numpy as np import numpy as np
from GPy.core.parameterization.parameter_core import HierarchyError from GPy.core.parameterization.parameter_core import HierarchyError
from GPy.core.parameterization.observable_array import ObsAr from GPy.core.parameterization.observable_array import ObsAr
from GPy.core.parameterization.transformations import NegativeLogexp from GPy.core.parameterization.transformations import NegativeLogexp, Logistic
from GPy.core.parameterization.parameterized import Parameterized
from GPy.core.parameterization.param import Param
class ArrayCoreTest(unittest.TestCase): class ArrayCoreTest(unittest.TestCase):
def setUp(self): def setUp(self):
@ -198,6 +200,20 @@ class ParameterizedTest(unittest.TestCase):
unfixed = self.testmodel.kern.unfix() unfixed = self.testmodel.kern.unfix()
self.assertListEqual(unfixed.tolist(), [0,1]) self.assertListEqual(unfixed.tolist(), [0,1])
def test_constraints_in_init(self):
class Test(Parameterized):
def __init__(self, name=None, parameters=[], *a, **kw):
super(Test, self).__init__(name=name)
self.x = Param('x', np.random.uniform(0,1,(3,4)))
self.x[0].constrain_bounded(0,1)
self.link_parameter(self.x)
self.x[1].fix()
t = Test()
c = {Logistic(0,1): np.array([0, 1, 2, 3]), 'fixed': np.array([4, 5, 6, 7])}
np.testing.assert_equal(t.x.constraints[Logistic(0,1)], c[Logistic(0,1)])
np.testing.assert_equal(t.x.constraints['fixed'], c['fixed'])
def test_printing(self): def test_printing(self):
print self.test1 print self.test1
print self.param print self.param