redesign the tie framework

This commit is contained in:
Zhenwen Dai 2014-09-03 15:56:33 +01:00
parent 1abb3087ae
commit 2463a954f7
3 changed files with 214 additions and 109 deletions

View file

@ -84,6 +84,7 @@ class Param(Parameterizable, ObsAr):
self._gradient_array_ = getattr(obj, '_gradient_array_', None) self._gradient_array_ = getattr(obj, '_gradient_array_', None)
self.constraints = getattr(obj, 'constraints', None) self.constraints = getattr(obj, 'constraints', None)
self.priors = getattr(obj, 'priors', None) self.priors = getattr(obj, 'priors', None)
self._tie_ = getattr(obj, '_tie_', None)
@property @property
def param_array(self): def param_array(self):
@ -115,6 +116,16 @@ class Param(Parameterizable, ObsAr):
def gradient(self, val): def gradient(self, val):
self._gradient_array_[:] = val self._gradient_array_[:] = val
@property
def tie(self):
    """
    Tie labels of this param, one unsigned integer per element.

    The label array is allocated lazily (zero-filled, shaped like
    ``self._realshape_``) the first time it is requested.
    """
    if getattr(self, '_tie_', None) is None:
        self._tie_ = numpy.zeros(self._realshape_, dtype=numpy.uint32)
    return self._tie_

@tie.setter
def tie(self, val):
    """Overwrite the tie labels in place with *val* (numpy broadcasting applies)."""
    # Go through the getter rather than touching ``_tie_`` directly:
    # ``_tie_`` is still None until the first read, and slice-assigning
    # into None raises TypeError.
    self.tie[:] = val
#=========================================================================== #===========================================================================
# Array operations -> done # Array operations -> done
#=========================================================================== #===========================================================================
@ -127,6 +138,7 @@ class Param(Parameterizable, ObsAr):
try: try:
new_arr._current_slice_ = s new_arr._current_slice_ = s
new_arr._gradient_array_ = self.gradient[s] new_arr._gradient_array_ = self.gradient[s]
new_arr._tie_ = self.tie[s]
new_arr._original_ = self.base is new_arr.base new_arr._original_ = self.base is new_arr.base
except AttributeError: pass # returning 0d array or float, double etc except AttributeError: pass # returning 0d array or float, double etc
return new_arr return new_arr
@ -237,7 +249,7 @@ class Param(Parameterizable, ObsAr):
def _ties_str(self): def _ties_str(self):
return [''] return ['']
def _ties_for(self, ravi): def _ties_for(self, ravi):
return [['N/A']]*ravi.size return [['N/A' if self.tie[i]==0 else str(self.tie[i])] for i in xrange(ravi.size)]
def __repr__(self, *args, **kwargs): def __repr__(self, *args, **kwargs):
name = "\033[1m{x:s}\033[0;0m:\n".format( name = "\033[1m{x:s}\033[0;0m:\n".format(
x=self.hierarchy_name()) x=self.hierarchy_name())

View file

@ -524,7 +524,7 @@ class Indexable(Nameable, Observable):
return True return True
def tie_together(self): def tie_together(self):
self._highest_parent_.tie.add_tied_parameter(self) self._highest_parent_.tie.tie_together([self])
self._highest_parent_._set_fixed(self,self._raveled_index()) self._highest_parent_._set_fixed(self,self._raveled_index())
self._trigger_params_changed() self._trigger_params_changed()

View file

