some speedup for tie

This commit is contained in:
Zhenwen Dai 2014-10-30 18:08:37 +00:00
parent c461e0cfd3
commit 7ca133977b

View file

@ -473,7 +473,7 @@ class Tie(Parameterized):
self._untie_ = None
else:
self.label_buf = np.zeros((self._highest_parent_.param_array.size,),dtype=np.uint32)
self._traverse_param(lambda x:np.put(self.label_buf,self._highest_parent_._raveled_index_for(x),x.tie.flat), (self._highest_parent_,), [])
self._traverse_param(lambda x:np.put(self.label_buf,xrange(self._highest_parent_._offset_for(x),x.size),x.tie.flat), (self._highest_parent_,), [])
self.buf_idx = self._highest_parent_._raveled_index_for(self.tied_param)
self._untie_ = self.label_buf==0
self._untie_[self.buf_idx] = True
@ -577,26 +577,60 @@ class Tie(Parameterized):
self.update_model(True)
def _check_change(self):
changed = False
changed = [0]
if self.tied_param is not None:
for i in xrange(self.tied_param.size):
b0 = self.label_buf==self.label_buf[self.buf_idx[i]]
b = self._highest_parent_.param_array[b0]!=self.tied_param[i]
if b.sum()==0:
# All the tied parameters are the same
continue
elif b.sum()==1:
# One of the tied parameter is different.
# It must be recently changed one.
# The rest will be set to its value.
val = self._highest_parent_.param_array[b0][b][0]
self._highest_parent_.param_array[b0] = val
else:
# It is most likely that the tie parameter is changed.
# Set all the tied parameter to the value of tie parameter.
self._highest_parent_.param_array[b0] = self.tied_param[i]
changed = True
return changed
from scipy import weave
from ...util.misc import param_to_array
param_array = self._highest_parent_.param_array
tied_param = param_to_array(self.tied_param)
tie_l = np.zeros_like(tied_param,dtype=np.int)
t_pairs = self.tie_pairs
t_pairs_size = self.tie_pairs.size
code="""
for(int i=0;i<t_pairs_size;i+=2) {
int pidx = t_pairs[i];
int tidx = t_pairs[i+1];
if(param_array[pidx] != tied_param[tidx]) {
if(tie_l[tidx]==0) {
tie_l[tidx] = pidx;
} else if(tie_l[tidx]>0) {
tie_l[tidx] = -1;
}
changed[0] = 1;
}
}
for(int i=0;i<t_pairs_size;i+=2) {
int pidx = t_pairs[i];
int tidx = t_pairs[i+1];
if(tie_l[tidx]>0) {
param_array[pidx] = param_array[tie_l[tidx]];
tied_param[tidx] = param_array[pidx];
} else if(tie_l[tidx]==-1) {
param_array[pidx] = tied_param[tidx];
}
}
"""
weave.inline(code, arg_names=['param_array','t_pairs','t_pairs_size','tied_param','tie_l','changed'])
# for i in xrange(self.tied_param.size):
# b0 = self.label_buf==self.label_buf[self.buf_idx[i]]
# b = self._highest_parent_.param_array[b0]!=self.tied_param[i]
# if b.sum()==0:
# # All the tied parameters are the same
# continue
# elif b.sum()==1:
# # One of the tied parameter is different.
# # It must be recently changed one.
# # The rest will be set to its value.
# val = self._highest_parent_.param_array[b0][b][0]
# self._highest_parent_.param_array[b0] = val
# else:
# # It is most likely that the tie parameter is changed.
# # Set all the tied parameter to the value of tie parameter.
# self._highest_parent_.param_array[b0] = self.tied_param[i]
# changed = True
return False if changed[0]==0 else True
def _parameters_changed_notification(self, me, which=None):
if which is not self: