latest updates for ties, still bery buggy, considering restructuring...

This commit is contained in:
Max Zwiessele 2013-11-12 14:01:29 +00:00
parent 851e6ec8e9
commit 505e5e9210
3 changed files with 24 additions and 14 deletions

View file

@ -223,6 +223,9 @@ class Param(ObservableArray, Nameable, Pickleable):
else: else:
self._direct_parent_._get_original(self).__setitem__(self._current_slice_, transform.initialize(self), update=False) self._direct_parent_._get_original(self).__setitem__(self._current_slice_, transform.initialize(self), update=False)
self._highest_parent_._add_constrain(self, transform, warning) self._highest_parent_._add_constrain(self, transform, warning)
for t in self._tied_to_me_.iterkeys():
if transform not in self._highest_parent_._constraints_for_collect(t, t._raveled_index()):
t._direct_parent_._get_original(t)[t._current_slice_].constrain(transform, warning, update)
if update: if update:
self._highest_parent_.parameters_changed() self._highest_parent_.parameters_changed()
@ -323,6 +326,9 @@ class Param(ObservableArray, Nameable, Pickleable):
self._direct_parent_._get_original(self)._tied_to_ += [param] self._direct_parent_._get_original(self)._tied_to_ += [param]
param._add_tie_listener(self) param._add_tie_listener(self)
self._highest_parent_._set_fixed(self) self._highest_parent_._set_fixed(self)
cs = self._highest_parent_._constraints_for(param, param._raveled_index())
for cs in self._highest_parent_._constraints_for(param, param._raveled_index()):
[self.constrain(c, warning=False) for c in cs]
# for t in self._tied_to_me_.keys(): # for t in self._tied_to_me_.keys():
# if t is not self: # if t is not self:
# t.untie(self) # t.untie(self)

View file

@ -555,6 +555,10 @@ class Parameterized(Nameable, Pickleable, Observable):
def _constraints_for(self, param, rav_index): def _constraints_for(self, param, rav_index):
# constraint for param given its internal rav_index # constraint for param given its internal rav_index
return self.constraints.properties_for(rav_index+self._offset_for(param)) return self.constraints.properties_for(rav_index+self._offset_for(param))
def _constraints_for_collect(self, param, rav_index):
# constraint for param given its internal rav_index
cs = self._constraints_for(param, rav_index)
return set(itertools.chain(*cs))
#=========================================================================== #===========================================================================
# Get/set parameters: # Get/set parameters:
#=========================================================================== #===========================================================================

View file

@ -20,13 +20,12 @@ class KernelTests(unittest.TestCase):
K.rbf.lengthscale[0].tie_to(K.rbf.lengthscale[2]) K.rbf.lengthscale[0].tie_to(K.rbf.lengthscale[2])
K.rbf.lengthscale[1].tie_to(K.rbf.lengthscale[3]) K.rbf.lengthscale[1].tie_to(K.rbf.lengthscale[3])
K.rbf.lengthscale[2].constrain_fixed() K.rbf.lengthscale[2].constrain_fixed()
K.rbf.lengthscale[3].tie_to(K.rbf.variance)
X = np.random.rand(5,5) X = np.random.rand(5,5)
Y = np.ones((5,1)) Y = np.ones((5,1))
m = GPy.models.GPRegression(X,Y,K) m = GPy.models.GPRegression(X,Y,K)
self.assertRaises(RuntimeError, lambda: m.kern.rbf.lengthscale[3].tie_to(m.kern.rbf.lengthscale[1])) #self.assertRaises(RuntimeError, lambda: m.kern.rbf.lengthscale[3].tie_to(m.kern.rbf.lengthscale[1]))
self.assertRaises(RuntimeError, lambda: m.kern.rbf.lengthscale[3].tie_to(m.kern.rbf.lengthscale[0])) #self.assertRaises(RuntimeError, lambda: m.kern.rbf.lengthscale[3].tie_to(m.kern.rbf.lengthscale[0]))
self.assertRaises(RuntimeError, lambda: m.kern.rbf.lengthscale.tie_to(m.kern.rbf.lengthscale)) #self.assertRaises(RuntimeError, lambda: m.kern.rbf.lengthscale.tie_to(m.kern.rbf.lengthscale))
import ipdb;ipdb.set_trace() import ipdb;ipdb.set_trace()
self.assertTrue(m.checkgrad()) self.assertTrue(m.checkgrad())
@ -121,14 +120,15 @@ class KernelTests(unittest.TestCase):
if __name__ == "__main__": if __name__ == "__main__":
K = GPy.kern.rbf(5, ARD=True) # K = GPy.kern.rbf(5, ARD=True)
K.rbf.lengthscale[0].tie_to(K.rbf.lengthscale[2]) # K.rbf.lengthscale[0].tie_to(K.rbf.lengthscale[2])
K.rbf.lengthscale[1].tie_to(K.rbf.lengthscale[3]) # K.rbf.lengthscale[1].tie_to(K.rbf.lengthscale[3])
K.rbf.lengthscale[2].constrain_fixed() # K.rbf.lengthscale[2].constrain_fixed()
K.rbf.lengthscale[2:].tie_to(K.rbf.variance) #
X = np.random.rand(5,5) # K.rbf.lengthscale[2:].tie_to(K.rbf.variance)
Y = np.ones((5,1)) # X = np.random.rand(5,5)
m = GPy.models.GPRegression(X,Y,K) # Y = np.ones((5,1))
# m = GPy.models.GPRegression(X,Y,K)
#print "Running unit tests, please be (very) patient..." print "Running unit tests, please be (very) patient..."
#unittest.main() unittest.main()