diff --git a/GPy/core/parameter.py b/GPy/core/parameter.py index 4f064eb5..78e53d5b 100644 --- a/GPy/core/parameter.py +++ b/GPy/core/parameter.py @@ -223,6 +223,9 @@ class Param(ObservableArray, Nameable, Pickleable): else: self._direct_parent_._get_original(self).__setitem__(self._current_slice_, transform.initialize(self), update=False) self._highest_parent_._add_constrain(self, transform, warning) + for t in self._tied_to_me_.iterkeys(): + if transform not in self._highest_parent_._constraints_for_collect(t, t._raveled_index()): + t._direct_parent_._get_original(t)[t._current_slice_].constrain(transform, warning, update) if update: self._highest_parent_.parameters_changed() @@ -323,6 +326,9 @@ class Param(ObservableArray, Nameable, Pickleable): self._direct_parent_._get_original(self)._tied_to_ += [param] param._add_tie_listener(self) self._highest_parent_._set_fixed(self) + cs = self._highest_parent_._constraints_for(param, param._raveled_index()) + for cs in self._highest_parent_._constraints_for(param, param._raveled_index()): + [self.constrain(c, warning=False) for c in cs] # for t in self._tied_to_me_.keys(): # if t is not self: # t.untie(self) diff --git a/GPy/core/parameterized.py b/GPy/core/parameterized.py index fbfaae83..36ae547e 100644 --- a/GPy/core/parameterized.py +++ b/GPy/core/parameterized.py @@ -555,6 +555,10 @@ class Parameterized(Nameable, Pickleable, Observable): def _constraints_for(self, param, rav_index): # constraint for param given its internal rav_index return self.constraints.properties_for(rav_index+self._offset_for(param)) + def _constraints_for_collect(self, param, rav_index): + # constraint for param given its internal rav_index + cs = self._constraints_for(param, rav_index) + return set(itertools.chain(*cs)) #=========================================================================== # Get/set parameters: #=========================================================================== diff --git a/GPy/testing/kernel_tests.py b/GPy/testing/kernel_tests.py index 1fb6b848..19cc1dcf 100644 --- a/GPy/testing/kernel_tests.py +++ b/GPy/testing/kernel_tests.py @@ -20,13 +20,12 @@ class KernelTests(unittest.TestCase): K.rbf.lengthscale[0].tie_to(K.rbf.lengthscale[2]) K.rbf.lengthscale[1].tie_to(K.rbf.lengthscale[3]) K.rbf.lengthscale[2].constrain_fixed() - K.rbf.lengthscale[3].tie_to(K.rbf.variance) X = np.random.rand(5,5) Y = np.ones((5,1)) m = GPy.models.GPRegression(X,Y,K) - self.assertRaises(RuntimeError, lambda: m.kern.rbf.lengthscale[3].tie_to(m.kern.rbf.lengthscale[1])) - self.assertRaises(RuntimeError, lambda: m.kern.rbf.lengthscale[3].tie_to(m.kern.rbf.lengthscale[0])) - self.assertRaises(RuntimeError, lambda: m.kern.rbf.lengthscale.tie_to(m.kern.rbf.lengthscale)) + #self.assertRaises(RuntimeError, lambda: m.kern.rbf.lengthscale[3].tie_to(m.kern.rbf.lengthscale[1])) + #self.assertRaises(RuntimeError, lambda: m.kern.rbf.lengthscale[3].tie_to(m.kern.rbf.lengthscale[0])) + #self.assertRaises(RuntimeError, lambda: m.kern.rbf.lengthscale.tie_to(m.kern.rbf.lengthscale)) import ipdb;ipdb.set_trace() self.assertTrue(m.checkgrad()) @@ -121,14 +120,15 @@ class KernelTests(unittest.TestCase): if __name__ == "__main__": - K = GPy.kern.rbf(5, ARD=True) - K.rbf.lengthscale[0].tie_to(K.rbf.lengthscale[2]) - K.rbf.lengthscale[1].tie_to(K.rbf.lengthscale[3]) - K.rbf.lengthscale[2].constrain_fixed() - K.rbf.lengthscale[2:].tie_to(K.rbf.variance) - X = np.random.rand(5,5) - Y = np.ones((5,1)) - m = GPy.models.GPRegression(X,Y,K) +# K = GPy.kern.rbf(5, ARD=True) +# K.rbf.lengthscale[0].tie_to(K.rbf.lengthscale[2]) +# K.rbf.lengthscale[1].tie_to(K.rbf.lengthscale[3]) +# K.rbf.lengthscale[2].constrain_fixed() +# +# K.rbf.lengthscale[2:].tie_to(K.rbf.variance) +# X = np.random.rand(5,5) +# Y = np.ones((5,1)) +# m = GPy.models.GPRegression(X,Y,K) - #print "Running unit tests, please be (very) patient..." - #unittest.main() + print "Running unit tests, please be (very) patient..." + unittest.main()