mirror of
https://github.com/SheffieldML/GPy.git
synced 2026-05-30 14:35:15 +02:00
finish the problem of tie with constraints
This commit is contained in:
parent
4293f37cca
commit
f093e60b5d
2 changed files with 13 additions and 16 deletions
|
|
@ -162,11 +162,11 @@ class Parameterized(Parameterizable):
|
||||||
parent = parent._parent_
|
parent = parent._parent_
|
||||||
|
|
||||||
self._notify_parent_change()
|
self._notify_parent_change()
|
||||||
|
self._connect_parameters()
|
||||||
|
# if not self._in_init_:
|
||||||
self._highest_parent_._notify_parent_change()
|
self._highest_parent_._notify_parent_change()
|
||||||
if not self._in_init_:
|
self._highest_parent_._connect_parameters(ignore_added_names=_ignore_added_names)
|
||||||
self._connect_parameters()
|
self._highest_parent_._connect_fixes()
|
||||||
self._highest_parent_._connect_parameters(ignore_added_names=_ignore_added_names)
|
|
||||||
self._highest_parent_._connect_fixes()
|
|
||||||
|
|
||||||
if isinstance(param,Parameterized):
|
if isinstance(param,Parameterized):
|
||||||
from ties_and_remappings import Tie
|
from ties_and_remappings import Tie
|
||||||
|
|
|
||||||
|
|
@ -54,7 +54,7 @@ class Tie(Parameterized):
|
||||||
|
|
||||||
TODO:
|
TODO:
|
||||||
1. Add the support for multiple parameter tie_together and tie_vector
|
1. Add the support for multiple parameter tie_together and tie_vector
|
||||||
2. Properly handling parameters with constraints
|
2. Properly handling parameters with constraints [DONE]
|
||||||
3. Properly handling the merging of two models [DONE]
|
3. Properly handling the merging of two models [DONE]
|
||||||
4. Properly handling initialization [DONE]
|
4. Properly handling initialization [DONE]
|
||||||
|
|
||||||
|
|
@ -104,7 +104,7 @@ class Tie(Parameterized):
|
||||||
self.updates = True
|
self.updates = True
|
||||||
|
|
||||||
def _sync_val_group(self, plist):
|
def _sync_val_group(self, plist):
|
||||||
val = np.asarray([p.param_array for p in plist]).mean()
|
val = np.hstack([p.param_array.flat for p in plist]).mean()
|
||||||
def _set_val(p):
|
def _set_val(p):
|
||||||
p[:] = val
|
p[:] = val
|
||||||
for p in plist:
|
for p in plist:
|
||||||
|
|
@ -124,9 +124,9 @@ class Tie(Parameterized):
|
||||||
if tie_con is not None:
|
if tie_con is not None:
|
||||||
for p in plist:
|
for p in plist:
|
||||||
if len(p.constraints.properties())!=1 or p.constraints[tie_con].size != p.size:
|
if len(p.constraints.properties())!=1 or p.constraints[tie_con].size != p.size:
|
||||||
print 'WARNING: '+p.name+' have different constraints! They will be constrained '+str(cons[0])+'!'
|
print 'WARNING: '+p.name+' have different constraints! They will be constrained '+str(tie_con)+'!'
|
||||||
p.constrain(cons[0])
|
p.constrain(tie_con)
|
||||||
return cons[0]
|
return tie_con
|
||||||
elif hastie:
|
elif hastie:
|
||||||
for p in plist:
|
for p in plist:
|
||||||
if p.constraints.size>0:
|
if p.constraints.size>0:
|
||||||
|
|
@ -149,8 +149,8 @@ class Tie(Parameterized):
|
||||||
def _get_labels(self, plist):
|
def _get_labels(self, plist):
|
||||||
labels = []
|
labels = []
|
||||||
for p in plist:
|
for p in plist:
|
||||||
self._traverse_param(lambda x: x.tie, p, labels)
|
self._traverse_param(lambda x: x.tie.flat, p, labels)
|
||||||
return np.unique(np.asarray(labels))
|
return np.unique(np.hstack(labels))
|
||||||
|
|
||||||
def _set_labels(self, plist, labels):
|
def _set_labels(self, plist, labels):
|
||||||
"""
|
"""
|
||||||
|
|
@ -234,7 +234,7 @@ class Tie(Parameterized):
|
||||||
|
|
||||||
def tie_together(self,plist):
|
def tie_together(self,plist):
|
||||||
"""tie a list of parameters"""
|
"""tie a list of parameters"""
|
||||||
self.updates = False
|
#self.updates = False
|
||||||
labels = self._get_labels(plist)
|
labels = self._get_labels(plist)
|
||||||
val = self._sync_val_group(plist)
|
val = self._sync_val_group(plist)
|
||||||
if labels[0]==0 and labels.size==1:
|
if labels[0]==0 and labels.size==1:
|
||||||
|
|
@ -242,11 +242,8 @@ class Tie(Parameterized):
|
||||||
tie_labels = self._expand_tie_param(1)
|
tie_labels = self._expand_tie_param(1)
|
||||||
self._set_labels(plist, tie_labels)
|
self._set_labels(plist, tie_labels)
|
||||||
tie_con = self._sync_constraint_group(plist)
|
tie_con = self._sync_constraint_group(plist)
|
||||||
print tie_con
|
|
||||||
if tie_con is not None:
|
if tie_con is not None:
|
||||||
print self.tied_param[self.tied_param.tie==tie_labels[0]]
|
|
||||||
self.tied_param[self.tied_param.tie==tie_labels[0]].constrain(tie_con)
|
self.tied_param[self.tied_param.tie==tie_labels[0]].constrain(tie_con)
|
||||||
# self.tied_param.constrain(tie_con)
|
|
||||||
else:
|
else:
|
||||||
# Some of parameters has been tied already.
|
# Some of parameters has been tied already.
|
||||||
# Merge the tie param
|
# Merge the tie param
|
||||||
|
|
@ -259,7 +256,7 @@ class Tie(Parameterized):
|
||||||
self._sync_constraint_group(plist, True, tie_con)
|
self._sync_constraint_group(plist, True, tie_con)
|
||||||
self._update_label_buf()
|
self._update_label_buf()
|
||||||
self.tied_param[self.tied_param.tie==tie_labels[0]] = val
|
self.tied_param[self.tied_param.tie==tie_labels[0]] = val
|
||||||
self.updates = True
|
#self.updates = True
|
||||||
|
|
||||||
def untie(self,plist):
|
def untie(self,plist):
|
||||||
"""Untie a list of parameters"""
|
"""Untie a list of parameters"""
|
||||||
|
|
|
||||||
Loading…
Add table
Add a link
Reference in a new issue