finish the problem of tie with constraints

This commit is contained in:
Zhenwen Dai 2014-09-05 11:21:47 +01:00
parent 4293f37cca
commit f093e60b5d
2 changed files with 13 additions and 16 deletions

View file

@ -162,11 +162,11 @@ class Parameterized(Parameterizable):
parent = parent._parent_ parent = parent._parent_
self._notify_parent_change() self._notify_parent_change()
self._connect_parameters()
# if not self._in_init_:
self._highest_parent_._notify_parent_change() self._highest_parent_._notify_parent_change()
if not self._in_init_: self._highest_parent_._connect_parameters(ignore_added_names=_ignore_added_names)
self._connect_parameters() self._highest_parent_._connect_fixes()
self._highest_parent_._connect_parameters(ignore_added_names=_ignore_added_names)
self._highest_parent_._connect_fixes()
if isinstance(param,Parameterized): if isinstance(param,Parameterized):
from ties_and_remappings import Tie from ties_and_remappings import Tie

View file

@ -54,7 +54,7 @@ class Tie(Parameterized):
TODO: TODO:
1. Add the support for multiple parameter tie_together and tie_vector 1. Add the support for multiple parameter tie_together and tie_vector
2. Properly handling parameters with constraints 2. Properly handling parameters with constraints [DONE]
3. Properly handling the merging of two models [DONE] 3. Properly handling the merging of two models [DONE]
4. Properly handling initialization [DONE] 4. Properly handling initialization [DONE]
@ -104,7 +104,7 @@ class Tie(Parameterized):
self.updates = True self.updates = True
def _sync_val_group(self, plist): def _sync_val_group(self, plist):
val = np.asarray([p.param_array for p in plist]).mean() val = np.hstack([p.param_array.flat for p in plist]).mean()
def _set_val(p): def _set_val(p):
p[:] = val p[:] = val
for p in plist: for p in plist:
@ -124,9 +124,9 @@ class Tie(Parameterized):
if tie_con is not None: if tie_con is not None:
for p in plist: for p in plist:
if len(p.constraints.properties())!=1 or p.constraints[tie_con].size != p.size: if len(p.constraints.properties())!=1 or p.constraints[tie_con].size != p.size:
print 'WARNING: '+p.name+' have different constraints! They will be constrained '+str(cons[0])+'!' print 'WARNING: '+p.name+' have different constraints! They will be constrained '+str(tie_con)+'!'
p.constrain(cons[0]) p.constrain(tie_con)
return cons[0] return tie_con
elif hastie: elif hastie:
for p in plist: for p in plist:
if p.constraints.size>0: if p.constraints.size>0:
@ -149,8 +149,8 @@ class Tie(Parameterized):
def _get_labels(self, plist): def _get_labels(self, plist):
labels = [] labels = []
for p in plist: for p in plist:
self._traverse_param(lambda x: x.tie, p, labels) self._traverse_param(lambda x: x.tie.flat, p, labels)
return np.unique(np.asarray(labels)) return np.unique(np.hstack(labels))
def _set_labels(self, plist, labels): def _set_labels(self, plist, labels):
""" """
@ -234,7 +234,7 @@ class Tie(Parameterized):
def tie_together(self,plist): def tie_together(self,plist):
"""tie a list of parameters""" """tie a list of parameters"""
self.updates = False #self.updates = False
labels = self._get_labels(plist) labels = self._get_labels(plist)
val = self._sync_val_group(plist) val = self._sync_val_group(plist)
if labels[0]==0 and labels.size==1: if labels[0]==0 and labels.size==1:
@ -242,11 +242,8 @@ class Tie(Parameterized):
tie_labels = self._expand_tie_param(1) tie_labels = self._expand_tie_param(1)
self._set_labels(plist, tie_labels) self._set_labels(plist, tie_labels)
tie_con = self._sync_constraint_group(plist) tie_con = self._sync_constraint_group(plist)
print tie_con
if tie_con is not None: if tie_con is not None:
print self.tied_param[self.tied_param.tie==tie_labels[0]]
self.tied_param[self.tied_param.tie==tie_labels[0]].constrain(tie_con) self.tied_param[self.tied_param.tie==tie_labels[0]].constrain(tie_con)
# self.tied_param.constrain(tie_con)
else: else:
# Some of parameters has been tied already. # Some of parameters has been tied already.
# Merge the tie param # Merge the tie param
@ -259,7 +256,7 @@ class Tie(Parameterized):
self._sync_constraint_group(plist, True, tie_con) self._sync_constraint_group(plist, True, tie_con)
self._update_label_buf() self._update_label_buf()
self.tied_param[self.tied_param.tie==tie_labels[0]] = val self.tied_param[self.tied_param.tie==tie_labels[0]] = val
self.updates = True #self.updates = True
def untie(self,plist): def untie(self,plist):
"""Untie a list of parameters""" """Untie a list of parameters"""