some speedup for tie

This commit is contained in:
Zhenwen Dai 2014-10-30 18:08:37 +00:00
parent c461e0cfd3
commit 7ca133977b

View file

@ -473,7 +473,7 @@ class Tie(Parameterized):
self._untie_ = None self._untie_ = None
else: else:
self.label_buf = np.zeros((self._highest_parent_.param_array.size,),dtype=np.uint32) self.label_buf = np.zeros((self._highest_parent_.param_array.size,),dtype=np.uint32)
self._traverse_param(lambda x:np.put(self.label_buf,self._highest_parent_._raveled_index_for(x),x.tie.flat), (self._highest_parent_,), []) self._traverse_param(lambda x:np.put(self.label_buf,xrange(self._highest_parent_._offset_for(x),x.size),x.tie.flat), (self._highest_parent_,), [])
self.buf_idx = self._highest_parent_._raveled_index_for(self.tied_param) self.buf_idx = self._highest_parent_._raveled_index_for(self.tied_param)
self._untie_ = self.label_buf==0 self._untie_ = self.label_buf==0
self._untie_[self.buf_idx] = True self._untie_[self.buf_idx] = True
@ -577,26 +577,60 @@ class Tie(Parameterized):
self.update_model(True) self.update_model(True)
def _check_change(self): def _check_change(self):
changed = False changed = [0]
if self.tied_param is not None: if self.tied_param is not None:
for i in xrange(self.tied_param.size):
b0 = self.label_buf==self.label_buf[self.buf_idx[i]] from scipy import weave
b = self._highest_parent_.param_array[b0]!=self.tied_param[i] from ...util.misc import param_to_array
if b.sum()==0: param_array = self._highest_parent_.param_array
# All the tied parameters are the same tied_param = param_to_array(self.tied_param)
continue tie_l = np.zeros_like(tied_param,dtype=np.int)
elif b.sum()==1: t_pairs = self.tie_pairs
# One of the tied parameter is different. t_pairs_size = self.tie_pairs.size
# It must be recently changed one. code="""
# The rest will be set to its value. for(int i=0;i<t_pairs_size;i+=2) {
val = self._highest_parent_.param_array[b0][b][0] int pidx = t_pairs[i];
self._highest_parent_.param_array[b0] = val int tidx = t_pairs[i+1];
else: if(param_array[pidx] != tied_param[tidx]) {
# It is most likely that the tie parameter is changed. if(tie_l[tidx]==0) {
# Set all the tied parameter to the value of tie parameter. tie_l[tidx] = pidx;
self._highest_parent_.param_array[b0] = self.tied_param[i] } else if(tie_l[tidx]>0) {
changed = True tie_l[tidx] = -1;
return changed }
changed[0] = 1;
}
}
for(int i=0;i<t_pairs_size;i+=2) {
int pidx = t_pairs[i];
int tidx = t_pairs[i+1];
if(tie_l[tidx]>0) {
param_array[pidx] = param_array[tie_l[tidx]];
tied_param[tidx] = param_array[pidx];
} else if(tie_l[tidx]==-1) {
param_array[pidx] = tied_param[tidx];
}
}
"""
weave.inline(code, arg_names=['param_array','t_pairs','t_pairs_size','tied_param','tie_l','changed'])
# for i in xrange(self.tied_param.size):
# b0 = self.label_buf==self.label_buf[self.buf_idx[i]]
# b = self._highest_parent_.param_array[b0]!=self.tied_param[i]
# if b.sum()==0:
# # All the tied parameters are the same
# continue
# elif b.sum()==1:
# # One of the tied parameter is different.
# # It must be recently changed one.
# # The rest will be set to its value.
# val = self._highest_parent_.param_array[b0][b][0]
# self._highest_parent_.param_array[b0] = val
# else:
# # It is most likely that the tie parameter is changed.
# # Set all the tied parameter to the value of tie parameter.
# self._highest_parent_.param_array[b0] = self.tied_param[i]
# changed = True
return False if changed[0]==0 else True
def _parameters_changed_notification(self, me, which=None): def _parameters_changed_notification(self, me, which=None):
if which is not self: if which is not self: