bug fix: tie framework

This commit is contained in:
Zhenwen Dai 2014-10-21 23:46:32 +01:00
parent 0cef82db29
commit 7d5b4f2769

View file

@ -477,30 +477,31 @@ class Tie(Parameterized):
self.buf_idx = self._highest_parent_._raveled_index_for(self.tied_param) self.buf_idx = self._highest_parent_._raveled_index_for(self.tied_param)
self._untie_ = self.label_buf==0 self._untie_ = self.label_buf==0
self._untie_[self.buf_idx] = True self._untie_[self.buf_idx] = True
self.tie_pairs = np.empty(((self.label_buf>0).sum()-self.tied_param.size,2),dtype=np.uint32) self.tie_pairs = np.empty(((np.logical_not(self._untie_)).sum(),2),dtype=np.uint32)
try: if self.tie_pairs.size>0:
from scipy import weave try:
self._label_to_idx = np.zeros((self.tied_param.tie.max()+1,),dtype=np.int32) from scipy import weave
self._label_to_idx[self.tied_param.tie] = range(self.tied_param.size) self._label_to_idx = np.zeros((self.tied_param.tie.max()+1,),dtype=np.int32)
ltoi = self._label_to_idx self._label_to_idx[self.tied_param.tie] = range(self.tied_param.size)
t_start=int(self.buf_idx[0]) ltoi = self._label_to_idx
t_end=int(self.buf_idx[-1]) t_start=int(self.buf_idx[0])
label_buf = self.label_buf t_end=int(self.buf_idx[-1])
buf_size = self.label_buf.size label_buf = self.label_buf
t_pairs = self.tie_pairs buf_size = self.label_buf.size
code = """ t_pairs = self.tie_pairs
int j=0; code = """
for(int i=0;i<buf_size;i++) { int j=0;
if(label_buf[i]>0 && !(i>=t_start && i<=t_end)) { for(int i=0;i<buf_size;i++) {
t_pairs[j*2] = i; if(label_buf[i]>0 && !(i>=t_start && i<=t_end)) {
t_pairs[j*2+1] = ltoi[label_buf[i]]; t_pairs[j*2] = i;
j++; t_pairs[j*2+1] = ltoi[label_buf[i]];
j++;
}
} }
} """
""" weave.inline(code, arg_names=['ltoi','t_start','t_end','label_buf','buf_size','t_pairs'])
weave.inline(code, arg_names=['ltoi','t_start','t_end','label_buf','buf_size','t_pairs']) except:
except: pass
pass
assert(np.all(self.tied_param.tie>0)) assert(np.all(self.tied_param.tie>0))
def _keepParamList(self,plist): def _keepParamList(self,plist):