bug fix: tie framework

This commit is contained in:
Zhenwen Dai 2014-10-21 23:46:32 +01:00
parent 0cef82db29
commit 7d5b4f2769

View file

@ -477,30 +477,31 @@ class Tie(Parameterized):
self.buf_idx = self._highest_parent_._raveled_index_for(self.tied_param)
self._untie_ = self.label_buf==0
self._untie_[self.buf_idx] = True
self.tie_pairs = np.empty(((self.label_buf>0).sum()-self.tied_param.size,2),dtype=np.uint32)
try:
from scipy import weave
self._label_to_idx = np.zeros((self.tied_param.tie.max()+1,),dtype=np.int32)
self._label_to_idx[self.tied_param.tie] = range(self.tied_param.size)
ltoi = self._label_to_idx
t_start=int(self.buf_idx[0])
t_end=int(self.buf_idx[-1])
label_buf = self.label_buf
buf_size = self.label_buf.size
t_pairs = self.tie_pairs
code = """
int j=0;
for(int i=0;i<buf_size;i++) {
if(label_buf[i]>0 && !(i>=t_start && i<=t_end)) {
t_pairs[j*2] = i;
t_pairs[j*2+1] = ltoi[label_buf[i]];
j++;
self.tie_pairs = np.empty(((np.logical_not(self._untie_)).sum(),2),dtype=np.uint32)
if self.tie_pairs.size>0:
try:
from scipy import weave
self._label_to_idx = np.zeros((self.tied_param.tie.max()+1,),dtype=np.int32)
self._label_to_idx[self.tied_param.tie] = range(self.tied_param.size)
ltoi = self._label_to_idx
t_start=int(self.buf_idx[0])
t_end=int(self.buf_idx[-1])
label_buf = self.label_buf
buf_size = self.label_buf.size
t_pairs = self.tie_pairs
code = """
int j=0;
for(int i=0;i<buf_size;i++) {
if(label_buf[i]>0 && !(i>=t_start && i<=t_end)) {
t_pairs[j*2] = i;
t_pairs[j*2+1] = ltoi[label_buf[i]];
j++;
}
}
}
"""
weave.inline(code, arg_names=['ltoi','t_start','t_end','label_buf','buf_size','t_pairs'])
except:
pass
"""
weave.inline(code, arg_names=['ltoi','t_start','t_end','label_buf','buf_size','t_pairs'])
except:
pass
assert(np.all(self.tied_param.tie>0))
def _keepParamList(self,plist):