@ -42,6 +42,8 @@ class Tie(Parameterized):
=====Implementation Details===== =====Implementation Details=====
The *Tie* object should only exist on the top of param tree (the highest parent). The *Tie* object should only exist on the top of param tree (the highest parent).
Each tied param object has the attribute _tie_ which stores the labels for tied parameters.
self.label_buf: self.label_buf:
It uses a label buffer that has the same length as all the parameters (self._highest_parent_.param_array). It uses a label buffer that has the same length as all the parameters (self._highest_parent_.param_array).
The buffer keeps track of all the tied parameters. All the tied parameters have a label (an integer) higher The buffer keeps track of all the tied parameters. All the tied parameters have a label (an integer) higher
@ -53,9 +55,10 @@ class Tie(Parameterized):
================================ ================================
TODO: TODO:
1. Properly handling parameters with constraints 1. Add the support for multiple parameter tie_together and tie_vector
2. Properly handling the merging of two models 2. Properly handling parameters with constraints
3. Properly handling initialization 3. Properly handling the merging of two models
4. Properly handling initialization
""" """
def __init__(self, name='tie'): def __init__(self, name='tie'):
@ -66,126 +69,218 @@ class Tie(Parameterized):
self.tied_param = None self.tied_param = None
# The buffer keeps track of tie status # The buffer keeps track of tie status
self.label_buf = None self.label_buf = None
# The global indices of the 'tied' param
self.buf_idx = None self.buf_idx = None
# A boolean array indicating non-tied parameters
self._tie_ = None
def getTieFlag(self, p=None): def _get_raveled_index(self,plist):
if self.tied_param is None: indices = []
if self._tie_ is None or self._tie_.size != self._highest_parent_.param_array.size: for p in plist:
self._tie_ = np.ones((self._highest_parent_.param_array.size,),dtype=np.bool) indices.extend(self._highest_parent_._raveled_index_for(p))
if p is not None: return indices
return self._tie_[p._highest_parent_._raveled_index_for(p)]
return self._tie_
def _init_labelBuf(self):
if self.label_buf is None:
self.label_buf = np.zeros(self._highest_parent_.param_array.shape, dtype=np.int)
if self._tie_ is None or self._tie_.size != self._highest_parent_.param_array.size:
self._tie_ = np.ones((self._highest_parent_.param_array.size,),dtype=np.bool)
def _updateTieFlag(self):
if self._tie_.size != self.label_buf.size:
self._tie_ = np.ones((self._highest_parent_.param_array.size,),dtype=np.bool)
self._tie_[self.label_buf>0] = False
self._tie_[self.buf_idx] = True
def add_tied_parameter(self, p, p2=None):
"""
Tie the list of parameters p together (p2==None) or
Tie the list of parameters p with the list of parameters p2 (p2!=None)
"""
self._init_labelBuf()
if p2 is None:
idx = self._highest_parent_._raveled_index_for(p)
val = self._sync_val_group(idx)
if np.all(self.label_buf[idx]==0):
# None of p has been tied before.
tie_idx = self._expandTieParam(1)
print tie_idx
tie_id = self.label_buf.max()+1
self.label_buf[tie_idx] = tie_id
else:
b = self.label_buf[idx]
ids = np.unique(b[b>0])
tie_id, tie_idx = self._merge_tie_param(ids)
self._highest_parent_.param_array[tie_idx] = val
idx = self._highest_parent_._raveled_index_for(p)
self.label_buf[idx] = tie_id
else:
pass
self._updateTieFlag()
def _merge_tie_param(self, ids):
"""Merge the tie parameters with ids in the list."""
if len(ids)==1:
id_final_idx = self.buf_idx[self.label_buf[self.buf_idx]==ids[0]][0]
return ids[0],id_final_idx
id_final = ids[0]
ids_rm = ids[1:]
label_buf_param = self.label_buf[self.buf_idx]
idx_param = [np.where(label_buf_param==i)[0][0] for i in ids_rm]
self._removeTieParam(idx_param)
[np.put(self.label_buf, np.where(self.label_buf==i), id_final) for i in ids_rm]
id_final_idx = self.buf_idx[self.label_buf[self.buf_idx]==id_final][0]
return id_final, id_final_idx
def _sync_val_group(self, idx): def _sync_val_group(self, idx):
self._highest_parent_.param_array[idx] = self._highest_parent_.param_array[idx].mean() self._highest_parent_.param_array[idx] = self._highest_parent_.param_array[idx].mean()
return self._highest_parent_.param_array[idx][0] return self._highest_parent_.param_array[idx][0]
def _expandTieParam(self, num): def _traverse_param(self, func, p, res):
"""
Traverse a param tree starting with *p*
Apply *func* to every leaf (Param object),
and collect return values into *res*
"""
if isinstance(p, Param):
res.append(func(p))
else:
for pc in p.parameters:
self._traverse_param(func,pc,res)
def _get_labels(self,idx):
if self.label_buf is None:
self.label_buf = np.zeros((self._highest_parent_.size,),dtype=np.uint32)
return np.unique(self.label_buf[idx])
def _set_labels(self, plist, labels):
"""
If there is only one label, set all the param objects to that label,
otherwise each parameter take a label.
"""
def _set_l1(p):
p.tie[:] = labels[0]
if len(labels)==1:
for p in plist:
self._traverse_param(_set_l1, p, [])
def _replace_labels(self, p, label_pairs):
def _replace_l(p):
for l1,l2 in label_pairs:
p.tie[p.tie==l1] = l2
self._traverse_param(_replace_l, p, [])
def _expand_tie_param(self, num):
"""Expand the tie param with the number of *num* parameters""" """Expand the tie param with the number of *num* parameters"""
if self.tied_param is None: if self.tied_param is None:
start_label = 1
new_buf = np.empty((num,)) new_buf = np.empty((num,))
self.tied_param = Param('tied',new_buf)
self.tied_param.tie[:] = range(start_label,start_label+num)
else: else:
start_label = self.tied_param.tie.max()+1
new_buf = np.empty((self.tied_param.size+num,)) new_buf = np.empty((self.tied_param.size+num,))
new_buf[:self.tied_param.size] = self.tied_param.param_array.copy() new_buf[:self.tied_param.size] = self.tied_param.param_array.copy()
old_tie_ = self.tied_param.tie.copy()
old_size = self.tied_param.size
self.remove_parameter(self.tied_param) self.remove_parameter(self.tied_param)
self.tied_param = Param('tied',new_buf) self.tied_param = Param('tied',new_buf)
self.tied_param.tie[:old_size] = old_tie_
self.tied_param.tie[old_size:] = range(start_label,start_label+num)
self.add_parameter(self.tied_param) self.add_parameter(self.tied_param)
buf_idx_new = self._highest_parent_._raveled_index_for(self.tied_param) return range(start_label,start_label+num)
self._expand_label_buf(self.buf_idx, buf_idx_new)
self.buf_idx = buf_idx_new
return self.buf_idx[-num:]
def _removeTieParam(self, idx): def _remove_tie_param(self, labels):
"""idx within tied_param""" """Remove the tie param corresponding to *labels*"""
new_buf = np.empty((self.tied_param.size-len(idx),)) if len(labels) == self.tied_param.size:
bool_list = np.ones((self.tied_param.size,),dtype=np.bool) self.remove_parameter(self.tied_param)
bool_list[idx] = False self.tied_param = None
new_buf[:] = self.tied_param.param_array[bool_list]
self.remove_parameter(self.tied_param)
self.tied_param = Param('tied',new_buf)
self.add_parameter(self.tied_param)
buf_idx_new = self._highest_parent_._raveled_index_for(self.tied_param)
self._shrink_label_buf(self.buf_idx, buf_idx_new, bool_list)
self.buf_idx = buf_idx_new
def _expand_label_buf(self, idx_old, idx_new):
"""Expand label buffer accordingly"""
if idx_old is None:
self.label_buf = np.zeros(self._highest_parent_.param_array.shape, dtype=np.int)
else: else:
bool_old = np.zeros((self.label_buf.size,),dtype=np.bool) new_buf = np.empty((self.tied_param.size-len(labels),))
bool_old[idx_old] = True idx = np.logical_not(np.in1d(self.tied_param.tie,labels))
bool_new = np.zeros((self._highest_parent_.param_array.size,),dtype=np.bool) new_buf[:] = self.tied_param[idx]
bool_new[idx_new] = True old_tie_ = self.tied_param.tie.copy()
label_buf_new = np.zeros(self._highest_parent_.param_array.shape, dtype=np.int) self.remove_parameter(self.tied_param)
label_buf_new[np.logical_not(bool_new)] = self.label_buf[np.logical_not(bool_old)] self.tied_param = Param('tied',new_buf)
label_buf_new[idx_new[:len(idx_old)]] = self.label_buf[idx_old] self.tied_param.tie[:] = old_tie_[idx]
self.label_buf = label_buf_new self.add_parameter(self.tied_param)
def _shrink_label_buf(self, idx_old, idx_new, bool_list):
bool_old = np.zeros((self.label_buf.size,),dtype=np.bool) def _merge_tie_labels(self, labels):
bool_old[idx_old] = True """Merge all the labels in the list to the first one"""
bool_new = np.zeros((self._highest_parent_.param_array.size,),dtype=np.bool) if len(labels)<2:
bool_new[idx_new] = True return
label_buf_new = np.empty(self._highest_parent_.param_array.shape, dtype=np.int) self._remove_tie_param(labels[1:])
label_buf_new[np.logical_not(bool_new)] = self.label_buf[np.logical_not(bool_old)] self._replace_labels(self._highest_parent_, [(l,labels[0]) for l in labels[1:]])
label_buf_new[idx_new] = self.label_buf[idx_old[bool_list]]
self.label_buf = label_buf_new def _update_label_buf(self):
if self.tied_param is None:
self.label_buf = None
self.buf_idx = None
else:
self.label_buf = np.zeros((self._highest_parent_.size,),dtype=np.uint32)
self._traverse_param(lambda x:np.put(self.label_buf,self._highest_parent_._raveled_index_for(x),x.tie), self._highest_parent_, [])
self.buf_idx = self._highest_parent_._raveled_index_for(self.tied_param)
def tie_together(self, plist):
    """
    Tie all parameters in *plist* into a single tie group.

    The selected entries of the model's parameter array are first set to
    their common mean (``_sync_val_group``). If none of them carries a tie
    label yet, one new slot is added to the tie param; otherwise the
    existing labels are merged into the smallest one. Finally the label
    buffer is rebuilt and the group's slot in ``self.tied_param`` is set to
    the shared value.

    An empty selection is a no-op.

    :param plist: list of param objects to tie together
    """
    indices = self._get_raveled_index(plist)
    if len(indices) == 0:
        # Nothing selected: there is no label to read and no mean to take.
        return
    labels = self._get_labels(indices)
    val = self._sync_val_group(indices)
    if labels.size == 1 and labels[0] == 0:
        # None of the parameters in plist has been tied before.
        tie_labels = self._expand_tie_param(1)
        self._set_labels(plist, tie_labels)
    else:
        # Some of the parameters have been tied already:
        # merge their tie params into the first (smallest) label.
        tie_labels = labels[labels > 0]
        if tie_labels.size > 1:
            self._merge_tie_labels(tie_labels)
        self._set_labels(plist, [tie_labels[0]])
    self._update_label_buf()
    self.tied_param[self.tied_param.tie == tie_labels[0]] = val
# def add_tied_parameter(self, p, p2=None):
# """
# Tie the list of parameters p together (p2==None) or
# Tie the list of parameters p with the list of parameters p2 (p2!=None)
# """
# self._init_labelBuf()
# if p2 is None:
# idx = self._highest_parent_._raveled_index_for(p)
# val = self._sync_val_group(idx)
# if np.all(self.label_buf[idx]==0):
# # None of p has been tied before.
# tie_idx = self._expandTieParam(1)
# print tie_idx
# tie_id = self.label_buf.max()+1
# self.label_buf[tie_idx] = tie_id
# else:
# b = self.label_buf[idx]
# ids = np.unique(b[b>0])
# tie_id, tie_idx = self._merge_tie_param(ids)
# self._highest_parent_.param_array[tie_idx] = val
# idx = self._highest_parent_._raveled_index_for(p)
# self.label_buf[idx] = tie_id
# else:
# pass
# self._updateTieFlag()
#
# def _merge_tie_param(self, ids):
# """Merge the tie parameters with ids in the list."""
# if len(ids)==1:
# id_final_idx = self.buf_idx[self.label_buf[self.buf_idx]==ids[0]][0]
# return ids[0],id_final_idx
# id_final = ids[0]
# ids_rm = ids[1:]
# label_buf_param = self.label_buf[self.buf_idx]
# idx_param = [np.where(label_buf_param==i)[0][0] for i in ids_rm]
# self._removeTieParam(idx_param)
# [np.put(self.label_buf, np.where(self.label_buf==i), id_final) for i in ids_rm]
# id_final_idx = self.buf_idx[self.label_buf[self.buf_idx]==id_final][0]
# return id_final, id_final_idx
#
# def _sync_val_group(self, idx):
# self._highest_parent_.param_array[idx] = self._highest_parent_.param_array[idx].mean()
# return self._highest_parent_.param_array[idx][0]
#
# def _expandTieParam(self, num):
# """Expand the tie param with the number of *num* parameters"""
# if self.tied_param is None:
# new_buf = np.empty((num,))
# else:
# new_buf = np.empty((self.tied_param.size+num,))
# new_buf[:self.tied_param.size] = self.tied_param.param_array.copy()
# self.remove_parameter(self.tied_param)
# self.tied_param = Param('tied',new_buf)
# self.add_parameter(self.tied_param)
# buf_idx_new = self._highest_parent_._raveled_index_for(self.tied_param)
# self._expand_label_buf(self.buf_idx, buf_idx_new)
# self.buf_idx = buf_idx_new
# return self.buf_idx[-num:]
#
# def _removeTieParam(self, idx):
# """idx within tied_param"""
# new_buf = np.empty((self.tied_param.size-len(idx),))
# bool_list = np.ones((self.tied_param.size,),dtype=np.bool)
# bool_list[idx] = False
# new_buf[:] = self.tied_param.param_array[bool_list]
# self.remove_parameter(self.tied_param)
# self.tied_param = Param('tied',new_buf)
# self.add_parameter(self.tied_param)
# buf_idx_new = self._highest_parent_._raveled_index_for(self.tied_param)
# self._shrink_label_buf(self.buf_idx, buf_idx_new, bool_list)
# self.buf_idx = buf_idx_new
#
# def _expand_label_buf(self, idx_old, idx_new):
# """Expand label buffer accordingly"""
# if idx_old is None:
# self.label_buf = np.zeros(self._highest_parent_.param_array.shape, dtype=np.int)
# else:
# bool_old = np.zeros((self.label_buf.size,),dtype=np.bool)
# bool_old[idx_old] = True
# bool_new = np.zeros((self._highest_parent_.param_array.size,),dtype=np.bool)
# bool_new[idx_new] = True
# label_buf_new = np.zeros(self._highest_parent_.param_array.shape, dtype=np.int)
# label_buf_new[np.logical_not(bool_new)] = self.label_buf[np.logical_not(bool_old)]
# label_buf_new[idx_new[:len(idx_old)]] = self.label_buf[idx_old]
# self.label_buf = label_buf_new
#
# def _shrink_label_buf(self, idx_old, idx_new, bool_list):
# bool_old = np.zeros((self.label_buf.size,),dtype=np.bool)
# bool_old[idx_old] = True
# bool_new = np.zeros((self._highest_parent_.param_array.size,),dtype=np.bool)
# bool_new[idx_new] = True
# label_buf_new = np.empty(self._highest_parent_.param_array.shape, dtype=np.int)
# label_buf_new[np.logical_not(bool_new)] = self.label_buf[np.logical_not(bool_old)]
# label_buf_new[idx_new] = self.label_buf[idx_old[bool_list]]
# self.label_buf = label_buf_new
def _check_change(self): def _check_change(self):
changed = False changed = False
@ -237,5 +332,3 @@ class Tie(Parameterized):